forked from scikit-learn-contrib/imbalanced-learn
-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathsmote_enn.py
179 lines (142 loc) · 6.42 KB
/
smote_enn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Class to perform over-sampling using SMOTE and cleaning using ENN."""
# Authors: Guillaume Lemaitre <[email protected]>
# Christos Aridas
# License: MIT
from __future__ import division
import logging
from sklearn.utils import check_X_y
from ..base import SamplerMixin
from ..over_sampling import SMOTE
from ..under_sampling import EditedNearestNeighbours
from ..utils import check_target_type, hash_X_y
class SMOTEENN(SamplerMixin):
    """Class to perform over-sampling using SMOTE and cleaning using ENN.

    Combine over- and under-sampling using SMOTE and Edited Nearest
    Neighbours.

    Read more in the :ref:`User Guide <combine>`.

    Parameters
    ----------
    ratio : str, dict, or callable, optional (default='auto')
        Ratio to use for resampling the data set.

        - If ``str``, has to be one of: (i) ``'minority'``: resample the
          minority class; (ii) ``'majority'``: resample the majority class,
          (iii) ``'not minority'``: resample all classes apart from the
          minority class, (iv) ``'all'``: resample all classes, and (v)
          ``'auto'``: correspond to ``'all'`` with for over-sampling methods
          and ``'not minority'`` for under-sampling methods. The classes
          targeted will be over-sampled or under-sampled to achieve an equal
          number of sample with the majority or minority class.
        - If ``dict``, the keys correspond to the targeted classes. The
          values correspond to the desired number of samples.
        - If callable, function taking ``y`` and returns a ``dict``. The keys
          correspond to the targeted classes. The values correspond to the
          desired number of samples.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, ``random_state`` is the seed used by the random number
        generator; If ``RandomState`` instance, random_state is the random
        number generator; If ``None``, the random number generator is the
        ``RandomState`` instance used by ``np.random``.

    smote : object, optional (default=SMOTE())
        The :class:`imblearn.over_sampling.SMOTE` object to use. If not given,
        a :class:`imblearn.over_sampling.SMOTE` object with default parameters
        will be given.

    enn : object, optional (default=EditedNearestNeighbours())
        The :class:`imblearn.under_sampling.EditedNearestNeighbours` object
        to use. If not given, an
        :class:`imblearn.under_sampling.EditedNearestNeighbours` object with
        ``ratio='all'`` will be given.

    Notes
    -----
    The method is presented in [1]_.

    Supports multi-class resampling. Refer to SMOTE and ENN regarding the
    scheme which used.

    See :ref:`sphx_glr_auto_examples_combine_plot_smote_enn.py` and
    :ref:`sphx_glr_auto_examples_combine_plot_comparison_combine.py`.

    See also
    --------
    SMOTETomek : Over-sample using SMOTE followed by under-sampling removing
        the Tomek's links.

    References
    ----------
    .. [1] G. Batista, R. C. Prati, M. C. Monard. "A study of the behavior of
       several methods for balancing machine learning training data," ACM
       Sigkdd Explorations Newsletter 6 (1), 20-29, 2004.

    Examples
    --------
    >>> from collections import Counter
    >>> from sklearn.datasets import make_classification
    >>> from imblearn.combine import SMOTEENN # doctest: +NORMALIZE_WHITESPACE
    >>> X, y = make_classification(n_classes=2, class_sep=2,
    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
    ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10)
    >>> print('Original dataset shape {}'.format(Counter(y)))
    Original dataset shape Counter({1: 900, 0: 100})
    >>> sme = SMOTEENN(random_state=42)
    >>> X_res, y_res = sme.fit_sample(X, y)
    >>> print('Resampled dataset shape {}'.format(Counter(y_res)))
    Resampled dataset shape Counter({0: 900, 1: 881})

    """

    def __init__(self,
                 ratio='auto',
                 random_state=None,
                 smote=None,
                 enn=None):
        super(SMOTEENN, self).__init__()
        self.ratio = ratio
        self.random_state = random_state
        self.smote = smote
        self.enn = enn
        # Module-level logger name keeps log records attributable to imblearn.
        self.logger = logging.getLogger(__name__)

    def _validate_estimator(self):
        "Private function to validate SMOTE and ENN objects"
        # Accept a user-supplied SMOTE instance (or subclass); reject
        # anything else rather than failing obscurely later in _sample.
        if self.smote is not None:
            if isinstance(self.smote, SMOTE):
                self.smote_ = self.smote
            else:
                # BUGFIX: a space was missing between the two concatenated
                # literals, producing "...a SMOTE object.Got <type> instead."
                raise ValueError('smote needs to be a SMOTE object.'
                                 ' Got {} instead.'.format(type(self.smote)))
        # Otherwise create a default SMOTE
        else:
            self.smote_ = SMOTE(
                ratio=self.ratio, random_state=self.random_state)

        if self.enn is not None:
            if isinstance(self.enn, EditedNearestNeighbours):
                self.enn_ = self.enn
            else:
                raise ValueError('enn needs to be an EditedNearestNeighbours.'
                                 ' Got {} instead.'.format(type(self.enn)))
        # Otherwise create a default EditedNearestNeighbours
        else:
            # ratio='all' so that ENN cleans every class after SMOTE, not
            # only the (former) majority classes.
            self.enn_ = EditedNearestNeighbours(ratio='all',
                                                random_state=self.random_state)

    def fit(self, X, y):
        """Find the classes statistics before to perform sampling.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : array-like, shape (n_samples,)
            Corresponding label for each sample in X.

        Returns
        -------
        self : object,
            Return self.

        """
        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], multi_output=True)
        y = check_target_type(y, self)
        self.ratio_ = self.ratio
        # Hash X/y so sample() can later verify it is called on the same data
        # it was fitted on (contract enforced by SamplerMixin).
        self.X_hash_, self.y_hash_ = hash_X_y(X, y)
        return self

    def _sample(self, X, y):
        """Resample the dataset.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : array-like, shape (n_samples,)
            Corresponding label for each sample in X.

        Returns
        -------
        X_resampled : {ndarray, sparse matrix}, shape \
(n_samples_new, n_features)
            The array containing the resampled data.

        y_resampled : ndarray, shape (n_samples_new)
            The corresponding label of `X_resampled`

        """
        self._validate_estimator()

        # Pipeline: SMOTE over-samples first, then ENN removes noisy samples
        # from the augmented set.
        X_res, y_res = self.smote_.fit_sample(X, y)
        return self.enn_.fit_sample(X_res, y_res)