Skip to content

Commit d9f7260

Browse files
author
ichkoar
committed
Refactor and add DaggingRegressor
1 parent a8e81c0 commit d9f7260

File tree

4 files changed

+145
-59
lines changed

4 files changed

+145
-59
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Example
2121

2222
.. code-block:: python
2323
24-
from dagging import Dagging
24+
from dagging import DaggingClassifier
2525
from sklearn.datasets import load_iris
2626
2727
# Load Iris from from scikit-learn.

dagging/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._version import get_versions
2-
from ._dagging import Dagging, DaggingClassifier # noqa
2+
from ._dagging import DaggingClassifier, DaggingRegressor # noqa
33

44
__version__ = get_versions()['version']
55
del get_versions

dagging/_dagging.py

Lines changed: 135 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,22 @@
11
import numpy as np
22

3-
from sklearn.base import ClassifierMixin
3+
from sklearn.base import ClassifierMixin, RegressorMixin, is_classifier, is_regressor
44
from sklearn.ensemble.base import BaseEnsemble
5-
from sklearn.model_selection import StratifiedKFold
5+
from sklearn.model_selection import check_cv
66
from sklearn.preprocessing import LabelEncoder
7-
from sklearn.tree import DecisionTreeClassifier
7+
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
88
from sklearn.utils import check_X_y
99
from sklearn.utils.multiclass import check_classification_targets
1010
from sklearn.utils.validation import check_is_fitted, check_random_state
1111

1212

