
Commit e78846f

Adam Cooper authored and rasbt committed
add groups parameter to SFS and EFS fit(), for forwarding to sklearn cv (#537)

* add groups parameter to SFS and EFS fit(), for forwarding to sklearn cv
* pep8 fix and bumping up the changelog note to the 0.17 release
1 parent b276396 commit e78846f
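
In practice, the new parameter lets group labels flow from the selector's `fit()` into a group-aware splitter. A minimal usage sketch (illustrative only, not part of the diff; the group labels here are made up):

# Sketch of the new API, assuming mlxtend >= 0.17: group labels
# passed to fit() are forwarded to the sklearn CV splitter.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GroupKFold
from mlxtend.feature_selection import SequentialFeatureSelector as SFS

X, y = load_iris(return_X_y=True)
groups = np.random.randint(0, 6, size=len(y))  # hypothetical group labels

sfs = SFS(KNeighborsClassifier(n_neighbors=4),
          k_features=3,
          cv=GroupKFold(n_splits=3))
sfs = sfs.fit(X, y, groups=groups)  # new: groups is forwarded to the splitter
print(sfs.k_feature_idx_)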

File tree

5 files changed: +101 -18 lines changed

docs/sources/CHANGELOG.md (+1 -1)

@@ -16,7 +16,7 @@ The CHANGELOG for the current development version is available at
 
 ##### New Features
 
-- -
+- Add optional `groups` parameter to `SequentialFeatureSelector` and `ExhaustiveFeatureSelector` `fit()` methods for forwarding to sklearn CV ([#537](https://github.com/rasbt/mlxtend/pull/537) via [arc12](https://github.com/arc12))
 
 ##### Changes
 
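
For context (not part of the diff): passing `groups` only changes behavior when the CV object is group-aware. sklearn splitters receive the labels via `split(X, y, groups)`; a small sketch of that contract:

# A group-aware splitter such as GroupKFold keeps every group
# entirely within one side of each train/test split.
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.arange(12).reshape(6, 2)
y = np.array([0, 0, 1, 1, 0, 1])
groups = np.array([0, 0, 1, 1, 2, 2])

for train_idx, test_idx in GroupKFold(n_splits=3).split(X, y, groups):
    # no group appears on both sides of a split
    assert not set(groups[train_idx]) & set(groups[test_idx])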

mlxtend/feature_selection/exhaustive_feature_selector.py (+12 -5)

@@ -25,10 +25,11 @@
 from sklearn.externals.joblib import Parallel, delayed
 
 
-def _calc_score(selector, X, y, indices, **fit_params):
+def _calc_score(selector, X, y, indices, groups=None, **fit_params):
     if selector.cv:
         scores = cross_val_score(selector.est_,
                                  X[:, indices], y,
+                                 groups=groups,
                                  cv=selector.cv,
                                  scoring=selector.scorer,
                                  n_jobs=1,
@@ -175,7 +176,7 @@ def __init__(self, estimator, min_features=1, max_features=1,
         # don't mess with this unless testing
         self._TESTING_INTERRUPT_MODE = False
 
-    def fit(self, X, y, custom_feature_names=None, **fit_params):
+    def fit(self, X, y, custom_feature_names=None, groups=None, **fit_params):
         """Perform feature selection and learn model from training data.
 
         Parameters
@@ -191,6 +192,9 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
             Custom feature names for `self.k_feature_names` and
             `self.subsets_[i]['feature_names']`.
             (new in v 0.13.0)
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set. Passed to the fit method of the cross-validator.
         fit_params : dict of string -> object, optional
             Parameters to pass to to the fit method of classifier.
 
@@ -268,7 +272,7 @@ def ncr(n, r):
         n_jobs = min(self.n_jobs, all_comb)
         parallel = Parallel(n_jobs=n_jobs, pre_dispatch=self.pre_dispatch)
         work = enumerate(parallel(delayed(_calc_score)
-                                  (self, X_, y, c, **fit_params)
+                                  (self, X_, y, c, groups=groups, **fit_params)
                                   for c in candidates))
 
         try:
@@ -336,7 +340,7 @@ def transform(self, X):
             X_ = X
         return X_[:, self.best_idx_]
 
-    def fit_transform(self, X, y, **fit_params):
+    def fit_transform(self, X, y, groups=None, **fit_params):
         """Fit to training data and return the best selected features from X.
 
         Parameters
@@ -348,6 +352,9 @@ def fit_transform(self, X, y, **fit_params):
             argument for X.
         y : array-like, shape = [n_samples]
             Target values.
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set. Passed to the fit method of the cross-validator.
         fit_params : dict of string -> object, optional
             Parameters to pass to to the fit method of classifier.
 
@@ -356,7 +363,7 @@ def fit_transform(self, X, y, **fit_params):
             Feature subset of X, shape={n_samples, k_features}
 
         """
-        self.fit(X, y, **fit_params)
+        self.fit(X, y, groups=groups, **fit_params)
         return self.transform(X)
 
     def get_metric_dict(self, confidence_interval=0.95):
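
The EFS change is mechanical: `_calc_score` accepts `groups` and hands it straight to `cross_val_score`, which forwards it to the splitter. A standalone sketch of that underlying sklearn call (illustrative; the feature subset is arbitrary, not code from this commit):

# What _calc_score now effectively does for one candidate subset:
# cross_val_score receives groups and passes it to cv.split().
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score, GroupKFold

X, y = load_iris(return_X_y=True)
groups = np.random.randint(0, 6, size=len(y))
indices = (0, 1, 2)  # one candidate feature subset

scores = cross_val_score(KNeighborsClassifier(n_neighbors=4),
                         X[:, indices], y,
                         groups=groups,
                         cv=GroupKFold(n_splits=3),
                         scoring='accuracy',
                         n_jobs=1)
print(scores.mean())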

mlxtend/feature_selection/sequential_feature_selector.py (+26 -12)

@@ -23,10 +23,11 @@
 from sklearn.externals.joblib import Parallel, delayed
 
 
-def _calc_score(selector, X, y, indices, **fit_params):
+def _calc_score(selector, X, y, indices, groups=None, **fit_params):
     if selector.cv:
         scores = cross_val_score(selector.est_,
                                  X[:, indices], y,
+                                 groups=groups,
                                  cv=selector.cv,
                                  scoring=selector.scorer,
                                  n_jobs=1,
@@ -242,7 +243,7 @@ def set_params(self, **params):
         self._set_params('estimator', 'named_estimators', **params)
         return self
 
-    def fit(self, X, y, custom_feature_names=None, **fit_params):
+    def fit(self, X, y, custom_feature_names=None, groups=None, **fit_params):
         """Perform feature selection and learn model from training data.
 
         Parameters
@@ -260,6 +261,9 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
             Custom feature names for `self.k_feature_names` and
             `self.subsets_[i]['feature_names']`.
             (new in v 0.13.0)
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set. Passed to the fit method of the cross-validator.
         fit_params : dict of string -> object, optional
             Parameters to pass to to the fit method of classifier.
 
@@ -291,8 +295,8 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
         if not isinstance(self.k_features, int) and\
                 not isinstance(self.k_features, tuple)\
                 and not isinstance(self.k_features, str):
-            raise AttributeError('k_features must be a positive integer'
-                                 ', tuple, or string')
+            raise AttributeError('k_features must be a positive integer'
+                                 ', tuple, or string')
 
         if (isinstance(self.k_features, int) and (
                 self.k_features < 1 or self.k_features > X_.shape[1])):
@@ -351,7 +355,8 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
             k_to_select = min_k
             k_idx = tuple(range(X_.shape[1]))
             k = len(k_idx)
-            k_idx, k_score = _calc_score(self, X_, y, k_idx, **fit_params)
+            k_idx, k_score = _calc_score(self, X_, y, k_idx,
+                                         groups=groups, **fit_params)
             self.subsets_[k] = {
                 'feature_idx': k_idx,
                 'cv_scores': k_score,
@@ -370,6 +375,7 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
                     subset=prev_subset,
                     X=X_,
                     y=y,
+                    groups=groups,
                     **fit_params
                 )
             else:
@@ -378,6 +384,7 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
                     feature_set=prev_subset,
                     X=X_,
                     y=y,
+                    groups=groups,
                     **fit_params
                 )
 
@@ -404,6 +411,7 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
                             fixed_feature=new_feature,
                             X=X_,
                             y=y,
+                            groups=groups,
                             **fit_params
                         )
 
@@ -413,6 +421,7 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
                             subset=set(k_idx),
                             X=X_,
                             y=y,
+                            groups=groups,
                             **fit_params
                         )
 
@@ -472,7 +481,7 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
                                 X)
                     raise KeyboardInterrupt
 
-        except KeyboardInterrupt as e:
+        except KeyboardInterrupt:
             self.interrupted_ = True
             sys.stderr.write('\nSTOPPING EARLY DUE TO KEYBOARD INTERRUPT...')
 
@@ -512,7 +521,7 @@ def fit(self, X, y, custom_feature_names=None, **fit_params):
         return self
 
     def _inclusion(self, orig_set, subset, X, y, ignore_feature=None,
-                   **fit_params):
+                   groups=None, **fit_params):
         all_avg_scores = []
         all_cv_scores = []
         all_subsets = []
@@ -526,7 +535,7 @@ def _inclusion(self, orig_set, subset, X, y, ignore_feature=None,
             work = parallel(delayed(_calc_score)
                             (self, X, y,
                              tuple(subset | {feature}),
-                             **fit_params)
+                             groups=groups, **fit_params)
                             for feature in remaining
                             if feature != ignore_feature)
 
@@ -541,7 +550,8 @@ def _inclusion(self, orig_set, subset, X, y, ignore_feature=None,
                    all_cv_scores[best])
             return res
 
-    def _exclusion(self, feature_set, X, y, fixed_feature=None, **fit_params):
+    def _exclusion(self, feature_set, X, y, fixed_feature=None,
+                   groups=None, **fit_params):
         n = len(feature_set)
         res = (None, None, None)
         if n > 1:
@@ -552,7 +562,8 @@ def _exclusion(self, feature_set, X, y, fixed_feature=None, **fit_params):
             n_jobs = min(self.n_jobs, features)
             parallel = Parallel(n_jobs=n_jobs, verbose=self.verbose,
                                 pre_dispatch=self.pre_dispatch)
-            work = parallel(delayed(_calc_score)(self, X, y, p, **fit_params)
+            work = parallel(delayed(_calc_score)(self, X, y, p,
+                                                 groups=groups, **fit_params)
                             for p in combinations(feature_set, r=n - 1)
                             if not fixed_feature or fixed_feature in set(p))
 
@@ -591,7 +602,7 @@ def transform(self, X):
             X_ = X
         return X_[:, self.k_feature_idx_]
 
-    def fit_transform(self, X, y, **fit_params):
+    def fit_transform(self, X, y, groups=None, **fit_params):
         """Fit to training data then reduce X to its most important features.
 
         Parameters
@@ -605,6 +616,9 @@ def fit_transform(self, X, y, **fit_params):
             Target values.
             New in v 0.13.0: a pandas Series are now also accepted as
             argument for y.
+        groups : array-like, with shape (n_samples,), optional
+            Group labels for the samples used while splitting the dataset into
+            train/test set. Passed to the fit method of the cross-validator.
         fit_params : dict of string -> object, optional
             Parameters to pass to to the fit method of classifier.
 
@@ -613,7 +627,7 @@ def fit_transform(self, X, y, **fit_params):
             Reduced feature subset of X, shape={n_samples, k_features}
 
         """
-        self.fit(X, y, **fit_params)
+        self.fit(X, y, groups=groups, **fit_params)
         return self.transform(X)
 
     def get_metric_dict(self, confidence_interval=0.95):
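
Because `groups=None` is the default at every level (`fit`, `fit_transform`, `_inclusion`, `_exclusion`, `_calc_score`), existing call sites are unchanged, and `groups` composes with the existing `**fit_params` forwarding. A hedged sketch of both points (assuming an estimator that accepts `sample_weight`; not code from this commit):

# groups defaults to None throughout, so the old signature still
# works; it can also be combined with estimator fit_params.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GroupKFold
from mlxtend.feature_selection import SequentialFeatureSelector as SFS

X, y = load_iris(return_X_y=True)
groups = np.random.randint(0, 6, size=len(y))  # hypothetical group labels
weights = np.ones(len(y))                      # hypothetical sample weights

sfs_plain = SFS(LogisticRegression(), k_features=2, cv=3)
sfs_plain.fit(X, y)  # pre-0.17 call sites keep working unchanged

sfs_grouped = SFS(LogisticRegression(), k_features=2,
                  cv=GroupKFold(n_splits=3))
# groups goes to the splitter; sample_weight rides along via **fit_params
sfs_grouped.fit(X, y, groups=groups, sample_weight=weights)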

mlxtend/feature_selection/tests/test_exhaustive_feature_selector.py (+35)

@@ -16,6 +16,7 @@
 from sklearn.linear_model import LinearRegression
 from sklearn.datasets import load_boston
 from mlxtend.utils import assert_raises
+from sklearn.model_selection import GroupKFold
 
 
 def dict_compare_utility(d1, d2):
@@ -183,6 +184,40 @@ def test_knn_cv3():
     assert round(efs1.best_score_, 4) == 0.9728
 
 
+def test_knn_cv3_groups():
+    iris = load_iris()
+    X = iris.data
+    y = iris.target
+    knn = KNeighborsClassifier(n_neighbors=4)
+    efs1 = EFS(knn,
+               min_features=3,
+               max_features=3,
+               scoring='accuracy',
+               cv=GroupKFold(n_splits=3),
+               print_progress=False)
+    np.random.seed(1630672634)
+    groups = np.random.randint(0, 6, size=len(y))
+    efs1 = efs1.fit(X, y, groups=groups)
+    # print(efs1.subsets_)
+    expect = {0: {'cv_scores': np.array([0.97916667, 0.93877551, 0.9245283]),
+                  'feature_idx': (0, 1, 2),
+                  'avg_score': 0.9474901595858469,
+                  'feature_names': ('0', '1', '2')},
+              1: {'cv_scores': np.array([1., 0.93877551, 0.9245283]),
+                  'feature_idx': (0, 1, 3),
+                  'avg_score': 0.9544346040302915,
+                  'feature_names': ('0', '1', '3')},
+              2: {'cv_scores': np.array([0.97916667, 0.95918367, 0.9245283]),
+                  'feature_idx': (0, 2, 3),
+                  'avg_score': 0.9542928806742822,
+                  'feature_names': ('0', '2', '3')},
+              3: {'cv_scores': np.array([0.97916667, 0.95918367, 0.94339623]),
+                  'feature_idx': (1, 2, 3),
+                  'avg_score': 0.9605821888503829,
+                  'feature_names': ('1', '2', '3')}}
+    dict_compare_utility(d1=expect, d2=efs1.subsets_)
+
+
 def test_fit_params():
     iris = load_iris()
     X = iris.data

mlxtend/feature_selection/tests/test_sequential_feature_selector.py (+27)

@@ -219,6 +219,33 @@ def test_knn_cv3():
     dict_compare_utility(d1=expect, d2=sfs1.subsets_)
 
 
+def test_knn_cv3_groups():
+    iris = load_iris()
+    X = iris.data
+    y = iris.target
+    knn = KNeighborsClassifier(n_neighbors=4)
+    sfs1 = SFS(knn,
+               k_features=3,
+               forward=True,
+               floating=False,
+               cv=GroupKFold(n_splits=3),
+               verbose=0)
+    np.random.seed(1630672634)
+    groups = np.random.randint(0, 6, size=len(y))
+    sfs1 = sfs1.fit(X, y, groups=groups)
+    # print(sfs1.subsets_)
+    expect = {
+        1: {'cv_scores': np.array([0.97916667, 0.93877551, 0.96226415]),
+            'feature_idx': (3,),
+            'avg_score': 0.9600687759380482},
+        2: {'cv_scores': np.array([0.95833333, 0.93877551, 0.98113208]),
+            'feature_idx': (1, 3),
+            'avg_score': 0.9594136396697044},
+        3: {'cv_scores': np.array([0.97916667, 0.95918367, 0.94339623]),
+            'feature_idx': (1, 2, 3),
+            'avg_score': 0.9605821888503829}}
+    dict_compare_utility(d1=expect, d2=sfs1.subsets_, decimal=3)
+
 def test_knn_rbf_groupkfold():
     nan_roc_auc_scorer = make_scorer(nan_roc_auc_score)
     rng = np.random.RandomState(123)
