Skip to content

Commit 6c46dcb

Browse files
authored
Add GenAcq + fix MineAcq + Docstrings of other algorithms (#16)
* gquacq to mineacq + fixes * add genacq * gquacq file to mineacq * Update __init__.py * genacq docstrings * more docstring changes
1 parent 1078448 commit 6c46dcb

File tree

6 files changed

+202
-28
lines changed

6 files changed

+202
-28
lines changed

pycona/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .find_constraint import FindC, FindC2
2727
from .query_generation import QGen, TQGen, PQGen
2828
from .find_scope import FindScope, FindScope2
29-
from .active_algorithms import QuAcq, PQuAcq, GQuAcq, GrowAcq, MQuAcq, MQuAcq2
29+
from .active_algorithms import QuAcq, PQuAcq, MineAcq, GrowAcq, MQuAcq, MQuAcq2, GenAcq
3030
from .problem_instance import ProblemInstance, absvar, langBasic, langDist, langEqNeq
3131
from .predictor import CountsPredictor, FeaturesRelDim, FeaturesSimpleRel
3232

pycona/active_algorithms/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
from .mquacq import MQuAcq
1313
from .growacq import GrowAcq
1414
from .pquacq import PQuAcq
15-
from .gquacq import GQuAcq
15+
from .mineacq import MineAcq
16+
from .genacq import GenAcq

pycona/active_algorithms/genacq.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import time
2+
from itertools import product
3+
4+
from .algorithm_core import AlgorithmCAInteractive
5+
from ..ca_environment.active_ca import ActiveCAEnv
6+
from ..utils import get_relation, get_scope, get_kappa, replace_variables
7+
from ..problem_instance import ProblemInstance
8+
from ..answering_queries import Oracle, UserOracle
9+
from .. import Metrics
10+
11+
12+
class GenAcq(AlgorithmCAInteractive):
    """
    GenAcq algorithm, using generalization queries on given types of variables. From:

    "Boosting Constraint Acquisition with Generalization Queries", ECAI 2014.
    """

    def __init__(self, ca_env: ActiveCAEnv = None, types=None, qg_max=3):
        """
        Initialize the GenAcq algorithm with an optional constraint acquisition environment.

        :param ca_env: An instance of ActiveCAEnv, default is None.
        :param types: list of types (groups) of variables given by the user. Defaults to None,
            which is stored as an empty list; with no types no generalization query is asked.
            NOTE(review): generalize() tests ``var.name in type_group`` but also reads ``.name``
            from the members of a group -- confirm the expected element type of a group.
        :param qg_max: number of consecutive rejected generalization queries after which a
            call to generalize() stops asking.
        """
        super().__init__(ca_env)
        # Type sequences whose generalization query has been rejected so far.
        self._negativeQ = []
        self._qg_max = qg_max
        self._types = types if types is not None else []

    def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None, X=None):
        """
        Learn constraints using the GenAcq algorithm by generating queries and analyzing the results.
        Using generalization queries on given types of variables.

        :param instance: the problem instance to acquire the constraints for
        :param oracle: An instance of Oracle, default is to use the user as the oracle.
        :param verbose: Verbosity level, default is 0.
        :param metrics: statistics logger during learning
        :param X: List of variables to consider for learning. If None, uses all variables from the instance.
        :return: the learned instance
        """
        self.env.init_state(instance, oracle, verbose, metrics)

        if X is None:
            X = list(self.env.instance.variables.flat)

        if len(self.env.instance.bias) == 0:
            self.env.instance.construct_bias(X)

        while True:
            if self.env.verbose > 0:
                print("Size of CL: ", len(self.env.instance.cl))
                print("Size of B: ", len(self.env.instance.bias))
                print("Number of Queries: ", self.env.metrics.total_queries)
                print("Number of Generalization Queries: ", self.env.metrics.generalization_queries_count)
                print("Number of Membership Queries: ", self.env.metrics.membership_queries_count)

            # Only the query generation call itself is timed.
            gen_start = time.time()
            Y = self.env.run_query_generation(X)
            gen_end = time.time()

            if len(Y) == 0:
                # if no query can be generated it means we have (prematurely) converged to the target network -----
                self.env.metrics.finalize_statistics()
                if self.env.verbose >= 1:
                    print(f"\nLearned {self.env.metrics.cl} constraints in "
                          f"{self.env.metrics.total_queries} queries.")
                self.env.instance.bias = []
                return self.env.instance

            self.env.metrics.increase_generation_time(gen_end - gen_start)
            self.env.metrics.increase_generated_queries()
            self.env.metrics.increase_top_queries()
            kappaB = get_kappa(self.env.instance.bias, Y)

            answer = self.env.ask_membership_query(Y)
            if answer:
                # it is a solution, so all candidates violated must go
                # B <- B \setminus K_B(e)
                self.env.remove_from_bias(kappaB)

            else:  # user says UNSAT

                scope = self.env.run_find_scope(Y)
                c = self.env.run_findc(scope)
                self.env.add_to_cl(c)
                # Try to extend the freshly learned constraint to variables of the same types.
                self.generalize(get_relation(c, self.env.instance.language), c)

    def generalize(self, r, c):
        """
        Generalize function presented in
        "Boosting Constraint Acquisition with Generalization Queries", ECAI 2014.

        For every combination of user-given types covering the scope of ``c``, build the
        constraints of relation ``r`` over the variables of those types and ask the oracle
        whether they all hold. Accepted constraints are added to the learned network through
        the environment; rejected type sequences are remembered in ``self._negativeQ``.

        :param r: The index of a relation in gamma.
        :param c: The constraint to generalize.
        :return: None. Learned constraints are stored via ``self.env.add_to_cl``.
        """
        language = self.env.instance.language
        relation = language[r]

        # For each variable in the scope of c, collect the type groups that mention it.
        type_sequences = [
            [type_group for type_group in self._types if var.name in type_group]
            for var in get_scope(c)
        ]

        # Nothing in the filtering loop below changes the excluded constraints,
        # so the set is built once instead of once per type sequence.
        excluded = set(self.env.instance.excluded_cons)

        # Filter out sequences based on NegativeQ and NonTarget
        candidate_sequences = []
        for s in product(*type_sequences):

            # Skip a sequence whose generalization query was already rejected
            # (exact match against the recorded negative sequences).
            if s in self._negativeQ:
                continue

            # Skip if an excluded (non-target) constraint has the same relation and
            # all of its variables are found in the type groups of the sequence.
            if any(get_relation(c2, language) == r and
                   all(any(var in set(type_group) for type_group in s) for var in get_scope(c2))
                   for c2 in excluded):
                continue

            candidate_sequences.append(s)

        # Sort sequences by number of distinct elements (ascending)
        candidate_sequences.sort(key=lambda seq: len(set().union(*seq)))

        # Abstract variables of the relation; they get replaced by concrete variables below.
        abs_vars = get_scope(relation)

        gq_counter = 0
        for Y in candidate_sequences:
            # Cutoff on consecutive rejected generalization queries.
            if gq_counter >= self._qg_max:
                break

            # CL (and possibly the excluded constraints) may have changed after the previous
            # query, so refresh the lookup sets once per sequence, not once per candidate.
            learned = set(self.env.instance.cl)
            excluded = set(self.env.instance.excluded_cons)

            # Instead of getting constraints from bias, generate them for this type sequence
            B = []
            for var_comb in product(*Y):

                if len(set(var_comb)) != len(var_comb):  # No duplicates
                    continue
                # Sort var_comb based on variable names
                var_comb = sorted(var_comb, key=lambda var: var.name)

                replace_dict = {abs_vars[i]: v for i, v in enumerate(var_comb)}
                constraint = replace_variables(relation, replace_dict)

                # Skip already learned or excluded constraints
                if constraint not in learned and constraint not in excluded:
                    B.append(constraint)

            if not B:
                # Every constraint of this sequence is already learned or excluded:
                # there is nothing to ask, so do not spend a generalization query on it.
                continue

            # If generalization query is accepted
            if self.env.ask_generalization_query(relation, B):
                self.env.add_to_cl(B)
                gq_counter = 0
            else:
                gq_counter += 1
                self._negativeQ.append(Y)
