Skip to content

Commit afcc8e6

Browse files
NV-jpt, sft-managed, glemaitre
authored
ENH duck-typing scikit-learn estimator instead of inheritance (#858)
Co-authored-by: sft-managed <u00ua78mcdbMxLjxr2357@rl-dgx-r12-u36-rapids-dgx107.raplab.nvidia.com> Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent b3a3dce commit afcc8e6

20 files changed

+410
-87
lines changed

build_tools/azure/install.sh

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ set -x
66
UNAMESTR=`uname`
77

88
make_conda() {
9+
conda update -yq conda
910
TO_INSTALL="$@"
1011
if [[ "$DISTRIB" == *"mamba"* ]]; then
1112
mamba create -n $VIRTUALENV --yes $TO_INSTALL

doc/developers_utils.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ which accepts arrays, matrices, or sparse matrices as arguments, the following
2929
should be used when applicable.
3030

3131
- :func:`check_neighbors_object`: Check that the object is consistent to be used as a NN.
32-
- :func:`check_target_type`: Check the target types to be conform to the current samplers.
32+
- :func:`check_target_type`: Check the target types to be conform to the current
33+
samplers.
3334
- :func:`check_sampling_strategy`: Checks that sampling target is consistent with
3435
the type and return a dictionary containing each targeted class with its
3536
corresponding number of samples.

doc/whats_new/v0.10.rst

+8
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,11 @@ Version 0.10.0 (ongoing)
55

66
Changelog
77
---------
8+
9+
Enhancements
10+
............
11+
12+
- Add support to accept compatible `NearestNeighbors` objects by only
13+
duck-typing. For instance, this allows passing cuML instances.
14+
:pr:`858` by :user:`NV-jpt <NV-jpt>` and
15+
:user:`Guillaume Lemaitre <glemaitre>`.

imblearn/over_sampling/_adasyn.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,17 @@ class ADASYN(BaseOverSampler):
3939
{random_state}
4040
4141
n_neighbors : int or estimator object, default=5
42-
If ``int``, number of nearest neighbours to used to construct synthetic
43-
samples. If object, an estimator that inherits from
44-
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
45-
find the k_neighbors.
42+
The nearest neighbors used to define the neighborhood of samples to use
43+
to generate the synthetic samples. You can pass:
44+
45+
- an `int` corresponding to the number of neighbors to use. A
46+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
47+
case.
48+
- an instance of a compatible nearest neighbors algorithm that should
49+
implement both methods `kneighbors` and `kneighbors_graph`. For
50+
instance, it could correspond to a
51+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
52+
any compatible class.
4653
4754
{n_jobs}
4855
@@ -124,7 +131,6 @@ def _validate_estimator(self):
124131
self.nn_ = check_neighbors_object(
125132
"n_neighbors", self.n_neighbors, additional_neighbor=1
126133
)
127-
self.nn_.set_params(**{"n_jobs": self.n_jobs})
128134

129135
def _fit_resample(self, X, y):
130136
self._validate_estimator()

imblearn/over_sampling/_smote/base.py

+33-12
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,17 @@ class SMOTE(BaseSMOTE):
224224
{random_state}
225225
226226
k_neighbors : int or object, default=5
227-
If ``int``, number of nearest neighbours to used to construct synthetic
228-
samples. If object, an estimator that inherits from
229-
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
230-
find the k_neighbors.
227+
The nearest neighbors used to define the neighborhood of samples to use
228+
to generate the synthetic samples. You can pass:
229+
230+
- an `int` corresponding to the number of neighbors to use. A
231+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
232+
case.
233+
- an instance of a compatible nearest neighbors algorithm that should
234+
implement both methods `kneighbors` and `kneighbors_graph`. For
235+
instance, it could correspond to a
236+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
237+
any compatible class.
231238
232239
{n_jobs}
233240
@@ -367,10 +374,17 @@ class SMOTENC(SMOTE):
367374
{random_state}
368375
369376
k_neighbors : int or object, default=5
370-
If ``int``, number of nearest neighbours to used to construct synthetic
371-
samples. If object, an estimator that inherits from
372-
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
373-
find the k_neighbors.
377+
The nearest neighbors used to define the neighborhood of samples to use
378+
to generate the synthetic samples. You can pass:
379+
380+
- an `int` corresponding to the number of neighbors to use. A
381+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
382+
case.
383+
- an instance of a compatible nearest neighbors algorithm that should
384+
implement both methods `kneighbors` and `kneighbors_graph`. For
385+
instance, it could correspond to a
386+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
387+
any compatible class.
374388
375389
{n_jobs}
376390
@@ -636,10 +650,17 @@ class SMOTEN(SMOTE):
636650
{random_state}
637651
638652
k_neighbors : int or object, default=5
639-
If ``int``, number of nearest neighbours to used to construct synthetic
640-
samples. If object, an estimator that inherits from
641-
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
642-
find the k_neighbors.
653+
The nearest neighbors used to define the neighborhood of samples to use
654+
to generate the synthetic samples. You can pass:
655+
656+
- an `int` corresponding to the number of neighbors to use. A
657+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
658+
case.
659+
- an instance of a compatible nearest neighbors algorithm that should
660+
implement both methods `kneighbors` and `kneighbors_graph`. For
661+
instance, it could correspond to a
662+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
663+
any compatible class.
643664
644665
{n_jobs}
645666

imblearn/over_sampling/_smote/cluster.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,17 @@ class KMeansSMOTE(BaseSMOTE):
4545
{random_state}
4646
4747
k_neighbors : int or object, default=2
48-
If ``int``, number of nearest neighbours to used to construct synthetic
49-
samples. If object, an estimator that inherits from
50-
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
51-
find the k_neighbors.
48+
The nearest neighbors used to define the neighborhood of samples to use
49+
to generate the synthetic samples. You can pass:
50+
51+
- an `int` corresponding to the number of neighbors to use. A
52+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
53+
case.
54+
- an instance of a compatible nearest neighbors algorithm that should
55+
implement both methods `kneighbors` and `kneighbors_graph`. For
56+
instance, it could correspond to a
57+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
58+
any compatible class.
5259
5360
{n_jobs}
5461

imblearn/over_sampling/_smote/filter.py

+53-22
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from sklearn.utils import _safe_indexing
1616

1717
from ..base import BaseOverSampler
18-
from ...exceptions import raise_isinstance_error
1918
from ...utils import check_neighbors_object
2019
from ...utils import Substitution
2120
from ...utils._docstring import _n_jobs_docstring
@@ -48,18 +47,32 @@ class BorderlineSMOTE(BaseSMOTE):
4847
{random_state}
4948
5049
k_neighbors : int or object, default=5
51-
If ``int``, number of nearest neighbours to used to construct synthetic
52-
samples. If object, an estimator that inherits from
53-
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
54-
find the k_neighbors.
50+
The nearest neighbors used to define the neighborhood of samples to use
51+
to generate the synthetic samples. You can pass:
52+
53+
- an `int` corresponding to the number of neighbors to use. A
54+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
55+
case.
56+
- an instance of a compatible nearest neighbors algorithm that should
57+
implement both methods `kneighbors` and `kneighbors_graph`. For
58+
instance, it could correspond to a
59+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
60+
any compatible class.
5561
5662
{n_jobs}
5763
5864
m_neighbors : int or object, default=10
59-
If int, number of nearest neighbours to use to determine if a minority
60-
sample is in danger. If object, an estimator that inherits
61-
from :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used
62-
to find the m_neighbors.
65+
The nearest neighbors used to determine if a minority sample is in
66+
"danger". You can pass:
67+
68+
- an `int` corresponding to the number of neighbors to use. A
69+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
70+
case.
71+
- an instance of a compatible nearest neighbors algorithm that should
72+
implement both methods `kneighbors` and `kneighbors_graph`. For
73+
instance, it could correspond to a
74+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
75+
any compatible class.
6376
6477
kind : {{"borderline-1", "borderline-2"}}, default='borderline-1'
6578
The type of SMOTE algorithm to use one of the following options:
@@ -155,7 +168,6 @@ def _validate_estimator(self):
155168
self.nn_m_ = check_neighbors_object(
156169
"m_neighbors", self.m_neighbors, additional_neighbor=1
157170
)
158-
self.nn_m_.set_params(**{"n_jobs": self.n_jobs})
159171
if self.kind not in ("borderline-1", "borderline-2"):
160172
raise ValueError(
161173
f'The possible "kind" of algorithm are '
@@ -263,21 +275,37 @@ class SVMSMOTE(BaseSMOTE):
263275
{random_state}
264276
265277
k_neighbors : int or object, default=5
266-
If ``int``, number of nearest neighbours to used to construct synthetic
267-
samples. If object, an estimator that inherits from
268-
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
269-
find the k_neighbors.
278+
The nearest neighbors used to define the neighborhood of samples to use
279+
to generate the synthetic samples. You can pass:
280+
281+
- an `int` corresponding to the number of neighbors to use. A
282+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
283+
case.
284+
- an instance of a compatible nearest neighbors algorithm that should
285+
implement both methods `kneighbors` and `kneighbors_graph`. For
286+
instance, it could correspond to a
287+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
288+
any compatible class.
270289
271290
{n_jobs}
272291
273292
m_neighbors : int or object, default=10
274-
If int, number of nearest neighbours to use to determine if a minority
275-
sample is in danger. If object, an estimator that inherits from
276-
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
277-
find the m_neighbors.
293+
The nearest neighbors used to determine if a minority sample is in
294+
"danger". You can pass:
295+
296+
- an `int` corresponding to the number of neighbors to use. A
297+
`~sklearn.neighbors.NearestNeighbors` instance will be fitted in this
298+
case.
299+
- an instance of a compatible nearest neighbors algorithm that should
300+
implement both methods `kneighbors` and `kneighbors_graph`. For
301+
instance, it could correspond to a
302+
:class:`~sklearn.neighbors.NearestNeighbors` but could be extended to
303+
any compatible class.
278304
279305
svm_estimator : estimator object, default=SVC()
280306
A parametrized :class:`~sklearn.svm.SVC` classifier can be passed.
307+
A scikit-learn compatible estimator can be passed but it is required
308+
to expose a `support_` fitted attribute.
281309
282310
out_step : float, default=0.5
283311
Step size when extrapolating.
@@ -381,14 +409,11 @@ def _validate_estimator(self):
381409
self.nn_m_ = check_neighbors_object(
382410
"m_neighbors", self.m_neighbors, additional_neighbor=1
383411
)
384-
self.nn_m_.set_params(**{"n_jobs": self.n_jobs})
385412

386413
if self.svm_estimator is None:
387414
self.svm_estimator_ = SVC(gamma="scale", random_state=self.random_state)
388-
elif isinstance(self.svm_estimator, SVC):
389-
self.svm_estimator_ = clone(self.svm_estimator)
390415
else:
391-
raise_isinstance_error("svm_estimator", [SVC], self.svm_estimator)
416+
self.svm_estimator_ = clone(self.svm_estimator)
392417

393418
def _fit_resample(self, X, y):
394419
self._validate_estimator()
@@ -403,6 +428,12 @@ def _fit_resample(self, X, y):
403428
X_class = _safe_indexing(X, target_class_indices)
404429

405430
self.svm_estimator_.fit(X, y)
431+
if not hasattr(self.svm_estimator_, "support_"):
432+
raise RuntimeError(
433+
"`svm_estimator` is required to exposed a `support_` fitted "
434+
"attribute. Such estimator belongs to the familly of Support "
435+
"Vector Machine."
436+
)
406437
support_index = self.svm_estimator_.support_[
407438
y[self.svm_estimator_.support_] == class_sample
408439
]

imblearn/over_sampling/_smote/tests/test_smote.py

-14
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,12 @@
44
# License: MIT
55

66
import numpy as np
7-
import pytest
87

98
from sklearn.utils._testing import assert_allclose
109
from sklearn.utils._testing import assert_array_equal
1110
from sklearn.neighbors import NearestNeighbors
1211

1312
from imblearn.over_sampling import SMOTE
14-
from imblearn.over_sampling import SVMSMOTE
15-
from imblearn.over_sampling import BorderlineSMOTE
1613

1714

1815
RND_SEED = 0
@@ -153,14 +150,3 @@ def test_sample_regular_with_nn():
153150
)
154151
assert_allclose(X_resampled, X_gt, rtol=R_TOL)
155152
assert_array_equal(y_resampled, y_gt)
156-
157-
158-
@pytest.mark.parametrize(
159-
"smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"]
160-
)
161-
def test_smote_m_neighbors(smote):
162-
# check that m_neighbors is properly set. Regression test for:
163-
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568
164-
_ = smote.fit_resample(X, Y)
165-
assert smote.nn_k_.n_neighbors == 6
166-
assert smote.nn_m_.n_neighbors == 11

imblearn/over_sampling/_smote/tests/test_svm_smote.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import numpy as np
33

4+
from sklearn.linear_model import LogisticRegression
45
from sklearn.neighbors import NearestNeighbors
56
from sklearn.svm import SVC
67

@@ -54,3 +55,12 @@ def test_svm_smote(data):
5455

5556
assert_allclose(X_res_1, X_res_2)
5657
assert_array_equal(y_res_1, y_res_2)
58+
59+
60+
def test_svm_smote_not_svm(data):
61+
"""Check that we raise a proper error if passing an estimator that does not
62+
expose a `support_` fitted attribute."""
63+
64+
err_msg = "`svm_estimator` is required to exposed a `support_` fitted attribute."
65+
with pytest.raises(RuntimeError, match=err_msg):
66+
SVMSMOTE(svm_estimator=LogisticRegression()).fit_resample(*data)

imblearn/over_sampling/tests/test_adasyn.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,11 @@ def test_ada_fit_resample_nn_obj():
131131
{"sampling_strategy": {0: 9, 1: 12}},
132132
"No samples will be generated.",
133133
),
134-
({"n_neighbors": "rnd"}, "has to be one of"),
134+
(
135+
{"n_neighbors": "rnd"},
136+
"n_neighbors must be an interger or an object compatible with the "
137+
"KNeighborsMixin API of scikit-learn",
138+
),
135139
],
136140
)
137141
def test_adasyn_error(adasyn_params, err_msg):

0 commit comments

Comments
 (0)