Skip to content

Commit efec02e

Browse files
author
chkoar
committed
Add EasyEnsembleGeneralization
1 parent dd941e7 commit efec02e

File tree

3 files changed

+289
-1
lines changed

3 files changed

+289
-1
lines changed

imblearn/ensemble/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
from .easy_ensemble import EasyEnsemble
7+
from .easy_ensemble_generalization import EasyEnsembleGeneralization
78
from .balance_cascade import BalanceCascade
89

9-
__all__ = ['EasyEnsemble', 'BalanceCascade']
10+
__all__ = ['EasyEnsemble', 'EasyEnsembleGeneralization', 'BalanceCascade']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
"Easy Ensemble Generalization"
2+
3+
# Authors: Christos Aridas
4+
#
5+
# License: MIT
6+
7+
import numpy as np
8+
9+
from sklearn.base import ClassifierMixin, clone
10+
from sklearn.ensemble import BaseEnsemble, VotingClassifier
11+
from sklearn.tree import DecisionTreeClassifier
12+
from sklearn.utils import check_random_state
13+
from sklearn.utils.validation import check_is_fitted
14+
15+
from ..pipeline import Pipeline
16+
from ..under_sampling import RandomUnderSampler as ROS
17+
18+
19+
MAX_INT = np.iinfo(np.int32).max
20+
21+
22+
class EasyEnsembleGeneralization(BaseEnsemble, ClassifierMixin):
    """Generalization of the Easy Ensemble algorithm for imbalanced datasets.

    Every ensemble member is a pipeline made of a clone of ``base_sampler``
    followed by a clone of ``base_estimator``. The members are combined with
    a ``VotingClassifier``.

    Parameters
    ----------
    base_estimator : object or None, optional (default=None)
        The classifier cloned for every member of the ensemble. If None, a
        ``DecisionTreeClassifier`` is used.

    base_sampler : object or None, optional (default=None)
        The sampler cloned for every member of the ensemble. It must expose
        a ``random_state`` parameter. If None, a ``RandomUnderSampler`` is
        used.

    n_estimators : int, optional (default=5)
        The number of base estimators in the ensemble.

    voting : str, {'hard', 'soft'} (default='hard')
        If 'hard', uses predicted class labels for majority rule voting.
        Else if 'soft', predicts the class label based on the argmax of
        the sums of the predicted probabilities, which is recommended for
        an ensemble of well-calibrated classifiers.

    random_state : int, RandomState instance or None, optional (default=None)
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    n_jobs : int, optional (default=1)
        The number of jobs to run in parallel for ``fit``.
        If -1, then the number of jobs is set to the number of cores.

    Attributes
    ----------
    base_estimator_ : estimator
        The estimator the ensemble members are cloned from.

    base_sampler_ : sampler
        The sampler the ensemble members are cloned from.

    estimators_ : list of classifiers
        The collection of fitted estimators.

    classes_ : array-like, shape = [n_classes]
        The classes labels.

    Examples
    --------
    >>> import numpy as np
    >>> from imblearn.ensemble import EasyEnsembleGeneralization as EEG
    >>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
    >>> y = np.array([1, 1, 1, 2, 2, 2])
    >>> eeg = EEG(voting='hard', random_state=0)
    >>> eeg = eeg.fit(X, y)
    >>> print(eeg.predict(X))
    [1 1 1 2 2 2]
    """

    def __init__(self,
                 base_estimator=None,
                 base_sampler=None,
                 n_estimators=5,
                 voting='hard',
                 random_state=None,
                 n_jobs=1):

        self.base_estimator = base_estimator
        self.base_sampler = base_sampler
        self.n_estimators = n_estimators
        self.voting = voting
        self.random_state = random_state
        self.n_jobs = n_jobs

    def _validate_estimator(self):
        """Check the estimator and set the base_estimator_ attribute."""
        super(EasyEnsembleGeneralization, self)._validate_estimator(
            default=DecisionTreeClassifier())

    def _validate_sampler(self):
        """Check the sampler and set the base_sampler_ attribute.

        Raises
        ------
        ValueError
            If the resolved sampler has no ``random_state`` parameter.
        """
        if self.base_sampler is not None:
            self.base_sampler_ = self.base_sampler
        else:
            # ``ROS`` is this module's alias for ``RandomUnderSampler``.
            self.base_sampler_ = ROS()

        # Inspect the resolved sampler rather than the raw constructor
        # argument, which is None whenever the default sampler is used.
        if not hasattr(self.base_sampler_, 'random_state'):
            raise ValueError(
                'Base sampler must have a random_state parameter')

    def fit(self, X, y, sample_weight=None):
        """Build an ensemble of estimators from the training set (X, y).

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape = [n_samples, n_features]
            The training input samples. Sparse matrices are accepted only if
            they are supported by the base estimator.

        y : array-like, shape = [n_samples]
            The target values (class labels).

        sample_weight : array-like, shape = [n_samples] or None
            Accepted for API compatibility only. The value is currently
            ignored: it is not forwarded to the underlying pipelines.

        Returns
        -------
        self : object
            Returns self.

        Raises
        ------
        ValueError
            If the sampler has no ``random_state`` parameter, or if the
            base-class estimator validation rejects the configuration.
        """
        self._validate_estimator()
        self._validate_sampler()

        # Draw one independent seed per member for both pipeline steps so
        # that a fixed ``random_state`` gives a reproducible ensemble.
        random_state = check_random_state(self.random_state)
        estimator_seeds = random_state.randint(MAX_INT,
                                               size=self.n_estimators)
        sampler_seeds = random_state.randint(MAX_INT,
                                             size=self.n_estimators)

        # Not every classifier is randomized; only seed those that are.
        seed_estimator = hasattr(self.base_estimator_, 'random_state')

        ensemble_members = []
        seeds = zip(estimator_seeds, sampler_seeds)
        for i, (estimator_seed, sampler_seed) in enumerate(seeds):
            sampler = clone(self.base_sampler_)
            sampler.set_params(random_state=sampler_seed)

            estimator = clone(self.base_estimator_)
            if seed_estimator:
                estimator.set_params(random_state=estimator_seed)

            pipeline = Pipeline([('sampler', sampler),
                                 ('estimator', estimator)])
            # ``VotingClassifier`` expects ``(name, estimator)`` tuples.
            ensemble_members.append((str(i), pipeline))

        self._voting = VotingClassifier(ensemble_members,
                                        voting=self.voting,
                                        n_jobs=self.n_jobs)
        self._voting.fit(X, y)

        self.classes_ = self._voting.classes_
        # Expose the fitted classifiers themselves, not the pipelines.
        self.estimators_ = [pipeline.named_steps['estimator']
                            for pipeline in self._voting.estimators_]

        return self

    def predict(self, X):
        """Predict class labels for X.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = [n_samples, n_features]
            Input vectors, where n_samples is the number of samples and
            n_features is the number of features.

        Returns
        -------
        maj : array-like, shape = [n_samples]
            Predicted class labels.
        """
        check_is_fitted(self, "_voting")
        return self._voting.predict(X)

    def predict_proba(self, X):
        """Compute probabilities of possible outcomes for all samples in X.

        The call is delegated to the fitted ``VotingClassifier``.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = [n_samples, n_features]
            Input vectors, where n_samples is the number of samples and
            n_features is the number of features.

        Returns
        -------
        avg : array-like, shape = [n_samples, n_classes]
            Average probability for each class per sample.
        """
        check_is_fitted(self, "_voting")
        return self._voting.predict_proba(X)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Testing for the EasyEnsembleGeneralization classifier"""
2+
3+
from __future__ import print_function
4+
5+
import numpy as np
6+
from sklearn.exceptions import NotFittedError
7+
from sklearn.model_selection import GridSearchCV, cross_val_score
8+
from sklearn.utils.testing import assert_almost_equal, assert_array_equal
9+
from sklearn.utils.testing import assert_equal, assert_true, assert_false
10+
from sklearn.utils.testing import assert_raise_message
11+
12+
from imblearn.ensemble import EasyEnsembleGeneralization as EEG
13+
14+
RND_SEED = 0
15+
X = np.array([[0.11622591, -0.0317206], [0.77481731, 0.60935141],
16+
[1.25192108, -0.22367336], [0.53366841, -0.30312976],
17+
[1.52091956, -0.49283504], [-0.28162401, -2.10400981],
18+
[0.83680821, 1.72827342], [0.3084254, 0.33299982],
19+
[0.70472253, -0.73309052], [0.28893132, -0.38761769],
20+
[1.15514042, 0.0129463], [0.88407872, 0.35454207],
21+
[1.31301027, -0.92648734], [-1.11515198, -0.93689695],
22+
[-0.18410027, -0.45194484], [0.9281014, 0.53085498],
23+
[-0.14374509, 0.27370049], [-0.41635887, -0.38299653],
24+
[0.08711622, 0.93259929], [1.70580611, -0.11219234]])
25+
y = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
26+
27+
28+
def test_estimator_init():
    """Fitting with ``n_estimators=0`` must raise a ``ValueError``."""
    expected_message = "n_estimators must be greater than zero, got 0."
    empty_ensemble = EEG(n_estimators=0)
    assert_raise_message(ValueError, expected_message,
                         empty_ensemble.fit, X, y)
33+
34+
35+
def test_predict_proba_hardvoting():
    """``predict_proba`` must be rejected when the ensemble votes 'hard'."""
    expected_message = "predict_proba is not available when voting='hard'"
    hard_voter = EEG(voting='hard', random_state=RND_SEED)
    hard_voter = hard_voter.fit(X, y)
    assert_raise_message(AttributeError, expected_message,
                         hard_voter.predict_proba, X)
39+
40+
41+
def test_notfitted():
    """Calling ``predict_proba`` before ``fit`` must raise NotFittedError."""
    expected_message = (
        "This EasyEnsembleGeneralization instance is not fitted yet. "
        "Call 'fit' with appropriate arguments before using this method.")
    unfitted = EEG()
    assert_raise_message(NotFittedError, expected_message,
                         unfitted.predict_proba, X)
46+
47+
48+
def test_majority_label():
    """Check the cross-validated ROC AUC of a soft-voting ensemble."""
    soft_voter = EEG(voting='soft', random_state=RND_SEED)
    fold_scores = cross_val_score(soft_voter, X, y, cv=5, scoring='roc_auc')
    mean_score = fold_scores.mean()
    print(mean_score)
    assert_almost_equal(mean_score, 0.625, decimal=2)
54+
55+
56+
def test_predict_on_toy_problem():
    """Manually check predicted class labels for the toy dataset."""
    eeg = EEG(voting='hard', random_state=RND_SEED)
    predicted = eeg.fit(X, y).predict(X[0:6])
    # Compare label by label. Reducing both sides with ``all()`` would only
    # compare two booleans and pass for almost any prediction.
    assert_array_equal(predicted, [0, 1, 0, 0, 0, 1])
60+
61+
62+
def test_gridsearch():
    """Check that the ensemble can be tuned with ``GridSearchCV``."""
    param_grid = {
        'voting': ['soft', 'hard'],
        'n_estimators': [2, 3, 4],
    }
    search = GridSearchCV(estimator=EEG(random_state=RND_SEED),
                          param_grid=param_grid,
                          cv=3)
    search.fit(X, y)
71+
72+
73+
def test_parallel_predict():
    """Check that fitting with one or two jobs gives identical outputs."""
    serial, parallel = [
        EEG(voting='soft', random_state=RND_SEED, n_jobs=n_jobs).fit(X, y)
        for n_jobs in (1, 2)
    ]

    assert_array_equal(serial.predict(X), parallel.predict(X))
    assert_array_equal(serial.predict_proba(X), parallel.predict_proba(X))

0 commit comments

Comments
 (0)