Skip to content

Commit 427486a

Browse files
committed
move the sampler to the base module
1 parent 30be926 commit 427486a

File tree

7 files changed

+139
-135
lines changed

7 files changed

+139
-135
lines changed

doc/api.rst

+14-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ Prototype selection
9292
combine.SMOTEENN
9393
combine.SMOTETomek
9494

95-
9695
.. _ensemble_ref:
9796

9897
:mod:`imblearn.ensemble`: Ensemble methods
@@ -112,6 +111,20 @@ Prototype selection
112111
ensemble.BalancedBaggingClassifier
113112
ensemble.EasyEnsemble
114113

114+
.. _misc_ref:
115+
116+
Miscellaneous
117+
=============
118+
119+
Imbalanced-learn provides some fast-prototyping tools.
120+
121+
.. currentmodule:: imblearn
122+
123+
.. autosummary::
124+
:toctree: generated/
125+
:template: class.rst
126+
127+
FunctionSampler
115128

116129
.. _pipeline_ref:
117130

doc/miscellaneous.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Miscellaneous samplers
55
======================
66

7-
.. currentmodule:: imblearn.misc
7+
.. currentmodule:: imblearn
88

99
.. _function_sampler:
1010

@@ -18,7 +18,7 @@ a single function. Additional parameters can be added using the attribute
1818
to retain the 10 first elements of the array ``X`` and ``y``::
1919

