Skip to content

Commit b3b1b05

Browse files
committed
fix: allow reification of lazy reasons
1 parent 7f7f0cf commit b3b1b05

File tree

9 files changed

+226
-125
lines changed

9 files changed

+226
-125
lines changed

pumpkin-solver/src/engine/conflict_analysis/conflict_analysis_context.rs

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ impl Debug for ConflictAnalysisContext<'_> {
5858
}
5959
}
6060

61-
impl<'a> ConflictAnalysisContext<'a> {
61+
impl ConflictAnalysisContext<'_> {
6262
/// Returns the last decision which was made by the solver.
6363
pub(crate) fn find_last_decision(&mut self) -> Option<Predicate> {
6464
self.assignments.find_last_decision()
@@ -132,15 +132,20 @@ impl<'a> ConflictAnalysisContext<'a> {
132132

133133
/// Returns the reason for a propagation; if it is implied then the reason will be the decision
134134
/// which implied the predicate.
135+
#[allow(
136+
clippy::too_many_arguments,
137+
reason = "borrow checker complains either here or elsewhere"
138+
)]
135139
pub(crate) fn get_propagation_reason(
136140
predicate: Predicate,
137141
assignments: &Assignments,
138142
current_nogood: CurrentNogood<'_>,
139-
reason_store: &'a mut ReasonStore,
140-
propagators: &'a mut PropagatorStore,
141-
proof_log: &'a mut ProofLog,
143+
reason_store: &mut ReasonStore,
144+
propagators: &mut PropagatorStore,
145+
proof_log: &mut ProofLog,
142146
unit_nogood_step_ids: &HashMap<Predicate, StepId>,
143-
) -> &'a [Predicate] {
147+
reason_out: &mut (impl Extend<Predicate> + AsRef<[Predicate]>),
148+
) {
144149
// TODO: this function could be put into the reason store
145150

146151
// Note that this function can only be called with propagations, and never decision
@@ -156,9 +161,8 @@ impl<'a> ConflictAnalysisContext<'a> {
156161
// there would be only one predicate from the current decision level. For this
157162
// reason, it is safe to assume that in the following, that any input predicate is
158163
// indeed a propagated predicate.
159-
reason_store.helper.clear();
160164
if assignments.is_initial_bound(predicate) {
161-
return reason_store.helper.as_slice();
165+
return;
162166
}
163167

164168
let trail_position = assignments
@@ -179,11 +183,17 @@ impl<'a> ConflictAnalysisContext<'a> {
179183

180184
let explanation_context = ExplanationContext::new(assignments, current_nogood);
181185

182-
let reason = reason_store
183-
.get_or_compute(reason_ref, explanation_context, propagators)
184-
.expect("reason reference should not be stale");
186+
let reason_exists = reason_store.get_or_compute(
187+
reason_ref,
188+
explanation_context,
189+
propagators,
190+
reason_out,
191+
);
192+
193+
assert!(reason_exists, "reason reference should not be stale");
194+
185195
if propagator_id == ConstraintSatisfactionSolver::get_nogood_propagator_id()
186-
&& reason.is_empty()
196+
&& reason_out.as_ref().is_empty()
187197
{
188198
// This means that a unit nogood was propagated, we indicate that this nogood step
189199
// was used
@@ -207,12 +217,10 @@ impl<'a> ConflictAnalysisContext<'a> {
207217
// Otherwise we log the inference which was used to derive the nogood
208218
let _ = proof_log.log_inference(
209219
constraint_tag,
210-
reason.iter().copied(),
220+
reason_out.as_ref().iter().copied(),
211221
Some(predicate),
212222
);
213223
}
214-
reason
215-
// The predicate is implicitly due as a result of a decision.
216224
}
217225
// 2) The predicate is true due to a propagation, and not explicitly on the trail.
218226
// It is necessary to further analyse what was the reason for setting the predicate true.
@@ -240,7 +248,7 @@ impl<'a> ConflictAnalysisContext<'a> {
240248
// todo: could consider lifting here, since the trail bound
241249
// might be too strong.
242250
if trail_lower_bound > input_lower_bound {
243-
reason_store.helper.push(trail_entry.predicate);
251+
reason_out.extend(std::iter::once(trail_entry.predicate));
244252
}
245253
// Otherwise, the input bound is strictly greater than the trailed
246254
// bound. This means the reason is due to holes in the domain.
@@ -270,8 +278,8 @@ impl<'a> ConflictAnalysisContext<'a> {
270278
domain_id,
271279
not_equal_constant: input_lower_bound - 1,
272280
};
273-
reason_store.helper.push(one_less_bound_predicate);
274-
reason_store.helper.push(not_equals_predicate);
281+
reason_out.extend(std::iter::once(one_less_bound_predicate));
282+
reason_out.extend(std::iter::once(not_equals_predicate));
275283
}
276284
}
277285
(
@@ -291,7 +299,7 @@ impl<'a> ConflictAnalysisContext<'a> {
291299
// so it safe to take the reason from the trail.
292300
// todo: lifting could be used here
293301
pumpkin_assert_simple!(trail_lower_bound > not_equal_constant);
294-
reason_store.helper.push(trail_entry.predicate);
302+
reason_out.extend(std::iter::once(trail_entry.predicate));
295303
}
296304
(
297305
Predicate::LowerBound {
@@ -323,8 +331,8 @@ impl<'a> ConflictAnalysisContext<'a> {
323331
domain_id,
324332
upper_bound: equality_constant,
325333
};
326-
reason_store.helper.push(predicate_lb);
327-
reason_store.helper.push(predicate_ub);
334+
reason_out.extend(std::iter::once(predicate_lb));
335+
reason_out.extend(std::iter::once(predicate_ub));
328336
}
329337
(
330338
Predicate::UpperBound {
@@ -344,7 +352,7 @@ impl<'a> ConflictAnalysisContext<'a> {
344352
// reason for the input predicate.
345353
// todo: lifting could be applied here.
346354
if trail_upper_bound < input_upper_bound {
347-
reason_store.helper.push(trail_entry.predicate);
355+
reason_out.extend(std::iter::once(trail_entry.predicate));
348356
} else {
349357
// I think it cannot be that the bounds are equal, since otherwise we
350358
// would have found the predicate explicitly on the trail.
@@ -365,8 +373,8 @@ impl<'a> ConflictAnalysisContext<'a> {
365373
domain_id,
366374
not_equal_constant: input_upper_bound + 1,
367375
};
368-
reason_store.helper.push(new_ub_predicate);
369-
reason_store.helper.push(not_equal_predicate);
376+
reason_out.extend(std::iter::once(new_ub_predicate));
377+
reason_out.extend(std::iter::once(not_equal_predicate));
370378
}
371379
}
372380
(
@@ -387,7 +395,7 @@ impl<'a> ConflictAnalysisContext<'a> {
387395

388396
// The bound was set past the not equals, so we can safely returns the trail
389397
// reason. todo: can do lifting here.
390-
reason_store.helper.push(trail_entry.predicate);
398+
reason_out.extend(std::iter::once(trail_entry.predicate));
391399
}
392400
(
393401
Predicate::UpperBound {
@@ -422,8 +430,8 @@ impl<'a> ConflictAnalysisContext<'a> {
422430
domain_id,
423431
upper_bound: equality_constant,
424432
};
425-
reason_store.helper.push(predicate_lb);
426-
reason_store.helper.push(predicate_ub);
433+
reason_out.extend(std::iter::once(predicate_lb));
434+
reason_out.extend(std::iter::once(predicate_ub));
427435
}
428436
(
429437
Predicate::NotEqual {
@@ -457,8 +465,8 @@ impl<'a> ConflictAnalysisContext<'a> {
457465
not_equal_constant: input_lower_bound - 1,
458466
};
459467

460-
reason_store.helper.push(new_lb_predicate);
461-
reason_store.helper.push(new_not_equals_predicate);
468+
reason_out.extend(std::iter::once(new_lb_predicate));
469+
reason_out.extend(std::iter::once(new_not_equals_predicate));
462470
}
463471
(
464472
Predicate::NotEqual {
@@ -492,8 +500,8 @@ impl<'a> ConflictAnalysisContext<'a> {
492500
not_equal_constant: input_upper_bound + 1,
493501
};
494502

495-
reason_store.helper.push(new_ub_predicate);
496-
reason_store.helper.push(new_not_equals_predicate);
503+
reason_out.extend(std::iter::once(new_ub_predicate));
504+
reason_out.extend(std::iter::once(new_not_equals_predicate));
497505
}
498506
(
499507
Predicate::NotEqual {
@@ -522,15 +530,14 @@ impl<'a> ConflictAnalysisContext<'a> {
522530
upper_bound: equality_constant,
523531
};
524532

525-
reason_store.helper.push(predicate_lb);
526-
reason_store.helper.push(predicate_ub);
533+
reason_out.extend(std::iter::once(predicate_lb));
534+
reason_out.extend(std::iter::once(predicate_ub));
527535
}
528536
_ => unreachable!(
529537
"Unreachable combination of {} and {}",
530538
trail_entry.predicate, predicate
531539
),
532540
};
533-
reason_store.helper.as_slice()
534541
}
535542
}
536543
}

pumpkin-solver/src/engine/conflict_analysis/minimisers/recursive_minimiser.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,19 @@ impl RecursiveMinimiser {
117117

118118
// Due to ownership rules, we have to take ownership of the reason.
119119
// TODO: Reuse the allocation if it becomes a bottleneck.
120-
let reason = ConflictAnalysisContext::get_propagation_reason(
120+
let mut reason = vec![];
121+
ConflictAnalysisContext::get_propagation_reason(
121122
input_predicate,
122123
context.assignments,
123124
CurrentNogood::from(current_nogood),
124125
context.reason_store,
125126
context.propagators,
126127
context.proof_log,
127128
context.unit_nogood_step_ids,
128-
)
129-
.to_vec();
129+
&mut reason,
130+
);
130131

131-
for antecedent_predicate in reason {
132+
for antecedent_predicate in reason.iter().copied() {
132133
// Root assignments can be safely ignored.
133134
if context
134135
.assignments

pumpkin-solver/src/engine/conflict_analysis/resolvers/resolution_resolver.rs

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ pub(crate) struct ResolutionResolver {
3939
recursive_minimiser: RecursiveMinimiser,
4040
/// Whether the resolver employs 1-UIP or all-decision learning.
4141
mode: AnalysisMode,
42+
/// Re-usable buffer which reasons are written into.
43+
reason_buffer: Vec<Predicate>,
4244
}
4345

4446
#[derive(Debug, Clone, Copy, Default)]
@@ -138,7 +140,8 @@ impl ConflictResolver for ResolutionResolver {
138140
// However, this can lead to [x <= v] to be processed *before* [x >= v -
139141
// y], meaning that these implied predicates should be replaced with their
140142
// reason
141-
let reason = ConflictAnalysisContext::get_propagation_reason(
143+
self.reason_buffer.clear();
144+
ConflictAnalysisContext::get_propagation_reason(
142145
predicate,
143146
context.assignments,
144147
CurrentNogood::new(
@@ -150,23 +153,24 @@ impl ConflictResolver for ResolutionResolver {
150153
context.propagators,
151154
context.proof_log,
152155
context.unit_nogood_step_ids,
156+
&mut self.reason_buffer,
153157
);
154158

155-
if reason.is_empty() {
159+
if self.reason_buffer.is_empty() {
156160
// In the case when the proof is being completed, it could be the case
157161
// that the reason for a root-level propagation is empty; this
158162
// predicate will be filtered out by the semantic minimisation
159163
pumpkin_assert_simple!(context.is_completing_proof);
160164
predicate
161165
} else {
162-
pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate(), "A non-decision predicate in the nogood should be either a lower-bound or a not-equals predicate but it was {predicate} with reason {reason:?}");
166+
pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate(), "A non-decision predicate in the nogood should be either a lower-bound or a not-equals predicate but it was {predicate} with reason {:?}", self.reason_buffer);
163167
pumpkin_assert_simple!(
164-
reason.len() == 1 && reason[0].is_lower_bound_predicate(),
168+
self.reason_buffer.len() == 1 && self.reason_buffer[0].is_lower_bound_predicate(),
165169
"The reason for the only propagated predicates left on the trail should be lower-bound predicates, but the reason for {predicate} was {:?}",
166-
reason
170+
self.reason_buffer,
167171
);
168172

169-
reason[0]
173+
self.reason_buffer[0]
170174
}
171175
};
172176

@@ -199,7 +203,9 @@ impl ConflictResolver for ResolutionResolver {
199203
.is_initial_bound(self.peek_predicate_from_conflict_nogood())
200204
{
201205
let predicate = self.peek_predicate_from_conflict_nogood();
202-
let reason = ConflictAnalysisContext::get_propagation_reason(
206+
207+
self.reason_buffer.clear();
208+
ConflictAnalysisContext::get_propagation_reason(
203209
predicate,
204210
context.assignments,
205211
CurrentNogood::new(
@@ -211,13 +217,14 @@ impl ConflictResolver for ResolutionResolver {
211217
context.propagators,
212218
context.proof_log,
213219
context.unit_nogood_step_ids,
220+
&mut self.reason_buffer,
214221
);
215222
pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate() , "If the final predicate in the conflict nogood is not a decision predicate then it should be either a lower-bound predicate or a not-equals predicate but was {predicate}");
216223
pumpkin_assert_simple!(
217-
reason.len() == 1 && reason[0].is_lower_bound_predicate(),
218-
"The reason for the decision predicate should be a lower-bound predicate but was {}", reason[0]
224+
self.reason_buffer.len() == 1 && self.reason_buffer[0].is_lower_bound_predicate(),
225+
"The reason for the decision predicate should be a lower-bound predicate but was {}", self.reason_buffer[0]
219226
);
220-
self.replace_predicate_in_conflict_nogood(predicate, reason[0]);
227+
self.replace_predicate_in_conflict_nogood(predicate, self.reason_buffer[0]);
221228
}
222229

223230
// The final predicate in the heap will get pushed in `extract_final_nogood`
@@ -226,7 +233,8 @@ impl ConflictResolver for ResolutionResolver {
226233
}
227234

228235
// 2.b) Standard case, get the reason for the predicate and add it to the nogood.
229-
let reason = ConflictAnalysisContext::get_propagation_reason(
236+
self.reason_buffer.clear();
237+
ConflictAnalysisContext::get_propagation_reason(
230238
next_predicate,
231239
context.assignments,
232240
CurrentNogood::new(
@@ -238,17 +246,22 @@ impl ConflictResolver for ResolutionResolver {
238246
context.propagators,
239247
context.proof_log,
240248
context.unit_nogood_step_ids,
249+
&mut self.reason_buffer,
241250
);
242251

243-
for predicate in reason.iter() {
252+
// We do a little swapping of the ownership of the buffer, so we can call
253+
// `self.add_predicate_to_conflict_nogood`.
254+
let reason = std::mem::take(&mut self.reason_buffer);
255+
for predicate in reason.iter().copied() {
244256
self.add_predicate_to_conflict_nogood(
245-
*predicate,
257+
predicate,
246258
context.assignments,
247259
context.brancher,
248260
self.mode,
249261
context.is_completing_proof,
250262
);
251263
}
264+
self.reason_buffer = reason;
252265
}
253266
Some(self.extract_final_nogood(context))
254267
}

pumpkin-solver/src/engine/constraint_satisfaction_solver.rs

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,20 +1137,18 @@ impl ConstraintSatisfactionSolver {
11371137

11381138
// Look up the reason for the bound that changed.
11391139
// The reason for changing the bound cannot be a decision, so we can safely unwrap.
1140-
let reason_changing_bound = reason_store
1141-
.get_or_compute(
1142-
entry.reason.unwrap(),
1143-
ExplanationContext::from(&*assignments),
1144-
propagators,
1145-
)
1146-
.unwrap();
1147-
11481140
let mut empty_domain_reason: Vec<Predicate> = vec![
11491141
predicate!(conflict_domain >= entry.old_lower_bound),
11501142
predicate!(conflict_domain <= entry.old_upper_bound),
11511143
];
11521144

1153-
empty_domain_reason.append(&mut reason_changing_bound.to_vec());
1145+
let _ = reason_store.get_or_compute(
1146+
entry.reason.unwrap(),
1147+
ExplanationContext::from(&*assignments),
1148+
propagators,
1149+
&mut empty_domain_reason,
1150+
);
1151+
11541152
empty_domain_reason.into()
11551153
}
11561154

@@ -1269,19 +1267,18 @@ impl ConstraintSatisfactionSolver {
12691267
) {
12701268
for trail_idx in start_trail_index..self.assignments.num_trail_entries() {
12711269
let entry = self.assignments.get_trail_entry(trail_idx);
1272-
let reason = entry
1270+
let reason_ref = entry
12731271
.reason
12741272
.expect("Added by a propagator and must therefore have a reason");
12751273

12761274
// Get the conjunction of predicates explaining the propagation.
1277-
let reason = self
1278-
.reason_store
1279-
.get_or_compute(
1280-
reason,
1281-
ExplanationContext::new(&self.assignments, CurrentNogood::empty()),
1282-
&mut self.propagators,
1283-
)
1284-
.expect("Reason ref is valid");
1275+
let mut reason = vec![];
1276+
let _ = self.reason_store.get_or_compute(
1277+
reason_ref,
1278+
ExplanationContext::new(&self.assignments, CurrentNogood::empty()),
1279+
&mut self.propagators,
1280+
&mut reason,
1281+
);
12851282

12861283
let propagated = entry.predicate;
12871284

0 commit comments

Comments
 (0)