|
24 | 24 |
|
25 | 25 |
|
26 | 26 | log = logging.getLogger(__name__) |
| 27 | +yaml = YAML(typ='safe') |
27 | 28 |
|
28 | 29 |
|
29 | | -FeatureGenerationConfig = namedtuple( |
| 30 | +_feature_gen_config = namedtuple( |
30 | 31 | 'FeatureGenerationConfig', |
31 | | - ['needed_columns', 'features'] |
| 32 | + ['needed_columns', 'features'], |
32 | 33 | ) |
33 | 34 |
|
34 | | -yaml = YAML(typ='safe') |
35 | 35 |
|
| 36 | +class FeatureGenerationConfig(_feature_gen_config): |
| 37 | + ''' |
| 38 | + Stores the needed features and the expressions for the |
| 39 | + feature generation |
| 40 | + ''' |
36 | 41 |
|
37 | | -def print_supported_classifiers(): |
38 | | - logging.info('Supported Classifiers:') |
| 42 | + def __new__(cls, needed_columns, features): |
| 43 | + if features is None: |
| 44 | + log.warning('Feature generation config present but no features defined.') |
| 45 | + features = {} |
| 46 | + return super().__new__(cls, needed_columns, features) |
| 47 | + |
| 48 | + |
| 49 | +def print_models(filter_func=is_classifier): |
39 | 50 | for name, module in sklearn_modules.items(): |
40 | 51 | for cls_name in dir(module): |
41 | 52 | cls = getattr(module, cls_name) |
42 | | - if is_classifier(cls): |
| 53 | + if filter_func(cls): |
43 | 54 | logging.info(name + '.' + cls.__name__) |
44 | 55 |
|
45 | 56 |
|
| 57 | +def print_supported_classifiers(): |
| 58 | + logging.info('Supported Classifiers:') |
| 59 | + print_models(is_classifier) |
| 60 | + |
| 61 | + |
46 | 62 | def print_supported_regressors(): |
47 | 63 | logging.info('Supported Regressors:') |
48 | | - for name, module in sklearn_modules.items(): |
49 | | - for cls_name in dir(module): |
50 | | - cls = getattr(module, cls_name) |
51 | | - if is_regressor(cls): |
52 | | - logging.info(name + '.' + cls.__name__) |
| 64 | + print_models(is_regressor) |
53 | 65 |
|
54 | 66 |
|
55 | 67 | def load_regressor(config): |
@@ -162,8 +174,8 @@ def __init__(self, config): |
162 | 174 | raise ValueError('Source dependent features used: {}'.format(source_features)) |
163 | 175 |
|
164 | 176 | if gen_config: |
165 | | - self.features.extend(gen_config['features'].keys()) |
166 | 177 | self.feature_generation = FeatureGenerationConfig(**gen_config) |
| 178 | + self.features.extend(self.feature_generation.features.keys()) |
167 | 179 | else: |
168 | 180 | self.feature_generation = None |
169 | 181 | self.features.sort() |
@@ -230,8 +242,8 @@ def __init__(self, config): |
230 | 242 | if len(source_features): |
231 | 243 | raise ValueError('Source dependent features used: {}'.format(source_features)) |
232 | 244 | if gen_config: |
233 | | - self.features.extend(gen_config['features'].keys()) |
234 | 245 | self.feature_generation = FeatureGenerationConfig(**gen_config) |
| 246 | + self.features.extend(self.feature_generation.features.keys()) |
235 | 247 | else: |
236 | 248 | self.feature_generation = None |
237 | 249 | self.features.sort() |
@@ -275,8 +287,8 @@ def __init__(self, config): |
275 | 287 | if len(source_features): |
276 | 288 | raise ValueError('Source dependent features used: {}'.format(source_features)) |
277 | 289 | if gen_config: |
278 | | - self.features.extend(gen_config['features'].keys()) |
279 | 290 | self.feature_generation = FeatureGenerationConfig(**gen_config) |
| 291 | + self.features.extend(self.feature_generation.features.keys()) |
280 | 292 | else: |
281 | 293 | self.feature_generation = None |
282 | 294 | self.features.sort() |
|
0 commit comments