175+
176+

pycona/active_algorithms/gquacq.py renamed to pycona/active_algorithms/mineacq.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,61 +12,60 @@
1212
from .. import Metrics
1313

1414

15-
class GQuAcq(AlgorithmCAInteractive):
15+
class MineAcq(AlgorithmCAInteractive):
1616

1717
"""
18-
QuAcq variation algorithm, using mine&Ask to detect types of variables and ask genralization queries. From:
18+
QuAcq variation algorithm, using mine&Ask to detect types of variables and ask generalization queries. From:
1919
"Detecting Types of Variables for Generalization in Constraint Acquisition", ICTAI 2015.
2020
"""
2121

2222
def __init__(self, ca_env: ActiveCAEnv = None, qg_max=10):
2323
"""
24-
Initialize the GQuAcq algorithm with an optional constraint acquisition environment.
24+
Initialize the MineAcq algorithm with an optional constraint acquisition environment.
2525
2626
:param ca_env: An instance of ActiveCAEnv, default is None.
27-
: param GQmax: maximum number of generalization queries
27+
:param qg_max: maximum number of generalization queries
2828
"""
2929
super().__init__(ca_env)
3030
self._negativeQ = []
3131
self._qg_max = qg_max
3232

33-
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
33+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None, X=None):
3434
"""
35-
Learn constraints using the GQuAcq algorithm by generating queries and analyzing the results.
35+
Learn constraints using the MineAcq algorithm by generating queries and analyzing the results.
36+
Using mine&ask to detect types of variables and ask generalization queries.
3637
3738
:param instance: the problem instance to acquire the constraints for
3839
:param oracle: An instance of Oracle, default is to use the user as the oracle.
3940
:param verbose: Verbosity level, default is 0.
4041
:param metrics: statistics logger during learning
41-
:param X: The set of variables to consider, default is None.
42+
:param X: List of variables to consider for learning. If None, uses all variables from the instance.
4243
:return: the learned instance
4344
"""
44-
if X is None:
45-
X = instance.X
46-
assert isinstance(X, list), "When using .learn(), set parameter X must be a list of variables. Instead got: {}".format(X)
47-
assert set(X).issubset(set(instance.X)), "When using .learn(), set parameter X must be a subset of the problem instance variables. Instead got: {}".format(X)
48-
4945
self.env.init_state(instance, oracle, verbose, metrics)
5046

