Skip to content

Commit 0e7d276

Browse files
authored
COSMIT use set_random_state in commmon test (#383)
1 parent c25edb6 commit 0e7d276

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

imblearn/utils/estimator_checks.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from sklearn.exceptions import NotFittedError
2525
from sklearn.utils.testing import assert_allclose
2626
from sklearn.utils.testing import set_random_state
27-
from sklearn.externals.funcsigs import signature
2827

2928
from imblearn.base import SamplerMixin
3029
from imblearn.over_sampling.base import BaseOverSampler
@@ -211,12 +210,10 @@ def check_samplers_sparse(name, Sampler):
211210
estimator=KMeans(random_state=1,
212211
algorithm='full'))]
213212
else:
214-
sampler_attr = signature(Sampler.__init__).parameters.keys()
215-
if 'random_state' in sampler_attr:
216-
samplers = [Sampler(random_state=0)]
217-
else:
218-
samplers = [Sampler()]
213+
samplers = [Sampler()]
214+
219215
for sampler in samplers:
216+
set_random_state(sampler)
220217
X_res_sparse, y_res_sparse = sampler.fit_sample(X_sparse, y)
221218
X_res, y_res = sampler.fit_sample(X, y)
222219
if not isinstance(sampler, BaseEnsembleSampler):
@@ -243,16 +240,16 @@ def check_samplers_pandas(name, Sampler):
243240
samplers = [Sampler(random_state=0, kind=kind)
244241
for kind in ('regular', 'borderline1',
245242
'borderline2', 'svm')]
243+
246244
elif isinstance(Sampler(), NearMiss):
247-
samplers = [Sampler(version=version)
248-
for version in (1, 2, 3)]
245+
samplers = [Sampler(version=version)
246+
for version in (1, 2, 3)]
247+
249248
else:
250-
sampler_attr = signature(Sampler.__init__).parameters.keys()
251-
if 'random_state' in sampler_attr:
252-
samplers = [Sampler(random_state=0)]
253-
else:
254-
samplers = [Sampler()]
249+
samplers = [Sampler()]
250+
255251
for sampler in samplers:
252+
set_random_state(sampler)
256253
X_res_pd, y_res_pd = sampler.fit_sample(X_pd, y_pd)
257254
X_res, y_res = sampler.fit_sample(X, y)
258255
assert_allclose(X_res_pd, X_res)

0 commit comments

Comments
 (0)