diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs index a4158c8e..e8a8c8b2 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/compiler/post_constraints.rs @@ -451,12 +451,11 @@ fn compile_bool2int( let a = context.resolve_bool_variable(&exprs[0])?; let b = context.resolve_integer_variable(&exprs[1])?; - Ok(constraints::binary_equals( - a, - context.solver.new_literal_for_predicate(predicate!(b == 1)), + Ok( + constraints::binary_equals(a.get_integer_variable(), b.scaled(1)) + .post(context.solver, None) + .is_ok(), ) - .post(context.solver, None) - .is_ok()) } fn compile_bool_or( diff --git a/pumpkin-solver/src/constraints/boolean.rs b/pumpkin-solver/src/constraints/boolean.rs index b9efaf5e..4d0c4014 100644 --- a/pumpkin-solver/src/constraints/boolean.rs +++ b/pumpkin-solver/src/constraints/boolean.rs @@ -3,7 +3,6 @@ use std::num::NonZero; use super::equals; use super::less_than_or_equals; use super::Constraint; -use crate::predicate; use crate::variables::AffineView; use crate::variables::DomainId; use crate::variables::Literal; @@ -49,7 +48,7 @@ impl Constraint for BooleanLessThanOrEqual { solver: &mut Solver, tag: Option>, ) -> Result<(), ConstraintOperationError> { - let domains = self.create_domains(solver); + let domains = self.create_domains(); less_than_or_equals(domains, self.rhs).post(solver, tag) } @@ -60,34 +59,19 @@ impl Constraint for BooleanLessThanOrEqual { reification_literal: Literal, tag: Option>, ) -> Result<(), ConstraintOperationError> { - let domains = self.create_domains(solver); + let domains = self.create_domains(); less_than_or_equals(domains, self.rhs).implied_by(solver, reification_literal, tag) } } impl BooleanLessThanOrEqual { - fn create_domains(&self, solver: &mut Solver) -> Vec> { - let domains = self - .bools + fn create_domains(&self) -> Vec> { + self.bools .iter() .enumerate() - .map(|(index, bool)| { - let corresponding_domain_id = solver.new_bounded_integer(0, 1); - // bool -> [domain = 1] - let _ = solver.add_clause([ - !(*bool).get_true_predicate(), - predicate![corresponding_domain_id >= 1], - ]); - // !bool -> [domain = 0] - let _ = solver.add_clause([ - bool.get_true_predicate(), - predicate![corresponding_domain_id <= 0], - ]); - corresponding_domain_id.scaled(self.weights[index]) - }) - .collect::>(); - domains + .map(|(index, bool)| bool.get_integer_variable().scaled(self.weights[index])) + .collect() } } @@ -103,7 +87,7 @@ impl Constraint for BooleanEqual { solver: &mut Solver, tag: Option>, ) -> Result<(), ConstraintOperationError> { - let domains = self.create_domains(solver); + let domains = self.create_domains(); equals(domains, 0).post(solver, tag) } @@ -114,31 +98,18 @@ impl Constraint for BooleanEqual { reification_literal: Literal, tag: Option>, ) -> Result<(), ConstraintOperationError> { - let domains = self.create_domains(solver); + let domains = self.create_domains(); equals(domains, 0).implied_by(solver, reification_literal, tag) } } impl BooleanEqual { - fn create_domains(&self, solver: &mut Solver) -> Vec> { + fn create_domains(&self) -> Vec> { self.bools .iter() .enumerate() - .map(|(index, bool)| { - let corresponding_domain_id = solver.new_bounded_integer(0, 1); - // bool -> [domain = 1] - let _ = solver.add_clause([ - !(*bool).get_true_predicate(), - predicate![corresponding_domain_id >= 1], - ]); - // !bool -> [domain = 0] - let _ = solver.add_clause([ - (*bool).get_true_predicate(), - predicate![corresponding_domain_id <= 0], - ]); - corresponding_domain_id.scaled(self.weights[index]) - }) + .map(|(index, bool)| bool.get_integer_variable().scaled(self.weights[index])) .chain(std::iter::once(self.rhs.scaled(-1))) .collect() } diff --git a/pumpkin-solver/src/engine/variables/literal.rs b/pumpkin-solver/src/engine/variables/literal.rs index 705370de..ea713807 100644 --- a/pumpkin-solver/src/engine/variables/literal.rs +++ b/pumpkin-solver/src/engine/variables/literal.rs @@ -34,6 +34,10 @@ impl Literal { } } + pub fn get_integer_variable(&self) -> AffineView { + self.integer_variable + } + pub fn get_true_predicate(&self) -> Predicate { self.lower_bound_predicate(1) }