47+
if X is None:
48+
X = list(self.env.instance.variables.flat)
49+
5150
if len(self.env.instance.bias) == 0:
5251
self.env.instance.construct_bias(X)
5352

5453
while True:
5554
if self.env.verbose > 0:
5655
print("Size of CL: ", len(self.env.instance.cl))
5756
print("Size of B: ", len(self.env.instance.bias))
58-
print("Number of Queries: ", self.env.metrics.membership_queries_count)
57+
print("Number of Queries: ", self.env.metrics.total_queries)
5958

6059
gen_start = time.time()
6160
Y = self.env.run_query_generation(X)
62-
gen_end = time.time()
61+
gen_end = time.time()
6362

6463
if len(Y) == 0:
6564
# if no query can be generated it means we have (prematurely) converged to the target network -----
6665
self.env.metrics.finalize_statistics()
6766
if self.env.verbose >= 1:
6867
print(f"\nLearned {self.env.metrics.cl} constraints in "
69-
f"{self.env.metrics.membership_queries_count} queries.")
68+
f"{self.env.metrics.total_queries} queries.")
7069
self.env.instance.bias = []
7170
return self.env.instance
7271

@@ -130,16 +129,13 @@ def mineAsk(self, r):
130129
# potentially generalizing leads to UNSAT
131130
new_CL = self.env.instance.cl.copy()
132131
new_CL += B
133-
if any(Y2.issubset(Y) for Y2 in self._negativeQ) or not can_be_clique(G.subgraph(Y), D) or \
134-
len(B) > 0 or cp.Model(new_CL).solve():
135-
continue
136-
137-
if self.env.ask_generalization_query(self.env.instance.language[r], B):
138-
gen_flag = True
139-
self.env.add_to_cl(B)
140-
else:
141-
gq_counter += 1
142-
self._negativeQ.append(Y)
132+
if not (any(Y2.issubset(Y) for Y2 in self._negativeQ) or not (can_be_clique(G.subgraph(Y), D) and (len(B) > 0) and cp.Model(new_CL).solve())):
133+
if self.env.ask_generalization_query(self.env.instance.language[r], B):
134+
gen_flag = True
135+
self.env.add_to_cl(B)
136+
else:
137+
gq_counter += 1
138+
self._negativeQ.append(Y)
143139

144140
if not gen_flag:
145141
communities = nx.community.greedy_modularity_communities(G.subgraph(Y))

pycona/active_algorithms/mquacq2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(self, ca_env: ActiveCAEnv = None, *, perform_analyzeAndLearn: bool
3434
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, X=None, metrics: Metrics = None):
3535
"""
3636
Learn constraints using the modified QuAcq algorithm by generating queries and analyzing the results.
37+
Learns multiple constraints from each generated query. Uses analyzeAndLearn to focus on the most promising constraints.
3738
3839
:param instance: the problem instance to acquire the constraints for
3940
:param oracle: An instance of Oracle, default is to use the user as the oracle.

tests/test_algorithms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
problem_generators = [construct_murder_problem(), construct_examtt_simple(), construct_nurse_rostering()]
1515

1616
classifiers = [DecisionTreeClassifier(), RandomForestClassifier()]
17-
algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2(), ca.GQuAcq(), ca.PQuAcq()]
17+
algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2(), ca.MineAcq(), ca.PQuAcq(), ca.GenAcq()]
1818
fast_tests_algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2()]
1919

2020
def _generate_fast_benchmarks():

0 commit comments

Comments
 (0)