1
+ import time
2
+ from itertools import product
3
+
4
+ from .algorithm_core import AlgorithmCAInteractive
5
+ from ..ca_environment .active_ca import ActiveCAEnv
6
+ from ..utils import get_relation , get_scope , get_kappa , replace_variables
7
+ from ..problem_instance import ProblemInstance
8
+ from ..answering_queries import Oracle , UserOracle
9
+ from .. import Metrics
10
+
11
+
12
+ class GenAcq (AlgorithmCAInteractive ):
13
+
14
+ """
15
+ GenAcq algorithm, using generalization queries on given types of variables. From:
16
+
17
+ "Boosting Constraint Acquisition with Generalization Queries", ECAI 2014.
18
+ """
19
+
20
+ def __init__ (self , ca_env : ActiveCAEnv = None , types = None , qg_max = 3 ):
21
+ """
22
+ Initialize the GenAcq algorithm with an optional constraint acquisition environment.
23
+
24
+ :param ca_env: An instance of ActiveCAEnv, default is None.
25
+ : param types: list of types of variables given by the user
26
+ : param qg_max: maximum number of generalization queries
27
+ """
28
+ super ().__init__ (ca_env )
29
+ self ._negativeQ = []
30
+ self ._qg_max = qg_max
31
+ self ._types = types if types is not None else []
32
+
33
+ def learn (self , instance : ProblemInstance , oracle : Oracle = UserOracle (), verbose = 0 , metrics : Metrics = None , X = None ):
34
+ """
35
+ Learn constraints using the GenAcq algorithm by generating queries and analyzing the results.
36
+ Using generalization queries on given types of variables.
37
+
38
+ :param instance: the problem instance to acquire the constraints for
39
+ :param oracle: An instance of Oracle, default is to use the user as the oracle.
40
+ :param verbose: Verbosity level, default is 0.
41
+ :param metrics: statistics logger during learning
42
+ :return: the learned instance
43
+ """
44
+ self .env .init_state (instance , oracle , verbose , metrics )
45
+
46
+ if X is None :
47
+ X = list (self .env .instance .variables .flat )
48
+
49
+ if len (self .env .instance .bias ) == 0 :
50
+ self .env .instance .construct_bias (X )
51
+
52
+ while True :
53
+ if self .env .verbose > 0 :
54
+ print ("Size of CL: " , len (self .env .instance .cl ))
55
+ print ("Size of B: " , len (self .env .instance .bias ))
56
+ print ("Number of Queries: " , self .env .metrics .total_queries )
57
+ print ("Number of Generalization Queries: " , self .env .metrics .generalization_queries_count )
58
+ print ("Number of Membership Queries: " , self .env .metrics .membership_queries_count )
59
+
60
+
61
+ gen_start = time .time ()
62
+ Y = self .env .run_query_generation (X )
63
+ gen_end = time .time ()
64
+
65
+ if len (Y ) == 0 :
66
+ # if no query can be generated it means we have (prematurely) converged to the target network -----
67
+ self .env .metrics .finalize_statistics ()
68
+ if self .env .verbose >= 1 :
69
+ print (f"\n Learned { self .env .metrics .cl } constraints in "
70
+ f"{ self .env .metrics .total_queries } queries." )
71
+ self .env .instance .bias = []
72
+ return self .env .instance
73
+
74
+ self .env .metrics .increase_generation_time (gen_end - gen_start )
75
+ self .env .metrics .increase_generated_queries ()
76
+ self .env .metrics .increase_top_queries ()
77
+ kappaB = get_kappa (self .env .instance .bias , Y )
78
+
79
+ answer = self .env .ask_membership_query (Y )
80
+ if answer :
81
+ # it is a solution, so all candidates violated must go
82
+ # B <- B \setminus K_B(e)
83
+ self .env .remove_from_bias (kappaB )
84
+
85
+ else : # user says UNSAT
86
+
87
+ scope = self .env .run_find_scope (Y )
88
+ c = self .env .run_findc (scope )
89
+ self .env .add_to_cl (c )
90
+ self .generalize (get_relation (c , self .env .instance .language ),c )
91
+
92
+
93
+
94
+ def generalize (self , r , c ):
95
+ """
96
+ Generalize function presented in
97
+ "Boosting Constraint Acquisition with Generalization Queries", ECAI 2014.
98
+
99
+
100
+ :param r: The index of a relation in gamma.
101
+ :param c: The constraint to generalize.
102
+ :return: List of learned constraints.
103
+ """
104
+ # Get the scope variables of constraint c
105
+ scope_vars = get_scope (c )
106
+
107
+ # Find all possible type sequences for the variables in the scope
108
+ type_sequences = []
109
+ for var in scope_vars :
110
+ var_types = []
111
+ for type_group in self ._types :
112
+ if var .name in type_group :
113
+ var_types .append (type_group )
114
+ type_sequences .append (var_types )
115
+
116
+ # Generate all possible combinations of type sequences
117
+ all_type_sequences = list (product (* type_sequences ))
118
+
119
+ # Filter out sequences based on NegativeQ and NonTarget
120
+ filtered_sequences = []
121
+ for s in all_type_sequences :
122
+
123
+ # Check if any negative sequence is a subset of current sequence
124
+ if s in self ._negativeQ :
125
+ continue
126
+
127
+ # Check if any non-target constraint has same relation and vars in sequence
128
+ if any (get_relation (c2 , self .env .instance .language ) == r and
129
+ all (any (var in set (type_group ) for type_group in s ) for var in get_scope (c2 ))
130
+ for c2 in set (self .env .instance .excluded_cons )):
131
+ continue
132
+
133
+ filtered_sequences .append (s )
134
+
135
+ all_type_sequences = filtered_sequences
136
+
137
+ gq_counter = 0
138
+
139
+ # Sort sequences by number of distinct elements (ascending)
140
+ all_type_sequences .sort (key = lambda seq : len (set ().union (* seq )))
141
+
142
+ while len (all_type_sequences ) > 0 and gq_counter < self ._qg_max :
143
+ Y = all_type_sequences .pop (0 )
144
+
145
+ # Instead of getting constraints from bias, generate them for this type sequence
146
+ B = []
147
+
148
+ # Generate all possible variable combinations
149
+ var_combinations = list (product (* Y ))
150
+ # Create constraints for each variable combination
151
+ for var_comb in var_combinations :
152
+
153
+ if len (set (var_comb )) != len (var_comb ): # No duplicates
154
+ continue
155
+ # Sort var_comb based on variable names
156
+ var_comb = sorted (var_comb , key = lambda var : var .name )
157
+
158
+ abs_vars = get_scope (self .env .instance .language [r ])
159
+ replace_dict = dict ()
160
+ for i , v in enumerate (var_comb ):
161
+ replace_dict [abs_vars [i ]] = v
162
+ constraint = replace_variables (self .env .instance .language [r ], replace_dict )
163
+
164
+ # Skip already learned or excluded constraints
165
+ if constraint not in set (self .env .instance .cl ) and constraint not in set (self .env .instance .excluded_cons ):
166
+ B .append (constraint )
167
+
168
+ # If generalization query is accepted
169
+ if self .env .ask_generalization_query (self .env .instance .language [r ], B ):
170
+ self .env .add_to_cl (B )
171
+ gq_counter = 0
172
+ else :
173
+ gq_counter += 1
174
+ self ._negativeQ .append (Y )
175
+
176
+
0 commit comments