13-
class Dagging(BaseEnsemble, ClassifierMixin):
14-
"""A Dagging classifier.
15-
This meta classifier creates a number of disjoint, stratified folds out of
16-
the data and feeds each chunk of data to a copy of the supplied base
17-
classifier. Predictions are made via hard or soft voting.
18-
Useful for base classifiers that are quadratic or worse in time behavior,
19-
regarding number of instances in the training data.
20-
21-
Parameters
22-
----------
23-
base_estimator : object or None, optional (default=None)
24-
The base estimator to fit on random subsets of the dataset.
25-
If None, then the base estimator is a decision tree.
26-
n_estimators : int, optional (default=3)
27-
The number of base estimators in the ensemble.
28-
voting : str, {'hard', 'soft'} (default='soft')
29-
If 'hard', uses predicted class labels for majority rule voting.
30-
Else if 'soft', predicts the class label based on the argmax of
31-
the sums of the predicted probabilities, which is recommended for
32-
an ensemble of well-calibrated classifiers.
33-
random_state : int, RandomState instance or None, optional (default=None)
34-
If int, random_state is the seed used by the random number generator;
35-
If RandomState instance, random_state is the random number generator;
36-
If None, the random number generator is the RandomState instance used
37-
by `np.random`.
38-
39-
Attributes
40-
----------
41-
base_estimator_ : estimator
42-
The base estimator from which the ensemble is grown.
43-
estimators_ : list of estimators
44-
The collection of fitted base estimators.
45-
References
46-
----------
47-
.. [1] Ting, K. M., Witten, I. H.: Stacking Bagged and Dagged Models.
48-
In: Fourteenth international Conference on Machine Learning,
49-
San Francisco, CA, 367-375, 1997
50-
"""
51-
13+
class BaseDagging(BaseEnsemble):
5214
def __init__(self,
5315
base_estimator=None,
5416
n_estimators=3,
5517
voting='soft',
5618
random_state=None):
57-
super(Dagging, self).__init__(
19+
super(BaseDagging, self).__init__(
5820
base_estimator=base_estimator,
5921
n_estimators=n_estimators)
6022
self.voting = voting
@@ -78,25 +40,29 @@ def fit(self, X, y):
7840
self : object
7941
"""
8042
X, y = check_X_y(X, y)
81-
check_classification_targets(y)
82-
83-
if isinstance(y, np.ndarray) and len(y.shape) > 1 and y.shape[1] > 1:
84-
raise NotImplementedError('Multilabel and multi-output'
85-
' classification is not supported.')
8643

8744
if self.voting not in ('soft', 'hard'):
8845
raise ValueError("Voting must be 'soft' or 'hard'; got (voting=%r)"
8946
% self.voting)
9047

9148
self._validate_estimator()
49+
if is_classifier(self.base_estimator_):
50+
check_classification_targets(y)
51+
if isinstance(y, np.ndarray) and len(y.shape) > 1 and y.shape[1] > 1:
52+
raise NotImplementedError('Multilabel and multi-output'
53+
' classification is not supported.')
54+
self.le_ = LabelEncoder().fit(y)
55+
self.classes_ = self.le_.classes_
56+
transformed_y = self.le_.transform(y)
57+
else:
58+
transformed_y = y
9259

93-
self.le_ = LabelEncoder().fit(y)
94-
self.classes_ = self.le_.classes_
9560
self.estimators_ = []
9661

97-
transformed_y = self.le_.transform(y)
9862
rs = check_random_state(self.random_state)
99-
splitter = StratifiedKFold(self.n_estimators, random_state=rs)
63+
splitter = check_cv(cv=self.n_estimators,
64+
y=transformed_y,
65+
classifier=is_classifier(self.base_estimator_))
10066

10167
for _, index in splitter.split(X, transformed_y):
10268
estimator = self._make_estimator(append=False,
@@ -106,6 +72,57 @@ def fit(self, X, y):
10672

10773
return self
10874

75+
76+
class DaggingClassifier(BaseDagging, ClassifierMixin):
77+
"""A Dagging classifier.
78+
This meta classifier creates a number of disjoint, stratified folds out of
79+
the data and feeds each chunk of data to a copy of the supplied base
80+
classifier. Predictions are made via hard or soft voting.
81+
Useful for base classifiers that are quadratic or worse in time behavior,
82+
regarding number of instances in the training data.
83+
84+
Parameters
85+
----------
86+
base_estimator : object or None, optional (default=None)
87+
The base estimator to fit on random subsets of the dataset.
88+
If None, then the base estimator is a decision tree.
89+
n_estimators : int, optional (default=3)
90+
The number of base estimators in the ensemble.
91+
voting : str, {'hard', 'soft'} (default='soft')
92+
If 'hard', uses predicted class labels for majority rule voting.
93+
Else if 'soft', predicts the class label based on the argmax of
94+
the sums of the predicted probabilities, which is recommended for
95+
an ensemble of well-calibrated classifiers.
96+
random_state : int, RandomState instance or None, optional (default=None)
97+
If int, random_state is the seed used by the random number generator;
98+
If RandomState instance, random_state is the random number generator;
99+
If None, the random number generator is the RandomState instance used
100+
by `np.random`.
101+
102+
Attributes
103+
----------
104+
base_estimator_ : estimator
105+
The base estimator from which the ensemble is grown.
106+
estimators_ : list of estimators
107+
The collection of fitted base estimators.
108+
References
109+
----------
110+
.. [1] Ting, K. M., Witten, I. H.: Stacking Bagged and Dagged Models.
111+
In: Fourteenth international Conference on Machine Learning,
112+
San Francisco, CA, 367-375, 1997
113+
"""
114+
115+
def __init__(self,
116+
base_estimator=None,
117+
n_estimators=3,
118+
voting='soft',
119+
random_state=None):
120+
super(DaggingClassifier, self).__init__(
121+
base_estimator=base_estimator,
122+
n_estimators=n_estimators,
123+
voting=voting,
124+
random_state=random_state)
125+
109126
def predict(self, X):
110127
""" Predict class labels for X.
111128
Parameters
@@ -168,8 +185,72 @@ def predict_proba(self):
168185

169186
def _validate_estimator(self):
170187
"""Check the estimator and set the base_estimator_ attribute."""
171-
super(Dagging, self)._validate_estimator(
188+
super(DaggingClassifier, self)._validate_estimator(
172189
default=DecisionTreeClassifier())
173190

174191

175-
DaggingClassifier = Dagging
192+
class DaggingRegressor(BaseDagging, RegressorMixin):
193+
"""A Dagging regressor.
194+
This meta regressor creates a number of disjoint, stratified folds out of
195+
the data and feeds each chunk of data to a copy of the supplied base
196+
regressor. Predictions are made via hard or soft voting.
197+
Useful for base regressor that are quadratic or worse in time behavior,
198+
regarding number of instances in the training data.
199+
200+
Parameters
201+
----------
202+
base_estimator : object or None, optional (default=None)
203+
The base estimator to fit on random subsets of the dataset.
204+
If None, then the base estimator is a decision tree.
205+
n_estimators : int, optional (default=3)
206+
The number of base estimators in the ensemble.
207+
random_state : int, RandomState instance or None, optional (default=None)
208+
If int, random_state is the seed used by the random number generator;
209+
If RandomState instance, random_state is the random number generator;
210+
If None, the random number generator is the RandomState instance used
211+
by `np.random`.
212+
213+
Attributes
214+
----------
215+
base_estimator_ : estimator
216+
The base estimator from which the ensemble is grown.
217+
estimators_ : list of estimators
218+
The collection of fitted base estimators.
219+
References
220+
----------
221+
.. [1] Ting, K. M., Witten, I. H.: Stacking Bagged and Dagged Models.
222+
In: Fourteenth international Conference on Machine Learning,
223+
San Francisco, CA, 367-375, 1997
224+
"""
225+
226+
def __init__(self,
227+
base_estimator=None,
228+
n_estimators=3,
229+
random_state=None):
230+
super(DaggingRegressor, self).__init__(
231+
base_estimator=base_estimator,
232+
n_estimators=n_estimators,
233+
random_state=random_state)
234+
235+
def predict(self, X):
236+
""" Predict class labels for X.
237+
Parameters
238+
----------
239+
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
240+
Training vectors, where n_samples is the number of samples and
241+
n_features is the number of features.
242+
Returns
243+
----------
244+
maj : array-like, shape = [n_samples]
245+
Predicted class labels.
246+
"""
247+
check_is_fitted(self, 'estimators_')
248+
predictions = []
249+
for estimator in self.estimators_:
250+
predictions.append(estimator.predict(X))
251+
return np.average(predictions, axis=0)
252+
253+
def _validate_estimator(self):
254+
"""Check the estimator and set the base_estimator_ attribute."""
255+
super(DaggingRegressor, self)._validate_estimator(
256+
default=DecisionTreeRegressor())

dagging/tests/test_dagging.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
from sklearn.utils.estimator_checks import check_estimator
22

3-
from dagging import Dagging
3+
from dagging import DaggingClassifier, DaggingRegressor
44

55

6-
def test_check_estimator():
7-
model = Dagging(random_state=0)
6+
def test_check_dagging_classifier():
7+
model = DaggingClassifier(random_state=0, n_estimators=2)
8+
check_estimator(model)
9+
10+
11+
def test_check_dagging_regressor():
12+
model = DaggingRegressor(random_state=0)
813
check_estimator(model)

0 commit comments

Comments
 (0)