Skip to content

Commit

Permalink
fix: allow reification of lazy reasons
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenflippo committed Feb 12, 2025
1 parent 7f7f0cf commit b3b1b05
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl Debug for ConflictAnalysisContext<'_> {
}
}

impl<'a> ConflictAnalysisContext<'a> {
impl ConflictAnalysisContext<'_> {
/// Returns the last decision which was made by the solver.
pub(crate) fn find_last_decision(&mut self) -> Option<Predicate> {
self.assignments.find_last_decision()
Expand Down Expand Up @@ -132,15 +132,20 @@ impl<'a> ConflictAnalysisContext<'a> {

/// Returns the reason for a propagation; if it is implied then the reason will be the decision
/// which implied the predicate.
#[allow(
clippy::too_many_arguments,
reason = "borrow checker complains either here or elsewhere"
)]
pub(crate) fn get_propagation_reason(
predicate: Predicate,
assignments: &Assignments,
current_nogood: CurrentNogood<'_>,
reason_store: &'a mut ReasonStore,
propagators: &'a mut PropagatorStore,
proof_log: &'a mut ProofLog,
reason_store: &mut ReasonStore,
propagators: &mut PropagatorStore,
proof_log: &mut ProofLog,
unit_nogood_step_ids: &HashMap<Predicate, StepId>,
) -> &'a [Predicate] {
reason_out: &mut (impl Extend<Predicate> + AsRef<[Predicate]>),
) {
// TODO: this function could be put into the reason store

// Note that this function can only be called with propagations, and never decision
Expand All @@ -156,9 +161,8 @@ impl<'a> ConflictAnalysisContext<'a> {
// there would be only one predicate from the current decision level. For this
// reason, it is safe to assume that in the following, that any input predicate is
// indeed a propagated predicate.
reason_store.helper.clear();
if assignments.is_initial_bound(predicate) {
return reason_store.helper.as_slice();
return;
}

let trail_position = assignments
Expand All @@ -179,11 +183,17 @@ impl<'a> ConflictAnalysisContext<'a> {

let explanation_context = ExplanationContext::new(assignments, current_nogood);

let reason = reason_store
.get_or_compute(reason_ref, explanation_context, propagators)
.expect("reason reference should not be stale");
let reason_exists = reason_store.get_or_compute(
reason_ref,
explanation_context,
propagators,
reason_out,
);

assert!(reason_exists, "reason reference should not be stale");

if propagator_id == ConstraintSatisfactionSolver::get_nogood_propagator_id()
&& reason.is_empty()
&& reason_out.as_ref().is_empty()
{
// This means that a unit nogood was propagated, we indicate that this nogood step
// was used
Expand All @@ -207,12 +217,10 @@ impl<'a> ConflictAnalysisContext<'a> {
// Otherwise we log the inference which was used to derive the nogood
let _ = proof_log.log_inference(
constraint_tag,
reason.iter().copied(),
reason_out.as_ref().iter().copied(),
Some(predicate),
);
}
reason
// The predicate is implicitly due as a result of a decision.
}
// 2) The predicate is true due to a propagation, and not explicitly on the trail.
// It is necessary to further analyse what was the reason for setting the predicate true.
Expand Down Expand Up @@ -240,7 +248,7 @@ impl<'a> ConflictAnalysisContext<'a> {
// todo: could consider lifting here, since the trail bound
// might be too strong.
if trail_lower_bound > input_lower_bound {
reason_store.helper.push(trail_entry.predicate);
reason_out.extend(std::iter::once(trail_entry.predicate));
}
// Otherwise, the input bound is strictly greater than the trailed
// bound. This means the reason is due to holes in the domain.
Expand Down Expand Up @@ -270,8 +278,8 @@ impl<'a> ConflictAnalysisContext<'a> {
domain_id,
not_equal_constant: input_lower_bound - 1,
};
reason_store.helper.push(one_less_bound_predicate);
reason_store.helper.push(not_equals_predicate);
reason_out.extend(std::iter::once(one_less_bound_predicate));
reason_out.extend(std::iter::once(not_equals_predicate));
}
}
(
Expand All @@ -291,7 +299,7 @@ impl<'a> ConflictAnalysisContext<'a> {
// so it safe to take the reason from the trail.
// todo: lifting could be used here
pumpkin_assert_simple!(trail_lower_bound > not_equal_constant);
reason_store.helper.push(trail_entry.predicate);
reason_out.extend(std::iter::once(trail_entry.predicate));
}
(
Predicate::LowerBound {
Expand Down Expand Up @@ -323,8 +331,8 @@ impl<'a> ConflictAnalysisContext<'a> {
domain_id,
upper_bound: equality_constant,
};
reason_store.helper.push(predicate_lb);
reason_store.helper.push(predicate_ub);
reason_out.extend(std::iter::once(predicate_lb));
reason_out.extend(std::iter::once(predicate_ub));
}
(
Predicate::UpperBound {
Expand All @@ -344,7 +352,7 @@ impl<'a> ConflictAnalysisContext<'a> {
// reason for the input predicate.
// todo: lifting could be applied here.
if trail_upper_bound < input_upper_bound {
reason_store.helper.push(trail_entry.predicate);
reason_out.extend(std::iter::once(trail_entry.predicate));
} else {
// I think it cannot be that the bounds are equal, since otherwise we
// would have found the predicate explicitly on the trail.
Expand All @@ -365,8 +373,8 @@ impl<'a> ConflictAnalysisContext<'a> {
domain_id,
not_equal_constant: input_upper_bound + 1,
};
reason_store.helper.push(new_ub_predicate);
reason_store.helper.push(not_equal_predicate);
reason_out.extend(std::iter::once(new_ub_predicate));
reason_out.extend(std::iter::once(not_equal_predicate));
}
}
(
Expand All @@ -387,7 +395,7 @@ impl<'a> ConflictAnalysisContext<'a> {

// The bound was set past the not equals, so we can safely returns the trail
// reason. todo: can do lifting here.
reason_store.helper.push(trail_entry.predicate);
reason_out.extend(std::iter::once(trail_entry.predicate));
}
(
Predicate::UpperBound {
Expand Down Expand Up @@ -422,8 +430,8 @@ impl<'a> ConflictAnalysisContext<'a> {
domain_id,
upper_bound: equality_constant,
};
reason_store.helper.push(predicate_lb);
reason_store.helper.push(predicate_ub);
reason_out.extend(std::iter::once(predicate_lb));
reason_out.extend(std::iter::once(predicate_ub));
}
(
Predicate::NotEqual {
Expand Down Expand Up @@ -457,8 +465,8 @@ impl<'a> ConflictAnalysisContext<'a> {
not_equal_constant: input_lower_bound - 1,
};

reason_store.helper.push(new_lb_predicate);
reason_store.helper.push(new_not_equals_predicate);
reason_out.extend(std::iter::once(new_lb_predicate));
reason_out.extend(std::iter::once(new_not_equals_predicate));
}
(
Predicate::NotEqual {
Expand Down Expand Up @@ -492,8 +500,8 @@ impl<'a> ConflictAnalysisContext<'a> {
not_equal_constant: input_upper_bound + 1,
};

reason_store.helper.push(new_ub_predicate);
reason_store.helper.push(new_not_equals_predicate);
reason_out.extend(std::iter::once(new_ub_predicate));
reason_out.extend(std::iter::once(new_not_equals_predicate));
}
(
Predicate::NotEqual {
Expand Down Expand Up @@ -522,15 +530,14 @@ impl<'a> ConflictAnalysisContext<'a> {
upper_bound: equality_constant,
};

reason_store.helper.push(predicate_lb);
reason_store.helper.push(predicate_ub);
reason_out.extend(std::iter::once(predicate_lb));
reason_out.extend(std::iter::once(predicate_ub));
}
_ => unreachable!(
"Unreachable combination of {} and {}",
trail_entry.predicate, predicate
),
};
reason_store.helper.as_slice()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,19 @@ impl RecursiveMinimiser {

// Due to ownership rules, we have to take ownership of the reason.
// TODO: Reuse the allocation if it becomes a bottleneck.
let reason = ConflictAnalysisContext::get_propagation_reason(
let mut reason = vec![];
ConflictAnalysisContext::get_propagation_reason(
input_predicate,
context.assignments,
CurrentNogood::from(current_nogood),
context.reason_store,
context.propagators,
context.proof_log,
context.unit_nogood_step_ids,
)
.to_vec();
&mut reason,
);

for antecedent_predicate in reason {
for antecedent_predicate in reason.iter().copied() {
// Root assignments can be safely ignored.
if context
.assignments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pub(crate) struct ResolutionResolver {
recursive_minimiser: RecursiveMinimiser,
/// Whether the resolver employs 1-UIP or all-decision learning.
mode: AnalysisMode,
/// Re-usable buffer which reasons are written into.
reason_buffer: Vec<Predicate>,
}

#[derive(Debug, Clone, Copy, Default)]
Expand Down Expand Up @@ -138,7 +140,8 @@ impl ConflictResolver for ResolutionResolver {
// However, this can lead to [x <= v] to be processed *before* [x >= v -
// y], meaning that these implied predicates should be replaced with their
// reason
let reason = ConflictAnalysisContext::get_propagation_reason(
self.reason_buffer.clear();
ConflictAnalysisContext::get_propagation_reason(
predicate,
context.assignments,
CurrentNogood::new(
Expand All @@ -150,23 +153,24 @@ impl ConflictResolver for ResolutionResolver {
context.propagators,
context.proof_log,
context.unit_nogood_step_ids,
&mut self.reason_buffer,
);

if reason.is_empty() {
if self.reason_buffer.is_empty() {
// In the case when the proof is being completed, it could be the case
// that the reason for a root-level propagation is empty; this
// predicate will be filtered out by the semantic minimisation
pumpkin_assert_simple!(context.is_completing_proof);
predicate
} else {
pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate(), "A non-decision predicate in the nogood should be either a lower-bound or a not-equals predicate but it was {predicate} with reason {reason:?}");
pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate(), "A non-decision predicate in the nogood should be either a lower-bound or a not-equals predicate but it was {predicate} with reason {:?}", self.reason_buffer);
pumpkin_assert_simple!(
reason.len() == 1 && reason[0].is_lower_bound_predicate(),
self.reason_buffer.len() == 1 && self.reason_buffer[0].is_lower_bound_predicate(),
"The reason for the only propagated predicates left on the trail should be lower-bound predicates, but the reason for {predicate} was {:?}",
reason
self.reason_buffer,
);

reason[0]
self.reason_buffer[0]
}
};

Expand Down Expand Up @@ -199,7 +203,9 @@ impl ConflictResolver for ResolutionResolver {
.is_initial_bound(self.peek_predicate_from_conflict_nogood())
{
let predicate = self.peek_predicate_from_conflict_nogood();
let reason = ConflictAnalysisContext::get_propagation_reason(

self.reason_buffer.clear();
ConflictAnalysisContext::get_propagation_reason(
predicate,
context.assignments,
CurrentNogood::new(
Expand All @@ -211,13 +217,14 @@ impl ConflictResolver for ResolutionResolver {
context.propagators,
context.proof_log,
context.unit_nogood_step_ids,
&mut self.reason_buffer,
);
pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate() , "If the final predicate in the conflict nogood is not a decision predicate then it should be either a lower-bound predicate or a not-equals predicate but was {predicate}");
pumpkin_assert_simple!(
reason.len() == 1 && reason[0].is_lower_bound_predicate(),
"The reason for the decision predicate should be a lower-bound predicate but was {}", reason[0]
self.reason_buffer.len() == 1 && self.reason_buffer[0].is_lower_bound_predicate(),
"The reason for the decision predicate should be a lower-bound predicate but was {}", self.reason_buffer[0]
);
self.replace_predicate_in_conflict_nogood(predicate, reason[0]);
self.replace_predicate_in_conflict_nogood(predicate, self.reason_buffer[0]);
}

// The final predicate in the heap will get pushed in `extract_final_nogood`
Expand All @@ -226,7 +233,8 @@ impl ConflictResolver for ResolutionResolver {
}

// 2.b) Standard case, get the reason for the predicate and add it to the nogood.
let reason = ConflictAnalysisContext::get_propagation_reason(
self.reason_buffer.clear();
ConflictAnalysisContext::get_propagation_reason(
next_predicate,
context.assignments,
CurrentNogood::new(
Expand All @@ -238,17 +246,22 @@ impl ConflictResolver for ResolutionResolver {
context.propagators,
context.proof_log,
context.unit_nogood_step_ids,
&mut self.reason_buffer,
);

for predicate in reason.iter() {
// We do a little swapping of the ownership of the buffer, so we can call
// `self.add_predicate_to_conflict_nogood`.
let reason = std::mem::take(&mut self.reason_buffer);
for predicate in reason.iter().copied() {
self.add_predicate_to_conflict_nogood(
*predicate,
predicate,
context.assignments,
context.brancher,
self.mode,
context.is_completing_proof,
);
}
self.reason_buffer = reason;
}
Some(self.extract_final_nogood(context))
}
Expand Down
33 changes: 15 additions & 18 deletions pumpkin-solver/src/engine/constraint_satisfaction_solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1137,20 +1137,18 @@ impl ConstraintSatisfactionSolver {

// Look up the reason for the bound that changed.
// The reason for changing the bound cannot be a decision, so we can safely unwrap.
let reason_changing_bound = reason_store
.get_or_compute(
entry.reason.unwrap(),
ExplanationContext::from(&*assignments),
propagators,
)
.unwrap();

let mut empty_domain_reason: Vec<Predicate> = vec![
predicate!(conflict_domain >= entry.old_lower_bound),
predicate!(conflict_domain <= entry.old_upper_bound),
];

empty_domain_reason.append(&mut reason_changing_bound.to_vec());
let _ = reason_store.get_or_compute(
entry.reason.unwrap(),
ExplanationContext::from(&*assignments),
propagators,
&mut empty_domain_reason,
);

empty_domain_reason.into()
}

Expand Down Expand Up @@ -1269,19 +1267,18 @@ impl ConstraintSatisfactionSolver {
) {
for trail_idx in start_trail_index..self.assignments.num_trail_entries() {
let entry = self.assignments.get_trail_entry(trail_idx);
let reason = entry
let reason_ref = entry
.reason
.expect("Added by a propagator and must therefore have a reason");

// Get the conjunction of predicates explaining the propagation.
let reason = self
.reason_store
.get_or_compute(
reason,
ExplanationContext::new(&self.assignments, CurrentNogood::empty()),
&mut self.propagators,
)
.expect("Reason ref is valid");
let mut reason = vec![];
let _ = self.reason_store.get_or_compute(
reason_ref,
ExplanationContext::new(&self.assignments, CurrentNogood::empty()),
&mut self.propagators,
&mut reason,
);

let propagated = entry.predicate;

Expand Down
Loading

0 comments on commit b3b1b05

Please sign in to comment.