Skip to content

Commit 5b6a502

Browse files
authored
Merge pull request #351 from dice-group/general_adjustments
EvoLearner bug fixed
2 parents 28343ea + f58195a commit 5b6a502

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

ontolearn/concept_learner.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ class EvoLearner(BaseConceptLearner[EvoLearnerNode]):
689689
__slots__ = 'fitness_func', 'init_method', 'algorithm', 'value_splitter', 'tournament_size', \
690690
'population_size', 'num_generations', 'height_limit', 'use_data_properties', 'pset', 'toolbox', \
691691
'_learning_problem', '_result_population', 'mut_uniform_gen', '_dp_to_prim_type', '_dp_splits', \
692-
'_split_properties', '_cache', 'use_card_restrictions', 'card_limit', 'use_inverse'
692+
'_split_properties', '_cache', 'use_card_restrictions', 'card_limit', 'use_inverse', 'total_fits'
693693

694694
name = 'evolearner'
695695

@@ -788,11 +788,12 @@ def __init__(self,
788788
self.population_size = population_size
789789
self.num_generations = num_generations
790790
self.height_limit = height_limit
791+
self.total_fits = 0
791792
self.__setup()
792793

793794
def __setup(self):
795+
self.clean(partial=True)
794796
self._cache = dict()
795-
self.clean()
796797
if self.fitness_func is None:
797798
self.fitness_func = LinearPressureFitness()
798799

@@ -971,7 +972,11 @@ def fit(self, *args, **kwargs) -> 'EvoLearner':
971972
"""
972973
Find hypotheses that explain pos and neg.
973974
"""
974-
self.clean()
975+
# Don't reset everything if the user is just using this model for 1 learning problem, since he may use the
976+
# register_op method, else-wise we need to `clean` before fitting to get a fresh fit.
977+
if self.total_fits > 0:
978+
self.clean()
979+
self.total_fits += 1
975980
learning_problem = self.construct_learning_problem(PosNegLPStandard, args, kwargs)
976981
self._learning_problem = learning_problem.encode_kb(self.kb)
977982

@@ -1049,18 +1054,30 @@ def _fitness_func(self, individual: Tree):
10491054
self._cache[ind_str] = (e.q, individual.fitness.values[0])
10501055
self._number_of_tested_concepts += 1
10511056

1052-
def clean(self):
1053-
self._result_population = None
1054-
1057+
def clean(self, partial: bool = False):
10551058
# Resets classes if they already exist, names must match the ones that were created in the toolbox
10561059
try:
10571060
del creator.Fitness
10581061
del creator.Individual
10591062
del creator.Quality
10601063
except AttributeError:
10611064
pass
1062-
self._cache.clear()
10631065
super().clean()
1066+
if not partial:
1067+
# Reset everything if fitting more than one lp. Tests have shown that this is necessary to get the
1068+
# best performance of EvoLearner.
1069+
self._result_population = None
1070+
self._cache.clear()
1071+
self.fitness_func = LinearPressureFitness()
1072+
self.init_method = EARandomWalkInitialization()
1073+
self.algorithm = EASimple()
1074+
self.mut_uniform_gen = EARandomInitialization(min_height=1, max_height=3)
1075+
self.value_splitter = EntropyValueSplitter()
1076+
self._dp_to_prim_type = dict()
1077+
self._dp_splits = dict()
1078+
self._split_properties = []
1079+
self.pset = self.__build_primitive_set()
1080+
self.toolbox = self.__build_toolbox()
10641081

10651082

10661083
class NCES(BaseNCES):

tests/test_evolearner.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def test_regression_family(self):
1818
kb = KnowledgeBase(path=settings['data_path'][3:])
1919
model = EvoLearner(knowledge_base=kb, max_runtime=10)
2020

21-
regression_test_evolearner = {'Aunt': 0.9, 'Brother': 1.0,
22-
'Cousin': 0.9, 'Granddaughter': 1.0,
23-
'Uncle': 0.9, 'Grandgrandfather': 0.94}
21+
regression_test_evolearner = {'Aunt': 1.0, 'Brother': 1.0,
22+
'Cousin': 1.0, 'Granddaughter': 1.0,
23+
'Uncle': 1.0, 'Grandgrandfather': 1.0}
2424
for str_target_concept, examples in settings['problems'].items():
2525
pos = set(map(OWLNamedIndividual, map(IRI.create, set(examples['positive_examples']))))
2626
neg = set(map(OWLNamedIndividual, map(IRI.create, set(examples['negative_examples']))))
@@ -31,8 +31,12 @@ def test_regression_family(self):
3131
self.assertEqual(returned_model, model)
3232
hypotheses = list(returned_model.best_hypotheses(n=3))
3333
self.assertGreaterEqual(hypotheses[0].quality, regression_test_evolearner[str_target_concept])
34-
self.assertGreaterEqual(hypotheses[0].quality, hypotheses[1].quality)
35-
self.assertGreaterEqual(hypotheses[1].quality, hypotheses[2].quality)
34+
# best_hypotheses returns distinct hypotheses and sometimes the model will not find 'n' distinct hypothesis,
35+
# hence the checks
36+
if len(hypotheses) == 2:
37+
self.assertGreaterEqual(hypotheses[0].quality, hypotheses[1].quality)
38+
if len(hypotheses) == 3:
39+
self.assertGreaterEqual(hypotheses[1].quality, hypotheses[2].quality)
3640

3741
def test_regression_mutagenesis_multiple_fits(self):
3842
kb = KnowledgeBase(path='KGs/Mutagenesis/mutagenesis.owl')

0 commit comments

Comments
 (0)