25
25
from sklearn .utils .testing import assert_allclose
26
26
from sklearn .utils .testing import set_random_state
27
27
28
- from imblearn .base import SamplerMixin
29
28
from imblearn .over_sampling .base import BaseOverSampler
30
29
from imblearn .under_sampling .base import BaseCleaningSampler , BaseUnderSampler
31
30
from imblearn .ensemble .base import BaseEnsembleSampler
34
33
35
34
from imblearn .utils .testing import warns
36
35
37
- NOT_TESTED_SAMPLERS = ('FunctionSampler' ) # This is not a normal sampler
38
-
39
36
40
37
def _yield_sampler_checks (name , Estimator ):
41
38
yield check_target_type
@@ -49,10 +46,10 @@ def _yield_sampler_checks(name, Estimator):
49
46
yield check_samplers_pandas
50
47
51
48
52
- def _yield_all_checks (name , Estimator ):
49
+ def _yield_all_checks (name , estimator ):
53
50
# trigger our checks if this is a SamplerMixin
54
- if issubclass ( Estimator , SamplerMixin ):
55
- for check in _yield_sampler_checks (name , Estimator ):
51
+ if hasattr ( estimator , 'sample' ):
52
+ for check in _yield_sampler_checks (name , estimator ):
56
53
yield check
57
54
58
55
@@ -75,7 +72,6 @@ def check_estimator(Estimator):
75
72
sklearn_check_estimator (Estimator )
76
73
check_parameters_default_constructible (name , Estimator )
77
74
for check in _yield_all_checks (name , Estimator ):
78
- # if name not in NOT_TESTED_SAMPLERS:
79
75
check (name , Estimator )
80
76
81
77
0 commit comments