-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathtesting.py
156 lines (128 loc) · 4.86 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""Test utilities."""
# Adapted from scikit-learn
# Authors: Guillaume Lemaitre <[email protected]>
# License: MIT
import inspect
import pkgutil
from contextlib import contextmanager
from importlib import import_module
from re import compile
from pathlib import Path
from operator import itemgetter
from pytest import warns as _warns
from sklearn.base import BaseEstimator
from sklearn.utils._testing import ignore_warnings
def all_estimators(type_filter=None,):
    """Get a list of all estimators from imblearn.

    This function crawls the ``imblearn`` package and gets all public classes
    that inherit from ``BaseEstimator``. The following are not included:

    - classes from subpackages/modules named ``tests`` or ``dask``;
    - classes reached only through private modules (``._`` in the path);
    - abstract classes (non-empty ``__abstractmethods__``);
    - classes whose defining module belongs to ``sklearn``.

    This function is adapted from sklearn.

    Parameters
    ----------
    type_filter : str, list of str, tuple of str or None, default=None
        Which kind of estimators should be returned. If None, no
        filter is applied and all estimators are returned. Possible
        values are 'sampler' to get estimators only of these specific
        types, or a list/tuple of these to get the estimators that fit at
        least one of the types.

    Returns
    -------
    estimators : list of tuples
        List of (name, class), where ``name`` is the class name as string
        and ``class`` is the actual type of the class, sorted by name.

    Raises
    ------
    ValueError
        If ``type_filter`` contains a value other than ``'sampler'``.
    """
    # Imported inside the function rather than at module level; presumably
    # to avoid an import cycle with ``imblearn.base`` -- TODO confirm.
    from ..base import SamplerMixin

    def is_abstract(c):
        # Abstract when the class still has unimplemented abstract methods.
        return bool(getattr(c, "__abstractmethods__", ()))

    all_classes = []
    modules_to_ignore = {"tests", "dask"}
    # This file lives in ``imblearn/utils``; two levels up is the package root.
    root = str(Path(__file__).parent.parent)
    # Ignore deprecation warnings triggered at import time and from walking
    # packages
    with ignore_warnings(category=FutureWarning):
        for _, modname, _ in pkgutil.walk_packages(
            path=[root], prefix="imblearn."
        ):
            mod_parts = modname.split(".")
            # Skip test/dask subpackages and private modules.
            if (
                any(part in modules_to_ignore for part in mod_parts)
                or "._" in modname
            ):
                continue
            module = import_module(modname)
            classes = inspect.getmembers(module, inspect.isclass)
            all_classes.extend(
                (name, est_cls)
                for name, est_cls in classes
                if not name.startswith("_")
            )

    # The same class is usually re-exported by several modules.
    all_classes = set(all_classes)

    estimators = [
        c
        for c in all_classes
        if issubclass(c[1], BaseEstimator) and c[0] != "BaseEstimator"
    ]
    # get rid of abstract base classes
    estimators = [c for c in estimators if not is_abstract(c[1])]
    # Get rid of sklearn estimators that have been imported into imblearn
    # modules. Compare the top-level package name instead of doing a
    # substring test, so an imblearn module whose dotted path merely
    # contains "sklearn" does not get its classes dropped.
    estimators = [
        c for c in estimators if c[1].__module__.split(".")[0] != "sklearn"
    ]

    if type_filter is not None:
        if isinstance(type_filter, (list, tuple)):
            type_filter = list(type_filter)  # copy: entries are consumed below
        else:
            type_filter = [type_filter]
        filtered_estimators = []
        filters = {"sampler": SamplerMixin}
        for name, mixin in filters.items():
            if name in type_filter:
                # Drop every occurrence so a repeated valid filter is not
                # reported as an unknown one.
                type_filter = [t for t in type_filter if t != name]
                filtered_estimators.extend(
                    est for est in estimators if issubclass(est[1], mixin)
                )
        estimators = filtered_estimators
        if type_filter:
            raise ValueError(
                "Parameter type_filter must be 'sampler' or "
                "None, got"
                " %s." % repr(type_filter)
            )

    # drop duplicates, sort for reproducibility
    # itemgetter is used to ensure the sort does not extend to the 2nd item of
    # the tuple (classes are not orderable)
    return sorted(set(estimators), key=itemgetter(0))
@contextmanager
def warns(expected_warning, match=None):
    r"""Assert that a warning is raised with an optional matching pattern.

    Assert that a code block/function call warns ``expected_warning``
    and raise a failure exception otherwise. It can be used within a context
    manager ``with``.

    Parameters
    ----------
    expected_warning : Warning
        Warning type.

    match : regex str or None, optional
        The pattern to be matched against the message of each recorded
        warning (``re.search`` semantics). By default, no check is done.

    Yields
    ------
    None

    Raises
    ------
    AssertionError
        If ``match`` is given and no recorded warning message matches it.

    Examples
    --------
    >>> import warnings
    >>> from imblearn.utils.testing import warns
    >>> with warns(UserWarning, match=r'must be \d+$'):
    ...     warnings.warn("value must be 42", UserWarning)
    """
    with _warns(expected_warning) as record:
        yield

    if match is None:
        return

    # Compile once instead of once per recorded warning.
    pattern = compile(match)
    messages = [str(r.message) for r in record]
    if not any(pattern.search(message) is not None for message in messages):
        # Raise explicitly: a bare ``assert`` is stripped under ``python -O``,
        # which would silently disable this check.
        raise AssertionError(
            "'{}' pattern not found in {}".format(match, messages)
        )