Skip to content

Commit 8d051c7

Browse files
committed
add genacq
1 parent 8bfed7d commit 8d051c7

File tree

4 files changed

+178
-2
lines changed

4 files changed

+178
-2
lines changed

pycona/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from .find_constraint import FindC, FindC2
2727
from .query_generation import QGen, TQGen, PQGen
2828
from .find_scope import FindScope, FindScope2
29-
from .active_algorithms import QuAcq, PQuAcq, MineAcq, GrowAcq, MQuAcq, MQuAcq2
29+
from .active_algorithms import QuAcq, PQuAcq, MineAcq, GrowAcq, MQuAcq, MQuAcq2, GenAcq
3030
from .problem_instance import ProblemInstance, absvar, langBasic, langDist, langEqNeq
3131
from .predictor import CountsPredictor, FeaturesRelDim, FeaturesSimpleRel
3232

pycona/active_algorithms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
from .growacq import GrowAcq
1414
from .pquacq import PQuAcq
1515
from .gquacq import MineAcq
16+
from .genacq import GenAcq

pycona/active_algorithms/genacq.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import time
2+
from itertools import product
3+
4+
from .algorithm_core import AlgorithmCAInteractive
5+
from ..ca_environment.active_ca import ActiveCAEnv
6+
from ..utils import get_relation, get_scope, get_kappa, replace_variables
7+
from ..problem_instance import ProblemInstance
8+
from ..answering_queries import Oracle, UserOracle
9+
from .. import Metrics
10+
11+
12+
class GenAcq(AlgorithmCAInteractive):
13+
14+
"""
15+
GenAcq algorithm, using mine&Ask to detect types of variables and ask genralization queries. From:
16+
17+
"Detecting Types of Variables for Generalization in Constraint Acquisition", ICTAI 2015.
18+
"""
19+
20+
def __init__(self, ca_env: ActiveCAEnv = None, types=None, qg_max=3):
21+
"""
22+
Initialize the PQuAcq algorithm with an optional constraint acquisition environment.
23+
24+
:param ca_env: An instance of ActiveCAEnv, default is None.
25+
: param types: list of types of variables given by the user
26+
: param qg_max: maximum number of generalization queries
27+
"""
28+
super().__init__(ca_env)
29+
self._negativeQ = []
30+
self._qg_max = qg_max
31+
self._types = types if types is not None else []
32+
33+
def learn(self, instance: ProblemInstance, oracle: Oracle = UserOracle(), verbose=0, metrics: Metrics = None, X=None):
34+
"""
35+
Learn constraints using the QuAcq algorithm by generating queries and analyzing the results.
36+
37+
:param instance: the problem instance to acquire the constraints for
38+
:param oracle: An instance of Oracle, default is to use the user as the oracle.
39+
:param verbose: Verbosity level, default is 0.
40+
:param metrics: statistics logger during learning
41+
:return: the learned instance
42+
"""
43+
self.env.init_state(instance, oracle, verbose, metrics)
44+
45+
if X is None:
46+
X = list(self.env.instance.variables.flat)
47+
48+
if len(self.env.instance.bias) == 0:
49+
self.env.instance.construct_bias(X)
50+
51+
while True:
52+
if self.env.verbose > 0:
53+
print("Size of CL: ", len(self.env.instance.cl))
54+
print("Size of B: ", len(self.env.instance.bias))
55+
print("Number of Queries: ", self.env.metrics.total_queries)
56+
print("Number of Generalization Queries: ", self.env.metrics.generalization_queries_count)
57+
print("Number of Membership Queries: ", self.env.metrics.membership_queries_count)
58+
59+
60+
gen_start = time.time()
61+
Y = self.env.run_query_generation(X)
62+
gen_end = time.time()
63+
64+
if len(Y) == 0:
65+
# if no query can be generated it means we have (prematurely) converged to the target network -----
66+
self.env.metrics.finalize_statistics()
67+
if self.env.verbose >= 1:
68+
print(f"\nLearned {self.env.metrics.cl} constraints in "
69+
f"{self.env.metrics.total_queries} queries.")
70+
self.env.instance.bias = []
71+
return self.env.instance
72+
73+
self.env.metrics.increase_generation_time(gen_end - gen_start)
74+
self.env.metrics.increase_generated_queries()
75+
self.env.metrics.increase_top_queries()
76+
kappaB = get_kappa(self.env.instance.bias, Y)
77+
78+
answer = self.env.ask_membership_query(Y)
79+
if answer:
80+
# it is a solution, so all candidates violated must go
81+
# B <- B \setminus K_B(e)
82+
self.env.remove_from_bias(kappaB)
83+
84+
else: # user says UNSAT
85+
86+
scope = self.env.run_find_scope(Y)
87+
c = self.env.run_findc(scope)
88+
self.env.add_to_cl(c)
89+
self.generalize(get_relation(c, self.env.instance.language),c)
90+
91+
92+
93+
def generalize(self, r, c):
94+
"""
95+
Generalize function presented in
96+
"Boosting Constraint Acquisition with Generalization Queries", ECAI 2014.
97+
98+
99+
:param r: The index of a relation in gamma.
100+
:param c: The constraint to generalize.
101+
:return: List of learned constraints.
102+
"""
103+
# Get the scope variables of constraint c
104+
scope_vars = get_scope(c)
105+
106+
# Find all possible type sequences for the variables in the scope
107+
type_sequences = []
108+
for var in scope_vars:
109+
var_types = []
110+
for type_group in self._types:
111+
if var.name in type_group:
112+
var_types.append(type_group)
113+
type_sequences.append(var_types)
114+
115+
# Generate all possible combinations of type sequences
116+
all_type_sequences = list(product(*type_sequences))
117+
118+
# Filter out sequences based on NegativeQ and NonTarget
119+
filtered_sequences = []
120+
for s in all_type_sequences:
121+
122+
# Check if any negative sequence is a subset of current sequence
123+
if s in self._negativeQ:
124+
continue
125+
126+
# Check if any non-target constraint has same relation and vars in sequence
127+
if any(get_relation(c2, self.env.instance.language) == r and
128+
all(any(var in set(type_group) for type_group in s) for var in get_scope(c2))
129+
for c2 in set(self.env.instance.excluded_cons)):
130+
continue
131+
132+
filtered_sequences.append(s)
133+
134+
all_type_sequences = filtered_sequences
135+
136+
gq_counter = 0
137+
138+
# Sort sequences by number of distinct elements (ascending)
139+
all_type_sequences.sort(key=lambda seq: len(set().union(*seq)))
140+
141+
while len(all_type_sequences) > 0 and gq_counter < self._qg_max:
142+
Y = all_type_sequences.pop(0)
143+
144+
# Instead of getting constraints from bias, generate them for this type sequence
145+
B = []
146+
147+
# Generate all possible variable combinations
148+
var_combinations = list(product(*Y))
149+
# Create constraints for each variable combination
150+
for var_comb in var_combinations:
151+
152+
if len(set(var_comb)) != len(var_comb): # No duplicates
153+
continue
154+
# Sort var_comb based on variable names
155+
var_comb = sorted(var_comb, key=lambda var: var.name)
156+
157+
abs_vars = get_scope(self.env.instance.language[r])
158+
replace_dict = dict()
159+
for i, v in enumerate(var_comb):
160+
replace_dict[abs_vars[i]] = v
161+
constraint = replace_variables(self.env.instance.language[r], replace_dict)
162+
163+
# Skip already learned or excluded constraints
164+
if constraint not in set(self.env.instance.cl) and constraint not in set(self.env.instance.excluded_cons):
165+
B.append(constraint)
166+
167+
# If generalization query is accepted
168+
if self.env.ask_generalization_query(self.env.instance.language[r], B):
169+
self.env.add_to_cl(B)
170+
gq_counter = 0
171+
else:
172+
gq_counter += 1
173+
self._negativeQ.append(Y)
174+
175+

tests/test_algorithms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
problem_generators = [construct_murder_problem(), construct_examtt_simple(), construct_nurse_rostering()]
1515

1616
classifiers = [DecisionTreeClassifier(), RandomForestClassifier()]
17-
algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2(), ca.MineAcq(), ca.PQuAcq()]
17+
algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2(), ca.MineAcq(), ca.PQuAcq(), ca.GenAcq()]
1818
fast_tests_algorithms = [ca.QuAcq(), ca.MQuAcq(), ca.MQuAcq2()]
1919

2020
def _generate_fast_benchmarks():

0 commit comments

Comments
 (0)