Skip to content

Commit dbdb17f

Browse files
committed
correct findc2 query generation
1 parent 078047c commit dbdb17f

File tree

1 file changed

+45
-31
lines changed

1 file changed

+45
-31
lines changed

pycona/find_constraint/findc2.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import cpmpy as cp
2+
import copy
23

34
from ..ca_environment.active_ca import ActiveCAEnv
45
from .utils import get_max_conjunction_size, get_delta_p
56
from .findc_core import FindCBase
67
from .utils import join_con_net
7-
from ..utils import restore_scope_values, get_con_subset, check_value
8+
from ..utils import restore_scope_values, get_con_subset, check_value, get_scope
89

910

1011
class FindC2(FindCBase):
@@ -14,7 +15,6 @@ class FindC2(FindCBase):
1415
1516
This function works also for non-normalised target networks!
1617
"""
17-
# TODO optimize to work better (probably only needs to make better the generate_find_query2)
1818

1919
def __init__(self, ca_env: ActiveCAEnv = None, time_limit=0.2, findscope=None):
2020
"""
@@ -54,15 +54,16 @@ def run(self, scope):
5454
"""
5555
assert self.ca is not None
5656

57+
scope_values = [x.value() for x in scope]
58+
5759
# Initialize delta
5860
delta = get_con_subset(self.ca.instance.bias, scope)
59-
delta = join_con_net(delta, [c for c in delta if check_value(c) is False])
61+
kappaD = [c for c in delta if check_value(c) is False]
62+
delta = join_con_net(delta, kappaD)
6063

6164
# We need to take into account only the constraints in the scope we search on
6265
sub_cl = get_con_subset(self.ca.instance.cl, scope)
6366

64-
scope_values = [x.value() for x in scope]
65-
6667
while True:
6768

6869
# Try to generate a counter example to reduce the candidates
@@ -76,6 +77,8 @@ def run(self, scope):
7677
restore_scope_values(scope, scope_values)
7778

7879
# Return random c in delta otherwise (if more than one, they are equivalent w.r.t. C_l)
80+
# Choose the constraint with the smallest number of conjunctions
81+
delta = sorted(delta, key=lambda x: len(x.args))
7982
return delta[0]
8083

8184
self.ca.metrics.increase_findc_queries()
@@ -90,15 +93,14 @@ def run(self, scope):
9093

9194
kappaD = [c for c in delta if check_value(c) is False]
9295

93-
scope2 = self.ca.run_find_scope(list(scope), kappaD) # TODO: replace with real findscope arguments when done!
96+
#scope2 = self.ca.run_find_scope(list(scope), kappaD) # TODO: replace with real findscope arguments when done!
9497

95-
if len(scope2) < len(scope):
96-
self.run(scope2)
97-
else:
98-
delta = join_con_net(delta, kappaD)
98+
#if len(scope2) < len(scope):
99+
# self.run(scope2)
100+
#else:
101+
delta = join_con_net(delta, kappaD)
99102

100103
def generate_findc_query(self, L, delta):
101-
# TODO: optimize to work better
102104
"""
103105
Changes directly the values of the variables
104106
@@ -107,35 +109,47 @@ def generate_findc_query(self, L, delta):
107109
:return: Boolean value representing a success or failure on the generation
108110
"""
109111

110-
tmp = cp.Model(L)
112+
tmp = cp.Model(L)
113+
114+
satisfied_delta = sum([c for c in delta]) # get the amount of satisfied constraints from B
115+
116+
scope = get_scope(delta[0])
117+
# at least 1 violated and at least 1 satisfied
118+
# we want this to assure that each answer of the user will reduce
119+
# the set of candidates
120+
tmp += satisfied_delta < len(delta)
121+
tmp += satisfied_delta > 0
122+
111123

112124
max_conj_size = get_max_conjunction_size(delta)
113125
delta_p = get_delta_p(delta)
114126

115-
p = cp.intvar(0, max_conj_size)
116-
kappa_delta_p = cp.intvar(0, len(delta), shape=(max_conj_size,))
117-
p_soft_con = cp.boolvar(shape=(max_conj_size,))
127+
for p in range(max_conj_size):
128+
s = cp.SolverLookup.get("ortools", tmp)
118129

119-
for i in range(max_conj_size):
120-
tmp += kappa_delta_p[i] == sum([c for c in delta_p[i]])
121-
p_soft_con[i] = (kappa_delta_p[i] > 0)
130+
kappa_delta_p = sum([c for c in delta_p[p]])
131+
s += kappa_delta_p < len(delta_p[p])
132+
122133

123-
tmp += p == min([i for i in range(max_conj_size) if (kappa_delta_p[i] < len(delta_p[i]))])
134+
if not s.solve(): # if a solution is found
135+
continue
124136

125-
objective = sum([c for c in delta]) # get the amount of satisfied constraints from B
137+
# Next solve will change the values of the variables in lY
138+
# so we need to return them to the original ones to continue if we don't find a solution next
139+
values = [x.value() for x in scope]
126140

127-
# at least 1 violated and at least 1 satisfied
128-
# we want this to assure that each answer of the user will reduce
129-
# the set of candidates
130-
tmp += objective < len(delta)
131-
tmp += objective > 0
132141

133-
# Try first without objective
134-
s = cp.SolverLookup.get("ortools", tmp)
142+
p_soft_con = (kappa_delta_p > 0)
143+
144+
# run with the objective
145+
s.maximize(p_soft_con)
135146

136-
# run with the objective
137-
s.minimize(100 * p - p_soft_con[p])
147+
# So a solution was found, try to find a better one now
148+
s.solution_hint(scope, values)
138149

139-
flag = s.solve(time_limit=self.time_limit)
150+
flag = s.solve(time_limit=self.time_limit, num_workers=8)
151+
if not flag:
152+
restore_scope_values(scope, values)
153+
return True
140154

141-
return flag
155+
return False

0 commit comments

Comments
 (0)