1
+ import time
2
+ from itertools import product
3
+
4
+ from .algorithm_core import AlgorithmCAInteractive
5
+ from ..ca_environment .active_ca import ActiveCAEnv
6
+ from ..utils import get_relation , get_scope , get_kappa , replace_variables
7
+ from ..problem_instance import ProblemInstance
8
+ from ..answering_queries import Oracle , UserOracle
9
+ from .. import Metrics
10
+
11
+
12
+ class GenAcq (AlgorithmCAInteractive ):
13
+
14
+ """
15
+ GenAcq algorithm, using mine&Ask to detect types of variables and ask genralization queries. From:
16
+
17
+ "Detecting Types of Variables for Generalization in Constraint Acquisition", ICTAI 2015.
18
+ """
19
+
20
+ def __init__ (self , ca_env : ActiveCAEnv = None , types = None , qg_max = 3 ):
21
+ """
22
+ Initialize the PQuAcq algorithm with an optional constraint acquisition environment.
23
+
24
+ :param ca_env: An instance of ActiveCAEnv, default is None.
25
+ : param types: list of types of variables given by the user
26
+ : param qg_max: maximum number of generalization queries
27
+ """
28
+ super ().__init__ (ca_env )
29
+ self ._negativeQ = []
30
+ self ._qg_max = qg_max
31
+ self ._types = types if types is not None else []
32
+
33
+ def learn (self , instance : ProblemInstance , oracle : Oracle = UserOracle (), verbose = 0 , metrics : Metrics = None , X = None ):
34
+ """
35
+ Learn constraints using the QuAcq algorithm by generating queries and analyzing the results.
36
+
37
+ :param instance: the problem instance to acquire the constraints for
38
+ :param oracle: An instance of Oracle, default is to use the user as the oracle.
39
+ :param verbose: Verbosity level, default is 0.
40
+ :param metrics: statistics logger during learning
41
+ :return: the learned instance
42
+ """
43
+ self .env .init_state (instance , oracle , verbose , metrics )
44
+
45
+ if X is None :
46
+ X = list (self .env .instance .variables .flat )
47
+
48
+ if len (self .env .instance .bias ) == 0 :
49
+ self .env .instance .construct_bias (X )
50
+
51
+ while True :
52
+ if self .env .verbose > 0 :
53
+ print ("Size of CL: " , len (self .env .instance .cl ))
54
+ print ("Size of B: " , len (self .env .instance .bias ))
55
+ print ("Number of Queries: " , self .env .metrics .total_queries )
56
+ print ("Number of Generalization Queries: " , self .env .metrics .generalization_queries_count )
57
+ print ("Number of Membership Queries: " , self .env .metrics .membership_queries_count )
58
+
59
+
60
+ gen_start = time .time ()
61
+ Y = self .env .run_query_generation (X )
62
+ gen_end = time .time ()
63
+
64
+ if len (Y ) == 0 :
65
+ # if no query can be generated it means we have (prematurely) converged to the target network -----
66
+ self .env .metrics .finalize_statistics ()
67
+ if self .env .verbose >= 1 :
68
+ print (f"\n Learned { self .env .metrics .cl } constraints in "
69
+ f"{ self .env .metrics .total_queries } queries." )
70
+ self .env .instance .bias = []
71
+ return self .env .instance
72
+
73
+ self .env .metrics .increase_generation_time (gen_end - gen_start )
74
+ self .env .metrics .increase_generated_queries ()
75
+ self .env .metrics .increase_top_queries ()
76
+ kappaB = get_kappa (self .env .instance .bias , Y )
77
+
78
+ answer = self .env .ask_membership_query (Y )
79
+ if answer :
80
+ # it is a solution, so all candidates violated must go
81
+ # B <- B \setminus K_B(e)
82
+ self .env .remove_from_bias (kappaB )
83
+
84
+ else : # user says UNSAT
85
+
86
+ scope = self .env .run_find_scope (Y )
87
+ c = self .env .run_findc (scope )
88
+ self .env .add_to_cl (c )
89
+ self .generalize (get_relation (c , self .env .instance .language ),c )
90
+
91
+
92
+
93
+ def generalize (self , r , c ):
94
+ """
95
+ Generalize function presented in
96
+ "Boosting Constraint Acquisition with Generalization Queries", ECAI 2014.
97
+
98
+
99
+ :param r: The index of a relation in gamma.
100
+ :param c: The constraint to generalize.
101
+ :return: List of learned constraints.
102
+ """
103
+ # Get the scope variables of constraint c
104
+ scope_vars = get_scope (c )
105
+
106
+ # Find all possible type sequences for the variables in the scope
107
+ type_sequences = []
108
+ for var in scope_vars :
109
+ var_types = []
110
+ for type_group in self ._types :
111
+ if var .name in type_group :
112
+ var_types .append (type_group )
113
+ type_sequences .append (var_types )
114
+
115
+ # Generate all possible combinations of type sequences
116
+ all_type_sequences = list (product (* type_sequences ))
117
+
118
+ # Filter out sequences based on NegativeQ and NonTarget
119
+ filtered_sequences = []
120
+ for s in all_type_sequences :
121
+
122
+ # Check if any negative sequence is a subset of current sequence
123
+ if s in self ._negativeQ :
124
+ continue
125
+
126
+ # Check if any non-target constraint has same relation and vars in sequence
127
+ if any (get_relation (c2 , self .env .instance .language ) == r and
128
+ all (any (var in set (type_group ) for type_group in s ) for var in get_scope (c2 ))
129
+ for c2 in set (self .env .instance .excluded_cons )):
130
+ continue
131
+
132
+ filtered_sequences .append (s )
133
+
134
+ all_type_sequences = filtered_sequences
135
+
136
+ gq_counter = 0
137
+
138
+ # Sort sequences by number of distinct elements (ascending)
139
+ all_type_sequences .sort (key = lambda seq : len (set ().union (* seq )))
140
+
141
+ while len (all_type_sequences ) > 0 and gq_counter < self ._qg_max :
142
+ Y = all_type_sequences .pop (0 )
143
+
144
+ # Instead of getting constraints from bias, generate them for this type sequence
145
+ B = []
146
+
147
+ # Generate all possible variable combinations
148
+ var_combinations = list (product (* Y ))
149
+ # Create constraints for each variable combination
150
+ for var_comb in var_combinations :
151
+
152
+ if len (set (var_comb )) != len (var_comb ): # No duplicates
153
+ continue
154
+ # Sort var_comb based on variable names
155
+ var_comb = sorted (var_comb , key = lambda var : var .name )
156
+
157
+ abs_vars = get_scope (self .env .instance .language [r ])
158
+ replace_dict = dict ()
159
+ for i , v in enumerate (var_comb ):
160
+ replace_dict [abs_vars [i ]] = v
161
+ constraint = replace_variables (self .env .instance .language [r ], replace_dict )
162
+
163
+ # Skip already learned or excluded constraints
164
+ if constraint not in set (self .env .instance .cl ) and constraint not in set (self .env .instance .excluded_cons ):
165
+ B .append (constraint )
166
+
167
+ # If generalization query is accepted
168
+ if self .env .ask_generalization_query (self .env .instance .language [r ], B ):
169
+ self .env .add_to_cl (B )
170
+ gq_counter = 0
171
+ else :
172
+ gq_counter += 1
173
+ self ._negativeQ .append (Y )
174
+
175
+
0 commit comments