2020
>>> import numpy as np
21-
>>> from imblearn.misc import FunctionSampler
21+
>>> from imblearn import FunctionSampler
2222
>>> from sklearn.datasets import make_classification
2323
>>> X, y = make_classification(n_samples=5000, n_features=2, n_informative=2,
2424
... n_redundant=0, n_repeated=0, n_classes=3,

examples/plot_outlier_rejections.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sklearn.linear_model import LogisticRegression
2222
from sklearn.metrics import classification_report
2323

24-
from imblearn.misc import FunctionSampler
24+
from imblearn import FunctionSampler
2525
from imblearn.pipeline import make_pipeline
2626

2727
print(__doc__)

imblearn/__init__.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@
2626
Module which allows creating pipelines with scikit-learn estimators.
2727
"""
2828

29+
from .base import FunctionSampler
2930
from ._version import __version__
3031

31-
# list all submodules available in imblearn and version
32-
__all__ = [
33-
'combine', 'ensemble', 'exceptions', 'metrics', 'over_sampling',
34-
'under_sampling', 'utils', 'pipeline', '__version__'
35-
]
32+
__all__ = ['FunctionSampler', '__version__']

imblearn/base.py

+104
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,107 @@ def fit(self, X, y):
159159
self.ratio_ = check_ratio(self.ratio, y, self._sampling_type)
160160

161161
return self
162+
163+
164+
def _identity(X, y):
165+
return X, y
166+
167+
168+
class FunctionSampler(SamplerMixin):
    """Construct a sampler from calling an arbitrary callable.

    Read more in the :ref:`User Guide <function_sampler>`.

    Parameters
    ----------
    func : callable or None, optional (default=None)
        The callable to use for the resampling. It is called as
        ``func(X, y, **kw_args)`` and its return value is returned by
        ``sample``. If func is None, then func will be the identity function.

    accept_sparse : bool, optional (default=True)
        Whether sparse inputs are supported. By default, sparse inputs are
        supported.

    kw_args : dict, optional (default=None)
        The keyword arguments expected by ``func``.

    Notes
    -----

    See
    :ref:`sphx_glr_auto_examples_plot_outlier_rejections.py`

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.datasets import make_classification
    >>> from imblearn import FunctionSampler
    >>> X, y = make_classification(n_classes=2, class_sep=2,
    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)

    We can create a sampler which selects only the first ten samples, for
    instance.

    >>> def func(X, y):
    ...   return X[:10], y[:10]
    >>> sampler = FunctionSampler(func=func)
    >>> X_res, y_res = sampler.fit_sample(X, y)
    >>> np.all(X_res == X[:10])
    True
    >>> np.all(y_res == y[:10])
    True

    We can also create a specific function which takes some arguments.

    >>> from collections import Counter
    >>> from imblearn.under_sampling import RandomUnderSampler
    >>> def func(X, y, ratio, random_state):
    ...   return RandomUnderSampler(ratio=ratio,
    ...                             random_state=random_state).fit_sample(X, y)
    >>> sampler = FunctionSampler(func=func,
    ...                           kw_args={'ratio': 'auto', 'random_state': 0})
    >>> X_res, y_res = sampler.fit_sample(X, y)
    >>> print('Resampled dataset shape {}'.format(
    ...     sorted(Counter(y_res).items())))
    Resampled dataset shape [(0, 100), (1, 100)]

    """

    def __init__(self, func=None, accept_sparse=True, kw_args=None):
        self.func = func
        self.accept_sparse = accept_sparse
        self.kw_args = kw_args
        self.logger = logging.getLogger(__name__)

    def _check_X_y(self, X, y):
        """Validate ``X`` and ``y`` and return the checked pair.

        CSR and CSC sparse matrices are accepted only when
        ``self.accept_sparse`` is true; otherwise sparse input is refused by
        ``check_X_y``. The target is additionally passed through
        ``check_target_type``.
        """
        if self.accept_sparse:
            X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
        else:
            X, y = check_X_y(X, y, accept_sparse=False)
        y = check_target_type(y)

        return X, y

    def fit(self, X, y):
        """Validate the data and remember its hashes.

        The hashes stored in ``X_hash_`` and ``y_hash_`` are compared again
        when sampling, to make sure the same arrays are given.

        Returns
        -------
        self : object
            Return self.
        """
        X, y = self._check_X_y(X, y)
        self.X_hash_, self.y_hash_ = hash_X_y(X, y)
        # when using a sampler, ratio_ is supposed to exist after fit
        self.ratio_ = 'is_fitted'

        return self

    def _sample(self, X, y, func=None, kw_args=None):
        """Apply ``func`` to the data that was given to ``fit``.

        Parameters
        ----------
        X, y : data and target
            Must hash to the same values as the arrays passed to ``fit``.

        func : callable or None, optional (default=None)
            Callable applied to ``X`` and ``y``. The identity function is
            used when it is None.

        kw_args : dict or None, optional (default=None)
            Keyword arguments forwarded to ``func``. None or an empty dict
            forwards nothing.

        Returns
        -------
        The return value of ``func(X, y, **kw_args)``.

        Raises
        ------
        RuntimeError
            If the hashes of ``X`` and ``y`` differ from the ones recorded by
            ``fit``.
        """
        X, y = self._check_X_y(X, y)
        check_is_fitted(self, 'ratio_')
        X_hash, y_hash = hash_X_y(X, y)
        if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
            raise RuntimeError("X and y need to be same array earlier fitted.")

        if func is None:
            func = _identity

        # Decide on the argument actually being unpacked. Testing
        # ``self.kw_args`` while unpacking ``kw_args`` would drop explicitly
        # passed kwargs, or try to unpack None, whenever the two differ.
        return func(X, y, **(kw_args if kw_args else {}))

    def sample(self, X, y):
        """Resample ``X`` and ``y`` with the configured callable.

        Delegates to ``_sample`` using ``self.func`` and ``self.kw_args``.
        """
        return self._sample(X, y, func=self.func, kw_args=self.kw_args)

imblearn/misc.py

-116
This file was deleted.

imblearn/tests/test_misc.py renamed to imblearn/tests/test_base.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.utils.testing import assert_allclose_dense_sparse
1313

1414
from imblearn.datasets import make_imbalance
15-
from imblearn.misc import FunctionSampler
15+
from imblearn import FunctionSampler
1616
from imblearn.under_sampling import RandomUnderSampler
1717

1818
iris = load_iris()
@@ -28,19 +28,23 @@ def test_function_sampler_reject_sparse():
2828
sampler.fit(X_sparse, y)
2929

3030

31-
@pytest.mark.parametrize("X,y", [(X, y),
32-
(sparse.csr_matrix(X), y),
33-
(sparse.csc_matrix(X), y)])
31+
@pytest.mark.parametrize(
    "X, y",
    [(X, y),
     (sparse.csr_matrix(X), y),
     (sparse.csc_matrix(X), y)])
def test_function_sampler_identity(X, y):
    """A sampler built without ``func`` returns its input unchanged.

    Exercised with dense, CSR and CSC versions of the same data.
    """
    sampler = FunctionSampler()
    X_res, y_res = sampler.fit_sample(X, y)
    assert_allclose_dense_sparse(X_res, X)
    assert_array_equal(y_res, y)
3941

4042

41-
@pytest.mark.parametrize("X,y", [(X, y),
42-
(sparse.csr_matrix(X), y),
43-
(sparse.csc_matrix(X), y)])
43+
@pytest.mark.parametrize(
44+
"X, y",
45+
[(X, y),
46+
(sparse.csr_matrix(X), y),
47+
(sparse.csc_matrix(X), y)])
4448
def test_function_sampler_func(X, y):
4549

4650
def func(X, y):
@@ -52,9 +56,11 @@ def func(X, y):
5256
assert_array_equal(y_res, y[:10])
5357

5458

55-
@pytest.mark.parametrize("X,y", [(X, y),
56-
(sparse.csr_matrix(X), y),
57-
(sparse.csc_matrix(X), y)])
59+
@pytest.mark.parametrize(
60+
"X, y",
61+
[(X, y),
62+
(sparse.csr_matrix(X), y),
63+
(sparse.csc_matrix(X), y)])
5864
def test_function_sampler_func_kwargs(X, y):
5965

6066
def func(X, y, ratio, random_state):

0 commit comments

Comments
 (0)