Skip to content

Commit d482829

Browse files
authored
[MRG] EHN Add a FunctionSampler (#342)
2 parents 3877493 + 792c29d commit d482829

File tree

8 files changed

+327
-10
lines changed

8 files changed

+327
-10
lines changed

doc/api.rst

+14-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ Prototype selection
9292
combine.SMOTEENN
9393
combine.SMOTETomek
9494

95-
9695
.. _ensemble_ref:
9796

9897
:mod:`imblearn.ensemble`: Ensemble methods
@@ -112,6 +111,20 @@ Prototype selection
112111
ensemble.BalancedBaggingClassifier
113112
ensemble.EasyEnsemble
114113

114+
.. _misc_ref:
115+
116+
Miscellaneous
117+
=============
118+
119+
Imbalanced-learn provides some fast-prototyping tools.
120+
121+
.. currentmodule:: imblearn
122+
123+
.. autosummary::
124+
:toctree: generated/
125+
:template: class.rst
126+
127+
FunctionSampler
115128

116129
.. _pipeline_ref:
117130

doc/miscellaneous.rst

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
.. _miscellaneous:
2+
3+
======================
4+
Miscellaneous samplers
5+
======================
6+
7+
.. currentmodule:: imblearn
8+
9+
.. _function_sampler:
10+
11+
Custom samplers
12+
---------------
13+
14+
A fully customized sampler, :class:`FunctionSampler`, is available in
15+
imbalanced-learn so that you can quickly prototype your own sampler by defining
16+
a single function. Additional parameters can be added using the attribute
17+
``kw_args`` which accepts a dictionary. The following example illustrates how
18+
to retain the first 10 elements of the arrays ``X`` and ``y``::
19+
20+
>>> import numpy as np
21+
>>> from imblearn import FunctionSampler
22+
>>> from sklearn.datasets import make_classification
23+
>>> X, y = make_classification(n_samples=5000, n_features=2, n_informative=2,
24+
... n_redundant=0, n_repeated=0, n_classes=3,
25+
... n_clusters_per_class=1,
26+
... weights=[0.01, 0.05, 0.94],
27+
... class_sep=0.8, random_state=0)
28+
>>> def func(X, y):
29+
... return X[:10], y[:10]
30+
>>> sampler = FunctionSampler(func=func)
31+
>>> X_res, y_res = sampler.fit_sample(X, y)
32+
>>> np.all(X_res == X[:10])
33+
True
34+
>>> np.all(y_res == y[:10])
35+
True
36+
37+
We illustrate the use of such a sampler to implement an outlier rejection
38+
estimator which can be easily used within a
39+
:class:`imblearn.pipeline.Pipeline`:
40+
:ref:`sphx_glr_auto_examples_plot_outlier_rejections.py`

doc/user_guide.rst

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ User Guide
1414
under_sampling.rst
1515
combine.rst
1616
ensemble.rst
17+
miscellaneous.rst
1718
metrics.rst
1819
Dataset loading utilities <datasets/index.rst>
1920
developers_utils.rst

examples/plot_outlier_rejections.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
===============================================================
3+
Customized sampler to implement an outlier rejections estimator
4+
===============================================================
5+
6+
This example illustrates the use of a custom sampler to implement an outlier
7+
rejections estimator. It can be used easily within a pipeline in which the
8+
number of samples can vary during training, which usually is a limitation of
9+
the current scikit-learn pipeline.
10+
11+
"""
12+
13+
# Authors: Guillaume Lemaitre <[email protected]>
14+
# License: MIT
15+
16+
import numpy as np
17+
import matplotlib.pyplot as plt
18+
19+
from sklearn.datasets import make_moons, make_blobs
20+
from sklearn.ensemble import IsolationForest
21+
from sklearn.linear_model import LogisticRegression
22+
from sklearn.metrics import classification_report
23+
24+
from imblearn import FunctionSampler
25+
from imblearn.pipeline import make_pipeline
26+
27+
print(__doc__)
28+
29+
rng = np.random.RandomState(42)
30+
31+
32+
def plot_scatter(X, y, title):
    """Scatter-plot both classes of a two-feature dataset in a new figure.

    Parameters
    ----------
    X : ndarray
        Samples; columns 0 and 1 are used as the plot coordinates.
    y : ndarray
        Labels, compared against ``1`` and ``0`` to split the samples.
    title : str
        Title given to the figure.
    """
    plt.figure()
    # Class #1 is drawn before class #0, keeping the original drawing order.
    for label in (1, 0):
        members = y == label
        plt.scatter(X[members, 0], X[members, 1],
                    label='Class #{}'.format(label))
    plt.legend()
    plt.title(title)
38+
39+
40+
# Generate contaminated training data. Every random draw goes through ``rng``
# so the generated datasets are the same each time the example is run.
moons, _ = make_moons(n_samples=500, noise=0.05, random_state=rng)
blobs, _ = make_blobs(n_samples=500, centers=[(-0.75, 2.25),
                                              (1.0, 2.0)],
                      cluster_std=0.25, random_state=rng)
# Uniform noise over the [-3, 3) square plays the role of the outliers.
outliers = rng.uniform(low=-3, high=3, size=(500, 2))
X_train = np.vstack([moons, blobs, outliers])
# Moons are labelled 1, blobs 0, and the outliers get random labels.
y_train = np.hstack([np.ones(moons.shape[0], dtype=np.int8),
                     np.zeros(blobs.shape[0], dtype=np.int8),
                     rng.randint(0, 2, size=outliers.shape[0],
                                 dtype=np.int8)])

plot_scatter(X_train, y_train, 'Training dataset')

# Generate non-contaminated testing data (same two clusters, no outliers)
moons, _ = make_moons(n_samples=50, noise=0.05, random_state=rng)
blobs, _ = make_blobs(n_samples=50, centers=[(-0.75, 2.25),
                                             (1.0, 2.0)],
                      cluster_std=0.25, random_state=rng)
X_test = np.vstack([moons, blobs])
y_test = np.hstack([np.ones(moons.shape[0], dtype=np.int8),
                    np.zeros(blobs.shape[0], dtype=np.int8)])

plot_scatter(X_test, y_test, 'Testing dataset')
64+
65+
66+
def outlier_rejection(X, y):
    """Keep only the samples that an isolation forest predicts as inliers.

    An ``IsolationForest`` is fitted on ``X`` and then used to predict on the
    same ``X``; the rows of ``X`` and entries of ``y`` whose prediction equals
    ``1`` are returned.
    """
    forest = IsolationForest(max_samples=100,
                             contamination=0.4,
                             random_state=rng)
    forest.fit(X)
    keep = forest.predict(X) == 1
    return X[keep], y[keep]
73+
74+
75+
# Show what is left of the training set once the sampler has been applied.
reject_sampler = FunctionSampler(func=outlier_rejection)
X_inliers, y_inliers = reject_sampler.fit_sample(X_train, y_train)
plot_scatter(X_inliers, y_inliers, 'Training data without outliers')

# Logistic regression trained behind the outlier-rejection sampler.
pipe = make_pipeline(FunctionSampler(func=outlier_rejection),
                     LogisticRegression(random_state=rng))
pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)
print(classification_report(y_test, y_pred))

# Logistic regression trained directly on the contaminated training set.
clf = LogisticRegression(random_state=rng)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred))

plt.show()

imblearn/__init__.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@
2626
Module which allows creating pipelines with scikit-learn estimators.
2727
"""
2828

29+
from .base import FunctionSampler
2930
from ._version import __version__
3031

31-
# list all submodules available in imblearn and version
32-
__all__ = [
33-
'combine', 'ensemble', 'exceptions', 'metrics', 'over_sampling',
34-
'under_sampling', 'utils', 'pipeline', '__version__'
35-
]
32+
__all__ = ['FunctionSampler', '__version__']

imblearn/base.py

+104
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,107 @@ def fit(self, X, y):
159159
self.ratio_ = check_ratio(self.ratio, y, self._sampling_type)
160160

161161
return self
162+
163+
164+
def _identity(X, y):
165+
return X, y
166+
167+
168+
class FunctionSampler(SamplerMixin):
    """Construct a sampler from calling an arbitrary callable.

    Read more in the :ref:`User Guide <function_sampler>`.

    Parameters
    ----------
    func : callable or None, optional (default=None)
        The callable used to resample the data. It is called as
        ``func(X, y, **kw_args)`` and whatever it returns is returned by
        ``sample``. If func is None, then func will be the identity function.

    accept_sparse : bool, optional (default=True)
        Whether sparse inputs are supported. By default, sparse inputs are
        supported.

    kw_args : dict, optional (default=None)
        The keyword arguments expected by ``func``.

    Notes
    -----

    See
    :ref:`sphx_glr_auto_examples_plot_outlier_rejections.py`

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.datasets import make_classification
    >>> from imblearn import FunctionSampler
    >>> X, y = make_classification(n_classes=2, class_sep=2,
    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)

    We can create a function to select only the first ten samples for
    instance.

    >>> def func(X, y):
    ...   return X[:10], y[:10]
    >>> sampler = FunctionSampler(func=func)
    >>> X_res, y_res = sampler.fit_sample(X, y)
    >>> np.all(X_res == X[:10])
    True
    >>> np.all(y_res == y[:10])
    True

    We can also create a specific function which takes some arguments.

    >>> from collections import Counter
    >>> from imblearn.under_sampling import RandomUnderSampler
    >>> def func(X, y, ratio, random_state):
    ...   return RandomUnderSampler(ratio=ratio,
    ...                             random_state=random_state).fit_sample(X, y)
    >>> sampler = FunctionSampler(func=func,
    ...                           kw_args={'ratio': 'auto', 'random_state': 0})
    >>> X_res, y_res = sampler.fit_sample(X, y)
    >>> print('Resampled dataset shape {}'.format(
    ...     sorted(Counter(y_res).items())))
    Resampled dataset shape [(0, 100), (1, 100)]

    """

    def __init__(self, func=None, accept_sparse=True, kw_args=None):
        self.func = func
        self.accept_sparse = accept_sparse
        self.kw_args = kw_args
        self.logger = logging.getLogger(__name__)

    def _check_X_y(self, X, y):
        """Validate ``X`` and ``y``, honouring ``accept_sparse``.

        Sparse input is accepted in CSR/CSC format when ``accept_sparse`` is
        true and rejected otherwise; ``y`` additionally goes through
        ``check_target_type``.
        """
        sparse_formats = ['csr', 'csc'] if self.accept_sparse else False
        X, y = check_X_y(X, y, accept_sparse=sparse_formats)
        y = check_target_type(y)

        return X, y

    def fit(self, X, y):
        """Validate the data and remember its hash for ``sample``.

        Parameters
        ----------
        X : {array-like, sparse matrix}
            Samples to be resampled later on.

        y : array-like
            Targets associated with ``X``.

        Returns
        -------
        self : object
            Return self.

        """
        X, y = self._check_X_y(X, y)
        self.X_hash_, self.y_hash_ = hash_X_y(X, y)
        # when using a sampler, ratio_ is supposed to exist after fit
        self.ratio_ = 'is_fitted'

        return self

    def _sample(self, X, y, func=None, kw_args=None):
        """Apply ``func`` to the data that was given to ``fit``.

        Parameters
        ----------
        X : {array-like, sparse matrix}
            Samples; they must hash to the same value as the fitted ``X``.

        y : array-like
            Targets; they must hash to the same value as the fitted ``y``.

        func : callable or None, optional (default=None)
            Callable invoked as ``func(X, y, **kw_args)``. The identity
            function is used when it is None.

        kw_args : dict or None, optional (default=None)
            Keyword arguments forwarded to ``func``; None or an empty dict
            forwards nothing.

        Returns
        -------
        The value returned by ``func``.

        Raises
        ------
        RuntimeError
            If ``X`` and ``y`` are not the arrays used during ``fit``.

        """
        X, y = self._check_X_y(X, y)
        check_is_fitted(self, 'ratio_')
        X_hash, y_hash = hash_X_y(X, y)
        if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
            raise RuntimeError("X and y need to be same array earlier fitted.")

        if func is None:
            func = _identity

        # Decide on the ``kw_args`` argument itself: testing ``self.kw_args``
        # while unpacking ``kw_args`` would drop or mis-unpack the argument
        # whenever the two differ.
        return func(X, y, **(kw_args if kw_args else {}))

    def sample(self, X, y):
        """Resample ``X`` and ``y`` with ``self.func`` and ``self.kw_args``.

        Parameters
        ----------
        X : {array-like, sparse matrix}
            The samples passed to ``fit``.

        y : array-like
            The targets passed to ``fit``.

        Returns
        -------
        The value returned by ``self.func`` (``(X, y)`` unchanged when
        ``self.func`` is None).

        """
        return self._sample(X, y, func=self.func, kw_args=self.kw_args)

imblearn/tests/test_base.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Test for miscellaneous samplers objects."""
2+
3+
# Authors: Guillaume Lemaitre <[email protected]>
4+
# License: MIT
5+
6+
import pytest
7+
8+
from scipy import sparse
9+
10+
from sklearn.datasets import load_iris
11+
from sklearn.utils.testing import assert_array_equal
12+
from sklearn.utils.testing import assert_allclose_dense_sparse
13+
14+
from imblearn.datasets import make_imbalance
15+
from imblearn import FunctionSampler
16+
from imblearn.under_sampling import RandomUnderSampler
17+
18+
# Module-level fixture shared by the tests below: iris passed through
# ``make_imbalance`` with a per-class ratio for classes 0 and 1.
iris = load_iris()
X, y = make_imbalance(
    iris.data,
    iris.target,
    ratio={0: 10, 1: 25},
    random_state=0,
)
21+
22+
23+
def test_function_sampler_reject_sparse():
    """``fit`` must raise ``TypeError`` on sparse ``X`` if sparse is refused."""
    X_sparse = sparse.csr_matrix(X)
    sampler = FunctionSampler(accept_sparse=False)
    # ``match`` checks the text of the raised error. The former ``message``
    # keyword only set the failure message shown when nothing was raised, so
    # the expected text was never verified.
    with pytest.raises(TypeError, match="A sparse matrix was passed, "
                                        "but dense data is required"):
        sampler.fit(X_sparse, y)
29+
30+
31+
@pytest.mark.parametrize(
    "X, y",
    [(X, y),
     (sparse.csr_matrix(X), y),
     (sparse.csc_matrix(X), y)])
def test_function_sampler_identity(X, y):
    """Without ``func`` the sampler hands back dense or sparse data as is."""
    identity_sampler = FunctionSampler()
    X_out, y_out = identity_sampler.fit_sample(X, y)
    assert_allclose_dense_sparse(X_out, X)
    assert_array_equal(y_out, y)
41+
42+
43+
@pytest.mark.parametrize(
    "X, y",
    [(X, y),
     (sparse.csr_matrix(X), y),
     (sparse.csc_matrix(X), y)])
def test_function_sampler_func(X, y):
    """A user callable decides which samples the sampler returns."""

    def first_ten(X, y):
        return X[:10], y[:10]

    X_out, y_out = FunctionSampler(func=first_ten).fit_sample(X, y)
    assert_allclose_dense_sparse(X_out, X[:10])
    assert_array_equal(y_out, y[:10])
57+
58+
59+
@pytest.mark.parametrize(
    "X, y",
    [(X, y),
     (sparse.csr_matrix(X), y),
     (sparse.csc_matrix(X), y)])
def test_function_sampler_func_kwargs(X, y):
    """``kw_args`` are forwarded to the callable as keyword arguments."""

    def under_sample(X, y, ratio, random_state):
        under_sampler = RandomUnderSampler(ratio=ratio,
                                           random_state=random_state)
        return under_sampler.fit_sample(X, y)

    forwarded = {'ratio': 'auto', 'random_state': 0}
    sampler = FunctionSampler(func=under_sample, kw_args=forwarded)
    X_out, y_out = sampler.fit_sample(X, y)
    # Reference result: the same under-sampler called directly.
    X_expected, y_expected = RandomUnderSampler(
        random_state=0).fit_sample(X, y)
    assert_allclose_dense_sparse(X_out, X_expected)
    assert_array_equal(y_out, y_expected)

imblearn/utils/estimator_checks.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from sklearn.utils.testing import assert_allclose
2626
from sklearn.utils.testing import set_random_state
2727

28-
from imblearn.base import SamplerMixin
2928
from imblearn.over_sampling.base import BaseOverSampler
3029
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
3130
from imblearn.ensemble.base import BaseEnsembleSampler
@@ -47,10 +46,10 @@ def _yield_sampler_checks(name, Estimator):
4746
yield check_samplers_pandas
4847

4948

50-
def _yield_all_checks(name, Estimator):
49+
def _yield_all_checks(name, estimator):
5150
# trigger our checks if this is a SamplerMixin
52-
if issubclass(Estimator, SamplerMixin):
53-
for check in _yield_sampler_checks(name, Estimator):
51+
if hasattr(estimator, 'sample'):
52+
for check in _yield_sampler_checks(name, estimator):
5453
yield check
5554

5655

0 commit comments

Comments
 (0)