diff --git a/build_tools/azure/install.sh b/build_tools/azure/install.sh index 4fa68fbaa..932c5b150 100755 --- a/build_tools/azure/install.sh +++ b/build_tools/azure/install.sh @@ -6,6 +6,7 @@ set -x UNAMESTR=`uname` make_conda() { + conda update -yq conda TO_INSTALL="$@" if [[ "$DISTRIB" == *"mamba"* ]]; then mamba create -n $VIRTUALENV --yes $TO_INSTALL diff --git a/doc/developers_utils.rst b/doc/developers_utils.rst index e7b9bc478..5a1dc5559 100644 --- a/doc/developers_utils.rst +++ b/doc/developers_utils.rst @@ -29,7 +29,8 @@ which accepts arrays, matrices, or sparse matrices as arguments, the following should be used when applicable. - :func:`check_neighbors_object`: Check the objects is consistent to be a NN. -- :func:`check_target_type`: Check the target types to be conform to the current sam plers. +- :func:`check_target_type`: Check the target types to be conform to the current + samplers. - :func:`check_sampling_strategy`: Checks that sampling target is onsistent with the type and return a dictionary containing each targeted class with its corresponding number of pixel. diff --git a/doc/whats_new/v0.10.rst b/doc/whats_new/v0.10.rst index ea585d986..ff37d151a 100644 --- a/doc/whats_new/v0.10.rst +++ b/doc/whats_new/v0.10.rst @@ -5,3 +5,11 @@ Version 0.10.0 (ongoing) Changelog --------- + +Enhancements +............ + +- Add support to accept compatible `NearestNeighbors` objects by only + duck-typing. For instance, it allows to accept cuML instances. + :pr:`858` by :user:`NV-jpt <NV-jpt>` and + :user:`Guillaume Lemaitre <glemaitre>`. diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py index cbfeeda22..bd50378c2 100644 --- a/imblearn/over_sampling/_adasyn.py +++ b/imblearn/over_sampling/_adasyn.py @@ -39,10 +39,17 @@ class ADASYN(BaseOverSampler): {random_state} n_neighbors : int or estimator object, default=5 - If ``int``, number of nearest neighbours to used to construct synthetic - samples. 
If object, an estimator that inherits from - :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the k_neighbors. + The nearest neighbors used to define the neighborhood of samples to use + to generate the synthetic samples. You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. {n_jobs} @@ -124,7 +131,6 @@ def _validate_estimator(self): self.nn_ = check_neighbors_object( "n_neighbors", self.n_neighbors, additional_neighbor=1 ) - self.nn_.set_params(**{"n_jobs": self.n_jobs}) def _fit_resample(self, X, y): self._validate_estimator() diff --git a/imblearn/over_sampling/_smote/base.py b/imblearn/over_sampling/_smote/base.py index 23bff84c1..01540e19f 100644 --- a/imblearn/over_sampling/_smote/base.py +++ b/imblearn/over_sampling/_smote/base.py @@ -224,10 +224,17 @@ class SMOTE(BaseSMOTE): {random_state} k_neighbors : int or object, default=5 - If ``int``, number of nearest neighbours to used to construct synthetic - samples. If object, an estimator that inherits from - :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the k_neighbors. + The nearest neighbors used to define the neighborhood of samples to use + to generate the synthetic samples. You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. 
For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. {n_jobs} @@ -367,10 +374,17 @@ class SMOTENC(SMOTE): {random_state} k_neighbors : int or object, default=5 - If ``int``, number of nearest neighbours to used to construct synthetic - samples. If object, an estimator that inherits from - :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the k_neighbors. + The nearest neighbors used to define the neighborhood of samples to use + to generate the synthetic samples. You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. {n_jobs} @@ -636,10 +650,17 @@ class SMOTEN(SMOTE): {random_state} k_neighbors : int or object, default=5 - If ``int``, number of nearest neighbours to used to construct synthetic - samples. If object, an estimator that inherits from - :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the k_neighbors. + The nearest neighbors used to define the neighborhood of samples to use + to generate the synthetic samples. You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. 
{n_jobs} diff --git a/imblearn/over_sampling/_smote/cluster.py b/imblearn/over_sampling/_smote/cluster.py index c18e9b7db..871577486 100644 --- a/imblearn/over_sampling/_smote/cluster.py +++ b/imblearn/over_sampling/_smote/cluster.py @@ -45,10 +45,17 @@ class KMeansSMOTE(BaseSMOTE): {random_state} k_neighbors : int or object, default=2 - If ``int``, number of nearest neighbours to used to construct synthetic - samples. If object, an estimator that inherits from - :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the k_neighbors. + The nearest neighbors used to define the neighborhood of samples to use + to generate the synthetic samples. You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. {n_jobs} diff --git a/imblearn/over_sampling/_smote/filter.py b/imblearn/over_sampling/_smote/filter.py index 3c49458ef..93aebc235 100644 --- a/imblearn/over_sampling/_smote/filter.py +++ b/imblearn/over_sampling/_smote/filter.py @@ -15,7 +15,6 @@ from sklearn.utils import _safe_indexing from ..base import BaseOverSampler -from ...exceptions import raise_isinstance_error from ...utils import check_neighbors_object from ...utils import Substitution from ...utils._docstring import _n_jobs_docstring @@ -48,18 +47,32 @@ class BorderlineSMOTE(BaseSMOTE): {random_state} k_neighbors : int or object, default=5 - If ``int``, number of nearest neighbours to used to construct synthetic - samples. If object, an estimator that inherits from - :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the k_neighbors. 
+ The nearest neighbors used to define the neighborhood of samples to use + to generate the synthetic samples. You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. {n_jobs} m_neighbors : int or object, default=10 - If int, number of nearest neighbours to use to determine if a minority - sample is in danger. If object, an estimator that inherits - from :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used - to find the m_neighbors. + The nearest neighbors used to determine if a minority sample is in + "danger". You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. kind : {{"borderline-1", "borderline-2"}}, default='borderline-1' The type of SMOTE algorithm to use one of the following options: @@ -155,7 +168,6 @@ def _validate_estimator(self): self.nn_m_ = check_neighbors_object( "m_neighbors", self.m_neighbors, additional_neighbor=1 ) - self.nn_m_.set_params(**{"n_jobs": self.n_jobs}) if self.kind not in ("borderline-1", "borderline-2"): raise ValueError( f'The possible "kind" of algorithm are ' @@ -263,21 +275,37 @@ class SVMSMOTE(BaseSMOTE): {random_state} k_neighbors : int or object, default=5 - If ``int``, number of nearest neighbours to used to construct synthetic - samples. 
If object, an estimator that inherits from - :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the k_neighbors. + The nearest neighbors used to define the neighborhood of samples to use + to generate the synthetic samples. You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. {n_jobs} m_neighbors : int or object, default=10 - If int, number of nearest neighbours to use to determine if a minority - sample is in danger. If object, an estimator that inherits from - :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to - find the m_neighbors. + The nearest neighbors used to determine if a minority sample is in + "danger". You can pass: + + - an `int` corresponding to the number of neighbors to use. A + `~sklearn.neighbors.NearestNeighbors` instance will be fitted in this + case. + - an instance of a compatible nearest neighbors algorithm that should + implement both methods `kneighbors` and `kneighbors_graph`. For + instance, it could correspond to a + :class:`~sklearn.neighbors.NearestNeighbors` but could be extended to + any compatible class. svm_estimator : estimator object, default=SVC() A parametrized :class:`~sklearn.svm.SVC` classifier can be passed. + A scikit-learn compatible estimator can be passed but it is required + to expose a `support_` fitted attribute. out_step : float, default=0.5 Step size when extrapolating. 
@@ -381,14 +409,11 @@ def _validate_estimator(self): self.nn_m_ = check_neighbors_object( "m_neighbors", self.m_neighbors, additional_neighbor=1 ) - self.nn_m_.set_params(**{"n_jobs": self.n_jobs}) if self.svm_estimator is None: self.svm_estimator_ = SVC(gamma="scale", random_state=self.random_state) - elif isinstance(self.svm_estimator, SVC): - self.svm_estimator_ = clone(self.svm_estimator) else: - raise_isinstance_error("svm_estimator", [SVC], self.svm_estimator) + self.svm_estimator_ = clone(self.svm_estimator) def _fit_resample(self, X, y): self._validate_estimator() @@ -403,6 +428,12 @@ def _fit_resample(self, X, y): X_class = _safe_indexing(X, target_class_indices) self.svm_estimator_.fit(X, y) + if not hasattr(self.svm_estimator_, "support_"): + raise RuntimeError( + "`svm_estimator` is required to exposed a `support_` fitted " + "attribute. Such estimator belongs to the familly of Support " + "Vector Machine." + ) support_index = self.svm_estimator_.support_[ y[self.svm_estimator_.support_] == class_sample ] diff --git a/imblearn/over_sampling/_smote/tests/test_smote.py b/imblearn/over_sampling/_smote/tests/test_smote.py index 3f9ed0b40..27e8e8b79 100644 --- a/imblearn/over_sampling/_smote/tests/test_smote.py +++ b/imblearn/over_sampling/_smote/tests/test_smote.py @@ -4,15 +4,12 @@ # License: MIT import numpy as np -import pytest from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_equal from sklearn.neighbors import NearestNeighbors from imblearn.over_sampling import SMOTE -from imblearn.over_sampling import SVMSMOTE -from imblearn.over_sampling import BorderlineSMOTE RND_SEED = 0 @@ -153,14 +150,3 @@ def test_sample_regular_with_nn(): ) assert_allclose(X_resampled, X_gt, rtol=R_TOL) assert_array_equal(y_resampled, y_gt) - - -@pytest.mark.parametrize( - "smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"] -) -def test_smote_m_neighbors(smote): - # check that m_neighbors is properly set. 
Regression test for: - # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568 - _ = smote.fit_resample(X, Y) - assert smote.nn_k_.n_neighbors == 6 - assert smote.nn_m_.n_neighbors == 11 diff --git a/imblearn/over_sampling/_smote/tests/test_svm_smote.py b/imblearn/over_sampling/_smote/tests/test_svm_smote.py index 578ceccde..c6ae61d0c 100644 --- a/imblearn/over_sampling/_smote/tests/test_svm_smote.py +++ b/imblearn/over_sampling/_smote/tests/test_svm_smote.py @@ -1,6 +1,7 @@ import pytest import numpy as np +from sklearn.linear_model import LogisticRegression from sklearn.neighbors import NearestNeighbors from sklearn.svm import SVC @@ -54,3 +55,12 @@ def test_svm_smote(data): assert_allclose(X_res_1, X_res_2) assert_array_equal(y_res_1, y_res_2) + + +def test_svm_smote_not_svm(data): + """Check that we raise a proper error if passing an estimator that does not + expose a `support_` fitted attribute.""" + + err_msg = "`svm_estimator` is required to exposed a `support_` fitted attribute." 
+ with pytest.raises(RuntimeError, match=err_msg): + SVMSMOTE(svm_estimator=LogisticRegression()).fit_resample(*data) diff --git a/imblearn/over_sampling/tests/test_adasyn.py b/imblearn/over_sampling/tests/test_adasyn.py index 173547b7f..65ce69182 100644 --- a/imblearn/over_sampling/tests/test_adasyn.py +++ b/imblearn/over_sampling/tests/test_adasyn.py @@ -131,7 +131,11 @@ def test_ada_fit_resample_nn_obj(): {"sampling_strategy": {0: 9, 1: 12}}, "No samples will be generated.", ), - ({"n_neighbors": "rnd"}, "has to be one of"), + ( + {"n_neighbors": "rnd"}, + "n_neighbors must be an interger or an object compatible with the " + "KNeighborsMixin API of scikit-learn", + ), ], ) def test_adasyn_error(adasyn_params, err_msg): diff --git a/imblearn/over_sampling/tests/test_common.py b/imblearn/over_sampling/tests/test_common.py new file mode 100644 index 000000000..804382c08 --- /dev/null +++ b/imblearn/over_sampling/tests/test_common.py @@ -0,0 +1,118 @@ +from collections import Counter + +import pytest +import numpy as np + +from imblearn.over_sampling import ( + ADASYN, + BorderlineSMOTE, + KMeansSMOTE, + SMOTE, + SMOTEN, + SMOTENC, + SVMSMOTE, +) +from imblearn.utils.testing import _CustomNearestNeighbors + + +@pytest.fixture +def numerical_data(): + rng = np.random.RandomState(0) + X = rng.randn(100, 2) + y = np.repeat([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0], 5) + + return X, y + + +@pytest.fixture +def categorical_data(): + rng = np.random.RandomState(0) + + feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30 + feature_2 = ["A"] * 40 + ["B"] * 20 + feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10 + X = np.array([feature_1, feature_2, feature_3], dtype=object).T + rng.shuffle(X) + y = np.array([0] * 20 + [1] * 40, dtype=np.int32) + y_labels = np.array(["not apple", "apple"], dtype=object) + y = y_labels[y] + return X, y + + +@pytest.fixture +def heterogeneous_data(): + rng = np.random.RandomState(42) + X = np.empty((30, 4), 
dtype=object) + X[:, :2] = rng.randn(30, 2) + X[:, 2] = rng.choice(["a", "b", "c"], size=30).astype(object) + X[:, 3] = rng.randint(3, size=30) + y = np.array([0] * 10 + [1] * 20) + return X, y, [2, 3] + + +@pytest.mark.parametrize( + "smote", [BorderlineSMOTE(), SVMSMOTE()], ids=["borderline", "svm"] +) +def test_smote_m_neighbors(numerical_data, smote): + # check that m_neighbors is properly set. Regression test for: + # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568 + X, y = numerical_data + _ = smote.fit_resample(X, y) + assert smote.nn_k_.n_neighbors == 6 + assert smote.nn_m_.n_neighbors == 11 + + +@pytest.mark.parametrize( + "smote, neighbor_estimator_name", + [ + (ADASYN(random_state=0), "n_neighbors"), + (BorderlineSMOTE(random_state=0), "k_neighbors"), + (KMeansSMOTE(random_state=1), "k_neighbors"), + (SMOTE(random_state=0), "k_neighbors"), + (SVMSMOTE(random_state=0), "k_neighbors"), + ], + ids=["adasyn", "borderline", "kmeans", "smote", "svm"], +) +def test_numerical_smote_custom_nn(numerical_data, smote, neighbor_estimator_name): + X, y = numerical_data + params = { + neighbor_estimator_name: _CustomNearestNeighbors(n_neighbors=5), + } + smote.set_params(**params) + X_res, _ = smote.fit_resample(X, y) + + assert X_res.shape[0] >= 120 + + +def test_categorical_smote_k_custom_nn(categorical_data): + X, y = categorical_data + smote = SMOTEN(k_neighbors=_CustomNearestNeighbors(n_neighbors=5)) + X_res, y_res = smote.fit_resample(X, y) + + assert X_res.shape == (80, 3) + assert Counter(y_res) == {"apple": 40, "not apple": 40} + + +def test_heterogeneous_smote_k_custom_nn(heterogeneous_data): + X, y, categorical_features = heterogeneous_data + smote = SMOTENC( + categorical_features, k_neighbors=_CustomNearestNeighbors(n_neighbors=5) + ) + X_res, y_res = smote.fit_resample(X, y) + + assert X_res.shape == (40, 4) + assert Counter(y_res) == {0: 20, 1: 20} + + +@pytest.mark.parametrize( + "smote", + [BorderlineSMOTE(random_state=0), 
SVMSMOTE(random_state=0)], + ids=["borderline", "svm"], +) +def test_numerical_smote_extra_custom_nn(numerical_data, smote): + X, y = numerical_data + smote.set_params(m_neighbors=_CustomNearestNeighbors(n_neighbors=5)) + X_res, y_res = smote.fit_resample(X, y) + + assert X_res.shape == (120, 2) + assert Counter(y_res) == {0: 60, 1: 60} diff --git a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py index 4ffe7f76b..b7f56274e 100644 --- a/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py +++ b/imblearn/under_sampling/_prototype_generation/_cluster_centroids.py @@ -49,7 +49,8 @@ class ClusterCentroids(BaseUnderSampler): {random_state} estimator : estimator object, default=None - Pass a :class:`~sklearn.cluster.KMeans` estimator. By default, it will + A scikit-learn compatible clustering method that exposes a `n_clusters` + parameter and a `cluster_centers_` fitted attribute. By default, it will be a default :class:`~sklearn.cluster.KMeans` estimator. voting : {{"hard", "soft", "auto"}}, default='auto' @@ -141,13 +142,13 @@ def _validate_estimator(self): ) if self.estimator is None: self.estimator_ = KMeans(random_state=self.random_state) - elif isinstance(self.estimator, KMeans): - self.estimator_ = clone(self.estimator) else: - raise ValueError( - f"`estimator` has to be a KMeans clustering." - f" Got {type(self.estimator)} instead." - ) + self.estimator_ = clone(self.estimator) + if "n_clusters" not in self.estimator_.get_params(): + raise ValueError( + "`estimator` should be a clustering estimator exposing a parameter" + " `n_clusters` and a fitted parameter `cluster_centers_`." 
+ ) def _generate_sample(self, X, y, centroids, target_class): if self.voting_ == "hard": @@ -188,6 +189,11 @@ def _fit_resample(self, X, y): n_samples = self.sampling_strategy_[target_class] self.estimator_.set_params(**{"n_clusters": n_samples}) self.estimator_.fit(_safe_indexing(X, target_class_indices)) + if not hasattr(self.estimator_, "cluster_centers_"): + raise RuntimeError( + "`estimator` should be a clustering estimator exposing a " + "fitted parameter `cluster_centers_`." + ) X_new, y_new = self._generate_sample( _safe_indexing(X, target_class_indices), _safe_indexing(y, target_class_indices), diff --git a/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py b/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py index 03d1970e5..d7399dffd 100644 --- a/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py +++ b/imblearn/under_sampling/_prototype_generation/tests/test_cluster_centroids.py @@ -5,10 +5,12 @@ import numpy as np from scipy import sparse +from sklearn.linear_model import LogisticRegression from sklearn.cluster import KMeans from sklearn.datasets import make_classification from imblearn.under_sampling import ClusterCentroids +from imblearn.utils.testing import _CustomClusterer RND_SEED = 0 X = np.array( @@ -101,7 +103,6 @@ def test_fit_hard_voting(): @pytest.mark.parametrize( "cluster_centroids_params, err_msg", [ - ({"estimator": "rnd"}, "has to be a KMeans clustering"), ({"voting": "unknown"}, "needs to be one of"), ], ) @@ -152,3 +153,21 @@ def test_cluster_centroids_hard_target_class(): for minority_sample in X_minority_class ] assert sum(sample_from_minority_in_majority) == 0 + + +def test_cluster_centroids_error_estimator(): + """Check that an error is raised when estimator does not have a cluster API.""" + + err_msg = ( + "`estimator` should be a clustering estimator exposing a parameter " + "`n_clusters` and a fitted parameter `cluster_centers_`." 
+ ) + with pytest.raises(ValueError, match=err_msg): + ClusterCentroids(estimator=LogisticRegression()).fit_resample(X, Y) + + err_msg = ( + "`estimator` should be a clustering estimator exposing a fitted parameter " + "`cluster_centers_`." + ) + with pytest.raises(RuntimeError, match=err_msg): + ClusterCentroids(estimator=_CustomClusterer()).fit_resample(X, Y) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py index 4c1856525..50680f632 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_edited_nearest_neighbours.py @@ -126,7 +126,11 @@ def test_enn_fit_resample_with_nn_object(): def test_enn_not_good_object(): nn = "rnd" enn = EditedNearestNeighbours(n_neighbors=nn, kind_sel="mode") - with pytest.raises(ValueError, match="has to be one of"): + err_msg = ( + "n_neighbors must be an interger or an object compatible with the " + "KNeighborsMixin API of scikit-learn" + ) + with pytest.raises(ValueError, match=err_msg): enn.fit_resample(X, Y) diff --git a/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py b/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py index 3e2e8686c..6a3fc04fb 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_nearmiss.py @@ -39,14 +39,17 @@ "nearmiss_params, err_msg", [ ({"version": 1000}, "must be 1, 2 or 3"), - ({"version": 1, "n_neighbors": "rnd"}, "has to be one of"), + ( + {"version": 1, "n_neighbors": "rnd"}, + "n_neighbors must be an interger or an object compatible", + ), ( { "version": 3, "n_neighbors": NearestNeighbors(n_neighbors=3), "n_neighbors_ver3": "rnd", }, - "has to be one of", + "n_neighbors_ver3 must be an interger or an object compatible", ), ], ) diff --git 
a/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py index fc84cb017..78119dde7 100644 --- a/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py +++ b/imblearn/under_sampling/_prototype_selection/tests/test_neighbourhood_cleaning_rule.py @@ -37,7 +37,10 @@ [ ({"threshold_cleaning": -10}, "value between 0 and 1"), ({"threshold_cleaning": 10}, "value between 0 and 1"), - ({"n_neighbors": "rnd"}, "has to be one of"), + ( + {"n_neighbors": "rnd"}, + "n_neighbors must be an interger or an object compatible", + ), ], ) def test_ncr_error(ncr_params, err_msg): diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index 23dc3b3ab..c3fa53357 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -12,13 +12,10 @@ import numpy as np from sklearn.base import clone -from sklearn.neighbors._base import KNeighborsMixin from sklearn.neighbors import NearestNeighbors from sklearn.utils import column_or_1d from sklearn.utils.multiclass import type_of_target -from ..exceptions import raise_isinstance_error - SAMPLING_KIND = ( "over-sampling", "under-sampling", @@ -67,20 +64,40 @@ def _transfrom_one(self, array, props): return ret +def _is_neighbors_object(estimator): + """Check that the estimator exposes a KNeighborsMixin-like API. + + A KNeighborsMixin-like API exposes the following methods: (i) `kneighbors`, + (ii) `kneighbors_graph`. + + Parameters + ---------- + estimator : object + A scikit-learn compatible estimator. + + Returns + ------- + is_neighbors_object : bool + True if the estimator exposes a KNeighborsMixin-like API. 
+ """ + neighbors_attributes = ["kneighbors", "kneighbors_graph"] + return all(hasattr(estimator, attr) for attr in neighbors_attributes) + + def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): - """Check the objects is consistent to be a NN. + """Check the objects is consistent to be a k nearest neighbors. - Several methods in imblearn relies on NN. Until version 0.4, these - objects can be passed at initialisation as an integer or a - KNeighborsMixin. After only KNeighborsMixin will be accepted. This - utility allows for type checking and raise if the type is wrong. + Several methods in `imblearn` relies on k nearest neighbors. These objects + can be passed at initialisation as an integer or as an object that has + KNeighborsMixin-like attributes. This utility will create or clone said + object, ensuring it is KNeighbors-like. Parameters ---------- nn_name : str The name associated to the object to raise an error if needed. - nn_object : int or KNeighborsMixin, + nn_object : int or KNeighborsMixin The object to be checked. additional_neighbor : int, default=0 @@ -93,10 +110,14 @@ def check_neighbors_object(nn_name, nn_object, additional_neighbor=0): """ if isinstance(nn_object, Integral): return NearestNeighbors(n_neighbors=nn_object + additional_neighbor) - elif isinstance(nn_object, KNeighborsMixin): + elif _is_neighbors_object(nn_object): return clone(nn_object) else: - raise_isinstance_error(nn_name, [int, KNeighborsMixin], nn_object) + raise ValueError( + f"{nn_name} must be an interger or an object compatible with the " + "KNeighborsMixin API of scikit-learn (i.e. implementing `kneighbors` " + "method)." 
+ ) def _count_class_sample(y): diff --git a/imblearn/utils/testing.py b/imblearn/utils/testing.py index eaad65efb..88abebd1b 100644 --- a/imblearn/utils/testing.py +++ b/imblearn/utils/testing.py @@ -9,13 +9,15 @@ import warnings from contextlib import contextmanager from importlib import import_module -from re import compile +from operator import itemgetter from pathlib import Path +from re import compile -from operator import itemgetter +from scipy import sparse from pytest import warns as _warns from sklearn.base import BaseEstimator +from sklearn.neighbors import KDTree from sklearn.utils._testing import ignore_warnings @@ -164,3 +166,42 @@ def warns(expected_warning, match=None): assert False, msg else: pass + + +class _CustomNearestNeighbors(BaseEstimator): + """Basic implementation of nearest neighbors not relying on scikit-learn. + + `kneighbors_graph` is ignored and `metric` does not have any impact. + """ + + def __init__(self, n_neighbors=1, metric="euclidean"): + self.n_neighbors = n_neighbors + self.metric = metric + + def fit(self, X, y=None): + X = X.toarray() if sparse.issparse(X) else X + self._kd_tree = KDTree(X) + return self + + def kneighbors(self, X, n_neighbors=None, return_distance=True): + n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors + X = X.toarray() if sparse.issparse(X) else X + distances, indices = self._kd_tree.query(X, k=n_neighbors) + if return_distance: + return distances, indices + return indices + + def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"): + """This method is not used within imblearn but it is required for + duck-typing.""" + pass + + +class _CustomClusterer(BaseEstimator): + """Class that mimics a cluster that does not expose `cluster_centers_`.""" + + def __init__(self, n_clusters=1): + self.n_clusters = n_clusters + + def fit(self, X, y=None): + return self diff --git a/imblearn/utils/tests/test_testing.py b/imblearn/utils/tests/test_testing.py index 
7f5e302c3..5d4ce2bde 100644 --- a/imblearn/utils/tests/test_testing.py +++ b/imblearn/utils/tests/test_testing.py @@ -5,8 +5,12 @@ import pytest +import numpy as np + +from sklearn.neighbors._base import KNeighborsMixin + from imblearn.base import SamplerMixin -from imblearn.utils.testing import all_estimators +from imblearn.utils.testing import all_estimators, _CustomNearestNeighbors from imblearn.utils.testing import warns @@ -59,3 +63,25 @@ def test_warns_deprecation(): with warns(UserWarning): warnings.warn("value must be 42") assert "The warns function is deprecated" in str(record[0].message) + + +def test_custom_nearest_neighbors(): + """Ensure the custom nearest neighbors helper satisfies the duck-typing + checks used internally without inheriting from `KNeighborsMixin`.""" + + custom_nn = _CustomNearestNeighbors(n_neighbors=3) + + assert not isinstance(custom_nn, KNeighborsMixin) + assert hasattr(custom_nn, "kneighbors") + assert hasattr(custom_nn, "kneighbors_graph") + + rng = np.random.RandomState(42) + points = rng.randn(150, 3) + labels = rng.randint(0, 2, 150) + custom_nn.fit(points, labels) + + dist, ind = custom_nn.kneighbors(points) + assert dist.shape == (150, 3) + assert ind.shape == (150, 3) + np.testing.assert_allclose(dist[:, 0], 0.0) + np.testing.assert_allclose(ind[:, 0], np.arange(150)) diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py index c356d4881..81db35a37 100644 --- a/imblearn/utils/tests/test_validation.py +++ b/imblearn/utils/tests/test_validation.py @@ -13,10 +13,10 @@ from sklearn.neighbors import NearestNeighbors from sklearn.utils._testing import assert_array_equal -from imblearn.utils.testing import warns from imblearn.utils import check_neighbors_object from imblearn.utils import check_sampling_strategy from imblearn.utils import check_target_type +from imblearn.utils.testing import warns, _CustomNearestNeighbors from imblearn.utils._validation import ArraysTransformer from 
imblearn.utils._validation import _deprecate_positional_args @@ -36,8 +36,15 @@ def test_check_neighbors_object(): estimator = NearestNeighbors(n_neighbors=n_neighbors) estimator_cloned = check_neighbors_object(name, estimator) assert estimator.n_neighbors == estimator_cloned.n_neighbors + estimator = _CustomNearestNeighbors() + estimator_cloned = check_neighbors_object(name, estimator) + assert isinstance(estimator_cloned, _CustomNearestNeighbors) n_neighbors = "rnd" - with pytest.raises(ValueError, match="has to be one of"): + err_msg = ( + "n_neighbors must be an interger or an object compatible with the " + "KNeighborsMixin API of scikit-learn" + ) + with pytest.raises(ValueError, match=err_msg): check_neighbors_object(name, n_neighbors)