Skip to content

Commit a7baaa9

Browse files
authored
Merge pull request #112 from fact-project/fix_111
Fix error in case of empty feature generation config, fixes #111
2 parents b2f7ec6 + 7348826 commit a7baaa9

2 files changed

Lines changed: 27 additions & 15 deletions

File tree

aict_tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.20.3'
1+
__version__ = '0.20.4'

aict_tools/configuration.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,44 @@
2424

2525

2626
log = logging.getLogger(__name__)
27+
yaml = YAML(typ='safe')
2728

2829

29-
FeatureGenerationConfig = namedtuple(
30+
_feature_gen_config = namedtuple(
3031
'FeatureGenerationConfig',
31-
['needed_columns', 'features']
32+
['needed_columns', 'features'],
3233
)
3334

34-
yaml = YAML(typ='safe')
3535

36+
class FeatureGenerationConfig(_feature_gen_config):
37+
'''
38+
Stores the needed features and the expressions for the
39+
feature generation
40+
'''
3641

37-
def print_supported_classifiers():
38-
logging.info('Supported Classifiers:')
42+
def __new__(cls, needed_columns, features):
43+
if features is None:
44+
log.warning('Feature generation config present but no features defined.')
45+
features = {}
46+
return super().__new__(cls, needed_columns, features)
47+
48+
49+
def print_models(filter_func=is_classifier):
3950
for name, module in sklearn_modules.items():
4051
for cls_name in dir(module):
4152
cls = getattr(module, cls_name)
42-
if is_classifier(cls):
53+
if filter_func(cls):
4354
logging.info(name + '.' + cls.__name__)
4455

4556

57+
def print_supported_classifiers():
58+
logging.info('Supported Classifiers:')
59+
print_models(is_classifier)
60+
61+
4662
def print_supported_regressors():
4763
logging.info('Supported Regressors:')
48-
for name, module in sklearn_modules.items():
49-
for cls_name in dir(module):
50-
cls = getattr(module, cls_name)
51-
if is_regressor(cls):
52-
logging.info(name + '.' + cls.__name__)
64+
print_models(is_regressor)
5365

5466

5567
def load_regressor(config):
@@ -162,8 +174,8 @@ def __init__(self, config):
162174
raise ValueError('Source dependent features used: {}'.format(source_features))
163175

164176
if gen_config:
165-
self.features.extend(gen_config['features'].keys())
166177
self.feature_generation = FeatureGenerationConfig(**gen_config)
178+
self.features.extend(self.feature_generation.features.keys())
167179
else:
168180
self.feature_generation = None
169181
self.features.sort()
@@ -230,8 +242,8 @@ def __init__(self, config):
230242
if len(source_features):
231243
raise ValueError('Source dependent features used: {}'.format(source_features))
232244
if gen_config:
233-
self.features.extend(gen_config['features'].keys())
234245
self.feature_generation = FeatureGenerationConfig(**gen_config)
246+
self.features.extend(self.feature_generation.features.keys())
235247
else:
236248
self.feature_generation = None
237249
self.features.sort()
@@ -275,8 +287,8 @@ def __init__(self, config):
275287
if len(source_features):
276288
raise ValueError('Source dependent features used: {}'.format(source_features))
277289
if gen_config:
278-
self.features.extend(gen_config['features'].keys())
279290
self.feature_generation = FeatureGenerationConfig(**gen_config)
291+
self.features.extend(self.feature_generation.features.keys())
280292
else:
281293
self.feature_generation = None
282294
self.features.sort()

0 commit comments

Comments
 (0)