diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 44c5d5ca..6bb509dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,6 +28,22 @@ jobs: - uses: dtolnay/rust-toolchain@stable - run: cargo test --release --no-fail-fast + pumpkin-py: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Install pytest + run: pip install pytest + - name: Install pumpkin-py + run: pip install -e . + working-directory: pumpkin-py + - name: Run tests + run: pytest + working-directory: pumpkin-py + docs: name: Documentation runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index 88f67e31..5351d7f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,6 +223,26 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "enum-map" +version = "2.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6866f3bfdf8207509a033af1a75a7b08abda06bbaaeae6669323fd5a097df2e9" +dependencies = [ + "enum-map-derive", +] + +[[package]] +name = "enum-map-derive" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "enumset" version = "1.1.5" @@ -527,6 +547,7 @@ dependencies = [ "convert_case", "downcast-rs", "drcp-format 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)", + "enum-map", "enumset", "env_logger", "fixedbitset", diff --git a/drcp-format/src/writer/mod.rs b/drcp-format/src/writer/mod.rs index 799af567..54d8ba2a 100644 --- a/drcp-format/src/writer/mod.rs +++ b/drcp-format/src/writer/mod.rs @@ -293,9 +293,7 @@ impl WritableProofStep for Deletion { mod tests { use super::*; - // Safety: Unwrapping an option is not stable, so we cannot get a NonZero safely in a const - // context. - const TEST_ID: NonZeroU64 = unsafe { NonZeroU64::new_unchecked(1) }; + const TEST_ID: NonZeroU64 = NonZeroU64::new(1).unwrap(); #[test] fn write_basic_inference() { diff --git a/pumpkin-py/src/lib.rs b/pumpkin-py/src/lib.rs index c33a265b..4957b490 100644 --- a/pumpkin-py/src/lib.rs +++ b/pumpkin-py/src/lib.rs @@ -1,5 +1,6 @@ mod constraints; mod model; +mod optimisation; mod result; mod variables; @@ -27,12 +28,15 @@ macro_rules! submodule { fn pumpkin_py(python: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; submodule!(constraints, python, m); + submodule!(optimisation, python, m); Ok(()) } diff --git a/pumpkin-py/src/model.rs b/pumpkin-py/src/model.rs index 21fa59ad..e112816b 100644 --- a/pumpkin-py/src/model.rs +++ b/pumpkin-py/src/model.rs @@ -2,6 +2,9 @@ use std::num::NonZero; use std::path::PathBuf; use pumpkin_solver::containers::KeyedVec; +use pumpkin_solver::optimisation::linear_sat_unsat::LinearSatUnsat; +use pumpkin_solver::optimisation::linear_unsat_sat::LinearUnsatSat; +use pumpkin_solver::optimisation::OptimisationDirection; use pumpkin_solver::options::SolverOptions; use pumpkin_solver::predicate; use pumpkin_solver::proof::Format; @@ -14,12 +17,17 @@ use pumpkin_solver::Solver; use pyo3::prelude::*; use crate::constraints::Constraint; +use crate::optimisation::Direction; +use crate::optimisation::OptimisationResult; +use crate::optimisation::Optimiser; use crate::result::SatisfactionResult; +use crate::result::SatisfactionUnderAssumptionsResult; use crate::result::Solution; use crate::variables::BoolExpression; use crate::variables::BoolVariable; use crate::variables::IntExpression; use crate::variables::IntVariable; +use crate::variables::Predicate; use crate::variables::VariableMap; #[pyclass] @@ -30,15 +38,6 @@ pub struct Model { constraints: Vec, } -#[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq, Eq)] -pub enum Comparator { - NotEqual, - Equal, - LessThanOrEqual, - GreaterThanOrEqual, -} - #[pymethods] impl Model { #[new] @@ -115,23 +114,13 @@ impl Model { } } - #[pyo3(signature = (integer, comparator, value, name=None))] - fn predicate_as_boolean( - &mut self, - integer: IntExpression, - comparator: Comparator, - value: i32, - name: Option<&str>, - ) -> BoolExpression { + #[pyo3(signature = (predicate, name=None))] + fn predicate_as_boolean(&mut self, predicate: Predicate, name: Option<&str>) -> BoolExpression { self.boolean_variables .push(ModelBoolVar { name: name.map(|n| n.to_owned()), integer_equivalent: None, - predicate: Some(Predicate { - integer, - comparator, - value, - }), + predicate: Some(predicate), }) .into() } @@ -163,27 +152,9 @@ impl Model { #[pyo3(signature = (proof=None))] fn satisfy(&self, proof: Option) -> SatisfactionResult { - let proof_log = proof - .map(|path| ProofLog::cp(&path, Format::Text, true, true)) - .transpose() - .map(|proof| proof.unwrap_or_default()) - .expect("failed to create proof file"); + let solver_setup = self.create_solver(proof); - let options = SolverOptions { - proof_log, - ..Default::default() - }; - - let mut solver = Solver::with_options(options); - - let solver_setup = self - .create_variable_map(&mut solver) - .and_then(|variable_map| { - self.post_constraints(&mut solver, &variable_map)?; - Ok(variable_map) - }); - - let Ok(variable_map) = solver_setup else { + let Ok((mut solver, variable_map)) = solver_setup else { return SatisfactionResult::Unsatisfiable(); }; @@ -202,6 +173,126 @@ impl Model { pumpkin_solver::results::SatisfactionResult::Unknown => SatisfactionResult::Unknown(), } } + + #[pyo3(signature = (assumptions))] + fn satisfy_under_assumptions( + &self, + assumptions: Vec, + ) -> SatisfactionUnderAssumptionsResult { + let solver_setup = self.create_solver(None); + + let Ok((mut solver, variable_map)) = solver_setup else { + return SatisfactionUnderAssumptionsResult::Unsatisfiable(); + }; + + let mut brancher = solver.default_brancher(); + + let solver_assumptions = assumptions + .iter() + .map(|pred| pred.to_solver_predicate(&variable_map)) + .collect::>(); + + // Maarten: I do not understand why it is necessary, but we have to create a local variable + // here that is the result of the `match` statement. Otherwise the compiler + // complains that `solver` and `brancher` potentially do not live long enough. + // + // Ideally this would not be necessary, but perhaps it is unavoidable with the setup we + // currently have. Either way, we take the suggestion by the compiler. + let result = match solver.satisfy_under_assumptions(&mut brancher, &mut Indefinite, &solver_assumptions) { + pumpkin_solver::results::SatisfactionResultUnderAssumptions::Satisfiable(solution) => { + SatisfactionUnderAssumptionsResult::Satisfiable(Solution { + solver_solution: solution, + variable_map, + }) + } + pumpkin_solver::results::SatisfactionResultUnderAssumptions::UnsatisfiableUnderAssumptions(mut result) => { + // Maarten: For now we assume that the core _must_ consist of the predicates that + // were the input to the solve call. In general this is not the case, e.g. when + // the assumptions can be semantically minized (the assumptions [y <= 1], + // [y >= 0] and [y != 0] will be compressed to [y == 1] which would end up in + // the core). + // + // In the future, perhaps we should make the distinction between predicates and + // literals in the python wrapper as well. For now, this is the simplest way + // forward. I expect that the situation above almost never happens in practice. + let core = result + .extract_core() + .iter() + .map(|predicate| assumptions + .iter() + .find(|pred| pred.to_solver_predicate(&variable_map) == *predicate) + .copied() + .expect("predicates in core must be part of the assumptions")) + .collect(); + + SatisfactionUnderAssumptionsResult::UnsatisfiableUnderAssumptions(core) + } + pumpkin_solver::results::SatisfactionResultUnderAssumptions::Unsatisfiable => { + SatisfactionUnderAssumptionsResult::Unsatisfiable() + } + pumpkin_solver::results::SatisfactionResultUnderAssumptions::Unknown => { + SatisfactionUnderAssumptionsResult::Unknown() + } + }; + + result + } + + #[pyo3(signature = (objective, optimiser=Optimiser::LinearSatUnsat, direction=Direction::Minimise, proof=None))] + fn optimise( + &self, + objective: IntExpression, + optimiser: Optimiser, + direction: Direction, + proof: Option, + ) -> OptimisationResult { + let solver_setup = self.create_solver(proof); + + let Ok((mut solver, variable_map)) = solver_setup else { + return OptimisationResult::Unsatisfiable(); + }; + + let mut brancher = solver.default_brancher(); + + let direction = match direction { + Direction::Minimise => OptimisationDirection::Minimise, + Direction::Maximise => OptimisationDirection::Maximise, + }; + + let objective = objective.to_affine_view(&variable_map); + + let result = match optimiser { + Optimiser::LinearSatUnsat => solver.optimise( + &mut brancher, + &mut Indefinite, + LinearSatUnsat::new(direction, objective, |_, _| {}), + ), + Optimiser::LinearUnsatSat => solver.optimise( + &mut brancher, + &mut Indefinite, + LinearUnsatSat::new(direction, objective, |_, _| {}), + ), + }; + + match result { + pumpkin_solver::results::OptimisationResult::Satisfiable(solution) => { + OptimisationResult::Satisfiable(Solution { + solver_solution: solution, + variable_map, + }) + } + pumpkin_solver::results::OptimisationResult::Optimal(solution) => { + OptimisationResult::Optimal(Solution { + solver_solution: solution, + variable_map, + }) + } + pumpkin_solver::results::OptimisationResult::Unsatisfiable => { + OptimisationResult::Unsatisfiable() + } + pumpkin_solver::results::OptimisationResult::Unknown => OptimisationResult::Unknown(), + } + } } impl Model { @@ -252,6 +343,33 @@ impl Model { Ok(()) } + + fn create_solver( + &self, + proof: Option, + ) -> Result<(Solver, VariableMap), ConstraintOperationError> { + let proof_log = proof + .map(|path| ProofLog::cp(&path, Format::Text, true, true)) + .transpose() + .map(|proof| proof.unwrap_or_default()) + .expect("failed to create proof file"); + + let options = SolverOptions { + proof_log, + ..Default::default() + }; + + let mut solver = Solver::with_options(options); + + let variable_map = self + .create_variable_map(&mut solver) + .and_then(|variable_map| { + self.post_constraints(&mut solver, &variable_map)?; + Ok(variable_map) + })?; + + Ok((solver, variable_map)) + } } #[derive(Clone)] @@ -340,26 +458,3 @@ impl ModelBoolVar { Ok(literal) } } - -struct Predicate { - integer: IntExpression, - comparator: Comparator, - value: i32, -} - -impl Predicate { - /// Convert the predicate in the model domain to a predicate in the solver domain. - fn to_solver_predicate( - &self, - variable_map: &VariableMap, - ) -> pumpkin_solver::predicates::Predicate { - let affine_view = self.integer.to_affine_view(variable_map); - - match self.comparator { - Comparator::NotEqual => predicate![affine_view != self.value], - Comparator::Equal => predicate![affine_view == self.value], - Comparator::LessThanOrEqual => predicate![affine_view <= self.value], - Comparator::GreaterThanOrEqual => predicate![affine_view >= self.value], - } - } -} diff --git a/pumpkin-py/src/optimisation.rs b/pumpkin-py/src/optimisation.rs new file mode 100644 index 00000000..429caa3d --- /dev/null +++ b/pumpkin-py/src/optimisation.rs @@ -0,0 +1,36 @@ +use pyo3::prelude::*; + +use crate::result::Solution; + +#[pyclass] +pub enum OptimisationResult { + /// The problem was solved to optimality, and the solution is an optimal one. + Optimal(Solution), + /// At least one solution was identified, and the solution is the best one. + Satisfiable(Solution), + /// The problem was unsatisfiable. + Unsatisfiable(), + /// None of the other variants were concluded. + Unknown(), +} + +#[pyclass(eq, eq_int)] +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum Optimiser { + LinearSatUnsat, + LinearUnsatSat, +} + +#[pyclass(eq, eq_int)] +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum Direction { + Minimise, + Maximise, +} + +pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/pumpkin-py/src/result.rs b/pumpkin-py/src/result.rs index 99e7dbe1..fc4653ae 100644 --- a/pumpkin-py/src/result.rs +++ b/pumpkin-py/src/result.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use crate::variables::BoolExpression; use crate::variables::IntExpression; +use crate::variables::Predicate; use crate::variables::VariableMap; #[pyclass] @@ -13,6 +14,15 @@ pub enum SatisfactionResult { Unknown(), } +#[pyclass] +#[allow(clippy::large_enum_variant)] +pub enum SatisfactionUnderAssumptionsResult { + Satisfiable(Solution), + UnsatisfiableUnderAssumptions(Vec), + Unsatisfiable(), + Unknown(), +} + #[pyclass] #[derive(Clone)] pub struct Solution { @@ -32,3 +42,7 @@ impl Solution { .get_literal_value(variable.to_literal(&self.variable_map)) } } + +#[pyclass] +#[derive(Clone)] +pub struct CoreExtractor {} diff --git a/pumpkin-py/src/variables.rs b/pumpkin-py/src/variables.rs index 6e47d888..1ca25636 100644 --- a/pumpkin-py/src/variables.rs +++ b/pumpkin-py/src/variables.rs @@ -1,12 +1,13 @@ use pumpkin_solver::containers::KeyedVec; use pumpkin_solver::containers::StorageKey; +use pumpkin_solver::predicate; use pumpkin_solver::variables::AffineView; use pumpkin_solver::variables::DomainId; use pumpkin_solver::variables::Literal; use pumpkin_solver::variables::TransformableVariable; use pyo3::prelude::*; -#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] pub struct IntVariable(usize); impl StorageKey for IntVariable { @@ -19,8 +20,8 @@ impl StorageKey for IntVariable { } } -#[pyclass] -#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +#[pyclass(eq, hash, frozen)] +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] pub struct IntExpression { pub variable: IntVariable, pub offset: i32, @@ -83,6 +84,52 @@ impl IntExpression { } } +#[pyclass(eq, eq_int, hash, frozen)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Comparator { + NotEqual, + Equal, + LessThanOrEqual, + GreaterThanOrEqual, +} + +#[pyclass(eq, get_all, hash, frozen)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct Predicate { + pub variable: IntExpression, + pub comparator: Comparator, + pub value: i32, +} + +#[pymethods] +impl Predicate { + #[new] + fn new(variable: IntExpression, comparator: Comparator, value: i32) -> Self { + Self { + variable, + comparator, + value, + } + } +} + +impl Predicate { + /// Convert the predicate in the model domain to a predicate in the solver domain. + pub(crate) fn to_solver_predicate( + self, + variable_map: &VariableMap, + ) -> pumpkin_solver::predicates::Predicate { + let affine_view = self.variable.to_affine_view(variable_map); + + match self.comparator { + Comparator::NotEqual => predicate![affine_view != self.value], + Comparator::Equal => predicate![affine_view == self.value], + Comparator::LessThanOrEqual => predicate![affine_view <= self.value], + Comparator::GreaterThanOrEqual => predicate![affine_view >= self.value], + } + } +} + #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] pub struct BoolVariable(usize); @@ -96,7 +143,7 @@ impl StorageKey for BoolVariable { } } -#[pyclass] +#[pyclass(eq)] #[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] pub struct BoolExpression(BoolVariable, bool); diff --git a/pumpkin-py/tests/test_assumptions.py b/pumpkin-py/tests/test_assumptions.py new file mode 100644 index 00000000..37f1b3a3 --- /dev/null +++ b/pumpkin-py/tests/test_assumptions.py @@ -0,0 +1,36 @@ +from pumpkin_py import Comparator, Model, Predicate, SatisfactionUnderAssumptionsResult +from pumpkin_py.constraints import LessThanOrEquals + + +def test_assumptions_are_respected(): + model = Model() + + x = model.new_integer_variable(1, 5, name="x") + + assumption = Predicate(x, Comparator.LessThanOrEqual, 3) + + result = model.satisfy_under_assumptions([assumption]) + assert isinstance(result, SatisfactionUnderAssumptionsResult.Satisfiable) + + solution = result._0 + x_value = solution.int_value(x) + assert x_value <= 3 + + +def test_core_extraction(): + model = Model() + + x = model.new_integer_variable(1, 5, name="x") + y = model.new_integer_variable(1, 5, name="x") + + x_ge_3 = Predicate(x, Comparator.GreaterThanOrEqual, 3) + y_ge_3 = Predicate(y, Comparator.GreaterThanOrEqual, 3) + + model.add_constraint(LessThanOrEquals([x, y], 5)) + + result = model.satisfy_under_assumptions([x_ge_3, y_ge_3]) + assert isinstance(result, SatisfactionUnderAssumptionsResult.UnsatisfiableUnderAssumptions) + + core = set(result._0) + assert set([x_ge_3, y_ge_3]) == core + diff --git a/pumpkin-py/tests/test_constraints.py b/pumpkin-py/tests/test_constraints.py new file mode 100644 index 00000000..213fb52a --- /dev/null +++ b/pumpkin-py/tests/test_constraints.py @@ -0,0 +1,143 @@ +""" + Generate constraints and expressions based on the grammar supported by the API + + Generates linear constraints, special operators and global constraints. + Whenever possible, the script also generates 'boolean as integer' versions of the arguments +""" + +import pytest +from pumpkin_py import constraints +import pumpkin_py + +# generate all linear sum-expressions +def generate_linear(): + for comp in "<=", "==", "!=": + for scaled in (False, True): # to generate a weighted sum + for bool in (False, True): # from bool-view? + model = pumpkin_py.Model() + + if bool: + args = [model.boolean_as_integer(model.new_boolean_variable(name=f"x[{i}]")) for i in range(3)] + else: + args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] + if scaled: # do scaling (0, -2, 4,...) + args = [a.scaled(-2 * i + 1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 + + rhs = 1 + if comp == "==": + cons = constraints.Equals(args, rhs) + if comp == "!=": + cons = constraints.NotEquals(args, rhs) + if comp == "<=": + cons = constraints.LessThanOrEquals(args, rhs) + + yield model, cons, comp, scaled, bool + +# generate other operators +def generate_operators(): + for name in ['div', 'mul', 'abs', 'min', 'max', 'element']: + for scaled in (False, True): + for bool in (False, True): # from bool-view? + model = pumpkin_py.Model() + + if bool: + args = [model.boolean_as_integer(model.new_boolean_variable(name=f"x[{i}]")) for i in range(3)] + else: + args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] + if scaled: # do scaling (0, -2, 4,...) + args = [a.scaled(-2 * i + 1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 + + rhs = model.new_integer_variable(-3, 5, name="rhs") + if name == "div": + denom = model.new_integer_variable(1, 3, name="denom") + cons = constraints.Division(args[0], denom, rhs) + if name == "mul": + cons = constraints.Times(*args[:2], rhs) + if name == "abs": + cons = constraints.Absolute(args[0], rhs) + if name == "min": + cons = constraints.Minimum(args, rhs) + if name == "max": + cons = constraints.Maximum(args, rhs) + if name == "element": + idx = model.new_integer_variable(-1, 5, name=f"idx") # sneaky, idx can be out of bounds + cons = constraints.Element(idx, args, rhs) + + yield model, cons, name, scaled, bool + +# generate global constraints, separate functions for readability +def generate_alldiff(): + + for scaled in (False, True): + for bool in (False, True): # from bool-view? Unlikely constraint, but anyway + model = pumpkin_py.Model() + if bool: + args = [model.boolean_as_integer(model.new_boolean_variable(name=f"x[{i}]")) for i in range(3)] + else: + args = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] + if scaled or bool: # do scaling (0, -2, 4,...) + args = [a.scaled(-2 * i +1) for i, a in enumerate(args)] # TODO: div by zero when scale = 0, fixed with +1 + + cons = constraints.AllDifferent(args) + yield model, cons, "alldifferent", scaled or bool, bool + +def generate_cumulative(): + duration = [2, 3, 4] + demand = [1, 2, 3] + capacity = 4 + + model = pumpkin_py.Model() + start = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] + cons = constraints.Cumulative(start, duration, demand, capacity) + yield model, cons, "cumulative", False, False + + model = pumpkin_py.Model() + start = [model.new_integer_variable(-3, 5, name=f"x[{i}]") for i in range(3)] + start = [a.scaled(-2 * i) for i, a in enumerate(start)] + cons = constraints.Cumulative(start, duration, demand, capacity) + yield model, cons, "cumulative", True, False + + +def generate_globals(): + + yield from generate_alldiff() + yield from generate_cumulative() + +def label(model, cons, name, scaled, bool): + return " ".join(["Scaled" if scaled else "Unscaled", "Boolean" if bool else "Integer", name]) + +LINEAR = list(generate_operators()) +@pytest.mark.parametrize(("model", "cons", "name", "scaled", "bool"), LINEAR, ids =[label(*a) for a in LINEAR]) +def test_linear(model, cons, name, scaled, bool): + + model.add_constraint(cons) + res = model.satisfy(proof="proof") + assert isinstance(res, pumpkin_py.SatisfactionResult.Satisfiable) + +OPERATORS = list(generate_operators()) +@pytest.mark.parametrize(("model", "cons", "name", "scaled", "bool"), OPERATORS, ids =[label(*a) for a in OPERATORS]) +def test_operators(model, cons, name, scaled, bool): + + model.add_constraint(cons) + res = model.satisfy(proof="proof") + assert isinstance(res, pumpkin_py.SatisfactionResult.Satisfiable) + +GLOBALS = list(generate_globals()) +@pytest.mark.parametrize(("model", "cons", "name", "scaled", "bool"), GLOBALS, ids =[label(*a) for a in GLOBALS]) +def test_global(model, cons, name, scaled, bool): + + model.add_constraint(cons) + res = model.satisfy(proof="proof") + assert isinstance(res, pumpkin_py.SatisfactionResult.Satisfiable) + +ALL_EXPR = list(generate_operators()) + list(generate_linear()) + list(generate_globals()) +@pytest.mark.parametrize(("model", "cons", "name", "scaled", "bool"), ALL_EXPR, ids=["->"+label(*a) for a in ALL_EXPR]) +def test_implication(model, cons, name, scaled, bool): + + if name == 'element': + return # TODO: propagator not yet implemented? + + bv = model.new_boolean_variable("bv") + model.add_implication(cons, bv) + res = model.satisfy(proof="proof") + assert isinstance(res, pumpkin_py.SatisfactionResult.Satisfiable) diff --git a/pumpkin-py/tests/test_optimisation.py b/pumpkin-py/tests/test_optimisation.py new file mode 100644 index 00000000..1cb270d9 --- /dev/null +++ b/pumpkin-py/tests/test_optimisation.py @@ -0,0 +1,30 @@ +from pumpkin_py import Model +from pumpkin_py.optimisation import Direction, OptimisationResult + + +def test_linear_sat_unsat_minimisation(): + model = Model() + + objective = model.new_integer_variable(1, 5, name="objective") + + result = model.optimise(objective, direction=Direction.Minimise) + + assert isinstance(result, OptimisationResult.Optimal) + + solution = result._0 + assert solution.int_value(objective) == 1 + + +def test_linear_sat_unsat_maximisation(): + model = Model() + + objective = model.new_integer_variable(1, 5, name="objective") + + result = model.optimise(objective, direction=Direction.Maximise) + + assert isinstance(result, OptimisationResult.Optimal) + + solution = result._0 + assert solution.int_value(objective) == 5 + + diff --git a/pumpkin-solver/Cargo.toml b/pumpkin-solver/Cargo.toml index b6e2007b..3e28f090 100644 --- a/pumpkin-solver/Cargo.toml +++ b/pumpkin-solver/Cargo.toml @@ -27,6 +27,7 @@ env_logger = "0.10.0" bitfield-struct = "0.9.2" num = "0.4.3" fixedbitset = "0.5.7" +enum-map = "2.7.3" [dev-dependencies] clap = { version = "4.5.17", features = ["derive"] } diff --git a/pumpkin-solver/examples/disjunctive_scheduling.rs b/pumpkin-solver/examples/disjunctive_scheduling.rs index daf045db..885a8466 100644 --- a/pumpkin-solver/examples/disjunctive_scheduling.rs +++ b/pumpkin-solver/examples/disjunctive_scheduling.rs @@ -1,10 +1,16 @@ //! A simple model for disjunctive scheduling using reified constraints //! Given a set of tasks and their processing times, it finds a schedule such that none of the jobs -//! overlap It thus finds a schedule such that either s_i >= s_j + p_j or s_j >= s_i + p_i (i.e. -//! either job i starts after j or job j starts after i) +//! overlap. The optimal schedule is thus all tasks scheduled right after each other. +//! +//! For two tasks x and y, either x ends before y starts, or y ends before x starts. So if s_i is +//! the start time of task i and p_i is then we can express the condition that x ends before y +//! starts as s_x + p_x <= s_y, and that y ends before x starts as s_y + p_y <= s_x. +//! +//! To ensure that one of these occurs, we create two Boolean variables, l_xy and l_yx, to signify +//! the two possibilities, and then post the constraint (l_xy \/ l_yx). use pumpkin_solver::constraints; -use pumpkin_solver::constraints::Constraint; +use pumpkin_solver::constraints::NegatableConstraint; use pumpkin_solver::results::ProblemSolution; use pumpkin_solver::results::SatisfactionResult; use pumpkin_solver::termination::Indefinite; @@ -38,8 +44,8 @@ fn main() { .map(|i| solver.new_bounded_integer(0, (horizon - processing_times[i]) as i32)) .collect::>(); - // Literal which indicates precedence (i.e. if precedence_literals[x][y] => s_y + p_y <= s_x - // which is equal to s_y - s_x <= -p_y) + // Literal which indicates precedence (i.e. precedence_literals[x][y] <=> x ends before y + // starts) let precedence_literals = (0..n_tasks) .map(|_| { (0..n_tasks) @@ -53,17 +59,14 @@ fn main() { if x == y { continue; } + // precedence_literals[x][y] <=> x ends before y starts let literal = precedence_literals[x][y]; - let variables = vec![start_variables[y].scaled(1), start_variables[x].scaled(-1)]; - // literal => s_y - s_x <= -p_y) - let _ = - constraints::less_than_or_equals(variables.clone(), -(processing_times[y] as i32)) - .implied_by(&mut solver, literal, None); - - //-literal => -s_y + s_x <= p_y) + // literal <=> (s_x + p_x <= s_y) + // equivelent to literal <=> (s_x - s_y <= -p_x) + // So the variables are -s_y and s_x, and the rhs is -p_x let variables = vec![start_variables[y].scaled(-1), start_variables[x].scaled(1)]; - let _ = constraints::less_than_or_equals(variables.clone(), processing_times[y] as i32) - .implied_by(&mut solver, literal, None); + let _ = constraints::less_than_or_equals(variables, -(processing_times[x] as i32)) + .reify(&mut solver, literal, None); // Either x starts before y or y start before x let _ = solver.add_clause([ diff --git a/pumpkin-solver/src/api/mod.rs b/pumpkin-solver/src/api/mod.rs index 9ae51742..63ffc5e5 100644 --- a/pumpkin-solver/src/api/mod.rs +++ b/pumpkin-solver/src/api/mod.rs @@ -1,4 +1,5 @@ mod outputs; + pub(crate) mod solver; pub mod results { @@ -14,7 +15,6 @@ pub mod results { //! right state for these operations. For example, //! [`SatisfactionResultUnderAssumptions::UnsatisfiableUnderAssumptions`] allows you to extract //! a core consisting of the assumptions using [`UnsatisfiableUnderAssumptions::extract_core`]. - pub use crate::api::outputs::solution_callback_arguments::SolutionCallbackArguments; pub use crate::api::outputs::solution_iterator; pub use crate::api::outputs::unsatisfiable; pub use crate::api::outputs::OptimisationResult; @@ -94,7 +94,7 @@ pub mod termination { } pub mod predicates { - //! Containts structures which represent certain [predicates](https://en.wikipedia.org/wiki/Predicate_(mathematical_logic)). + //! Contains structures which represent certain [predicates](https://en.wikipedia.org/wiki/Predicate_(mathematical_logic)). //! //! The solver only utilizes the following types of predicates: //! - A [`Predicate::LowerBound`] of the form `[x >= v]` diff --git a/pumpkin-solver/src/api/outputs/mod.rs b/pumpkin-solver/src/api/outputs/mod.rs index a66dbdc8..6ac19907 100644 --- a/pumpkin-solver/src/api/outputs/mod.rs +++ b/pumpkin-solver/src/api/outputs/mod.rs @@ -2,7 +2,6 @@ use self::unsatisfiable::UnsatisfiableUnderAssumptions; pub use crate::basic_types::ProblemSolution; use crate::basic_types::Solution; pub use crate::basic_types::SolutionReference; -pub(crate) mod solution_callback_arguments; pub mod solution_iterator; pub mod unsatisfiable; use crate::branching::Brancher; @@ -47,7 +46,7 @@ pub enum SatisfactionResultUnderAssumptions<'solver, 'brancher, B: Brancher> { Unknown, } -/// The result of a call to [`Solver::maximise`] or [`Solver::minimise`]. +/// The result of a call to [`Solver::optimise`]. #[derive(Debug)] pub enum OptimisationResult { /// Indicates that an optimal solution has been found and proven to be optimal. It provides an diff --git a/pumpkin-solver/src/api/outputs/solution_callback_arguments.rs b/pumpkin-solver/src/api/outputs/solution_callback_arguments.rs deleted file mode 100644 index 805e9790..00000000 --- a/pumpkin-solver/src/api/outputs/solution_callback_arguments.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::results::Solution; -use crate::Solver; - -/// The input which is passed to the solution callback (which can be set using -/// [`Solver::with_solution_callback`]). -/// -/// Provides direct access to the solution via [`SolutionCallbackArguments::solution`] and allows -/// logging the statistics of the [`Solver`] using [`SolutionCallbackArguments::log_statistics`]. -#[derive(Debug)] -pub struct SolutionCallbackArguments<'a, 'b> { - /// The solver which found the solution - solver: &'a Solver, - /// The solution which has been found - pub solution: &'b Solution, - /// The (optional) objective value provided to the [`Solver`]. - objective_value: Option, -} - -impl<'a, 'b> SolutionCallbackArguments<'a, 'b> { - pub(crate) fn new( - solver: &'a Solver, - solution: &'b Solution, - objective_value: Option, - ) -> Self { - Self { - solver, - solution, - objective_value, - } - } - - /// Log the statistics of the [`Solver`]. - /// - /// If the solution was found using [`Solver::minimise`] or [`Solver::maximise`] then the - /// objective value of the current solution is included in the statistics. - pub fn log_statistics(&self) { - if let Some(objective_value) = self.objective_value { - self.solver.log_statistics_with_objective(objective_value) - } else { - self.solver.log_statistics() - } - } -} diff --git a/pumpkin-solver/src/api/outputs/solution_iterator.rs b/pumpkin-solver/src/api/outputs/solution_iterator.rs index fa0221b0..2009a4b8 100644 --- a/pumpkin-solver/src/api/outputs/solution_iterator.rs +++ b/pumpkin-solver/src/api/outputs/solution_iterator.rs @@ -53,7 +53,7 @@ impl<'solver, 'brancher, 'termination, B: Brancher, T: TerminationCondition> Satisfiable(solution) => { self.has_solution = true; self.next_blocking_clause = Some(get_blocking_clause(&solution)); - IteratedSolution::Solution(solution) + IteratedSolution::Solution(solution, self.solver) } Unsatisfiable => { if self.has_solution { @@ -84,9 +84,9 @@ fn get_blocking_clause(solution: &Solution) -> Vec { reason = "these will not be stored in bulk, so this is not an issue" )] #[derive(Debug)] -pub enum IteratedSolution { +pub enum IteratedSolution<'a> { /// A new solution was identified. - Solution(Solution), + Solution(Solution, &'a Solver), /// No more solutions exist. Finished, diff --git a/pumpkin-solver/src/api/solver.rs b/pumpkin-solver/src/api/solver.rs index 76cd0499..4a589070 100644 --- a/pumpkin-solver/src/api/solver.rs +++ b/pumpkin-solver/src/api/solver.rs @@ -1,5 +1,6 @@ use std::num::NonZero; +use super::outputs::SolutionReference; use super::results::OptimisationResult; use super::results::SatisfactionResult; use super::results::SatisfactionResultUnderAssumptions; @@ -9,11 +10,10 @@ use crate::basic_types::HashSet; use crate::basic_types::Solution; use crate::branching::branchers::autonomous_search::AutonomousSearch; use crate::branching::branchers::independent_variable_value_brancher::IndependentVariableValueBrancher; -use crate::branching::tie_breaking::InOrderTieBreaker; -use crate::branching::value_selection::InDomainMin; +use crate::branching::value_selection::RandomSplitter; #[cfg(doc)] use crate::branching::value_selection::ValueSelector; -use crate::branching::variable_selection::Smallest; +use crate::branching::variable_selection::RandomSelector; #[cfg(doc)] use crate::branching::variable_selection::VariableSelector; use crate::branching::Brancher; @@ -25,12 +25,16 @@ use crate::engine::variables::DomainId; use crate::engine::variables::IntegerVariable; use crate::engine::variables::Literal; use crate::engine::ConstraintSatisfactionSolver; +#[cfg(doc)] +use crate::optimisation::linear_sat_unsat::LinearSatUnsat; +#[cfg(doc)] +use crate::optimisation::linear_unsat_sat::LinearUnsatSat; +use crate::optimisation::OptimisationProcedure; use crate::options::SolverOptions; -use crate::predicate; -use crate::pumpkin_assert_simple; +#[cfg(doc)] +use crate::predicates; use crate::results::solution_iterator::SolutionIterator; use crate::results::unsatisfiable::UnsatisfiableUnderAssumptions; -use crate::results::SolutionCallbackArguments; use crate::statistics::log_statistic; use crate::statistics::log_statistic_postfix; @@ -83,12 +87,10 @@ use crate::statistics::log_statistic_postfix; /// /// # Using the Solver /// For examples on how to use the solver, see the [root-level crate documentation](crate) or [one of these examples](https://github.com/ConSol-Lab/Pumpkin/tree/master/pumpkin-lib/examples). +#[derive(Debug)] pub struct Solver { /// The internal [`ConstraintSatisfactionSolver`] which is used to solve the problems. - satisfaction_solver: ConstraintSatisfactionSolver, - /// The function is called whenever an optimisation function finds a solution; see - /// [`Solver::with_solution_callback`]. - solution_callback: Box, + pub(crate) satisfaction_solver: ConstraintSatisfactionSolver, true_literal: Literal, } @@ -98,25 +100,11 @@ impl Default for Solver { let true_literal = Literal::new(Predicate::trivially_true().get_domain()); Self { satisfaction_solver, - solution_callback: create_empty_function(), true_literal, } } } -/// Creates a place-holder empty function which does not do anything when a solution is found. -fn create_empty_function() -> Box { - Box::new(|_| {}) -} - -impl std::fmt::Debug for Solver { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Solver") - .field("satisfaction_solver", &self.satisfaction_solver) - .finish() - } -} - impl Solver { /// Creates a solver with the provided [`SolverOptions`]. pub fn with_options(solver_options: SolverOptions) -> Self { @@ -124,24 +112,10 @@ impl Solver { let true_literal = Literal::new(Predicate::trivially_true().get_domain()); Self { satisfaction_solver, - solution_callback: create_empty_function(), true_literal, } } - /// Adds a call-back to the [`Solver`] which is called every time that a solution is found when - /// optimising using [`Solver::maximise`] or [`Solver::minimise`]. - /// - /// Note that this will also - /// perform the call-back on the optimal solution which is returned in - /// [`OptimisationResult::Optimal`]. - pub fn with_solution_callback( - &mut self, - solution_callback: impl Fn(SolutionCallbackArguments) + 'static, - ) { - self.solution_callback = Box::new(solution_callback); - } - /// Logs the statistics currently present in the solver with the provided objective value. pub fn log_statistics_with_objective(&self, objective_value: i64) { log_statistic("objective", objective_value); @@ -157,6 +131,10 @@ impl Solver { pub(crate) fn get_satisfaction_solver_mut(&mut self) -> &mut ConstraintSatisfactionSolver { &mut self.satisfaction_solver } + + pub fn get_solution_reference(&self) -> SolutionReference { + self.satisfaction_solver.get_solution_reference() + } } /// Methods to retrieve information about variables @@ -334,7 +312,9 @@ impl Solver { CSPSolverExecutionFlag::Feasible => { let solution: Solution = self.satisfaction_solver.get_solution_reference().into(); self.satisfaction_solver.restore_state_at_root(brancher); - self.process_solution(&solution, brancher); + + brancher.on_solution(solution.as_reference()); + SatisfactionResult::Satisfiable(solution) } CSPSolverExecutionFlag::Infeasible => { @@ -371,7 +351,7 @@ impl Solver { /// which can be used to obtain the found solution or find other solutions. /// /// This method takes as input a list of [`Predicate`]s which represent so-called assumptions - /// (see \[1\] for a more detailed explanation). See the [`predicate`] documentation for how + /// (see \[1\] for a more detailed explanation). See the [`predicates`] documentation for how /// to construct these predicates. /// /// # Bibliography @@ -419,215 +399,20 @@ impl Solver { } /// Solves the model currently in the [`Solver`] to optimality where the provided - /// `objective_variable` is minimised (or is indicated to terminate by the provided - /// [`TerminationCondition`]). + /// `objective_variable` is optimised as indicated by the `direction` (or is indicated to + /// terminate by the provided [`TerminationCondition`]). Uses a search strategy based on the + /// provided [`OptimisationProcedure`], currently [`LinearSatUnsat`] and + /// [`LinearUnsatSat`] are supported. /// /// It returns an [`OptimisationResult`] which can be used to retrieve the optimal solution if /// it exists. - pub fn minimise( + pub fn optimise( &mut self, brancher: &mut impl Brancher, termination: &mut impl TerminationCondition, - objective_variable: impl IntegerVariable, + mut optimisation_procedure: impl OptimisationProcedure, ) -> OptimisationResult { - self.minimise_internal(brancher, termination, objective_variable, false) - } - - /// Solves the model currently in the [`Solver`] to optimality where the provided - /// `objective_variable` is maximised (or is indicated to terminate by the provided - /// [`TerminationCondition`]). - /// - /// It returns an [`OptimisationResult`] which can be used to retrieve the optimal solution if - /// it exists. - pub fn maximise( - &mut self, - brancher: &mut impl Brancher, - termination: &mut impl TerminationCondition, - objective_variable: impl IntegerVariable, - ) -> OptimisationResult { - self.minimise_internal(brancher, termination, objective_variable.scaled(-1), true) - } - - /// The internal method which optimizes the objective function, this function takes an extra - /// argument (`is_maximising`) as compared to [`Solver::maximise`] and [`Solver::minimise`] - /// which determines whether the logged objective value should be scaled by `-1` or not. - /// - /// This is necessary due to the fact that [`Solver::maximise`] simply calls minimise with - /// the objective variable scaled with `-1` which would lead to incorrect statistic if not - /// scaled back. - fn minimise_internal( - &mut self, - brancher: &mut impl Brancher, - termination: &mut impl TerminationCondition, - objective_variable: impl IntegerVariable, - is_maximising: bool, - ) -> OptimisationResult { - // If we are maximising then when we simply scale the variable by -1, however, this will - // lead to the printed objective value in the statistics to be multiplied by -1; this - // objective_multiplier ensures that the objective is correctly logged. - let objective_multiplier = if is_maximising { -1 } else { 1 }; - - let initial_solve = self.satisfaction_solver.solve(termination, brancher); - match initial_solve { - CSPSolverExecutionFlag::Feasible => {} - CSPSolverExecutionFlag::Infeasible => { - // Reset the state whenever we return a result - self.satisfaction_solver.restore_state_at_root(brancher); - let _ = self.satisfaction_solver.conclude_proof_unsat(); - return OptimisationResult::Unsatisfiable; - } - CSPSolverExecutionFlag::Timeout => { - // Reset the state whenever we return a result - self.satisfaction_solver.restore_state_at_root(brancher); - return OptimisationResult::Unknown; - } - } - let mut best_objective_value = Default::default(); - let mut best_solution = Solution::default(); - - self.update_best_solution_and_process( - objective_multiplier, - &objective_variable, - &mut best_objective_value, - &mut best_solution, - brancher, - ); - - loop { - self.satisfaction_solver.restore_state_at_root(brancher); - - let objective_bound_predicate = if is_maximising { - predicate![objective_variable >= best_objective_value as i32 * objective_multiplier] - } else { - predicate![objective_variable <= best_objective_value as i32 * objective_multiplier] - }; - - if self - .strengthen( - &objective_variable, - best_objective_value * objective_multiplier as i64, - ) - .is_err() - { - // Reset the state whenever we return a result - self.satisfaction_solver.restore_state_at_root(brancher); - let _ = self - .satisfaction_solver - .conclude_proof_optimal(objective_bound_predicate); - return OptimisationResult::Optimal(best_solution); - } - - let solve_result = self.satisfaction_solver.solve(termination, brancher); - match solve_result { - CSPSolverExecutionFlag::Feasible => { - self.debug_bound_change( - &objective_variable, - best_objective_value * objective_multiplier as i64, - ); - self.update_best_solution_and_process( - objective_multiplier, - &objective_variable, - &mut best_objective_value, - &mut best_solution, - brancher, - ); - } - CSPSolverExecutionFlag::Infeasible => { - { - // Reset the state whenever we return a result - self.satisfaction_solver.restore_state_at_root(brancher); - let _ = self - .satisfaction_solver - .conclude_proof_optimal(objective_bound_predicate); - return OptimisationResult::Optimal(best_solution); - } - } - CSPSolverExecutionFlag::Timeout => { - // Reset the state whenever we return a result - self.satisfaction_solver.restore_state_at_root(brancher); - return OptimisationResult::Satisfiable(best_solution); - } - } - } - } - - /// Processes a solution when it is found, it consists of the following procedure: - /// - Assigning `best_objective_value` the value assigned to `objective_variable` (multiplied by - /// `objective_multiplier`). - /// - Storing the new best solution in `best_solution`. - /// - Calling [`Brancher::on_solution`] on the provided `brancher`. - /// - Logging the statistics using [`Solver::log_statistics_with_objective`]. - /// - Calling the solution callback stored in [`Solver::solution_callback`]. - fn update_best_solution_and_process( - &self, - objective_multiplier: i32, - objective_variable: &impl IntegerVariable, - best_objective_value: &mut i64, - best_solution: &mut Solution, - brancher: &mut impl Brancher, - ) { - *best_objective_value = (objective_multiplier - * self - .satisfaction_solver - .get_assigned_integer_value(objective_variable) - .expect("expected variable to be assigned")) as i64; - *best_solution = self.satisfaction_solver.get_solution_reference().into(); - - self.internal_process_solution(best_solution, brancher, Some(*best_objective_value)) - } - - pub(crate) fn process_solution(&self, solution: &Solution, brancher: &mut impl Brancher) { - self.internal_process_solution(solution, brancher, None) - } - - fn internal_process_solution( - &self, - solution: &Solution, - brancher: &mut impl Brancher, - objective_value: Option, - ) { - brancher.on_solution(solution.as_reference()); - - (self.solution_callback)(SolutionCallbackArguments::new( - self, - solution, - objective_value, - )); - } - - /// Given the current objective value `best_objective_value`, it adds a constraint specifying - /// that the objective value should be at most `best_objective_value - 1`. Note that it is - /// assumed that we are always minimising the variable. - fn strengthen( - &mut self, - objective_variable: &impl IntegerVariable, - best_objective_value: i64, - ) -> Result<(), ConstraintOperationError> { - self.satisfaction_solver.add_clause([predicate!( - objective_variable <= (best_objective_value - 1) as i32 - )]) - } - - fn debug_bound_change( - &self, - objective_variable: &impl IntegerVariable, - best_objective_value: i64, - ) { - pumpkin_assert_simple!( - (self - .satisfaction_solver - .get_assigned_integer_value(objective_variable) - .expect("expected variable to be assigned") as i64) - < best_objective_value, - "{}", - format!( - "The current bound {} should be smaller than the previous bound {}", - self.satisfaction_solver - .get_assigned_integer_value(objective_variable) - .expect("expected variable to be assigned"), - best_objective_value - ) - ); + optimisation_procedure.optimise(brancher, termination, self) } } @@ -730,9 +515,9 @@ impl Solver { /// A brancher which makes use of VSIDS \[1\] and solution-based phase saving (both adapted for CP). /// /// If VSIDS does not contain any (unfixed) predicates then it will default to the -/// [`IndependentVariableValueBrancher`] using [`Smallest`] for variable selection -/// (over the variables in the order in which they were defined) and [`InDomainMin`] for value -/// selection. +/// [`IndependentVariableValueBrancher`] using [`RandomSelector`] for variable selection +/// (over the variables in the order in which they were defined) and [`RandomSplitter`] for +/// value selection. /// /// # Bibliography /// \[1\] M. W. Moskewicz, C. F. Madigan, Y. Zhao, L. Zhang, and S. Malik, ‘Chaff: Engineering an @@ -741,10 +526,5 @@ impl Solver { /// \[2\] E. Demirović, G. Chu, and P. J. Stuckey, ‘Solution-based phase saving for CP: A /// value-selection heuristic to simulate local search behavior in complete solvers’, in the /// proceedings of the Principles and Practice of Constraint Programming (CP 2018). -pub type DefaultBrancher = AutonomousSearch< - IndependentVariableValueBrancher< - DomainId, - Smallest>, - InDomainMin, - >, ->; +pub type DefaultBrancher = + AutonomousSearch>; diff --git a/pumpkin-solver/src/basic_types/predicate_id_generator.rs b/pumpkin-solver/src/basic_types/predicate_id_generator.rs index 5363a022..29a77cad 100644 --- a/pumpkin-solver/src/basic_types/predicate_id_generator.rs +++ b/pumpkin-solver/src/basic_types/predicate_id_generator.rs @@ -73,13 +73,6 @@ impl PredicateIdGenerator { self.predicate_to_id.clear(); } - /// Returns an iterator over all active predicate ids. - /// Note that constructing the iterator is not constant time, - /// since the function internally sortes the inactive predicate ids. - pub(crate) fn iter(&self) -> PredicateIdIterator { - PredicateIdIterator::new(self.next_id, self.deleted_ids.clone()) - } - pub(crate) fn replace_predicate(&mut self, predicate: Predicate, replacement: Predicate) { pumpkin_assert_moderate!(self.has_id_for_predicate(predicate)); let predicate_id = self.get_id(predicate); @@ -87,6 +80,7 @@ impl PredicateIdGenerator { } } +#[cfg(test)] #[derive(Debug)] pub(crate) struct PredicateIdIterator { sorted_deleted_ids: Vec, @@ -94,6 +88,7 @@ pub(crate) struct PredicateIdIterator { next_deleted: u32, } +#[cfg(test)] impl PredicateIdIterator { fn new(end_id: u32, mut deleted_ids: Vec) -> PredicateIdIterator { deleted_ids.sort(); @@ -136,6 +131,7 @@ impl PredicateIdIterator { } } +#[cfg(test)] impl Iterator for PredicateIdIterator { type Item = PredicateId; diff --git a/pumpkin-solver/src/basic_types/propositional_conjunction.rs b/pumpkin-solver/src/basic_types/propositional_conjunction.rs index e06080d8..97cfee47 100644 --- a/pumpkin-solver/src/basic_types/propositional_conjunction.rs +++ b/pumpkin-solver/src/basic_types/propositional_conjunction.rs @@ -77,6 +77,12 @@ impl PropositionalConjunction { } } +impl Extend for PropositionalConjunction { + fn extend>(&mut self, iter: T) { + self.predicates_in_conjunction.extend(iter); + } +} + impl IntoIterator for PropositionalConjunction { type Item = Predicate; diff --git a/pumpkin-solver/src/basic_types/random.rs b/pumpkin-solver/src/basic_types/random.rs index 25613b6b..c265b620 100644 --- a/pumpkin-solver/src/basic_types/random.rs +++ b/pumpkin-solver/src/basic_types/random.rs @@ -59,6 +59,10 @@ pub trait Random: Debug { /// ``` fn generate_usize_in_range(&mut self, range: Range) -> usize; + /// Generates a random i32 in the provided range with equal probability; this can be seen as + /// sampling from a uniform distribution in the range `[range.start, range.end)` + fn generate_i32_in_range(&mut self, range: Range) -> i32; + /// Generate a random float in the range 0..1. fn generate_f64(&mut self) -> f64; @@ -87,6 +91,10 @@ where self.gen_range(range) } + fn generate_i32_in_range(&mut self, range: Range) -> i32 { + self.gen_range(range) + } + fn generate_f64(&mut self) -> f64 { self.gen_range(0.0..1.0) } @@ -127,6 +135,7 @@ pub(crate) mod tests { #[derive(Debug)] pub(crate) struct TestRandom { pub(crate) usizes: Vec, + pub(crate) integers: Vec, pub(crate) bools: Vec, pub(crate) weighted_choice: fn(&[f64]) -> Option, } @@ -134,9 +143,10 @@ pub(crate) mod tests { impl Default for TestRandom { fn default() -> Self { TestRandom { + weighted_choice: |_| unimplemented!(), usizes: vec![], + integers: vec![], bools: vec![], - weighted_choice: |_| unimplemented!(), } } } @@ -157,6 +167,15 @@ pub(crate) mod tests { selected } + fn generate_i32_in_range(&mut self, range: Range) -> i32 { + let selected = self.integers.remove(0); + pumpkin_assert_simple!( + range.contains(&selected), + "The selected element by `TestRandom` ({selected}) is not in the provided range ({range:?}) and thus should not be returned, please ensure that your test cases are correctly defined" + ); + selected + } + fn generate_usize_in_range(&mut self, range: Range) -> usize { let selected = self.usizes.remove(0); pumpkin_assert_simple!( diff --git a/pumpkin-solver/src/basic_types/solution.rs b/pumpkin-solver/src/basic_types/solution.rs index e1cd6a49..99e1fb35 100644 --- a/pumpkin-solver/src/basic_types/solution.rs +++ b/pumpkin-solver/src/basic_types/solution.rs @@ -1,4 +1,4 @@ -use crate::engine::propagation::propagation_context::HasAssignments; +use crate::engine::propagation::contexts::HasAssignments; use crate::engine::variables::DomainGeneratorIterator; use crate::engine::variables::DomainId; use crate::engine::variables::Literal; diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/instance.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/instance.rs index 14c78d6a..c785b35c 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/instance.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/instance.rs @@ -15,16 +15,6 @@ pub(crate) enum FlatzincObjective { Minimize(DomainId), } -impl FlatzincObjective { - /// Returns the [DomainId] of the objective function - pub(crate) fn get_domain(&self) -> &DomainId { - match self { - FlatzincObjective::Maximize(domain) => domain, - FlatzincObjective::Minimize(domain) => domain, - } - } -} - #[derive(Default)] pub(crate) struct FlatZincInstance { pub(super) outputs: Vec, diff --git a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs index 64305a6f..3a43b085 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/flatzinc/mod.rs @@ -12,17 +12,24 @@ use std::time::Duration; use pumpkin_solver::branching::branchers::alternating_brancher::AlternatingBrancher; use pumpkin_solver::branching::branchers::alternating_brancher::AlternatingStrategy; use pumpkin_solver::branching::branchers::dynamic_brancher::DynamicBrancher; +use pumpkin_solver::branching::Brancher; #[cfg(doc)] use pumpkin_solver::constraints::cumulative; +use pumpkin_solver::optimisation::linear_sat_unsat::LinearSatUnsat; +use pumpkin_solver::optimisation::linear_unsat_sat::LinearUnsatSat; +use pumpkin_solver::optimisation::OptimisationDirection; +use pumpkin_solver::optimisation::OptimisationStrategy; use pumpkin_solver::options::CumulativeOptions; use pumpkin_solver::results::solution_iterator::IteratedSolution; use pumpkin_solver::results::OptimisationResult; use pumpkin_solver::results::ProblemSolution; use pumpkin_solver::results::SatisfactionResult; -use pumpkin_solver::results::Solution; +use pumpkin_solver::results::SolutionReference; use pumpkin_solver::termination::Combinator; use pumpkin_solver::termination::OsSignal; +use pumpkin_solver::termination::TerminationCondition; use pumpkin_solver::termination::TimeBudget; +use pumpkin_solver::variables::DomainId; use pumpkin_solver::Solver; use self::instance::FlatZincInstance; @@ -44,6 +51,26 @@ pub(crate) struct FlatZincOptions { /// Options used for the cumulative constraint (see [`cumulative`]). pub(crate) cumulative_options: CumulativeOptions, + + /// Determines which type of search is performed by the solver + pub(crate) optimisation_strategy: OptimisationStrategy, +} + +fn solution_callback( + instance_objective_function: Option, + options_all_solutions: bool, + outputs: &[Output], + solver: &Solver, + solution: SolutionReference, +) { + if options_all_solutions || instance_objective_function.is_none() { + if let Some(objective) = instance_objective_function { + solver.log_statistics_with_objective(solution.get_integer_value(objective) as i64); + } else { + solver.log_statistics() + } + print_solution_from_solver(solution, outputs); + } } pub(crate) fn solve( @@ -73,90 +100,121 @@ pub(crate) fn solve( instance.search.expect("Expected a search to be defined") }; - solver.with_solution_callback(move |solution_callback_arguments| { - if options.all_solutions || instance.objective_function.is_none() { - solution_callback_arguments.log_statistics(); - print_solution_from_solver(solution_callback_arguments.solution, &outputs); + let (direction, objective) = match instance.objective_function { + Some(FlatzincObjective::Maximize(domain_id)) => { + (OptimisationDirection::Maximise, domain_id) } - }); + Some(FlatzincObjective::Minimize(domain_id)) => { + (OptimisationDirection::Minimise, domain_id) + } + None => { + satisfy(options, &mut solver, brancher, termination, outputs); + return Ok(()); + } + }; - let value = if let Some(objective_function) = &instance.objective_function { - let result = match objective_function { - FlatzincObjective::Maximize(domain_id) => { - solver.maximise(&mut brancher, &mut termination, *domain_id) - } - FlatzincObjective::Minimize(domain_id) => { - solver.minimise(&mut brancher, &mut termination, *domain_id) + let callback = |solver: &Solver, solution: SolutionReference<'_>| { + solution_callback( + Some(objective), + options.all_solutions, + &outputs, + solver, + solution, + ); + }; + + let result = match options.optimisation_strategy { + OptimisationStrategy::LinearSatUnsat => solver.optimise( + &mut brancher, + &mut termination, + LinearSatUnsat::new(direction, objective, callback), + ), + OptimisationStrategy::LinearUnsatSat => solver.optimise( + &mut brancher, + &mut termination, + LinearUnsatSat::new(direction, objective, callback), + ), + }; + + match result { + OptimisationResult::Optimal(optimal_solution) => { + if !options.all_solutions { + solver.log_statistics(); + print_solution_from_solver(optimal_solution.as_reference(), &instance.outputs) } - }; + println!("=========="); + } + OptimisationResult::Satisfiable(_) => { + // Solutions are printed in the callback. + solver.log_statistics(); + } + OptimisationResult::Unsatisfiable => { + solver.log_statistics(); + println!("{MSG_UNSATISFIABLE}"); + } + OptimisationResult::Unknown => { + solver.log_statistics(); + println!("{MSG_UNKNOWN}"); + } + }; + + Ok(()) +} - match result { - OptimisationResult::Optimal(optimal_solution) => { - let optimal_objective_value = - optimal_solution.get_integer_value(*objective_function.get_domain()); - if !options.all_solutions { +fn satisfy( + options: FlatZincOptions, + solver: &mut Solver, + mut brancher: impl Brancher, + mut termination: impl TerminationCondition, + outputs: Vec, +) { + if options.all_solutions { + let mut solution_iterator = solver.get_solution_iterator(&mut brancher, &mut termination); + loop { + match solution_iterator.next_solution() { + IteratedSolution::Solution(solution, solver) => { + solution_callback( + None, + options.all_solutions, + &outputs, + solver, + solution.as_reference(), + ); + } + IteratedSolution::Finished => { + println!("=========="); + break; + } + IteratedSolution::Unknown => { solver.log_statistics(); - print_solution_from_solver(&optimal_solution, &instance.outputs) + break; + } + IteratedSolution::Unsatisfiable => { + solver.log_statistics(); + println!("{MSG_UNSATISFIABLE}"); + break; } - println!("=========="); - Some(optimal_objective_value) - } - OptimisationResult::Satisfiable(solution) => { - let best_found_objective_value = - solution.get_integer_value(*objective_function.get_domain()); - Some(best_found_objective_value) - } - OptimisationResult::Unsatisfiable => { - println!("{MSG_UNSATISFIABLE}"); - None - } - OptimisationResult::Unknown => { - println!("{MSG_UNKNOWN}"); - None } } } else { - if options.all_solutions { - let mut solution_iterator = - solver.get_solution_iterator(&mut brancher, &mut termination); - loop { - match solution_iterator.next_solution() { - IteratedSolution::Solution(_) => {} - IteratedSolution::Finished => { - println!("=========="); - break; - } - IteratedSolution::Unknown => { - break; - } - IteratedSolution::Unsatisfiable => { - println!("{MSG_UNSATISFIABLE}"); - break; - } - } + match solver.satisfy(&mut brancher, &mut termination) { + SatisfactionResult::Satisfiable(solution) => solution_callback( + None, + options.all_solutions, + &outputs, + &*solver, + solution.as_reference(), + ), + SatisfactionResult::Unsatisfiable => { + solver.log_statistics(); + println!("{MSG_UNSATISFIABLE}"); } - } else { - match solver.satisfy(&mut brancher, &mut termination) { - SatisfactionResult::Satisfiable(_) => {} - SatisfactionResult::Unsatisfiable => { - println!("{MSG_UNSATISFIABLE}"); - } - SatisfactionResult::Unknown => { - println!("{MSG_UNKNOWN}"); - } + SatisfactionResult::Unknown => { + solver.log_statistics(); + println!("{MSG_UNKNOWN}"); } } - - None - }; - - if let Some(value) = value { - solver.log_statistics_with_objective(value as i64) - } else { - solver.log_statistics() } - - Ok(()) } fn parse_and_compile( @@ -169,7 +227,7 @@ fn parse_and_compile( } /// Prints the current solution. -fn print_solution_from_solver(solution: &Solution, outputs: &[Output]) { +fn print_solution_from_solver(solution: SolutionReference, outputs: &[Output]) { for output_specification in outputs { match output_specification { Output::Bool(output) => { diff --git a/pumpkin-solver/src/bin/pumpkin-solver/main.rs b/pumpkin-solver/src/bin/pumpkin-solver/main.rs index 6872742d..1142f582 100644 --- a/pumpkin-solver/src/bin/pumpkin-solver/main.rs +++ b/pumpkin-solver/src/bin/pumpkin-solver/main.rs @@ -26,6 +26,7 @@ use maxsat::PseudoBooleanEncoding; use parsers::dimacs::parse_cnf; use parsers::dimacs::SolverArgs; use parsers::dimacs::SolverDimacsSink; +use pumpkin_solver::optimisation::OptimisationStrategy; use pumpkin_solver::options::*; use pumpkin_solver::proof::Format; use pumpkin_solver::proof::ProofLog; @@ -345,6 +346,10 @@ struct Args { /// Possible values: bool #[arg(long = "cumulative-incremental-backtracking")] cumulative_incremental_backtracking: bool, + + /// Determine what type of optimisation strategy is used by the solver + #[arg(long = "optimisation-strategy", default_value_t)] + optimisation_strategy: OptimisationStrategy, } fn configure_logging( @@ -548,6 +553,7 @@ fn run() -> PumpkinResult<()> { args.cumulative_propagation_method, args.cumulative_incremental_backtracking, ), + optimisation_strategy: args.optimisation_strategy, }, )?, } diff --git a/pumpkin-solver/src/branching/brancher.rs b/pumpkin-solver/src/branching/brancher.rs index ecdeb006..b79240e1 100644 --- a/pumpkin-solver/src/branching/brancher.rs +++ b/pumpkin-solver/src/branching/brancher.rs @@ -1,9 +1,13 @@ +use enum_map::Enum; + #[cfg(doc)] use crate::basic_types::Random; use crate::basic_types::SolutionReference; #[cfg(doc)] use crate::branching; #[cfg(doc)] +use crate::branching::branchers::dynamic_brancher::DynamicBrancher; +#[cfg(doc)] use crate::branching::value_selection::ValueSelector; #[cfg(doc)] use crate::branching::variable_selection::VariableSelector; @@ -39,15 +43,24 @@ pub trait Brancher { /// A function which is called after a conflict has been found and processed but (currently) /// does not provide any additional information. + /// + /// To receive information about this event, use [`BrancherEvent::Conflict`] in + /// [`Self::subscribe_to_events`] fn on_conflict(&mut self) {} /// A function which is called whenever a backtrack occurs in the [`Solver`]. + /// + /// To receive information about this event, use [`BrancherEvent::Backtrack`] in + /// [`Self::subscribe_to_events`] fn on_backtrack(&mut self) {} /// This method is called when a solution is found; this will either be called when a new /// incumbent solution is found (i.e. a solution with a better objective value than previously /// known) or when a new solution is found when iterating over solutions using /// [`SolutionIterator`]. + /// + /// To receive information about this event, use [`BrancherEvent::Solution`] in + /// [`Self::subscribe_to_events`] fn on_solution(&mut self, _solution: SolutionReference) {} /// A function which is called after a [`DomainId`] is unassigned during backtracking (i.e. when @@ -55,17 +68,28 @@ pub trait Brancher { /// [`DomainId`] which has been reset and `value` which is the value to which the variable was /// previously fixed. This method could thus be called multiple times in a single /// backtracking operation by the solver. + /// + /// To receive information about this event, use [`BrancherEvent::UnassignInteger`] in + /// [`Self::subscribe_to_events`] fn on_unassign_integer(&mut self, _variable: DomainId, _value: i32) {} /// A function which is called when a [`Predicate`] appears in a conflict during conflict /// analysis. + /// + /// To receive information about this event, use + /// [`BrancherEvent::AppearanceInConflictPredicate`] in [`Self::subscribe_to_events`] fn on_appearance_in_conflict_predicate(&mut self, _predicate: Predicate) {} /// This method is called whenever a restart is performed. + /// To receive information about this event, use [`BrancherEvent::Restart`] in + /// [`Self::subscribe_to_events`] fn on_restart(&mut self) {} /// Called after backtracking. /// Used to reset internal data structures to account for the backtrack. + /// + /// To receive information about this event, use [`BrancherEvent::Synchronise`] in + /// [`Self::subscribe_to_events`] fn synchronise(&mut self, _assignments: &Assignments) {} /// This method returns whether a restart is *currently* pointless for the [`Brancher`]. @@ -81,4 +105,31 @@ pub trait Brancher { fn is_restart_pointless(&mut self) -> bool { true } + + /// Indicates which [`BrancherEvent`] are relevant for this particular [`Brancher`]. + /// + /// This can be used by [`Brancher::subscribe_to_events`] to determine upon which + /// events which [`VariableSelector`] should be called. + fn subscribe_to_events(&self) -> Vec; +} + +/// The events which can occur for a [`Brancher`]. Used for returning which events are relevant in +/// [`Brancher::subscribe_to_events`], [`VariableSelector::subscribe_to_events`], +/// and [`ValueSelector::subscribe_to_events`]. +#[derive(Debug, Clone, Copy, Enum, Hash, PartialEq, Eq)] +pub enum BrancherEvent { + /// Event for when a conflict is detected + Conflict, + /// Event for when a backtrack is performed + Backtrack, + /// Event for when a solution has been found + Solution, + /// Event for when an integer variable has become unassigned + UnassignInteger, + /// Event for when a predicate appears during conflict analysis + AppearanceInConflictPredicate, + /// Event for when a restart occurs + Restart, + /// Event which is called with the new state after a backtrack has occurred + Synchronise, } diff --git a/pumpkin-solver/src/branching/branchers/alternating_brancher.rs b/pumpkin-solver/src/branching/branchers/alternating_brancher.rs index e42a45d8..b00bddec 100644 --- a/pumpkin-solver/src/branching/branchers/alternating_brancher.rs +++ b/pumpkin-solver/src/branching/branchers/alternating_brancher.rs @@ -2,6 +2,7 @@ //! on the strategy specified in [`AlternatingStrategy`]. use crate::basic_types::SolutionReference; +use crate::branching::brancher::BrancherEvent; use crate::branching::Brancher; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -208,6 +209,16 @@ impl Brancher for AlternatingBrancher { self.other_brancher.synchronise(assignments); } } + + fn subscribe_to_events(&self) -> Vec { + // We require the restart event and on solution event for the alternating brancher itself; + // additionally, it will be interested in the events of its sub-branchers + [BrancherEvent::Restart, BrancherEvent::Solution] + .into_iter() + .chain(self.default_brancher.subscribe_to_events()) + .chain(self.other_brancher.subscribe_to_events()) + .collect() + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/branchers/autonomous_search.rs b/pumpkin-solver/src/branching/branchers/autonomous_search.rs index 0522129a..2002b24d 100644 --- a/pumpkin-solver/src/branching/branchers/autonomous_search.rs +++ b/pumpkin-solver/src/branching/branchers/autonomous_search.rs @@ -2,15 +2,17 @@ use super::independent_variable_value_brancher::IndependentVariableValueBrancher use crate::basic_types::PredicateId; use crate::basic_types::PredicateIdGenerator; use crate::basic_types::SolutionReference; -use crate::branching::value_selection::InDomainMin; -use crate::branching::variable_selection::Smallest; +use crate::branching::value_selection::RandomSplitter; +use crate::branching::variable_selection::RandomSelector; use crate::branching::Brancher; +use crate::branching::BrancherEvent; use crate::branching::SelectionContext; use crate::containers::KeyValueHeap; use crate::containers::StorageKey; use crate::engine::predicates::predicate::Predicate; use crate::engine::Assignments; use crate::results::Solution; +use crate::variables::DomainId; use crate::DefaultBrancher; /// A [`Brancher`] that combines [VSIDS \[1\]](https://dl.acm.org/doi/pdf/10.1145/378239.379017) /// and [Solution-based phase saving \[2\]](https://people.eng.unimelb.edu.au/pstuckey/papers/lns-restarts.pdf). @@ -94,9 +96,8 @@ impl DefaultBrancher { /// `0.95` for the decay factor and `0.0` for the initial VSIDS value). /// /// If there are no more predicates left to select, this [`Brancher`] switches to - /// [`Smallest`] with [`InDomainMin`]. + /// [`RandomSelector`] with [`RandomSplitter`]. pub fn default_over_all_variables(assignments: &Assignments) -> DefaultBrancher { - let variables = assignments.get_domains().collect::>(); AutonomousSearch { predicate_id_info: PredicateIdGenerator::default(), heap: KeyValueHeap::default(), @@ -106,8 +107,8 @@ impl DefaultBrancher { decay_factor: DEFAULT_VSIDS_DECAY_FACTOR, best_known_solution: None, backup_brancher: IndependentVariableValueBrancher::new( - Smallest::new(&variables), - InDomainMin, + RandomSelector::new(assignments.get_domains()), + RandomSplitter, ), } } @@ -132,10 +133,6 @@ impl AutonomousSearch { } } - fn minimum_activity_threshold(&self) -> f64 { - 1_f64 / self.increment - } - /// Resizes the heap to accommodate for the id. /// Recall that the underlying heap uses direct hashing. fn resize_heap(&mut self, id: PredicateId) { @@ -149,6 +146,7 @@ impl AutonomousSearch { fn bump_activity(&mut self, predicate: Predicate) { let id = self.predicate_id_info.get_id(predicate); self.resize_heap(id); + self.heap.restore_key(id); // Scale the activities if the values are too large. // Also remove predicates that have activities close to zero. @@ -157,21 +155,6 @@ impl AutonomousSearch { // Adjust heap values. self.heap.divide_values(self.max_threshold); - // Remove inactive predicates from the heap, - // and stage the ids for removal from the id generator. - self.predicate_id_info.iter().for_each(|predicate_id| { - // If the predicate does not reach the minimum activity threshold then we remove it - // from the heap and we remove its id from the generator - // - // Note that we check whether the current predicate being removed is not the - // predicate is being bumped, this is to prevent multiple IDs from being assigned. - if *self.heap.get_value(predicate_id) <= self.minimum_activity_threshold() - && predicate_id != id - { - self.heap.delete_key(predicate_id); - self.predicate_id_info.delete_id(predicate_id); - } - }); // Adjust increment. It is important to adjust the increment after the above code. self.increment /= self.max_threshold; } @@ -279,31 +262,56 @@ impl Brancher for AutonomousSearch { true } }); + self.backup_brancher.synchronise(assignments); } fn on_conflict(&mut self) { self.decay_activities(); + self.backup_brancher.on_conflict(); } fn on_solution(&mut self, solution: SolutionReference) { // We store the best known solution self.best_known_solution = Some(solution.into()); + self.backup_brancher.on_solution(solution); } fn on_appearance_in_conflict_predicate(&mut self, predicate: Predicate) { - self.bump_activity(predicate) + self.bump_activity(predicate); + self.backup_brancher + .on_appearance_in_conflict_predicate(predicate); + } + + fn on_restart(&mut self) { + self.backup_brancher.on_restart(); + } + + fn on_unassign_integer(&mut self, variable: DomainId, value: i32) { + self.backup_brancher.on_unassign_integer(variable, value) } fn is_restart_pointless(&mut self) -> bool { false } + + fn subscribe_to_events(&self) -> Vec { + [ + BrancherEvent::Solution, + BrancherEvent::Conflict, + BrancherEvent::Backtrack, + BrancherEvent::Synchronise, + BrancherEvent::AppearanceInConflictPredicate, + ] + .into_iter() + .chain(self.backup_brancher.subscribe_to_events()) + .collect() + } } #[cfg(test)] mod tests { use super::AutonomousSearch; use crate::basic_types::tests::TestRandom; - use crate::branching::branchers::autonomous_search::DEFAULT_VSIDS_MAX_THRESHOLD; use crate::branching::Brancher; use crate::branching::SelectionContext; use crate::engine::Assignments; @@ -324,28 +332,6 @@ mod tests { (0..100).for_each(|_| brancher.on_conflict()); } - #[test] - fn value_removed_if_threshold_too_small() { - let mut assignments = Assignments::default(); - let x = assignments.grow(0, 10); - let y = assignments.grow(-10, 0); - - let mut brancher = AutonomousSearch::default_over_all_variables(&assignments); - brancher.on_appearance_in_conflict_predicate(predicate!(x >= 5)); - brancher.on_appearance_in_conflict_predicate(predicate!(y >= -5)); - - brancher.increment = DEFAULT_VSIDS_MAX_THRESHOLD; - - brancher.on_appearance_in_conflict_predicate(predicate!(y >= -5)); - - assert!(!brancher - .predicate_id_info - .has_id_for_predicate(predicate!(x >= 5))); - assert!(brancher - .predicate_id_info - .has_id_for_predicate(predicate!(y >= -5))); - } - #[test] fn dormant_values() { let mut assignments = Assignments::default(); @@ -413,13 +399,14 @@ mod tests { let result = brancher.next_decision(&mut SelectionContext::new( &assignments, &mut TestRandom { - usizes: vec![], + integers: vec![2], + usizes: vec![0], + bools: vec![false], weighted_choice: |_| unreachable!(), - ..Default::default() }, )); - assert_eq!(result, Some(predicate!(x <= 0))); + assert_eq!(result, Some(predicate!(x <= 2))); } #[test] diff --git a/pumpkin-solver/src/branching/branchers/dynamic_brancher.rs b/pumpkin-solver/src/branching/branchers/dynamic_brancher.rs index 22d18e7a..a8fa3d3c 100644 --- a/pumpkin-solver/src/branching/branchers/dynamic_brancher.rs +++ b/pumpkin-solver/src/branching/branchers/dynamic_brancher.rs @@ -6,7 +6,11 @@ use std::cmp::min; use std::fmt::Debug; +use enum_map::EnumMap; + +use crate::basic_types::HashSet; use crate::basic_types::SolutionReference; +use crate::branching::brancher::BrancherEvent; use crate::branching::Brancher; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -28,6 +32,9 @@ use crate::engine::Assignments; pub struct DynamicBrancher { branchers: Vec>, brancher_index: usize, + + relevant_event_to_index: EnumMap>, + relevant_events: Vec, } impl Debug for DynamicBrancher { @@ -40,14 +47,36 @@ impl DynamicBrancher { /// Creates a new [`DynamicBrancher`] with the provided `branchers`. It will attempt to use the /// `branchers` in the order in which they were provided. pub fn new(branchers: Vec>) -> Self { + let mut relevant_event_to_index: EnumMap> = EnumMap::default(); + let mut relevant_events = HashSet::new(); + + // The dynamic brancher will reset the indices upon these events so they should be called + let _ = relevant_events.insert(BrancherEvent::Solution); + let _ = relevant_events.insert(BrancherEvent::Conflict); + + branchers.iter().enumerate().for_each(|(index, brancher)| { + for event in brancher.subscribe_to_events() { + relevant_event_to_index[event].push(index); + let _ = relevant_events.insert(event); + } + }); Self { branchers, brancher_index: 0, + + relevant_event_to_index, + relevant_events: relevant_events.into_iter().collect(), } } pub fn add_brancher(&mut self, brancher: Box) { - self.branchers.push(brancher) + for event in brancher.subscribe_to_events() { + self.relevant_event_to_index[event].push(self.branchers.len()); + if !self.relevant_events.contains(&event) { + self.relevant_events.push(event); + } + } + self.branchers.push(brancher); } } @@ -70,46 +99,50 @@ impl Brancher for DynamicBrancher { // A conflict has occurred, we do not know which brancher now can select a variable, reset // to the first one self.brancher_index = 0; - self.branchers - .iter_mut() - .for_each(|brancher| brancher.on_conflict()); + self.relevant_event_to_index[BrancherEvent::Conflict] + .iter() + .for_each(|&brancher_index| self.branchers[brancher_index].on_conflict()); } fn on_backtrack(&mut self) { - self.branchers - .iter_mut() - .for_each(|brancher| brancher.on_backtrack()); + self.relevant_event_to_index[BrancherEvent::Backtrack] + .iter() + .for_each(|&brancher_index| self.branchers[brancher_index].on_backtrack()); } fn on_unassign_integer(&mut self, variable: DomainId, value: i32) { - self.branchers - .iter_mut() - .for_each(|brancher| brancher.on_unassign_integer(variable, value)); + self.relevant_event_to_index[BrancherEvent::UnassignInteger] + .iter() + .for_each(|&brancher_index| { + self.branchers[brancher_index].on_unassign_integer(variable, value) + }); } fn on_appearance_in_conflict_predicate(&mut self, predicate: Predicate) { - self.branchers - .iter_mut() - .for_each(|brancher| brancher.on_appearance_in_conflict_predicate(predicate)); + self.relevant_event_to_index[BrancherEvent::AppearanceInConflictPredicate] + .iter() + .for_each(|&brancher_index| { + self.branchers[brancher_index].on_appearance_in_conflict_predicate(predicate) + }); } fn on_solution(&mut self, solution: SolutionReference) { self.brancher_index = 0; - self.branchers - .iter_mut() - .for_each(|brancher| brancher.on_solution(solution)); + self.relevant_event_to_index[BrancherEvent::Solution] + .iter() + .for_each(|&brancher_index| self.branchers[brancher_index].on_solution(solution)); } fn on_restart(&mut self) { - self.branchers - .iter_mut() - .for_each(|brancher| brancher.on_restart()); + self.relevant_event_to_index[BrancherEvent::Restart] + .iter() + .for_each(|&brancher_index| self.branchers[brancher_index].on_restart()); } fn synchronise(&mut self, assignments: &Assignments) { - self.branchers - .iter_mut() - .for_each(|brancher| brancher.synchronise(assignments)); + self.relevant_event_to_index[BrancherEvent::Synchronise] + .iter() + .for_each(|&brancher_index| self.branchers[brancher_index].synchronise(assignments)); } fn is_restart_pointless(&mut self) -> bool { @@ -120,4 +153,8 @@ impl Brancher for DynamicBrancher { .iter_mut() .all(|brancher| brancher.is_restart_pointless()) } + + fn subscribe_to_events(&self) -> Vec { + self.relevant_events.clone() + } } diff --git a/pumpkin-solver/src/branching/branchers/independent_variable_value_brancher.rs b/pumpkin-solver/src/branching/branchers/independent_variable_value_brancher.rs index 03062f2a..784523a2 100644 --- a/pumpkin-solver/src/branching/branchers/independent_variable_value_brancher.rs +++ b/pumpkin-solver/src/branching/branchers/independent_variable_value_brancher.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use crate::basic_types::SolutionReference; +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::variable_selection::VariableSelector; use crate::branching::Brancher; @@ -89,4 +90,12 @@ where fn is_restart_pointless(&mut self) -> bool { self.variable_selector.is_restart_pointless() && self.value_selector.is_restart_pointless() } + + fn subscribe_to_events(&self) -> Vec { + self.variable_selector + .subscribe_to_events() + .into_iter() + .chain(self.value_selector.subscribe_to_events()) + .collect() + } } diff --git a/pumpkin-solver/src/branching/mod.rs b/pumpkin-solver/src/branching/mod.rs index be69b65f..6bd55b5e 100644 --- a/pumpkin-solver/src/branching/mod.rs +++ b/pumpkin-solver/src/branching/mod.rs @@ -11,8 +11,7 @@ //! hooks into the solver); the main method of this trait is the [`ValueSelector::select_value`] //! method. //! -//! A [`Brancher`] is expected to be passed to [`Solver::satisfy`], [`Solver::maximise`], and -//! [`Solver::minimise`]: +//! A [`Brancher`] is expected to be passed to [`Solver::satisfy`], and [`Solver::optimise`]: //! ```rust //! # use pumpkin_solver::Solver; //! # use pumpkin_solver::variables::Literal; @@ -72,7 +71,7 @@ pub mod tie_breaking; pub mod value_selection; pub mod variable_selection; -pub use brancher::Brancher; +pub use brancher::*; pub use selection_context::SelectionContext; #[cfg(doc)] diff --git a/pumpkin-solver/src/branching/value_selection/dynamic_value_selector.rs b/pumpkin-solver/src/branching/value_selection/dynamic_value_selector.rs index 5b31cd0b..d4b51834 100644 --- a/pumpkin-solver/src/branching/value_selection/dynamic_value_selector.rs +++ b/pumpkin-solver/src/branching/value_selection/dynamic_value_selector.rs @@ -2,6 +2,7 @@ use std::fmt::Debug; use super::ValueSelector; use crate::basic_types::SolutionReference; +use crate::branching::brancher::BrancherEvent; #[cfg(doc)] use crate::branching::branchers::dynamic_brancher::DynamicBrancher; use crate::branching::SelectionContext; @@ -46,4 +47,8 @@ impl ValueSelector for DynamicValueSelector { fn is_restart_pointless(&mut self) -> bool { self.selector.is_restart_pointless() } + + fn subscribe_to_events(&self) -> Vec { + self.selector.subscribe_to_events() + } } diff --git a/pumpkin-solver/src/branching/value_selection/in_domain_interval.rs b/pumpkin-solver/src/branching/value_selection/in_domain_interval.rs index a7315855..d3ec87e9 100644 --- a/pumpkin-solver/src/branching/value_selection/in_domain_interval.rs +++ b/pumpkin-solver/src/branching/value_selection/in_domain_interval.rs @@ -1,4 +1,5 @@ use super::InDomainSplit; +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -38,6 +39,10 @@ impl ValueSelector for InDomainInterval { InDomainSplit::get_predicate_excluding_upper_half(context, decision_variable) } } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/in_domain_max.rs b/pumpkin-solver/src/branching/value_selection/in_domain_max.rs index bca36bd1..1addc9bb 100644 --- a/pumpkin-solver/src/branching/value_selection/in_domain_max.rs +++ b/pumpkin-solver/src/branching/value_selection/in_domain_max.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -16,6 +17,10 @@ impl ValueSelector for InDomainMax { ) -> Predicate { predicate!(decision_variable >= context.upper_bound(decision_variable)) } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/in_domain_median.rs b/pumpkin-solver/src/branching/value_selection/in_domain_median.rs index 3cd61a36..7bcb61ae 100644 --- a/pumpkin-solver/src/branching/value_selection/in_domain_median.rs +++ b/pumpkin-solver/src/branching/value_selection/in_domain_median.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -21,6 +22,10 @@ impl ValueSelector for InDomainMedian { .collect::>(); predicate!(decision_variable == values_in_domain[values_in_domain.len() / 2]) } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/in_domain_middle.rs b/pumpkin-solver/src/branching/value_selection/in_domain_middle.rs index 937e73f7..ff6acd48 100644 --- a/pumpkin-solver/src/branching/value_selection/in_domain_middle.rs +++ b/pumpkin-solver/src/branching/value_selection/in_domain_middle.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; #[cfg(doc)] use crate::branching::value_selection::InDomainMedian; use crate::branching::value_selection::ValueSelector; @@ -45,6 +46,10 @@ impl ValueSelector for InDomainMiddle { } unreachable!("There should be at least 1 selectable variable in the domain"); } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/in_domain_min.rs b/pumpkin-solver/src/branching/value_selection/in_domain_min.rs index 3a320e18..77e0b82f 100644 --- a/pumpkin-solver/src/branching/value_selection/in_domain_min.rs +++ b/pumpkin-solver/src/branching/value_selection/in_domain_min.rs @@ -1,4 +1,5 @@ use super::ValueSelector; +use crate::branching::brancher::BrancherEvent; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; use crate::engine::variables::IntegerVariable; @@ -16,6 +17,10 @@ impl ValueSelector for InDomainMin { ) -> Predicate { predicate!(decision_variable <= context.lower_bound(decision_variable)) } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/in_domain_random.rs b/pumpkin-solver/src/branching/value_selection/in_domain_random.rs index 1172f65b..d0beeec1 100644 --- a/pumpkin-solver/src/branching/value_selection/in_domain_random.rs +++ b/pumpkin-solver/src/branching/value_selection/in_domain_random.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -28,6 +29,10 @@ impl ValueSelector for InDomainRandom { fn is_restart_pointless(&mut self) -> bool { false } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } impl ValueSelector for InDomainRandom { @@ -46,6 +51,10 @@ impl ValueSelector for InDomainRandom { fn is_restart_pointless(&mut self) -> bool { false } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/in_domain_split.rs b/pumpkin-solver/src/branching/value_selection/in_domain_split.rs index 4752798f..b25c82e2 100644 --- a/pumpkin-solver/src/branching/value_selection/in_domain_split.rs +++ b/pumpkin-solver/src/branching/value_selection/in_domain_split.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -21,6 +22,10 @@ impl ValueSelector for InDomainSplit { ) -> Predicate { InDomainSplit::get_predicate_excluding_upper_half(context, decision_variable) } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } impl InDomainSplit { diff --git a/pumpkin-solver/src/branching/value_selection/in_domain_split_random.rs b/pumpkin-solver/src/branching/value_selection/in_domain_split_random.rs index e179bcf4..597dede1 100644 --- a/pumpkin-solver/src/branching/value_selection/in_domain_split_random.rs +++ b/pumpkin-solver/src/branching/value_selection/in_domain_split_random.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -29,6 +30,10 @@ impl ValueSelector for InDomainSplitRandom { fn is_restart_pointless(&mut self) -> bool { false } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/mod.rs b/pumpkin-solver/src/branching/value_selection/mod.rs index 7ebd2014..8c9c0c6d 100644 --- a/pumpkin-solver/src/branching/value_selection/mod.rs +++ b/pumpkin-solver/src/branching/value_selection/mod.rs @@ -18,6 +18,7 @@ mod out_domain_max; mod out_domain_median; mod out_domain_min; mod out_domain_random; +mod random_splitter; mod reverse_in_domain_split; mod value_selector; @@ -34,5 +35,6 @@ pub use out_domain_max::*; pub use out_domain_median::*; pub use out_domain_min::*; pub use out_domain_random::*; +pub use random_splitter::*; pub use reverse_in_domain_split::*; pub use value_selector::ValueSelector; diff --git a/pumpkin-solver/src/branching/value_selection/out_domain_max.rs b/pumpkin-solver/src/branching/value_selection/out_domain_max.rs index 0f645448..fd7522b0 100644 --- a/pumpkin-solver/src/branching/value_selection/out_domain_max.rs +++ b/pumpkin-solver/src/branching/value_selection/out_domain_max.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -16,6 +17,10 @@ impl ValueSelector for OutDomainMax { ) -> Predicate { predicate!(decision_variable <= context.upper_bound(decision_variable) - 1) } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/out_domain_median.rs b/pumpkin-solver/src/branching/value_selection/out_domain_median.rs index 4af50c50..eccb51c0 100644 --- a/pumpkin-solver/src/branching/value_selection/out_domain_median.rs +++ b/pumpkin-solver/src/branching/value_selection/out_domain_median.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -20,6 +21,10 @@ impl ValueSelector for OutDomainMedian { .collect::>(); predicate!(decision_variable != values_in_domain[values_in_domain.len() / 2]) } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/out_domain_min.rs b/pumpkin-solver/src/branching/value_selection/out_domain_min.rs index 348f2754..c944781a 100644 --- a/pumpkin-solver/src/branching/value_selection/out_domain_min.rs +++ b/pumpkin-solver/src/branching/value_selection/out_domain_min.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -16,6 +17,10 @@ impl ValueSelector for OutDomainMin { ) -> Predicate { predicate!(decision_variable >= context.lower_bound(decision_variable) + 1) } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/out_domain_random.rs b/pumpkin-solver/src/branching/value_selection/out_domain_random.rs index 1f746822..a4e97d4f 100644 --- a/pumpkin-solver/src/branching/value_selection/out_domain_random.rs +++ b/pumpkin-solver/src/branching/value_selection/out_domain_random.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -27,6 +28,10 @@ impl ValueSelector for OutDomainRandom { fn is_restart_pointless(&mut self) -> bool { false } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/random_splitter.rs b/pumpkin-solver/src/branching/value_selection/random_splitter.rs new file mode 100644 index 00000000..6c95655c --- /dev/null +++ b/pumpkin-solver/src/branching/value_selection/random_splitter.rs @@ -0,0 +1,80 @@ +use crate::branching::value_selection::ValueSelector; +use crate::branching::BrancherEvent; +use crate::branching::SelectionContext; +use crate::engine::predicates::predicate::Predicate; +use crate::engine::variables::DomainId; +use crate::predicate; + +/// A [`ValueSelector`] which splits the domain in a random manner (between the lower-bound and +/// lower-bound, disregarding holes), randomly selecting whether to exclude the lower-half or the +/// upper-half. +#[derive(Debug, Clone, Copy)] +pub struct RandomSplitter; + +impl ValueSelector for RandomSplitter { + fn select_value( + &mut self, + context: &mut SelectionContext, + decision_variable: DomainId, + ) -> Predicate { + // Randomly generate a value within the lower-bound and upper-bound + let range = + context.lower_bound(decision_variable)..context.upper_bound(decision_variable) + 1; + let bound = context.random().generate_i32_in_range(range); + + // We need to handle two special cases: + // + // 1. If the bound is equal to the lower-bound then we need to assign it to this bound since + // [x >= lb] is currently true + // 2. If the bound is equal to the upper-bound then we need to assign it to this bound since + // [x <= ub] is currentl true + if bound == context.lower_bound(decision_variable) { + return predicate!(decision_variable <= bound); + } else if bound == context.upper_bound(decision_variable) { + return predicate!(decision_variable >= bound); + } + + // Then randomly determine how to split the domain + if context.random().generate_bool(0.5) { + predicate!(decision_variable >= bound) + } else { + predicate!(decision_variable <= bound) + } + } + + fn is_restart_pointless(&mut self) -> bool { + false + } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } +} + +#[cfg(test)] +mod tests { + + use crate::basic_types::tests::TestRandom; + use crate::branching::value_selection::RandomSplitter; + use crate::branching::value_selection::ValueSelector; + use crate::branching::SelectionContext; + use crate::predicate; + + #[test] + fn test_returns_correct_literal() { + let assignments = SelectionContext::create_for_testing(vec![(0, 10)]); + let mut test_random = TestRandom { + integers: vec![2], + bools: vec![true], + ..Default::default() + }; + let mut context = SelectionContext::new(&assignments, &mut test_random); + let domain_ids = context.get_domains().collect::>(); + + let mut selector = RandomSplitter; + + let selected_predicate = selector.select_value(&mut context, domain_ids[0]); + + assert_eq!(selected_predicate, predicate!(domain_ids[0] >= 2)) + } +} diff --git a/pumpkin-solver/src/branching/value_selection/reverse_in_domain_split.rs b/pumpkin-solver/src/branching/value_selection/reverse_in_domain_split.rs index b357233a..6f6cdfcb 100644 --- a/pumpkin-solver/src/branching/value_selection/reverse_in_domain_split.rs +++ b/pumpkin-solver/src/branching/value_selection/reverse_in_domain_split.rs @@ -1,3 +1,4 @@ +use crate::branching::brancher::BrancherEvent; use crate::branching::value_selection::ValueSelector; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; @@ -32,6 +33,10 @@ impl ValueSelector for ReverseInDomainSplit { ); predicate!(decision_variable >= bound) } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/value_selection/value_selector.rs b/pumpkin-solver/src/branching/value_selection/value_selector.rs index c423dfc1..6f05650c 100644 --- a/pumpkin-solver/src/branching/value_selection/value_selector.rs +++ b/pumpkin-solver/src/branching/value_selection/value_selector.rs @@ -1,8 +1,13 @@ use crate::basic_types::SolutionReference; +use crate::branching::brancher::BrancherEvent; +#[cfg(doc)] +use crate::branching::branchers::dynamic_brancher::DynamicBrancher; #[cfg(doc)] use crate::branching::value_selection::InDomainMin; #[cfg(doc)] use crate::branching::value_selection::InDomainRandom; +#[cfg(doc)] +use crate::branching::Brancher; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; use crate::engine::variables::DomainId; @@ -25,11 +30,17 @@ pub trait ValueSelector { /// [`DomainId`] which has been reset and `value` which is the value to which the variable was /// previously fixed. This method could thus be called multiple times in a single /// backtracking operation by the solver. + /// + /// To receive information about this event, use [`BrancherEvent::UnassignInteger`] in + /// [`Self::subscribe_to_events`] fn on_unassign_integer(&mut self, _variable: DomainId, _value: i32) {} /// This method is called when a solution is found; either when iterating over all solutions in /// the case of a satisfiable problem or on solutions of increasing quality when solving an /// optimisation problem. + /// + /// To receive information about this event, use [`BrancherEvent::Solution`] in + /// [`Self::subscribe_to_events`] fn on_solution(&mut self, _solution: SolutionReference) {} /// This method returns whether a restart is *currently* pointless for the [`ValueSelector`]. @@ -43,4 +54,10 @@ pub trait ValueSelector { fn is_restart_pointless(&mut self) -> bool { true } + + /// Indicates which [`BrancherEvent`] are relevant for this particular [`ValueSelector`]. + /// + /// This can be used by [`Brancher::subscribe_to_events`] to determine upon which + /// events which [`ValueSelector`] should be called. + fn subscribe_to_events(&self) -> Vec; } diff --git a/pumpkin-solver/src/branching/variable_selection/anti_first_fail.rs b/pumpkin-solver/src/branching/variable_selection/anti_first_fail.rs index 98dc0457..b38577c9 100644 --- a/pumpkin-solver/src/branching/variable_selection/anti_first_fail.rs +++ b/pumpkin-solver/src/branching/variable_selection/anti_first_fail.rs @@ -1,5 +1,6 @@ use log::warn; +use crate::branching::brancher::BrancherEvent; use crate::branching::tie_breaking::Direction; use crate::branching::tie_breaking::InOrderTieBreaker; use crate::branching::tie_breaking::TieBreaker; @@ -72,6 +73,10 @@ impl> VariableSelector }); self.tie_breaker.select() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/variable_selection/dynamic_variable_selector.rs b/pumpkin-solver/src/branching/variable_selection/dynamic_variable_selector.rs index aa98b48f..c4f83dda 100644 --- a/pumpkin-solver/src/branching/variable_selection/dynamic_variable_selector.rs +++ b/pumpkin-solver/src/branching/variable_selection/dynamic_variable_selector.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; use super::VariableSelector; +use crate::branching::brancher::BrancherEvent; #[cfg(doc)] use crate::branching::branchers::dynamic_brancher::DynamicBrancher; use crate::branching::SelectionContext; @@ -45,4 +46,8 @@ impl VariableSelector for DynamicVariableSelector { fn is_restart_pointless(&mut self) -> bool { self.selector.is_restart_pointless() } + + fn subscribe_to_events(&self) -> Vec { + self.selector.subscribe_to_events() + } } diff --git a/pumpkin-solver/src/branching/variable_selection/first_fail.rs b/pumpkin-solver/src/branching/variable_selection/first_fail.rs index 577e05fe..00c61707 100644 --- a/pumpkin-solver/src/branching/variable_selection/first_fail.rs +++ b/pumpkin-solver/src/branching/variable_selection/first_fail.rs @@ -1,5 +1,6 @@ use log::warn; +use crate::branching::brancher::BrancherEvent; use crate::branching::tie_breaking::Direction; use crate::branching::tie_breaking::InOrderTieBreaker; use crate::branching::tie_breaking::TieBreaker; @@ -73,6 +74,10 @@ where }); self.tie_breaker.select() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/variable_selection/input_order.rs b/pumpkin-solver/src/branching/variable_selection/input_order.rs index 49f8840c..8a60596f 100644 --- a/pumpkin-solver/src/branching/variable_selection/input_order.rs +++ b/pumpkin-solver/src/branching/variable_selection/input_order.rs @@ -1,5 +1,6 @@ use log::warn; +use crate::branching::brancher::BrancherEvent; use crate::branching::variable_selection::VariableSelector; use crate::branching::SelectionContext; use crate::engine::variables::DomainId; @@ -30,6 +31,10 @@ impl VariableSelector for InputOrder { .find(|variable| !context.is_integer_fixed(**variable)) .copied() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } impl VariableSelector for InputOrder { @@ -39,6 +44,10 @@ impl VariableSelector for InputOrder { .find(|&variable| !context.is_predicate_assigned(variable.get_true_predicate())) .copied() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/variable_selection/largest.rs b/pumpkin-solver/src/branching/variable_selection/largest.rs index 79f0eb7c..fbe6feca 100644 --- a/pumpkin-solver/src/branching/variable_selection/largest.rs +++ b/pumpkin-solver/src/branching/variable_selection/largest.rs @@ -1,5 +1,6 @@ use log::warn; +use crate::branching::brancher::BrancherEvent; use crate::branching::tie_breaking::Direction; use crate::branching::tie_breaking::InOrderTieBreaker; use crate::branching::tie_breaking::TieBreaker; @@ -76,6 +77,10 @@ where }); self.tie_breaker.select() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/variable_selection/max_regret.rs b/pumpkin-solver/src/branching/variable_selection/max_regret.rs index 2568a331..02bdffda 100644 --- a/pumpkin-solver/src/branching/variable_selection/max_regret.rs +++ b/pumpkin-solver/src/branching/variable_selection/max_regret.rs @@ -1,5 +1,6 @@ use log::warn; +use crate::branching::brancher::BrancherEvent; use crate::branching::tie_breaking::Direction; use crate::branching::tie_breaking::InOrderTieBreaker; use crate::branching::tie_breaking::TieBreaker; @@ -90,6 +91,10 @@ where }); self.tie_breaker.select() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/variable_selection/mod.rs b/pumpkin-solver/src/branching/variable_selection/mod.rs index 625627c4..e854253e 100644 --- a/pumpkin-solver/src/branching/variable_selection/mod.rs +++ b/pumpkin-solver/src/branching/variable_selection/mod.rs @@ -14,6 +14,7 @@ mod max_regret; mod most_constrained; mod occurrence; mod proportional_domain_size; +mod random; mod smallest; mod variable_selector; @@ -26,5 +27,6 @@ pub use max_regret::*; pub use most_constrained::*; pub use occurrence::*; pub use proportional_domain_size::*; +pub use random::RandomSelector; pub use smallest::*; pub use variable_selector::VariableSelector; diff --git a/pumpkin-solver/src/branching/variable_selection/most_constrained.rs b/pumpkin-solver/src/branching/variable_selection/most_constrained.rs index 78febf7f..678a0542 100644 --- a/pumpkin-solver/src/branching/variable_selection/most_constrained.rs +++ b/pumpkin-solver/src/branching/variable_selection/most_constrained.rs @@ -2,6 +2,7 @@ use std::cmp::Ordering; use log::warn; +use crate::branching::brancher::BrancherEvent; use crate::branching::tie_breaking::Direction; use crate::branching::tie_breaking::InOrderTieBreaker; use crate::branching::tie_breaking::TieBreaker; @@ -89,6 +90,10 @@ where self.tie_breaker.select() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/variable_selection/occurrence.rs b/pumpkin-solver/src/branching/variable_selection/occurrence.rs index 72d20bf5..963f5693 100644 --- a/pumpkin-solver/src/branching/variable_selection/occurrence.rs +++ b/pumpkin-solver/src/branching/variable_selection/occurrence.rs @@ -1,5 +1,6 @@ use log::warn; +use crate::branching::brancher::BrancherEvent; use crate::branching::tie_breaking::Direction; use crate::branching::tie_breaking::InOrderTieBreaker; use crate::branching::tie_breaking::TieBreaker; @@ -55,6 +56,10 @@ where }); self.tie_breaker.select() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/variable_selection/proportional_domain_size.rs b/pumpkin-solver/src/branching/variable_selection/proportional_domain_size.rs index f75cfcfe..346ebfae 100644 --- a/pumpkin-solver/src/branching/variable_selection/proportional_domain_size.rs +++ b/pumpkin-solver/src/branching/variable_selection/proportional_domain_size.rs @@ -1,4 +1,5 @@ use super::VariableSelector; +use crate::branching::brancher::BrancherEvent; use crate::branching::SelectionContext; use crate::pumpkin_assert_extreme; use crate::variables::DomainId; @@ -67,4 +68,8 @@ impl VariableSelector for ProportionalDomainSize { self.weights_idx_to_variables.push(idx); } } + + fn subscribe_to_events(&self) -> Vec { + vec![BrancherEvent::Backtrack] + } } diff --git a/pumpkin-solver/src/branching/variable_selection/random.rs b/pumpkin-solver/src/branching/variable_selection/random.rs new file mode 100644 index 00000000..c6bf12a4 --- /dev/null +++ b/pumpkin-solver/src/branching/variable_selection/random.rs @@ -0,0 +1,157 @@ +use super::VariableSelector; +use crate::branching::BrancherEvent; +use crate::branching::SelectionContext; +use crate::containers::SparseSet; +use crate::containers::StorageKey; +use crate::variables::DomainId; + +/// A [`VariableSelector`] which selects a random unfixed variable. +#[derive(Debug)] +pub struct RandomSelector { + variables: SparseSet, +} + +impl RandomSelector { + pub fn new(variables: impl IntoIterator) -> Self { + // Note the -1 due to the fact that the indices of the domain ids start at 1 + Self { + variables: SparseSet::new(variables.into_iter().collect(), |element| { + element.index() - 1 + }), + } + } +} + +impl VariableSelector for RandomSelector { + fn select_variable(&mut self, context: &mut SelectionContext) -> Option { + if self.variables.is_empty() { + return None; + } + + let mut variable = *self.variables.get( + context + .random() + .generate_usize_in_range(0..self.variables.len()), + ); + + while context.is_integer_fixed(variable) { + self.variables.remove_temporarily(&variable); + if self.variables.is_empty() { + return None; + } + + variable = *self.variables.get( + context + .random() + .generate_usize_in_range(0..self.variables.len()), + ); + } + + Some(variable) + } + + fn on_unassign_integer(&mut self, variable: DomainId, _value: i32) { + self.variables.insert(variable); + } + + fn is_restart_pointless(&mut self) -> bool { + false + } + + fn subscribe_to_events(&self) -> Vec { + vec![BrancherEvent::UnassignInteger] + } +} + +#[cfg(test)] +mod tests { + use crate::basic_types::tests::TestRandom; + use crate::branching::variable_selection::RandomSelector; + use crate::branching::variable_selection::VariableSelector; + use crate::branching::SelectionContext; + + #[test] + fn test_selects_randomly() { + let assignments = SelectionContext::create_for_testing(vec![(0, 10), (5, 20), (1, 3)]); + let mut test_rng = TestRandom { + usizes: vec![1], + ..Default::default() + }; + let integer_variables = assignments.get_domains().collect::>(); + let mut strategy = RandomSelector::new(assignments.get_domains()); + + let mut context = SelectionContext::new(&assignments, &mut test_rng); + + let selected = strategy.select_variable(&mut context); + assert!(selected.is_some()); + assert_eq!(selected.unwrap(), integer_variables[1]); + } + + #[test] + fn test_selects_randomly_not_unfixed() { + let assignments = SelectionContext::create_for_testing(vec![(0, 10), (5, 5), (1, 3)]); + let mut test_rng = TestRandom { + usizes: vec![1, 0], + ..Default::default() + }; + let integer_variables = assignments.get_domains().collect::>(); + let mut strategy = RandomSelector::new(assignments.get_domains()); + + let mut context = SelectionContext::new(&assignments, &mut test_rng); + + let selected = strategy.select_variable(&mut context); + assert!(selected.is_some()); + assert_eq!(selected.unwrap(), integer_variables[0]); + } + + #[test] + fn test_select_nothing_if_all_fixed() { + let assignments = SelectionContext::create_for_testing(vec![(0, 0), (5, 5), (1, 1)]); + let mut test_rng = TestRandom { + usizes: vec![1, 0, 0], + ..Default::default() + }; + let mut strategy = RandomSelector::new(assignments.get_domains()); + + let mut context = SelectionContext::new(&assignments, &mut test_rng); + + let selected = strategy.select_variable(&mut context); + assert!(selected.is_none()); + } + + #[test] + fn test_select_unfixed_variable_after_fixing() { + let mut assignments = SelectionContext::create_for_testing(vec![(0, 0), (5, 7), (1, 1)]); + let mut test_rng = TestRandom { + usizes: vec![2, 0, 0, 0, 0], + ..Default::default() + }; + let integer_variables = assignments.get_domains().collect::>(); + let mut strategy = RandomSelector::new(assignments.get_domains()); + + { + let mut context = SelectionContext::new(&assignments, &mut test_rng); + + let selected = strategy.select_variable(&mut context); + assert!(selected.is_some()); + assert_eq!(selected.unwrap(), integer_variables[1]); + } + + assignments.increase_decision_level(); + let _ = assignments.tighten_lower_bound(integer_variables[1], 7, None); + + { + let mut context = SelectionContext::new(&assignments, &mut test_rng); + + let selected = strategy.select_variable(&mut context); + assert!(selected.is_none()); + } + + let _ = assignments.synchronise(0, 0, false); + strategy.on_unassign_integer(integer_variables[1], 7); + let mut context = SelectionContext::new(&assignments, &mut test_rng); + let selected = strategy.select_variable(&mut context); + assert!(selected.is_some()); + assert_eq!(selected.unwrap(), integer_variables[1]); + } +} diff --git a/pumpkin-solver/src/branching/variable_selection/smallest.rs b/pumpkin-solver/src/branching/variable_selection/smallest.rs index 6807c0c6..5628039e 100644 --- a/pumpkin-solver/src/branching/variable_selection/smallest.rs +++ b/pumpkin-solver/src/branching/variable_selection/smallest.rs @@ -1,6 +1,7 @@ use log::warn; use super::VariableSelector; +use crate::branching::brancher::BrancherEvent; use crate::branching::tie_breaking::Direction; use crate::branching::tie_breaking::InOrderTieBreaker; use crate::branching::tie_breaking::TieBreaker; @@ -72,6 +73,10 @@ where }); self.tie_breaker.select() } + + fn subscribe_to_events(&self) -> Vec { + vec![] + } } #[cfg(test)] diff --git a/pumpkin-solver/src/branching/variable_selection/variable_selector.rs b/pumpkin-solver/src/branching/variable_selection/variable_selector.rs index 0c1ccd70..1b626073 100644 --- a/pumpkin-solver/src/branching/variable_selection/variable_selector.rs +++ b/pumpkin-solver/src/branching/variable_selection/variable_selector.rs @@ -1,5 +1,10 @@ +use crate::branching::brancher::BrancherEvent; +#[cfg(doc)] +use crate::branching::branchers::dynamic_brancher::DynamicBrancher; #[cfg(doc)] use crate::branching::variable_selection::Smallest; +#[cfg(doc)] +use crate::branching::Brancher; use crate::branching::SelectionContext; use crate::engine::predicates::predicate::Predicate; use crate::engine::variables::DomainId; @@ -18,19 +23,31 @@ pub trait VariableSelector { /// A function which is called after a conflict has been found and processed but (currently) /// does not provide any additional information. + /// + /// To receive information about this event, use [`BrancherEvent::Conflict`] in + /// [`Self::subscribe_to_events`] fn on_conflict(&mut self) {} /// A function which is called whenever a backtrack occurs in the solver. + /// + /// To receive information about this event, use [`BrancherEvent::Backtrack`] in + /// [`Self::subscribe_to_events`] fn on_backtrack(&mut self) {} /// A function which is called after a [`DomainId`] is unassigned during backtracking (i.e. when /// it was fixed but is no longer), specifically, it provides `variable` which is the /// [`DomainId`] which has been reset. This method could thus be called multiple times in a /// single backtracking operation by the solver. + /// + /// To receive information about this event, use [`BrancherEvent::UnassignInteger`] in + /// [`Self::subscribe_to_events`] fn on_unassign_integer(&mut self, _variable: DomainId, _value: i32) {} /// A function which is called when a [`Predicate`] appears in a conflict during conflict /// analysis. + /// + /// To receive information about this event, use + /// [`BrancherEvent::AppearanceInConflictPredicate`] in [`Self::subscribe_to_events`] fn on_appearance_in_conflict_predicate(&mut self, _predicate: Predicate) {} /// This method returns whether a restart is *currently* pointless for the [`VariableSelector`]. @@ -44,4 +61,10 @@ pub trait VariableSelector { fn is_restart_pointless(&mut self) -> bool { true } + + /// Indicates which [`BrancherEvent`] are relevant for this particular [`VariableSelector`]. + /// + /// This can be used by [`Brancher::subscribe_to_events`] to determine upon which + /// events which [`VariableSelector`] should be called. + fn subscribe_to_events(&self) -> Vec; } diff --git a/pumpkin-solver/src/containers/sparse_set.rs b/pumpkin-solver/src/containers/sparse_set.rs index eb03da0a..e9d2928e 100644 --- a/pumpkin-solver/src/containers/sparse_set.rs +++ b/pumpkin-solver/src/containers/sparse_set.rs @@ -25,6 +25,9 @@ //! implementation’, in CP workshop on Techniques foR Implementing Constraint programming Systems //! (TRICS), 2013, pp. 1–10. +use crate::pumpkin_assert_moderate; +use crate::pumpkin_assert_simple; + /// A set for keeping track of which values are still part of the original domain based on [\[1\]](https://hal.science/hal-01339250/document). /// See the module level documentation for more information. /// @@ -69,7 +72,7 @@ impl SparseSet { } pub(crate) fn set_to_empty(&mut self) { - self.indices = vec![usize::MAX; self.domain.len()]; + self.indices = vec![usize::MAX; self.indices.len()]; self.domain.clear(); self.size = 0; } @@ -87,6 +90,7 @@ impl SparseSet { /// Returns the `index`th element in the domain; if `index` is larger than or equal to /// [`SparseSet::len`] then this method will panic. pub(crate) fn get(&self, index: usize) -> &T { + pumpkin_assert_simple!(index < self.size); &self.domain[index] } @@ -104,18 +108,24 @@ impl SparseSet { if self.indices[(self.mapping)(to_remove)] < self.size { // The element is part of the domain and should be removed self.size -= 1; + if self.size > 0 { + self.swap(self.indices[(self.mapping)(to_remove)], self.size); + } + self.swap( self.indices[(self.mapping)(to_remove)], self.domain.len() - 1, ); - let _ = self.domain.pop().expect("Has to have something to pop."); + let element = self.domain.pop().expect("Has to have something to pop."); + pumpkin_assert_moderate!((self.mapping)(&element) == (self.mapping)(to_remove)); self.indices[(self.mapping)(to_remove)] = usize::MAX; } else if self.indices[(self.mapping)(to_remove)] < self.domain.len() { self.swap( self.indices[(self.mapping)(to_remove)], self.domain.len() - 1, ); - let _ = self.domain.pop().expect("Has to have something to pop."); + let element = self.domain.pop().expect("Has to have something to pop."); + pumpkin_assert_moderate!((self.mapping)(&element) == (self.mapping)(to_remove)); self.indices[(self.mapping)(to_remove)] = usize::MAX; } } @@ -138,6 +148,7 @@ impl SparseSet { && self.indices[(self.mapping)(element)] < self.size } + /// Accomodates the `element`. pub(crate) fn accommodate(&mut self, element: &T) { let index = (self.mapping)(element); if self.indices.len() <= index { @@ -145,12 +156,14 @@ impl SparseSet { } } + /// Inserts the element if it is not already contained in the sparse set. pub(crate) fn insert(&mut self, element: T) { if !self.contains(&element) { self.accommodate(&element); self.indices[(self.mapping)(&element)] = self.domain.len(); self.domain.push(element); + self.swap(self.size, self.domain.len() - 1); self.size += 1; } } diff --git a/pumpkin-solver/src/engine/conflict_analysis/conflict_analysis_context.rs b/pumpkin-solver/src/engine/conflict_analysis/conflict_analysis_context.rs index 165d35d1..61e75047 100644 --- a/pumpkin-solver/src/engine/conflict_analysis/conflict_analysis_context.rs +++ b/pumpkin-solver/src/engine/conflict_analysis/conflict_analysis_context.rs @@ -18,6 +18,7 @@ use crate::engine::Assignments; use crate::engine::ConstraintSatisfactionSolver; use crate::engine::IntDomainEvent; use crate::engine::PropagatorQueue; +use crate::engine::TrailedAssignments; use crate::engine::WatchListCP; use crate::predicate; use crate::proof::ProofLog; @@ -48,6 +49,7 @@ pub(crate) struct ConflictAnalysisContext<'a> { pub(crate) is_completing_proof: bool, pub(crate) unit_nogood_step_ids: &'a HashMap, + pub(crate) stateful_assignments: &'a mut TrailedAssignments, } impl Debug for ConflictAnalysisContext<'_> { @@ -56,7 +58,7 @@ impl Debug for ConflictAnalysisContext<'_> { } } -impl<'a> ConflictAnalysisContext<'a> { +impl ConflictAnalysisContext<'_> { /// Returns the last decision which was made by the solver. pub(crate) fn find_last_decision(&mut self) -> Option { self.assignments.find_last_decision() @@ -82,6 +84,7 @@ impl<'a> ConflictAnalysisContext<'a> { self.backtrack_event_drain, backtrack_level, self.brancher, + self.stateful_assignments, ) } @@ -127,17 +130,24 @@ impl<'a> ConflictAnalysisContext<'a> { } } - /// Returns the reason for a propagation; if it is implied then the reason will be the decision - /// which implied the predicate. + /// Compute the reason for `predicate` being true. The reason will be stored in + /// `reason_buffer`. + /// + /// If `predicate` is not true, or it is a decision, then this function will panic. + #[allow( + clippy::too_many_arguments, + reason = "borrow checker complains either here or elsewhere" + )] pub(crate) fn get_propagation_reason( predicate: Predicate, assignments: &Assignments, current_nogood: CurrentNogood<'_>, - reason_store: &'a mut ReasonStore, - propagators: &'a mut PropagatorStore, - proof_log: &'a mut ProofLog, + reason_store: &mut ReasonStore, + propagators: &mut PropagatorStore, + proof_log: &mut ProofLog, unit_nogood_step_ids: &HashMap, - ) -> &'a [Predicate] { + reason_buffer: &mut (impl Extend + AsRef<[Predicate]>), + ) { // TODO: this function could be put into the reason store // Note that this function can only be called with propagations, and never decision @@ -153,9 +163,8 @@ impl<'a> ConflictAnalysisContext<'a> { // there would be only one predicate from the current decision level. For this // reason, it is safe to assume that in the following, that any input predicate is // indeed a propagated predicate. - reason_store.helper.clear(); if assignments.is_initial_bound(predicate) { - return reason_store.helper.as_slice(); + return; } let trail_position = assignments @@ -176,11 +185,17 @@ impl<'a> ConflictAnalysisContext<'a> { let explanation_context = ExplanationContext::new(assignments, current_nogood); - let reason = reason_store - .get_or_compute(reason_ref, explanation_context, propagators) - .expect("reason reference should not be stale"); + let reason_exists = reason_store.get_or_compute( + reason_ref, + explanation_context, + propagators, + reason_buffer, + ); + + assert!(reason_exists, "reason reference should not be stale"); + if propagator_id == ConstraintSatisfactionSolver::get_nogood_propagator_id() - && reason.is_empty() + && reason_buffer.as_ref().is_empty() { // This means that a unit nogood was propagated, we indicate that this nogood step // was used @@ -204,12 +219,10 @@ impl<'a> ConflictAnalysisContext<'a> { // Otherwise we log the inference which was used to derive the nogood let _ = proof_log.log_inference( constraint_tag, - reason.iter().copied(), + reason_buffer.as_ref().iter().copied(), Some(predicate), ); } - reason - // The predicate is implicitly due as a result of a decision. } // 2) The predicate is true due to a propagation, and not explicitly on the trail. // It is necessary to further analyse what was the reason for setting the predicate true. @@ -237,7 +250,7 @@ impl<'a> ConflictAnalysisContext<'a> { // todo: could consider lifting here, since the trail bound // might be too strong. if trail_lower_bound > input_lower_bound { - reason_store.helper.push(trail_entry.predicate); + reason_buffer.extend(std::iter::once(trail_entry.predicate)); } // Otherwise, the input bound is strictly greater than the trailed // bound. This means the reason is due to holes in the domain. @@ -267,8 +280,8 @@ impl<'a> ConflictAnalysisContext<'a> { domain_id, not_equal_constant: input_lower_bound - 1, }; - reason_store.helper.push(one_less_bound_predicate); - reason_store.helper.push(not_equals_predicate); + reason_buffer.extend(std::iter::once(one_less_bound_predicate)); + reason_buffer.extend(std::iter::once(not_equals_predicate)); } } ( @@ -288,7 +301,7 @@ impl<'a> ConflictAnalysisContext<'a> { // so it safe to take the reason from the trail. // todo: lifting could be used here pumpkin_assert_simple!(trail_lower_bound > not_equal_constant); - reason_store.helper.push(trail_entry.predicate); + reason_buffer.extend(std::iter::once(trail_entry.predicate)); } ( Predicate::LowerBound { @@ -320,8 +333,8 @@ impl<'a> ConflictAnalysisContext<'a> { domain_id, upper_bound: equality_constant, }; - reason_store.helper.push(predicate_lb); - reason_store.helper.push(predicate_ub); + reason_buffer.extend(std::iter::once(predicate_lb)); + reason_buffer.extend(std::iter::once(predicate_ub)); } ( Predicate::UpperBound { @@ -341,7 +354,7 @@ impl<'a> ConflictAnalysisContext<'a> { // reason for the input predicate. // todo: lifting could be applied here. if trail_upper_bound < input_upper_bound { - reason_store.helper.push(trail_entry.predicate); + reason_buffer.extend(std::iter::once(trail_entry.predicate)); } else { // I think it cannot be that the bounds are equal, since otherwise we // would have found the predicate explicitly on the trail. @@ -362,8 +375,8 @@ impl<'a> ConflictAnalysisContext<'a> { domain_id, not_equal_constant: input_upper_bound + 1, }; - reason_store.helper.push(new_ub_predicate); - reason_store.helper.push(not_equal_predicate); + reason_buffer.extend(std::iter::once(new_ub_predicate)); + reason_buffer.extend(std::iter::once(not_equal_predicate)); } } ( @@ -384,7 +397,7 @@ impl<'a> ConflictAnalysisContext<'a> { // The bound was set past the not equals, so we can safely returns the trail // reason. todo: can do lifting here. - reason_store.helper.push(trail_entry.predicate); + reason_buffer.extend(std::iter::once(trail_entry.predicate)); } ( Predicate::UpperBound { @@ -419,8 +432,8 @@ impl<'a> ConflictAnalysisContext<'a> { domain_id, upper_bound: equality_constant, }; - reason_store.helper.push(predicate_lb); - reason_store.helper.push(predicate_ub); + reason_buffer.extend(std::iter::once(predicate_lb)); + reason_buffer.extend(std::iter::once(predicate_ub)); } ( Predicate::NotEqual { @@ -454,8 +467,8 @@ impl<'a> ConflictAnalysisContext<'a> { not_equal_constant: input_lower_bound - 1, }; - reason_store.helper.push(new_lb_predicate); - reason_store.helper.push(new_not_equals_predicate); + reason_buffer.extend(std::iter::once(new_lb_predicate)); + reason_buffer.extend(std::iter::once(new_not_equals_predicate)); } ( Predicate::NotEqual { @@ -489,8 +502,8 @@ impl<'a> ConflictAnalysisContext<'a> { not_equal_constant: input_upper_bound + 1, }; - reason_store.helper.push(new_ub_predicate); - reason_store.helper.push(new_not_equals_predicate); + reason_buffer.extend(std::iter::once(new_ub_predicate)); + reason_buffer.extend(std::iter::once(new_not_equals_predicate)); } ( Predicate::NotEqual { @@ -519,15 +532,14 @@ impl<'a> ConflictAnalysisContext<'a> { upper_bound: equality_constant, }; - reason_store.helper.push(predicate_lb); - reason_store.helper.push(predicate_ub); + reason_buffer.extend(std::iter::once(predicate_lb)); + reason_buffer.extend(std::iter::once(predicate_ub)); } _ => unreachable!( "Unreachable combination of {} and {}", trail_entry.predicate, predicate ), }; - reason_store.helper.as_slice() } } } diff --git a/pumpkin-solver/src/engine/conflict_analysis/minimisers/recursive_minimiser.rs b/pumpkin-solver/src/engine/conflict_analysis/minimisers/recursive_minimiser.rs index 52b6534b..47b3c174 100644 --- a/pumpkin-solver/src/engine/conflict_analysis/minimisers/recursive_minimiser.rs +++ b/pumpkin-solver/src/engine/conflict_analysis/minimisers/recursive_minimiser.rs @@ -117,7 +117,8 @@ impl RecursiveMinimiser { // Due to ownership rules, we have to take ownership of the reason. // TODO: Reuse the allocation if it becomes a bottleneck. - let reason = ConflictAnalysisContext::get_propagation_reason( + let mut reason = vec![]; + ConflictAnalysisContext::get_propagation_reason( input_predicate, context.assignments, CurrentNogood::from(current_nogood), @@ -125,10 +126,10 @@ impl RecursiveMinimiser { context.propagators, context.proof_log, context.unit_nogood_step_ids, - ) - .to_vec(); + &mut reason, + ); - for antecedent_predicate in reason { + for antecedent_predicate in reason.iter().copied() { // Root assignments can be safely ignored. if context .assignments diff --git a/pumpkin-solver/src/engine/conflict_analysis/resolvers/resolution_resolver.rs b/pumpkin-solver/src/engine/conflict_analysis/resolvers/resolution_resolver.rs index e5f11b59..25f78372 100644 --- a/pumpkin-solver/src/engine/conflict_analysis/resolvers/resolution_resolver.rs +++ b/pumpkin-solver/src/engine/conflict_analysis/resolvers/resolution_resolver.rs @@ -39,6 +39,8 @@ pub(crate) struct ResolutionResolver { recursive_minimiser: RecursiveMinimiser, /// Whether the resolver employs 1-UIP or all-decision learning. mode: AnalysisMode, + /// Re-usable buffer which reasons are written into. + reason_buffer: Vec, } #[derive(Debug, Clone, Copy, Default)] @@ -138,7 +140,8 @@ impl ConflictResolver for ResolutionResolver { // However, this can lead to [x <= v] to be processed *before* [x >= v - // y], meaning that these implied predicates should be replaced with their // reason - let reason = ConflictAnalysisContext::get_propagation_reason( + self.reason_buffer.clear(); + ConflictAnalysisContext::get_propagation_reason( predicate, context.assignments, CurrentNogood::new( @@ -150,23 +153,24 @@ impl ConflictResolver for ResolutionResolver { context.propagators, context.proof_log, context.unit_nogood_step_ids, + &mut self.reason_buffer, ); - if reason.is_empty() { + if self.reason_buffer.is_empty() { // In the case when the proof is being completed, it could be the case // that the reason for a root-level propagation is empty; this // predicate will be filtered out by the semantic minimisation pumpkin_assert_simple!(context.is_completing_proof); predicate } else { - pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate(), "A non-decision predicate in the nogood should be either a lower-bound or a not-equals predicate but it was {predicate} with reason {reason:?}"); + pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate(), "A non-decision predicate in the nogood should be either a lower-bound or a not-equals predicate but it was {predicate} with reason {:?}", self.reason_buffer); pumpkin_assert_simple!( - reason.len() == 1 && reason[0].is_lower_bound_predicate(), + self.reason_buffer.len() == 1 && self.reason_buffer[0].is_lower_bound_predicate(), "The reason for the only propagated predicates left on the trail should be lower-bound predicates, but the reason for {predicate} was {:?}", - reason + self.reason_buffer, ); - reason[0] + self.reason_buffer[0] } }; @@ -199,7 +203,9 @@ impl ConflictResolver for ResolutionResolver { .is_initial_bound(self.peek_predicate_from_conflict_nogood()) { let predicate = self.peek_predicate_from_conflict_nogood(); - let reason = ConflictAnalysisContext::get_propagation_reason( + + self.reason_buffer.clear(); + ConflictAnalysisContext::get_propagation_reason( predicate, context.assignments, CurrentNogood::new( @@ -211,13 +217,14 @@ impl ConflictResolver for ResolutionResolver { context.propagators, context.proof_log, context.unit_nogood_step_ids, + &mut self.reason_buffer, ); pumpkin_assert_simple!(predicate.is_lower_bound_predicate() || predicate.is_not_equal_predicate() , "If the final predicate in the conflict nogood is not a decision predicate then it should be either a lower-bound predicate or a not-equals predicate but was {predicate}"); pumpkin_assert_simple!( - reason.len() == 1 && reason[0].is_lower_bound_predicate(), - "The reason for the decision predicate should be a lower-bound predicate but was {}", reason[0] + self.reason_buffer.len() == 1 && self.reason_buffer[0].is_lower_bound_predicate(), + "The reason for the decision predicate should be a lower-bound predicate but was {}", self.reason_buffer[0] ); - self.replace_predicate_in_conflict_nogood(predicate, reason[0]); + self.replace_predicate_in_conflict_nogood(predicate, self.reason_buffer[0]); } // The final predicate in the heap will get pushed in `extract_final_nogood` @@ -226,7 +233,8 @@ impl ConflictResolver for ResolutionResolver { } // 2.b) Standard case, get the reason for the predicate and add it to the nogood. - let reason = ConflictAnalysisContext::get_propagation_reason( + self.reason_buffer.clear(); + ConflictAnalysisContext::get_propagation_reason( next_predicate, context.assignments, CurrentNogood::new( @@ -238,11 +246,12 @@ impl ConflictResolver for ResolutionResolver { context.propagators, context.proof_log, context.unit_nogood_step_ids, + &mut self.reason_buffer, ); - for predicate in reason.iter() { + for i in 0..self.reason_buffer.len() { self.add_predicate_to_conflict_nogood( - *predicate, + self.reason_buffer[i], context.assignments, context.brancher, self.mode, @@ -250,6 +259,7 @@ impl ConflictResolver for ResolutionResolver { ); } } + Some(self.extract_final_nogood(context)) } diff --git a/pumpkin-solver/src/engine/constraint_satisfaction_solver.rs b/pumpkin-solver/src/engine/constraint_satisfaction_solver.rs index 1a8b3f0a..4800c41a 100644 --- a/pumpkin-solver/src/engine/constraint_satisfaction_solver.rs +++ b/pumpkin-solver/src/engine/constraint_satisfaction_solver.rs @@ -16,6 +16,7 @@ use super::conflict_analysis::LearnedNogood; use super::conflict_analysis::NoLearningResolver; use super::conflict_analysis::SemanticMinimiser; use super::nogoods::Lbd; +use super::propagation::contexts::StatefulPropagationContext; use super::propagation::store::PropagatorStore; use super::propagation::PropagatorId; use super::solver_statistics::SolverStatistics; @@ -23,6 +24,7 @@ use super::termination::TerminationCondition; use super::variables::IntegerVariable; use super::variables::Literal; use super::ResolutionResolver; +use super::TrailedAssignments; use crate::basic_types::moving_averages::MovingAverage; use crate::basic_types::CSPSolverExecutionFlag; use crate::basic_types::ConstraintOperationError; @@ -33,6 +35,7 @@ use crate::basic_types::Random; use crate::basic_types::SolutionReference; use crate::basic_types::StoredConflictInfo; use crate::branching::Brancher; +use crate::branching::BrancherEvent; use crate::branching::SelectionContext; use crate::engine::conflict_analysis::ConflictResolver as Resolver; use crate::engine::cp::PropagatorQueue; @@ -142,6 +145,8 @@ pub struct ConstraintSatisfactionSolver { unit_nogood_step_ids: HashMap, /// The resolver which is used upon a conflict. conflict_resolver: Box, + + pub(crate) stateful_assignments: TrailedAssignments, } impl Default for ConstraintSatisfactionSolver { @@ -250,6 +255,7 @@ impl ConstraintSatisfactionSolver { propagators: &mut PropagatorStore, propagator_queue: &mut PropagatorQueue, assignments: &mut Assignments, + stateful_assignments: &mut TrailedAssignments, ) { pumpkin_assert_moderate!( propagators[Self::get_nogood_propagator_id()].name() == "NogoodPropagator" @@ -266,6 +272,7 @@ impl ConstraintSatisfactionSolver { propagators, propagator_queue, assignments, + stateful_assignments, ); } @@ -276,8 +283,9 @@ impl ConstraintSatisfactionSolver { propagators: &mut PropagatorStore, propagator_queue: &mut PropagatorQueue, assignments: &mut Assignments, + stateful_assignments: &mut TrailedAssignments, ) { - let context = PropagationContext::new(assignments); + let context = StatefulPropagationContext::new(stateful_assignments, assignments); let enqueue_decision = propagators[propagator_id].notify(context, local_id, event.into()); @@ -306,6 +314,7 @@ impl ConstraintSatisfactionSolver { &mut self.propagators, &mut self.propagator_queue, &mut self.assignments, + &mut self.stateful_assignments, ); // Now notify other propagators subscribed to this event. for propagator_var in self.watch_list_cp.get_affected_propagators(event, domain) { @@ -318,6 +327,7 @@ impl ConstraintSatisfactionSolver { &mut self.propagators, &mut self.propagator_queue, &mut self.assignments, + &mut self.stateful_assignments, ); } } @@ -371,6 +381,7 @@ impl ConstraintSatisfactionSolver { proof_log: &mut self.internal_parameters.proof_log, is_completing_proof: true, unit_nogood_step_ids: &self.unit_nogood_step_ids, + stateful_assignments: &mut self.stateful_assignments, }; let result = self @@ -410,6 +421,7 @@ impl ConstraintSatisfactionSolver { ConflictResolver::UIP => Box::new(ResolutionResolver::default()), }, internal_parameters: solver_options, + stateful_assignments: TrailedAssignments::default(), }; // As a convention, the assignments contain a dummy domain_id=0, which represents a 0-1 @@ -654,6 +666,7 @@ impl ConstraintSatisfactionSolver { proof_log: &mut self.internal_parameters.proof_log, is_completing_proof: false, unit_nogood_step_ids: &self.unit_nogood_step_ids, + stateful_assignments: &mut self.stateful_assignments, }; let mut resolver = ResolutionResolver::with_mode(AnalysisMode::AllDecision); @@ -724,6 +737,7 @@ impl ConstraintSatisfactionSolver { &mut self.backtrack_event_drain, 0, brancher, + &mut self.stateful_assignments, ); self.state.declare_ready(); } @@ -860,6 +874,7 @@ impl ConstraintSatisfactionSolver { pub(crate) fn declare_new_decision_level(&mut self) { self.assignments.increase_decision_level(); + self.stateful_assignments.increase_decision_level(); self.reason_store.increase_decision_level(); } @@ -895,6 +910,7 @@ impl ConstraintSatisfactionSolver { proof_log: &mut self.internal_parameters.proof_log, is_completing_proof: false, unit_nogood_step_ids: &self.unit_nogood_step_ids, + stateful_assignments: &mut self.stateful_assignments, }; let learned_nogood = self @@ -963,6 +979,7 @@ impl ConstraintSatisfactionSolver { fn add_learned_nogood(&mut self, learned_nogood: LearnedNogood) { let mut context = PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut self.semantic_minimiser, @@ -1030,6 +1047,7 @@ impl ConstraintSatisfactionSolver { &mut self.backtrack_event_drain, 0, brancher, + &mut self.stateful_assignments, ); self.restart_strategy.notify_restart(); @@ -1050,6 +1068,7 @@ impl ConstraintSatisfactionSolver { backtrack_event_drain: &mut Vec<(IntDomainEvent, DomainId)>, backtrack_level: usize, brancher: &mut BrancherType, + stateful_assignments: &mut TrailedAssignments, ) { pumpkin_assert_simple!(backtrack_level < assignments.get_decision_level()); @@ -1065,6 +1084,9 @@ impl ConstraintSatisfactionSolver { .for_each(|(domain_id, previous_value)| { brancher.on_unassign_integer(*domain_id, *previous_value) }); + + stateful_assignments.synchronise(backtrack_level); + *last_notified_cp_trail_index = assignments.num_trail_entries(); reason_store.synchronise(backtrack_level); @@ -1115,20 +1137,18 @@ impl ConstraintSatisfactionSolver { // Look up the reason for the bound that changed. // The reason for changing the bound cannot be a decision, so we can safely unwrap. - let reason_changing_bound = reason_store - .get_or_compute( - entry.reason.unwrap(), - ExplanationContext::from(&*assignments), - propagators, - ) - .unwrap(); - let mut empty_domain_reason: Vec = vec![ predicate!(conflict_domain >= entry.old_lower_bound), predicate!(conflict_domain <= entry.old_upper_bound), ]; - empty_domain_reason.append(&mut reason_changing_bound.to_vec()); + let _ = reason_store.get_or_compute( + entry.reason.unwrap(), + ExplanationContext::from(&*assignments), + propagators, + &mut empty_domain_reason, + ); + empty_domain_reason.into() } @@ -1146,6 +1166,7 @@ impl ConstraintSatisfactionSolver { let propagation_status = { let propagator = &mut self.propagators[propagator_id]; let context = PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut self.semantic_minimiser, @@ -1188,6 +1209,7 @@ impl ConstraintSatisfactionSolver { // A propagator-specific reason for the current conflict. Inconsistency::Conflict(conflict_nogood) => { pumpkin_assert_advanced!(DebugHelper::debug_reported_failure( + &self.stateful_assignments, &self.assignments, &conflict_nogood, &self.propagators[propagator_id], @@ -1207,6 +1229,7 @@ impl ConstraintSatisfactionSolver { DebugHelper::debug_check_propagations( num_trail_entries_before, propagator_id, + &self.stateful_assignments, &self.assignments, &mut self.reason_store, &mut self.propagators @@ -1223,7 +1246,11 @@ impl ConstraintSatisfactionSolver { // since otherwise the state may be inconsistent. pumpkin_assert_extreme!( self.state.is_conflicting() - || DebugHelper::debug_fixed_point_propagation(&self.assignments, &self.propagators,) + || DebugHelper::debug_fixed_point_propagation( + &self.stateful_assignments, + &self.assignments, + &self.propagators, + ) ); } @@ -1240,19 +1267,18 @@ impl ConstraintSatisfactionSolver { ) { for trail_idx in start_trail_index..self.assignments.num_trail_entries() { let entry = self.assignments.get_trail_entry(trail_idx); - let reason = entry + let reason_ref = entry .reason .expect("Added by a propagator and must therefore have a reason"); // Get the conjunction of predicates explaining the propagation. - let reason = self - .reason_store - .get_or_compute( - reason, - ExplanationContext::new(&self.assignments, CurrentNogood::empty()), - &mut self.propagators, - ) - .expect("Reason ref is valid"); + let mut reason = vec![]; + let _ = self.reason_store.get_or_compute( + reason_ref, + ExplanationContext::new(&self.assignments, CurrentNogood::empty()), + &mut self.propagators, + &mut reason, + ); let propagated = entry.predicate; @@ -1349,8 +1375,9 @@ impl ConstraintSatisfactionSolver { let mut initialisation_context = PropagatorInitialisationContext::new( &mut self.watch_list_cp, + &mut self.stateful_assignments, new_propagator_id, - &self.assignments, + &mut self.assignments, ); let initialisation_status = new_propagator.initialise_at_root(&mut initialisation_context); @@ -1398,6 +1425,7 @@ impl ConstraintSatisfactionSolver { pub fn add_nogood(&mut self, nogood: Vec) -> Result<(), ConstraintOperationError> { let mut propagation_context = PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut self.semantic_minimiser, @@ -1638,6 +1666,10 @@ impl Brancher for DummyBrancher { fn next_decision(&mut self, _context: &mut SelectionContext) -> Option { todo!() } + + fn subscribe_to_events(&self) -> Vec { + todo!() + } } #[cfg(test)] diff --git a/pumpkin-solver/src/engine/cp/assignments.rs b/pumpkin-solver/src/engine/cp/assignments.rs index 9a597f9a..019556de 100644 --- a/pumpkin-solver/src/engine/cp/assignments.rs +++ b/pumpkin-solver/src/engine/cp/assignments.rs @@ -734,23 +734,17 @@ impl Assignments { } } }); + // Drain does not remove the events from the internal data structure. Elements are removed // lazily, as the iterator gets executed. For this reason we go through the entire iterator. let iter = self.events.drain(); let _ = iter.count(); - // println!("ASSIGN AFTER SYNC PRESENT: {:?}", self.events.present); - // println!("others: {:?}", self.events.events); unfixed_variables } /// todo: This is a temporary hack, not to be used in general. pub(crate) fn remove_last_trail_element(&mut self) { let entry = self.trail.pop().unwrap(); - // println!( - // "\tHacky remova: {} {}", - // entry.predicate, - // entry.reason.is_none() - // ); let domain_id = entry.predicate.get_domain(); self.domains[domain_id].undo_trail_entry(&entry); } diff --git a/pumpkin-solver/src/engine/cp/mod.rs b/pumpkin-solver/src/engine/cp/mod.rs index 41cf84b8..2f3af9a9 100644 --- a/pumpkin-solver/src/engine/cp/mod.rs +++ b/pumpkin-solver/src/engine/cp/mod.rs @@ -6,12 +6,14 @@ pub(crate) mod propagation; mod propagator_queue; pub(crate) mod reason; pub(crate) mod test_solver; +mod trailed; mod watch_list_cp; pub(crate) use assignments::Assignments; pub(crate) use assignments::EmptyDomain; pub(crate) use event_sink::*; pub(crate) use propagator_queue::PropagatorQueue; +pub(crate) use trailed::*; pub(crate) use watch_list_cp::IntDomainEvent; pub(crate) use watch_list_cp::WatchListCP; pub(crate) use watch_list_cp::Watchers; @@ -26,10 +28,12 @@ mod tests { use crate::engine::propagation::PropagationContextMut; use crate::engine::propagation::PropagatorId; use crate::engine::reason::ReasonStore; + use crate::engine::TrailedAssignments; #[test] fn test_no_update_reason_store_if_no_update_lower_bound() { let mut assignments = Assignments::default(); + let mut stateful_assignments = TrailedAssignments::default(); let domain = assignments.grow(5, 10); let mut reason_store = ReasonStore::default(); @@ -37,6 +41,7 @@ mod tests { { let mut semantic_miniser = SemanticMinimiser::default(); let mut context = PropagationContextMut::new( + &mut stateful_assignments, &mut assignments, &mut reason_store, &mut semantic_miniser, @@ -52,6 +57,7 @@ mod tests { #[test] fn test_no_update_reason_store_if_no_update_upper_bound() { let mut assignments = Assignments::default(); + let mut stateful_assignments = TrailedAssignments::default(); let domain = assignments.grow(5, 10); let mut reason_store = ReasonStore::default(); @@ -60,6 +66,7 @@ mod tests { { let mut semantic_miniser = SemanticMinimiser::default(); let mut context = PropagationContextMut::new( + &mut stateful_assignments, &mut assignments, &mut reason_store, &mut semantic_miniser, @@ -75,6 +82,7 @@ mod tests { #[test] fn test_no_update_reason_store_if_no_update_remove() { let mut assignments = Assignments::default(); + let mut stateful_assignments = TrailedAssignments::default(); let domain = assignments.grow(5, 10); let mut reason_store = ReasonStore::default(); @@ -83,6 +91,7 @@ mod tests { { let mut semantic_miniser = SemanticMinimiser::default(); let mut context = PropagationContextMut::new( + &mut stateful_assignments, &mut assignments, &mut reason_store, &mut semantic_miniser, diff --git a/pumpkin-solver/src/engine/cp/propagation/explanation_context.rs b/pumpkin-solver/src/engine/cp/propagation/contexts/explanation_context.rs similarity index 92% rename from pumpkin-solver/src/engine/cp/propagation/explanation_context.rs rename to pumpkin-solver/src/engine/cp/propagation/contexts/explanation_context.rs index 5ddca53d..bdd835b1 100644 --- a/pumpkin-solver/src/engine/cp/propagation/explanation_context.rs +++ b/pumpkin-solver/src/engine/cp/propagation/contexts/explanation_context.rs @@ -1,5 +1,6 @@ use std::sync::LazyLock; +use super::HasAssignments; use crate::basic_types::PredicateId; use crate::basic_types::PredicateIdGenerator; use crate::containers::KeyValueHeap; @@ -31,12 +32,6 @@ impl<'a> ExplanationContext<'a> { } } - /// Get the underlying assignments. - #[deprecated = "using the assignments directly is not ideal, and we should develop this context API further instead"] - pub(crate) fn assignments(&self) -> &'a Assignments { - self.assignments - } - /// Get the current working nogood. /// /// The working nogood does not necessarily contain the predicate that is being explained. @@ -48,6 +43,12 @@ impl<'a> ExplanationContext<'a> { } } +impl HasAssignments for ExplanationContext<'_> { + fn assignments(&self) -> &Assignments { + self.assignments + } +} + static EMPTY_HEAP: KeyValueHeap = KeyValueHeap::new(); static EMPTY_PREDICATE_IDS: LazyLock = diff --git a/pumpkin-solver/src/engine/cp/propagation/contexts/mod.rs b/pumpkin-solver/src/engine/cp/propagation/contexts/mod.rs new file mode 100644 index 00000000..8d7ff0a7 --- /dev/null +++ b/pumpkin-solver/src/engine/cp/propagation/contexts/mod.rs @@ -0,0 +1,5 @@ +pub(crate) mod explanation_context; +pub(crate) mod propagation_context; +pub(crate) mod propagator_initialisation_context; + +pub(crate) use propagation_context::*; diff --git a/pumpkin-solver/src/engine/cp/propagation/propagation_context.rs b/pumpkin-solver/src/engine/cp/propagation/contexts/propagation_context.rs similarity index 70% rename from pumpkin-solver/src/engine/cp/propagation/propagation_context.rs rename to pumpkin-solver/src/engine/cp/propagation/contexts/propagation_context.rs index 845dbeab..702ea1a8 100644 --- a/pumpkin-solver/src/engine/cp/propagation/propagation_context.rs +++ b/pumpkin-solver/src/engine/cp/propagation/contexts/propagation_context.rs @@ -1,14 +1,40 @@ -use super::PropagatorId; use crate::engine::conflict_analysis::SemanticMinimiser; use crate::engine::predicates::predicate::Predicate; +use crate::engine::propagation::PropagatorId; use crate::engine::reason::Reason; use crate::engine::reason::ReasonStore; +use crate::engine::reason::StoredReason; use crate::engine::variables::IntegerVariable; use crate::engine::variables::Literal; use crate::engine::Assignments; use crate::engine::EmptyDomain; +use crate::engine::TrailedAssignments; +use crate::engine::TrailedInt; use crate::pumpkin_assert_simple; +pub(crate) struct StatefulPropagationContext<'a> { + pub(crate) stateful_assignments: &'a mut TrailedAssignments, + pub(crate) assignments: &'a Assignments, +} + +impl<'a> StatefulPropagationContext<'a> { + pub(crate) fn new( + stateful_assignments: &'a mut TrailedAssignments, + assignments: &'a Assignments, + ) -> Self { + Self { + stateful_assignments, + assignments, + } + } + + pub(crate) fn as_readonly(&self) -> PropagationContext<'_> { + PropagationContext { + assignments: self.assignments, + } + } +} + /// [`PropagationContext`] is passed to propagators during propagation. /// It may be queried to retrieve information about the current variable domains such as the /// lower-bound of a particular variable, or used to apply changes to the domain of a variable @@ -30,6 +56,7 @@ impl<'a> PropagationContext<'a> { #[derive(Debug)] pub(crate) struct PropagationContextMut<'a> { + pub(crate) stateful_assignments: &'a mut TrailedAssignments, pub(crate) assignments: &'a mut Assignments, pub(crate) reason_store: &'a mut ReasonStore, pub(crate) propagator_id: PropagatorId, @@ -39,12 +66,14 @@ pub(crate) struct PropagationContextMut<'a> { impl<'a> PropagationContextMut<'a> { pub(crate) fn new( + stateful_assignments: &'a mut TrailedAssignments, assignments: &'a mut Assignments, reason_store: &'a mut ReasonStore, semantic_minimiser: &'a mut SemanticMinimiser, propagator_id: PropagatorId, ) -> Self { PropagationContextMut { + stateful_assignments, assignments, reason_store, propagator_id, @@ -63,17 +92,30 @@ impl<'a> PropagationContextMut<'a> { self.reification_literal = Some(reification_literal); } - fn build_reason(&self, reason: Reason) -> Reason { - if let Some(reification_literal) = self.reification_literal { - match reason { - Reason::Eager(mut conjunction) => { - conjunction.add(reification_literal.get_true_predicate()); - Reason::Eager(conjunction) + fn build_reason(&self, reason: Reason) -> StoredReason { + match reason { + Reason::Eager(mut conjunction) => { + conjunction.extend( + self.reification_literal + .iter() + .map(|lit| lit.get_true_predicate()), + ); + StoredReason::Eager(conjunction) + } + Reason::DynamicLazy(code) => { + if let Some(reification_literal) = self.reification_literal { + StoredReason::ReifiedLazy(reification_literal, code) + } else { + StoredReason::DynamicLazy(code) } - Reason::DynamicLazy(_) => todo!(), } - } else { - reason + } + } + + pub(crate) fn as_stateful_readonly(&mut self) -> StatefulPropagationContext { + StatefulPropagationContext { + stateful_assignments: self.stateful_assignments, + assignments: self.assignments, } } @@ -95,9 +137,34 @@ pub trait HasAssignments { fn assignments(&self) -> &Assignments; } +pub(crate) trait HasStatefulAssignments { + fn stateful_assignments(&self) -> &TrailedAssignments; + fn stateful_assignments_mut(&mut self) -> &mut TrailedAssignments; +} + mod private { use super::*; + impl HasStatefulAssignments for StatefulPropagationContext<'_> { + fn stateful_assignments(&self) -> &TrailedAssignments { + self.stateful_assignments + } + + fn stateful_assignments_mut(&mut self) -> &mut TrailedAssignments { + self.stateful_assignments + } + } + + impl HasStatefulAssignments for PropagationContextMut<'_> { + fn stateful_assignments(&self) -> &TrailedAssignments { + self.stateful_assignments + } + + fn stateful_assignments_mut(&mut self) -> &mut TrailedAssignments { + self.stateful_assignments + } + } + impl HasAssignments for PropagationContext<'_> { fn assignments(&self) -> &Assignments { self.assignments @@ -109,8 +176,36 @@ mod private { self.assignments } } + + impl HasAssignments for StatefulPropagationContext<'_> { + fn assignments(&self) -> &Assignments { + self.assignments + } + } } +pub(crate) trait ManipulateStatefulIntegers: HasStatefulAssignments { + fn new_stateful_integer(&mut self, initial_value: i64) -> TrailedInt { + self.stateful_assignments_mut().grow(initial_value) + } + + fn value(&self, stateful_integer: TrailedInt) -> i64 { + self.stateful_assignments().read(stateful_integer) + } + + fn add_assign(&mut self, stateful_integer: TrailedInt, addition: i64) { + self.stateful_assignments_mut() + .add_assign(stateful_integer, addition); + } + + fn assign(&mut self, stateful_integer: TrailedInt, value: i64) { + self.stateful_assignments_mut() + .assign(stateful_integer, value); + } +} + +impl ManipulateStatefulIntegers for T {} + pub(crate) trait ReadDomains: HasAssignments { fn is_predicate_satisfied(&self, predicate: Predicate) -> bool { self.assignments() @@ -251,7 +346,8 @@ impl PropagationContextMut<'_> { .is_value_in_domain(domain_id, equality_constant) && !self.assignments.is_domain_assigned(&domain_id) { - let reason = self.reason_store.push(self.propagator_id, reason.into()); + let reason = self.build_reason(reason.into()); + let reason = self.reason_store.push(self.propagator_id, reason); self.assignments .make_assignment(domain_id, equality_constant, Some(reason))?; } diff --git a/pumpkin-solver/src/engine/cp/propagation/propagator_initialisation_context.rs b/pumpkin-solver/src/engine/cp/propagation/contexts/propagator_initialisation_context.rs similarity index 78% rename from pumpkin-solver/src/engine/cp/propagation/propagator_initialisation_context.rs rename to pumpkin-solver/src/engine/cp/propagation/contexts/propagator_initialisation_context.rs index b4f4f4fa..28e8fe86 100644 --- a/pumpkin-solver/src/engine/cp/propagation/propagator_initialisation_context.rs +++ b/pumpkin-solver/src/engine/cp/propagation/contexts/propagator_initialisation_context.rs @@ -1,5 +1,6 @@ use super::PropagationContext; use super::ReadDomains; +use super::StatefulPropagationContext; use crate::engine::domain_events::DomainEvents; use crate::engine::propagation::LocalId; #[cfg(doc)] @@ -8,6 +9,7 @@ use crate::engine::propagation::PropagatorId; use crate::engine::propagation::PropagatorVarId; use crate::engine::variables::IntegerVariable; use crate::engine::Assignments; +use crate::engine::TrailedAssignments; use crate::engine::WatchListCP; use crate::engine::Watchers; @@ -19,29 +21,39 @@ use crate::engine::Watchers; #[derive(Debug)] pub(crate) struct PropagatorInitialisationContext<'a> { watch_list: &'a mut WatchListCP, + pub(crate) stateful_assignments: &'a mut TrailedAssignments, propagator_id: PropagatorId, next_local_id: LocalId, - context: PropagationContext<'a>, + pub assignments: &'a mut Assignments, } impl PropagatorInitialisationContext<'_> { pub(crate) fn new<'a>( watch_list: &'a mut WatchListCP, + stateful_assignments: &'a mut TrailedAssignments, propagator_id: PropagatorId, - assignments: &'a Assignments, + assignments: &'a mut Assignments, ) -> PropagatorInitialisationContext<'a> { PropagatorInitialisationContext { watch_list, + stateful_assignments, propagator_id, next_local_id: LocalId::from(0), - context: PropagationContext::new(assignments), + assignments, + } + } + + pub(crate) fn as_stateful_readonly(&mut self) -> StatefulPropagationContext { + StatefulPropagationContext { + stateful_assignments: self.stateful_assignments, + assignments: self.assignments, } } pub(crate) fn as_readonly(&self) -> PropagationContext { - self.context + PropagationContext::new(self.assignments) } /// Subscribes the propagator to the given [`DomainEvents`]. @@ -61,7 +73,7 @@ impl PropagatorInitialisationContext<'_> { domain_events: DomainEvents, local_id: LocalId, ) -> Var { - if self.context.is_fixed(&var) { + if PropagationContext::new(self.assignments).is_fixed(&var) { return var; } let propagator_var = PropagatorVarId { @@ -116,11 +128,22 @@ impl PropagatorInitialisationContext<'_> { mod private { use super::*; - use crate::engine::propagation::propagation_context::HasAssignments; + use crate::engine::propagation::contexts::HasAssignments; + use crate::engine::propagation::contexts::HasStatefulAssignments; impl HasAssignments for PropagatorInitialisationContext<'_> { fn assignments(&self) -> &Assignments { - self.context.assignments + self.assignments + } + } + + impl HasStatefulAssignments for PropagatorInitialisationContext<'_> { + fn stateful_assignments(&self) -> &TrailedAssignments { + self.stateful_assignments + } + + fn stateful_assignments_mut(&mut self) -> &mut TrailedAssignments { + self.stateful_assignments } } } diff --git a/pumpkin-solver/src/engine/cp/propagation/mod.rs b/pumpkin-solver/src/engine/cp/propagation/mod.rs index ccb38ca0..c577446c 100644 --- a/pumpkin-solver/src/engine/cp/propagation/mod.rs +++ b/pumpkin-solver/src/engine/cp/propagation/mod.rs @@ -76,25 +76,23 @@ //! International Workshop on Constraint Solving and Constraint Logic Programming, 2005, pp. //! 118–132. -mod explanation_context; +pub(crate) mod contexts; pub(crate) mod local_id; -pub(crate) mod propagation_context; pub(crate) mod propagator; pub(crate) mod propagator_id; -pub(crate) mod propagator_initialisation_context; pub(crate) mod propagator_var_id; pub(crate) mod store; -pub(crate) use explanation_context::CurrentNogood; -pub(crate) use explanation_context::ExplanationContext; +pub(crate) use contexts::explanation_context::CurrentNogood; +pub(crate) use contexts::explanation_context::ExplanationContext; +pub(crate) use contexts::propagation_context::PropagationContext; +pub(crate) use contexts::propagation_context::PropagationContextMut; +pub(crate) use contexts::propagation_context::ReadDomains; +pub(crate) use contexts::propagator_initialisation_context::PropagatorInitialisationContext; pub(crate) use local_id::LocalId; -pub(crate) use propagation_context::PropagationContext; -pub(crate) use propagation_context::PropagationContextMut; -pub(crate) use propagation_context::ReadDomains; pub(crate) use propagator::EnqueueDecision; pub(crate) use propagator::Propagator; pub(crate) use propagator_id::PropagatorId; -pub(crate) use propagator_initialisation_context::PropagatorInitialisationContext; pub(crate) use propagator_var_id::PropagatorVarId; #[cfg(doc)] diff --git a/pumpkin-solver/src/engine/cp/propagation/propagator.rs b/pumpkin-solver/src/engine/cp/propagation/propagator.rs index 4bd7661c..dcc1724e 100644 --- a/pumpkin-solver/src/engine/cp/propagation/propagator.rs +++ b/pumpkin-solver/src/engine/cp/propagation/propagator.rs @@ -1,8 +1,11 @@ use downcast_rs::impl_downcast; use downcast_rs::Downcast; -use super::explanation_context::ExplanationContext; -use super::propagator_initialisation_context::PropagatorInitialisationContext; +use super::contexts::StatefulPropagationContext; +use super::ExplanationContext; +use super::PropagationContext; +use super::PropagationContextMut; +use super::PropagatorInitialisationContext; #[cfg(doc)] use crate::basic_types::Inconsistency; use crate::basic_types::PropagationStatusCP; @@ -10,8 +13,6 @@ use crate::basic_types::PropagationStatusCP; use crate::create_statistics_struct; use crate::engine::opaque_domain_event::OpaqueDomainEvent; use crate::engine::propagation::local_id::LocalId; -use crate::engine::propagation::propagation_context::PropagationContext; -use crate::engine::propagation::propagation_context::PropagationContextMut; #[cfg(doc)] use crate::engine::ConstraintSatisfactionSolver; use crate::predicates::Predicate; @@ -90,7 +91,7 @@ pub(crate) trait Propagator: Downcast { /// [`PropagatorInitialisationContext::register()`]. fn notify( &mut self, - _context: PropagationContext, + _context: StatefulPropagationContext, _local_id: LocalId, _event: OpaqueDomainEvent, ) -> EnqueueDecision { @@ -157,7 +158,7 @@ pub(crate) trait Propagator: Downcast { /// inconsistency as well. fn detect_inconsistency( &self, - _context: PropagationContext, + _context: StatefulPropagationContext, ) -> Option { None } diff --git a/pumpkin-solver/src/engine/cp/reason.rs b/pumpkin-solver/src/engine/cp/reason.rs index 73ccfe1c..3b384e9d 100644 --- a/pumpkin-solver/src/engine/cp/reason.rs +++ b/pumpkin-solver/src/engine/cp/reason.rs @@ -7,16 +7,16 @@ use crate::basic_types::PropositionalConjunction; use crate::basic_types::Trail; use crate::predicates::Predicate; use crate::pumpkin_assert_simple; +use crate::variables::Literal; /// The reason store holds a reason for each change made by a CP propagator on a trail. #[derive(Default, Debug)] pub(crate) struct ReasonStore { - trail: Trail<(PropagatorId, Reason)>, - pub helper: PropositionalConjunction, + trail: Trail<(PropagatorId, StoredReason)>, } impl ReasonStore { - pub(crate) fn push(&mut self, propagator: PropagatorId, reason: Reason) -> ReasonRef { + pub(crate) fn push(&mut self, propagator: PropagatorId, reason: StoredReason) -> ReasonRef { let index = self.trail.len(); self.trail.push((propagator, reason)); pumpkin_assert_simple!( @@ -27,22 +27,35 @@ impl ReasonStore { ReasonRef(index as u32) } - pub(crate) fn get_or_compute<'this>( - &'this self, + /// Evaluate the reason with the given reference, and write the predicates to + /// `destination_buffer`. + pub(crate) fn get_or_compute( + &self, reference: ReasonRef, context: ExplanationContext<'_>, - propagators: &'this mut PropagatorStore, - ) -> Option<&'this [Predicate]> { - self.trail - .get(reference.0 as usize) - .map(|reason| reason.1.compute(context, reason.0, propagators)) + propagators: &mut PropagatorStore, + destination_buffer: &mut impl Extend, + ) -> bool { + let Some(reason) = self.trail.get(reference.0 as usize) else { + return false; + }; + + reason + .1 + .compute(context, reason.0, propagators, destination_buffer); + + true } pub(crate) fn get_lazy_code(&self, reference: ReasonRef) -> Option<&u64> { match self.trail.get(reference.0 as usize) { Some(reason) => match &reason.1 { - Reason::Eager(_) => None, - Reason::DynamicLazy(code) => Some(code), + StoredReason::Eager(_) => None, + StoredReason::DynamicLazy(code) => Some(code), + StoredReason::ReifiedLazy(_, _) => { + // If this happens, we need to rethink this API. + unimplemented!("cannot get code of reified lazy explanation") + } }, None => None, } @@ -86,21 +99,52 @@ pub(crate) enum Reason { DynamicLazy(u64), } -impl Reason { - pub(crate) fn compute<'a>( - &'a self, +/// A reason for CP propagator to make a change +#[derive(Debug)] +pub(crate) enum StoredReason { + /// An eager reason contains the propositional conjunction with the reason, without the + /// propagated predicate. + Eager(PropositionalConjunction), + /// A lazy reason, which is computed on-demand rather than up-front. This is also referred to + /// as a 'backward' reason. + /// + /// A lazy reason contains a payload that propagators can use to identify what type of + /// propagation the reason is for. The payload should be enough for the propagator to construct + /// an explanation based on its internal state. + DynamicLazy(u64), + /// A lazy explanation that has been reified. + ReifiedLazy(Literal, u64), +} + +impl StoredReason { + /// Evaluate the reason, and write the predicates to the `destination_buffer`. + pub(crate) fn compute( + &self, context: ExplanationContext<'_>, propagator_id: PropagatorId, - propagators: &'a mut PropagatorStore, - ) -> &'a [Predicate] { + propagators: &mut PropagatorStore, + destination_buffer: &mut impl Extend, + ) { match self { // We do not replace the reason with an eager explanation for dynamic lazy explanations. // // Benchmarking will have to show whether this should change or not. - Reason::DynamicLazy(code) => { - propagators[propagator_id].lazy_explanation(*code, context) + StoredReason::DynamicLazy(code) => destination_buffer.extend( + propagators[propagator_id] + .lazy_explanation(*code, context) + .iter() + .copied(), + ), + StoredReason::Eager(result) => destination_buffer.extend(result.iter().copied()), + StoredReason::ReifiedLazy(literal, code) => { + destination_buffer.extend( + propagators[propagator_id] + .lazy_explanation(*code, context) + .iter() + .copied(), + ); + destination_buffer.extend(std::iter::once(literal.get_true_predicate())); } - Reason::Eager(result) => result.as_slice(), } } } @@ -115,8 +159,10 @@ impl From for Reason { mod tests { use super::*; use crate::conjunction; + use crate::engine::propagation::Propagator; use crate::engine::variables::DomainId; use crate::engine::Assignments; + use crate::predicate; #[test] fn computing_an_eager_reason_returns_a_reference_to_the_conjunction() { @@ -126,16 +172,17 @@ mod tests { let y = DomainId::new(1); let conjunction = conjunction!([x == 1] & [y == 2]); - let reason = Reason::Eager(conjunction.clone()); - - assert_eq!( - conjunction.as_slice(), - reason.compute( - ExplanationContext::from(&integers), - PropagatorId(0), - &mut PropagatorStore::default() - ) + let reason = StoredReason::Eager(conjunction.clone()); + + let mut out_reason = vec![]; + reason.compute( + ExplanationContext::from(&integers), + PropagatorId(0), + &mut PropagatorStore::default(), + &mut out_reason, ); + + assert_eq!(conjunction.as_slice(), &out_reason); } #[test] @@ -147,17 +194,73 @@ mod tests { let y = DomainId::new(1); let conjunction = conjunction!([x == 1] & [y == 2]); - let reason_ref = reason_store.push(PropagatorId(0), Reason::Eager(conjunction.clone())); + let reason_ref = + reason_store.push(PropagatorId(0), StoredReason::Eager(conjunction.clone())); assert_eq!(ReasonRef(0), reason_ref); - assert_eq!( - Some(conjunction.as_slice()), - reason_store.get_or_compute( - reason_ref, - ExplanationContext::from(&integers), - &mut PropagatorStore::default() - ) + let mut out_reason = vec![]; + let _ = reason_store.get_or_compute( + reason_ref, + ExplanationContext::from(&integers), + &mut PropagatorStore::default(), + &mut out_reason, ); + + assert_eq!(conjunction.as_slice(), &out_reason); + } + + #[test] + fn reified_lazy_explanation_has_reification_added_after_compute() { + let mut reason_store = ReasonStore::default(); + let mut integers = Assignments::default(); + + let x = integers.grow(1, 5); + let reif = Literal::new(integers.grow(0, 1)); + + struct TestPropagator(Vec); + + impl Propagator for TestPropagator { + fn name(&self) -> &str { + todo!() + } + + fn debug_propagate_from_scratch( + &self, + _: crate::engine::propagation::PropagationContextMut, + ) -> crate::basic_types::PropagationStatusCP { + todo!() + } + + fn initialise_at_root( + &mut self, + _: &mut crate::engine::propagation::PropagatorInitialisationContext, + ) -> Result<(), PropositionalConjunction> { + todo!() + } + + fn lazy_explanation(&mut self, code: u64, _: ExplanationContext) -> &[Predicate] { + assert_eq!(0, code); + + &self.0 + } + } + + let mut propagator_store = PropagatorStore::default(); + let propagator_id = + propagator_store.alloc(Box::new(TestPropagator(vec![predicate![x >= 2]])), None); + let reason_ref = reason_store.push(propagator_id, StoredReason::ReifiedLazy(reif, 0)); + + assert_eq!(ReasonRef(0), reason_ref); + + let mut reason = vec![]; + let _ = reason_store.get_or_compute( + reason_ref, + ExplanationContext::from(&integers), + &mut propagator_store, + &mut reason, + ); + + assert_eq!(vec![predicate![x >= 2], reif.get_true_predicate()], reason); } } diff --git a/pumpkin-solver/src/engine/cp/test_solver.rs b/pumpkin-solver/src/engine/cp/test_solver.rs index 8acc377d..4dcc886a 100644 --- a/pumpkin-solver/src/engine/cp/test_solver.rs +++ b/pumpkin-solver/src/engine/cp/test_solver.rs @@ -7,12 +7,13 @@ use super::propagation::store::PropagatorStore; use super::propagation::EnqueueDecision; use super::propagation::ExplanationContext; use super::propagation::PropagatorInitialisationContext; +use super::TrailedAssignments; use crate::basic_types::Inconsistency; use crate::engine::conflict_analysis::SemanticMinimiser; use crate::engine::opaque_domain_event::OpaqueDomainEvent; use crate::engine::predicates::predicate::Predicate; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::LocalId; -use crate::engine::propagation::PropagationContext; use crate::engine::propagation::PropagationContextMut; use crate::engine::propagation::Propagator; use crate::engine::propagation::PropagatorId; @@ -33,6 +34,7 @@ pub(crate) struct TestSolver { pub propagator_store: PropagatorStore, pub reason_store: ReasonStore, pub semantic_minimiser: SemanticMinimiser, + pub stateful_assignments: TrailedAssignments, watch_list: WatchListCP, } @@ -44,6 +46,7 @@ impl Default for TestSolver { propagator_store: Default::default(), semantic_minimiser: Default::default(), watch_list: Default::default(), + stateful_assignments: Default::default(), }; // We allocate space for the zero-th dummy variable at the root level of the assignments. solver.watch_list.grow(); @@ -71,10 +74,12 @@ impl TestSolver { self.propagator_store[id].initialise_at_root(&mut PropagatorInitialisationContext::new( &mut self.watch_list, + &mut self.stateful_assignments, id, - &self.assignments, + &mut self.assignments, ))?; let context = PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut self.semantic_minimiser, @@ -102,7 +107,8 @@ impl TestSolver { ) -> EnqueueDecision { let result = self.assignments.tighten_lower_bound(var, value, None); assert!(result.is_ok(), "The provided value to `increase_lower_bound` caused an empty domain, generally the propagator should not be notified of this change!"); - let context = PropagationContext::new(&self.assignments); + let context = + StatefulPropagationContext::new(&mut self.stateful_assignments, &self.assignments); self.propagator_store[propagator].notify( context, LocalId::from(local_id), @@ -125,7 +131,8 @@ impl TestSolver { ) -> EnqueueDecision { let result = self.assignments.tighten_upper_bound(var, value, None); assert!(result.is_ok(), "The provided value to `increase_lower_bound` caused an empty domain, generally the propagator should not be notified of this change!"); - let context = PropagationContext::new(&self.assignments); + let context = + StatefulPropagationContext::new(&mut self.stateful_assignments, &self.assignments); self.propagator_store[propagator].notify( context, LocalId::from(local_id), @@ -169,6 +176,7 @@ impl TestSolver { pub(crate) fn propagate(&mut self, propagator: PropagatorId) -> Result<(), Inconsistency> { let context = PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut self.semantic_minimiser, @@ -187,6 +195,7 @@ impl TestSolver { { // Specify the life-times to be able to retrieve the trail entries let context = PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut self.semantic_minimiser, @@ -205,15 +214,22 @@ impl TestSolver { pub(crate) fn notify_propagator(&mut self, propagator: PropagatorId) { let events = self.assignments.drain_domain_events().collect::>(); - let context = PropagationContext::new(&self.assignments); for (event, domain) in events { // The nogood propagator is treated in a special way, since it is not explicitly // subscribed to any domain updates, but implicitly is subscribed to all updates. if self.propagator_store[propagator].name() == "NogoodPropagator" { + let context = StatefulPropagationContext::new( + &mut self.stateful_assignments, + &self.assignments, + ); let local_id = LocalId::from(domain.id); let _ = self.propagator_store[propagator].notify(context, local_id, event.into()); } else { for propagator_var in self.watch_list.get_affected_propagators(event, domain) { + let context = StatefulPropagationContext::new( + &mut self.stateful_assignments, + &self.assignments, + ); let _ = self.propagator_store[propagator].notify( context, propagator_var.variable, @@ -228,14 +244,13 @@ impl TestSolver { let reason_ref = self .assignments .get_reason_for_predicate_brute_force(predicate); - let predicates = self - .reason_store - .get_or_compute( - reason_ref, - ExplanationContext::from(&self.assignments), - &mut self.propagator_store, - ) - .expect("reason_ref should not be stale"); + let mut predicates = vec![]; + let _ = self.reason_store.get_or_compute( + reason_ref, + ExplanationContext::from(&self.assignments), + &mut self.propagator_store, + &mut predicates, + ); PropositionalConjunction::from(predicates) } diff --git a/pumpkin-solver/src/engine/cp/trailed/mod.rs b/pumpkin-solver/src/engine/cp/trailed/mod.rs new file mode 100644 index 00000000..d361fbaa --- /dev/null +++ b/pumpkin-solver/src/engine/cp/trailed/mod.rs @@ -0,0 +1,4 @@ +mod trailed_assignments; +mod trailed_change; +pub(crate) use trailed_assignments::*; +pub(crate) use trailed_change::*; diff --git a/pumpkin-solver/src/engine/cp/trailed/trailed_assignments.rs b/pumpkin-solver/src/engine/cp/trailed/trailed_assignments.rs new file mode 100644 index 00000000..db641273 --- /dev/null +++ b/pumpkin-solver/src/engine/cp/trailed/trailed_assignments.rs @@ -0,0 +1,118 @@ +use super::TrailedChange; +use crate::basic_types::Trail; +use crate::containers::KeyedVec; +use crate::containers::StorageKey; + +#[derive(Debug, Clone, Copy)] +pub(crate) struct TrailedInt { + id: u32, +} + +impl Default for TrailedInt { + fn default() -> Self { + Self { id: u32::MAX } + } +} + +impl StorageKey for TrailedInt { + fn index(&self) -> usize { + self.id as usize + } + + fn create_from_index(index: usize) -> Self { + Self { id: index as u32 } + } +} + +#[derive(Default, Debug, Clone)] +pub(crate) struct TrailedAssignments { + trail: Trail, + values: KeyedVec, +} + +impl TrailedAssignments { + pub(crate) fn grow(&mut self, initial_value: i64) -> TrailedInt { + self.values.push(initial_value) + } + + pub(crate) fn increase_decision_level(&mut self) { + self.trail.increase_decision_level() + } + + pub(crate) fn read(&self, stateful_int: TrailedInt) -> i64 { + self.values[stateful_int] + } + + pub(crate) fn synchronise(&mut self, new_decision_level: usize) { + self.trail + .synchronise(new_decision_level) + .for_each(|state_change| self.values[state_change.reference] = state_change.old_value) + } + + fn write(&mut self, stateful_int: TrailedInt, value: i64) { + let old_value = self.values[stateful_int]; + if old_value == value { + return; + } + let entry = TrailedChange { + old_value, + reference: stateful_int, + }; + self.trail.push(entry); + self.values[stateful_int] = value; + } + + pub(crate) fn add_assign(&mut self, stateful_int: TrailedInt, addition: i64) { + self.write(stateful_int, self.values[stateful_int] + addition); + } + + pub(crate) fn assign(&mut self, stateful_int: TrailedInt, value: i64) { + self.write(stateful_int, value); + } + + pub(crate) fn debug_create_empty_clone(&self) -> Self { + let mut new_trail = self.trail.clone(); + let mut new_values = self.values.clone(); + if new_trail.get_decision_level() > 0 { + new_trail.synchronise(0).for_each(|state_change| { + new_values[state_change.reference] = state_change.old_value + }); + } + Self { + trail: new_trail, + values: new_values, + } + } +} + +#[cfg(test)] +mod tests { + use crate::engine::TrailedAssignments; + + #[test] + fn test_write_resets() { + let mut assignments = TrailedAssignments::default(); + let trailed_int = assignments.grow(0); + + assert_eq!(assignments.read(trailed_int), 0); + + assignments.increase_decision_level(); + assignments.add_assign(trailed_int, 5); + + assert_eq!(assignments.read(trailed_int), 5); + + assignments.add_assign(trailed_int, 5); + assert_eq!(assignments.read(trailed_int), 10); + + assignments.increase_decision_level(); + assignments.add_assign(trailed_int, 1); + + assert_eq!(assignments.read(trailed_int), 11); + + assignments.synchronise(1); + assert_eq!(assignments.read(trailed_int), 10); + + assignments.synchronise(0); + assert_eq!(assignments.read(trailed_int), 0); + } +} diff --git a/pumpkin-solver/src/engine/cp/trailed/trailed_change.rs b/pumpkin-solver/src/engine/cp/trailed/trailed_change.rs new file mode 100644 index 00000000..af4d9389 --- /dev/null +++ b/pumpkin-solver/src/engine/cp/trailed/trailed_change.rs @@ -0,0 +1,7 @@ +use super::TrailedInt; + +#[derive(Debug, Clone)] +pub(crate) struct TrailedChange { + pub(crate) old_value: i64, + pub(crate) reference: TrailedInt, +} diff --git a/pumpkin-solver/src/engine/debug_helper.rs b/pumpkin-solver/src/engine/debug_helper.rs index 3a0c7106..99f0109f 100644 --- a/pumpkin-solver/src/engine/debug_helper.rs +++ b/pumpkin-solver/src/engine/debug_helper.rs @@ -12,6 +12,7 @@ use super::propagation::store::PropagatorStore; use super::propagation::ExplanationContext; use super::reason::ReasonStore; use super::ConstraintSatisfactionSolver; +use super::TrailedAssignments; use crate::basic_types::Inconsistency; use crate::basic_types::PropositionalConjunction; use crate::engine::cp::Assignments; @@ -49,10 +50,12 @@ impl DebugHelper { /// Additionally checks whether the internal data structures of the clausal propagator are okay /// and consistent with the assignments_propositional pub(crate) fn debug_fixed_point_propagation( + stateful_assignments: &TrailedAssignments, assignments: &Assignments, propagators: &PropagatorStore, ) -> bool { let mut assignments_clone = assignments.clone(); + let mut stateful_assignments_clone = stateful_assignments.clone(); // Check whether constraint programming propagators missed anything // // It works by asking each propagator to propagate from scratch, and checking whether any @@ -73,6 +76,7 @@ impl DebugHelper { let mut reason_store = Default::default(); let mut semantic_minimiser = SemanticMinimiser::default(); let context = PropagationContextMut::new( + &mut stateful_assignments_clone, &mut assignments_clone, &mut reason_store, &mut semantic_minimiser, @@ -113,12 +117,14 @@ impl DebugHelper { } pub(crate) fn debug_reported_failure( + stateful_assignments: &TrailedAssignments, assignments: &Assignments, failure_reason: &PropositionalConjunction, propagator: &dyn Propagator, propagator_id: PropagatorId, ) -> bool { DebugHelper::debug_reported_propagations_reproduce_failure( + stateful_assignments, assignments, failure_reason, propagator, @@ -126,6 +132,7 @@ impl DebugHelper { ); DebugHelper::debug_reported_propagations_negate_failure_and_check( + stateful_assignments, assignments, failure_reason, propagator, @@ -143,6 +150,7 @@ impl DebugHelper { pub(crate) fn debug_check_propagations( num_trail_entries_before: usize, propagator_id: PropagatorId, + stateful_assignments: &TrailedAssignments, assignments: &Assignments, reason_store: &mut ReasonStore, propagators: &mut PropagatorStore, @@ -154,20 +162,20 @@ impl DebugHelper { for trail_index in num_trail_entries_before..assignments.num_trail_entries() { let trail_entry = assignments.get_trail_entry(trail_index); - let reason = reason_store - .get_or_compute( - trail_entry - .reason - .expect("Expected checked propagation to have a reason"), - ExplanationContext::from(assignments), - propagators, - ) - .expect("reason should exist for this propagation") - .to_vec(); + let mut reason = vec![]; + let _ = reason_store.get_or_compute( + trail_entry + .reason + .expect("Expected checked propagation to have a reason"), + ExplanationContext::from(assignments), + propagators, + &mut reason, + ); result &= Self::debug_propagator_reason( trail_entry.predicate, &reason, + stateful_assignments, assignments, &propagators[propagator_id], propagator_id, @@ -179,6 +187,7 @@ impl DebugHelper { fn debug_propagator_reason( propagated_predicate: Predicate, reason: &[Predicate], + stateful_assignments: &TrailedAssignments, assignments: &Assignments, propagator: &dyn Propagator, propagator_id: PropagatorId, @@ -209,6 +218,7 @@ impl DebugHelper { // Does setting the predicates from the reason indeed lead to the propagation? { let mut assignments_clone = assignments.debug_create_empty_clone(); + let mut stateful_assignments_clone = stateful_assignments.debug_create_empty_clone(); let reason_predicates: Vec = reason.to_vec(); let adding_predicates_was_successful = DebugHelper::debug_add_predicates_to_assignments( @@ -221,6 +231,7 @@ impl DebugHelper { let mut reason_store = Default::default(); let mut semantic_minimiser = SemanticMinimiser::default(); let context = PropagationContextMut::new( + &mut stateful_assignments_clone, &mut assignments_clone, &mut reason_store, &mut semantic_minimiser, @@ -304,6 +315,7 @@ impl DebugHelper { // related to reverse unit propagation { let mut assignments_clone = assignments.debug_create_empty_clone(); + let mut stateful_assignments_clone = stateful_assignments.debug_create_empty_clone(); let failing_predicates: Vec = once(!propagated_predicate) .chain(reason.iter().copied()) @@ -330,6 +342,7 @@ impl DebugHelper { let num_predicates_before = assignments_clone.num_trail_entries(); let context = PropagationContextMut::new( + &mut stateful_assignments_clone, &mut assignments_clone, &mut reason_store, &mut semantic_minimiser, @@ -369,6 +382,7 @@ impl DebugHelper { } fn debug_reported_propagations_reproduce_failure( + stateful_assignments: &TrailedAssignments, assignments: &Assignments, failure_reason: &PropositionalConjunction, propagator: &dyn Propagator, @@ -378,6 +392,7 @@ impl DebugHelper { return; } let mut assignments_clone = assignments.debug_create_empty_clone(); + let mut stateful_assignments_clone = stateful_assignments.debug_create_empty_clone(); let reason_predicates: Vec = failure_reason.iter().copied().collect(); let adding_predicates_was_successful = DebugHelper::debug_add_predicates_to_assignments( @@ -390,6 +405,7 @@ impl DebugHelper { let mut reason_store = Default::default(); let mut semantic_minimiser = SemanticMinimiser::default(); let context = PropagationContextMut::new( + &mut stateful_assignments_clone, &mut assignments_clone, &mut reason_store, &mut semantic_minimiser, @@ -414,6 +430,7 @@ impl DebugHelper { } fn debug_reported_propagations_negate_failure_and_check( + stateful_assignments: &TrailedAssignments, assignments: &Assignments, failure_reason: &PropositionalConjunction, propagator: &dyn Propagator, @@ -440,6 +457,7 @@ impl DebugHelper { let mut found_nonconflicting_state_at_root = false; for predicate in &reason_predicates { let mut assignments_clone = assignments.debug_create_empty_clone(); + let mut stateful_assignments_clone = stateful_assignments.debug_create_empty_clone(); let negated_predicate = predicate.not(); let outcome = assignments_clone.post_predicate(negated_predicate, None); @@ -448,6 +466,7 @@ impl DebugHelper { let mut reason_store = Default::default(); let mut semantic_minimiser = SemanticMinimiser::default(); let context = PropagationContextMut::new( + &mut stateful_assignments_clone, &mut assignments_clone, &mut reason_store, &mut semantic_minimiser, diff --git a/pumpkin-solver/src/lib.rs b/pumpkin-solver/src/lib.rs index 611d69ee..3c87b175 100644 --- a/pumpkin-solver/src/lib.rs +++ b/pumpkin-solver/src/lib.rs @@ -103,8 +103,9 @@ //! } //! ``` //! -//! **Optimizing an objective** can be done in a similar way using [`Solver::maximise`] or -//! [`Solver::minimise`]; first the objective variable and a constraint over this value are added: +//! **Optimizing an objective** can be done in a similar way using [`Solver::optimise`]; first the +//! objective variable and a constraint over this value are added: +//! //! ```rust //! # use pumpkin_solver::Solver; //! # use pumpkin_solver::constraints; @@ -122,7 +123,7 @@ //! .post(); //! ``` //! -//! Then we can find the optimal solution using [`Solver::minimise`] or [`Solver::maximise`]: +//! Then we can find the optimal solution using [`Solver::optimise`]: //! ```rust //! # use pumpkin_solver::Solver; //! # use pumpkin_solver::results::OptimisationResult; @@ -130,7 +131,10 @@ //! # use pumpkin_solver::results::ProblemSolution; //! # use pumpkin_solver::constraints; //! # use pumpkin_solver::constraints::Constraint; +//! # use pumpkin_solver::optimisation::OptimisationDirection; +//! # use pumpkin_solver::optimisation::linear_sat_unsat::LinearSatUnsat; //! # use std::cmp::max; +//! # use crate::pumpkin_solver::optimisation::OptimisationProcedure; //! # let mut solver = Solver::default(); //! # let x = solver.new_bounded_integer(5, 10); //! # let y = solver.new_bounded_integer(-3, 15); @@ -141,7 +145,11 @@ //! # let mut termination = Indefinite; //! # let mut brancher = solver.default_brancher(); //! // Then we solve to optimality -//! let result = solver.minimise(&mut brancher, &mut termination, objective); +//! let result = solver.optimise( +//! &mut brancher, +//! &mut termination, +//! LinearSatUnsat::new(OptimisationDirection::Minimise, objective, |_, _| {}), +//! ); //! //! if let OptimisationResult::Optimal(optimal_solution) = result { //! let value_x = optimal_solution.get_integer_value(x); @@ -202,7 +210,7 @@ //! //! loop { //! match solution_iterator.next_solution() { -//! IteratedSolution::Solution(solution) => { +//! IteratedSolution::Solution(solution, _) => { //! number_of_solutions += 1; //! // We have found another solution, the same invariant should hold //! let value_x = solution.get_integer_value(x); @@ -287,6 +295,7 @@ pub(crate) mod math; pub(crate) mod propagators; pub(crate) mod pumpkin_asserts; pub(crate) mod variable_names; + #[cfg(doc)] use crate::branching::Brancher; #[cfg(doc)] @@ -294,6 +303,7 @@ use crate::termination::TerminationCondition; pub mod branching; pub mod constraints; +pub mod optimisation; pub mod proof; pub mod statistics; diff --git a/pumpkin-solver/src/optimisation/linear_sat_unsat.rs b/pumpkin-solver/src/optimisation/linear_sat_unsat.rs new file mode 100644 index 00000000..0b8d3d50 --- /dev/null +++ b/pumpkin-solver/src/optimisation/linear_sat_unsat.rs @@ -0,0 +1,195 @@ +use super::OptimisationProcedure; +use crate::basic_types::CSPSolverExecutionFlag; +use crate::branching::Brancher; +use crate::optimisation::OptimisationDirection; +use crate::predicate; +use crate::pumpkin_assert_simple; +use crate::results::OptimisationResult; +use crate::results::Solution; +use crate::results::SolutionReference; +use crate::termination::TerminationCondition; +use crate::variables::IntegerVariable; +use crate::ConstraintOperationError; +use crate::Solver; + +/// Implements the linear SAT-UNSAT (LSU) optimisation procedure. +#[derive(Debug, Clone, Copy)] +pub struct LinearSatUnsat { + direction: OptimisationDirection, + objective: Var, + solution_callback: Callback, +} + +impl LinearSatUnsat +where + // The trait bound here is not common; see + // linear_unsat_sat for more info. + Callback: Fn(&Solver, SolutionReference), +{ + /// Create a new instance of [`LinearSatUnsat`]. + pub fn new( + direction: OptimisationDirection, + objective: Var, + solution_callback: Callback, + ) -> Self { + Self { + direction, + objective, + solution_callback, + } + } +} + +impl LinearSatUnsat { + /// Given the current objective value `best_objective_value`, it adds a constraint specifying + /// that the objective value should be at most `best_objective_value - 1`. Note that it is + /// assumed that we are always minimising the variable. + fn strengthen( + &mut self, + objective_variable: &impl IntegerVariable, + best_objective_value: i64, + solver: &mut Solver, + ) -> Result<(), ConstraintOperationError> { + solver.satisfaction_solver.add_clause([predicate!( + objective_variable <= (best_objective_value - 1) as i32 + )]) + } + + fn debug_bound_change( + &self, + objective_variable: &impl IntegerVariable, + best_objective_value: i64, + solver: &Solver, + ) { + pumpkin_assert_simple!( + (solver + .satisfaction_solver + .get_assigned_integer_value(objective_variable) + .expect("expected variable to be assigned") as i64) + < best_objective_value, + "{}", + format!( + "The current bound {} should be smaller than the previous bound {}", + solver + .satisfaction_solver + .get_assigned_integer_value(objective_variable) + .expect("expected variable to be assigned"), + best_objective_value + ) + ); + } +} + +impl OptimisationProcedure for LinearSatUnsat +where + Var: IntegerVariable, + Callback: Fn(&Solver, SolutionReference), +{ + fn optimise( + &mut self, + brancher: &mut impl Brancher, + termination: &mut impl TerminationCondition, + solver: &mut Solver, + ) -> OptimisationResult { + let is_maximising = matches!(self.direction, OptimisationDirection::Maximise); + let objective = match self.direction { + OptimisationDirection::Maximise => self.objective.scaled(-1), + OptimisationDirection::Minimise => self.objective.scaled(1), + }; + // If we are maximising then when we simply scale the variable by -1, however, this will + // lead to the printed objective value in the statistics to be multiplied by -1; this + // objective_multiplier ensures that the objective is correctly logged. + let objective_multiplier = if is_maximising { -1 } else { 1 }; + + let initial_solve = solver.satisfaction_solver.solve(termination, brancher); + match initial_solve { + CSPSolverExecutionFlag::Feasible => {} + CSPSolverExecutionFlag::Infeasible => { + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + let _ = solver.satisfaction_solver.conclude_proof_unsat(); + return OptimisationResult::Unsatisfiable; + } + CSPSolverExecutionFlag::Timeout => { + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + return OptimisationResult::Unknown; + } + } + let mut best_objective_value = Default::default(); + let mut best_solution = Solution::default(); + + self.update_best_solution_and_process( + objective_multiplier, + &objective, + &mut best_objective_value, + &mut best_solution, + brancher, + solver, + ); + + loop { + solver.satisfaction_solver.restore_state_at_root(brancher); + + let objective_bound_predicate = if is_maximising { + predicate![objective >= best_objective_value as i32 * objective_multiplier] + } else { + predicate![objective <= best_objective_value as i32 * objective_multiplier] + }; + + if self + .strengthen( + &objective, + best_objective_value * objective_multiplier as i64, + solver, + ) + .is_err() + { + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + let _ = solver + .satisfaction_solver + .conclude_proof_optimal(objective_bound_predicate); + return OptimisationResult::Optimal(best_solution); + } + + let solve_result = solver.satisfaction_solver.solve(termination, brancher); + match solve_result { + CSPSolverExecutionFlag::Feasible => { + self.debug_bound_change( + &objective, + best_objective_value * objective_multiplier as i64, + solver, + ); + self.update_best_solution_and_process( + objective_multiplier, + &objective, + &mut best_objective_value, + &mut best_solution, + brancher, + solver, + ); + } + CSPSolverExecutionFlag::Infeasible => { + { + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + let _ = solver + .satisfaction_solver + .conclude_proof_optimal(objective_bound_predicate); + return OptimisationResult::Optimal(best_solution); + } + } + CSPSolverExecutionFlag::Timeout => { + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + return OptimisationResult::Satisfiable(best_solution); + } + } + } + } + + fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference) { + (self.solution_callback)(solver, solution) + } +} diff --git a/pumpkin-solver/src/optimisation/linear_unsat_sat.rs b/pumpkin-solver/src/optimisation/linear_unsat_sat.rs new file mode 100644 index 00000000..d4012f10 --- /dev/null +++ b/pumpkin-solver/src/optimisation/linear_unsat_sat.rs @@ -0,0 +1,158 @@ +use log::info; + +use super::OptimisationProcedure; +use crate::basic_types::CSPSolverExecutionFlag; +use crate::branching::Brancher; +use crate::optimisation::OptimisationDirection; +use crate::predicate; +use crate::results::OptimisationResult; +use crate::results::Solution; +use crate::results::SolutionReference; +use crate::termination::TerminationCondition; +use crate::variables::IntegerVariable; +use crate::Solver; + +/// Implements the linear UNSAT-SAT (LUS) optimisation procedure. +#[derive(Debug, Clone, Copy)] +pub struct LinearUnsatSat { + direction: OptimisationDirection, + objective: Var, + solution_callback: Callback, +} + +impl LinearUnsatSat +where + // The trait bound here is contrary to common + // practice; typically the bounds are only enforced + // where they are required (in this case, in the + // implementation of OptimisationProcedure). + // + // However, if we don't have the trait bound here, + // the compiler may implement `FnOnce` for the + // empty closure, which causes problems. So, we + // have the hint here. + // + // Similar is also the case in linear SAT-UNSAT. + Callback: Fn(&Solver, SolutionReference), +{ + /// Create a new instance of [`LinearUnsatSat`]. + pub fn new( + direction: OptimisationDirection, + objective: Var, + solution_callback: Callback, + ) -> Self { + Self { + direction, + objective, + solution_callback, + } + } +} + +impl + OptimisationProcedure for LinearUnsatSat +{ + fn optimise( + &mut self, + brancher: &mut impl Brancher, + termination: &mut impl TerminationCondition, + solver: &mut Solver, + ) -> OptimisationResult { + let is_maximising = matches!(self.direction, OptimisationDirection::Maximise); + let objective = match self.direction { + OptimisationDirection::Maximise => self.objective.scaled(-1), + OptimisationDirection::Minimise => self.objective.scaled(1), + }; + // If we are maximising then when we simply scale the variable by -1, however, this will + // lead to the printed objective value in the statistics to be multiplied by -1; this + // objective_multiplier ensures that the objective is correctly logged. + let objective_multiplier = if is_maximising { -1 } else { 1 }; + + // First we do a feasibility check + let feasibility_check = solver.satisfaction_solver.solve(termination, brancher); + match feasibility_check { + CSPSolverExecutionFlag::Feasible => {} + CSPSolverExecutionFlag::Infeasible => { + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + let _ = solver.satisfaction_solver.conclude_proof_unsat(); + return OptimisationResult::Unsatisfiable; + } + CSPSolverExecutionFlag::Timeout => { + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + return OptimisationResult::Unknown; + } + } + let mut best_objective_value = Default::default(); + let mut best_solution = Solution::default(); + + self.update_best_solution_and_process( + objective_multiplier, + &objective, + &mut best_objective_value, + &mut best_solution, + brancher, + solver, + ); + solver.satisfaction_solver.restore_state_at_root(brancher); + + loop { + let assumption = predicate!(objective <= solver.lower_bound(&objective)); + + info!( + "Lower-Bounding Search - Attempting to find solution with assumption {assumption}" + ); + + // Solve under the assumption that the objective variable is lower than `lower-bound` + let solve_result = solver.satisfaction_solver.solve_under_assumptions( + &[assumption], + termination, + brancher, + ); + match solve_result { + CSPSolverExecutionFlag::Feasible => { + self.update_best_solution_and_process( + objective_multiplier, + &objective, + &mut best_objective_value, + &mut best_solution, + brancher, + solver, + ); + + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + + // We create a predicate specifying the best-found solution for the proof + // logging + let objective_bound_predicate = if is_maximising { + predicate![objective >= best_objective_value as i32 * objective_multiplier] + } else { + predicate![objective <= best_objective_value as i32 * objective_multiplier] + }; + let _ = solver + .satisfaction_solver + .conclude_proof_optimal(objective_bound_predicate); + + return OptimisationResult::Optimal(best_solution); + } + CSPSolverExecutionFlag::Infeasible => { + solver.satisfaction_solver.restore_state_at_root(brancher); + // We add the (hard) constraint that the negated assumption should hold (i.e., + // the solution should be at least as large as the found solution) + let _ = solver.add_clause([!assumption]); + } + CSPSolverExecutionFlag::Timeout => { + // Reset the state whenever we return a result + solver.satisfaction_solver.restore_state_at_root(brancher); + return OptimisationResult::Satisfiable(best_solution); + } + } + } + } + + fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference) { + (self.solution_callback)(solver, solution) + } +} diff --git a/pumpkin-solver/src/optimisation/mod.rs b/pumpkin-solver/src/optimisation/mod.rs new file mode 100644 index 00000000..aa5dd145 --- /dev/null +++ b/pumpkin-solver/src/optimisation/mod.rs @@ -0,0 +1,91 @@ +//! Contains structures related to optimissation. +use std::fmt::Display; + +use clap::ValueEnum; + +use crate::branching::Brancher; +use crate::results::OptimisationResult; +use crate::results::Solution; +use crate::results::SolutionReference; +use crate::termination::TerminationCondition; +use crate::variables::IntegerVariable; +use crate::Solver; + +pub mod linear_sat_unsat; +pub mod linear_unsat_sat; + +pub trait OptimisationProcedure { + fn optimise( + &mut self, + brancher: &mut impl Brancher, + termination: &mut impl TerminationCondition, + solver: &mut Solver, + ) -> OptimisationResult; + + fn on_solution_callback(&self, solver: &Solver, solution: SolutionReference); + + /// Processes a solution when it is found, it consists of the following procedure: + /// - Assigning `best_objective_value` the value assigned to `objective_variable` (multiplied by + /// `objective_multiplier`). + /// - Storing the new best solution in `best_solution`. + /// - Calling [`Brancher::on_solution`] on the provided `brancher`. + /// - Logging the statistics using [`Solver::log_statistics_with_objective`]. + /// - Calling the solution callback. + fn update_best_solution_and_process( + &self, + objective_multiplier: i32, + objective_variable: &impl IntegerVariable, + best_objective_value: &mut i64, + best_solution: &mut Solution, + brancher: &mut impl Brancher, + solver: &Solver, + ) { + *best_objective_value = (objective_multiplier + * solver + .satisfaction_solver + .get_assigned_integer_value(objective_variable) + .expect("expected variable to be assigned")) as i64; + *best_solution = solver.satisfaction_solver.get_solution_reference().into(); + + self.internal_process_solution(best_solution, brancher, solver) + } + + fn internal_process_solution( + &self, + solution: &Solution, + brancher: &mut impl Brancher, + solver: &Solver, + ) { + brancher.on_solution(solution.as_reference()); + + self.on_solution_callback(solver, solution.as_reference()) + } +} + +/// The type of search which is performed by the solver. +#[derive(Debug, Clone, Copy, ValueEnum, Default)] +pub enum OptimisationStrategy { + /// Linear SAT-UNSAT - Starts with a satisfiable solution and tightens the bound on the + /// objective variable until an UNSAT result is reached. Can be seen as upper-bounding search. + #[default] + LinearSatUnsat, + /// Linear UNSAT-SAT - Starts with an unsatisfiable solution and tightens the bound on the + /// objective variable until a SAT result is reached. Can be seen as lower-bounding search. + LinearUnsatSat, +} + +impl Display for OptimisationStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OptimisationStrategy::LinearSatUnsat => write!(f, "linear-sat-unsat"), + OptimisationStrategy::LinearUnsatSat => write!(f, "linear-unsat-sat"), + } + } +} + +/// The direction of the optimisation, either maximising or minimising. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum OptimisationDirection { + Maximise, + Minimise, +} diff --git a/pumpkin-solver/src/proof/mod.rs b/pumpkin-solver/src/proof/mod.rs index d60f44a4..4d112d2f 100644 --- a/pumpkin-solver/src/proof/mod.rs +++ b/pumpkin-solver/src/proof/mod.rs @@ -35,9 +35,7 @@ pub struct ProofLog { } /// A dummy proof step ID. Used when there is proof logging is not enabled. -// Safety: Unwrapping an option is not stable, so we cannot get a NonZero safely in a const -// context. -const DUMMY_STEP_ID: NonZeroU64 = unsafe { NonZeroU64::new_unchecked(1) }; +const DUMMY_STEP_ID: NonZeroU64 = NonZeroU64::new(1).unwrap(); impl ProofLog { /// Create a CP proof logger. diff --git a/pumpkin-solver/src/propagators/arithmetic/division.rs b/pumpkin-solver/src/propagators/arithmetic/division.rs index 04a4f0d5..50b7070f 100644 --- a/pumpkin-solver/src/propagators/arithmetic/division.rs +++ b/pumpkin-solver/src/propagators/arithmetic/division.rs @@ -1,10 +1,10 @@ use crate::basic_types::PropagationStatusCP; use crate::conjunction; -use crate::engine::cp::propagation::propagation_context::ReadDomains; use crate::engine::propagation::LocalId; use crate::engine::propagation::PropagationContextMut; use crate::engine::propagation::Propagator; use crate::engine::propagation::PropagatorInitialisationContext; +use crate::engine::propagation::ReadDomains; use crate::engine::variables::IntegerVariable; use crate::engine::DomainEvents; use crate::predicates::PropositionalConjunction; diff --git a/pumpkin-solver/src/propagators/arithmetic/linear_less_or_equal.rs b/pumpkin-solver/src/propagators/arithmetic/linear_less_or_equal.rs index ea0d1982..ba11e2c0 100644 --- a/pumpkin-solver/src/propagators/arithmetic/linear_less_or_equal.rs +++ b/pumpkin-solver/src/propagators/arithmetic/linear_less_or_equal.rs @@ -1,8 +1,12 @@ +use itertools::Itertools; + use crate::basic_types::PropagationStatusCP; use crate::basic_types::PropositionalConjunction; use crate::engine::cp::propagation::ReadDomains; use crate::engine::domain_events::DomainEvents; use crate::engine::opaque_domain_event::OpaqueDomainEvent; +use crate::engine::propagation::contexts::ManipulateStatefulIntegers; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::EnqueueDecision; use crate::engine::propagation::LocalId; use crate::engine::propagation::PropagationContext; @@ -10,6 +14,7 @@ use crate::engine::propagation::PropagationContextMut; use crate::engine::propagation::Propagator; use crate::engine::propagation::PropagatorInitialisationContext; use crate::engine::variables::IntegerVariable; +use crate::engine::TrailedInt; use crate::predicate; use crate::pumpkin_assert_simple; @@ -20,9 +25,9 @@ pub(crate) struct LinearLessOrEqualPropagator { c: i32, /// The lower bound of the sum of the left-hand side. This is incremental state. - lower_bound_left_hand_side: i64, + lower_bound_left_hand_side: TrailedInt, /// The value at index `i` is the bound for `x[i]`. - current_bounds: Box<[i32]>, + current_bounds: Box<[TrailedInt]>, } impl LinearLessOrEqualPropagator @@ -30,33 +35,20 @@ where Var: IntegerVariable, { pub(crate) fn new(x: Box<[Var]>, c: i32) -> Self { - let current_bounds = vec![0; x.len()].into(); + let current_bounds = (0..x.len()) + .map(|_| TrailedInt::default()) + .collect_vec() + .into(); // incremental state will be properly initialized in `Propagator::initialise_at_root`. LinearLessOrEqualPropagator:: { x, c, - lower_bound_left_hand_side: 0, + lower_bound_left_hand_side: TrailedInt::default(), current_bounds, } } - /// Recalculates the incremental state from scratch. - fn recalculate_incremental_state(&mut self, context: PropagationContext) { - self.lower_bound_left_hand_side = self - .x - .iter() - .map(|var| context.lower_bound(var) as i64) - .sum(); - - self.current_bounds - .iter_mut() - .enumerate() - .for_each(|(index, bound)| { - *bound = context.lower_bound(&self.x[index]); - }); - } - fn create_conflict_reason(&self, context: PropagationContext) -> PropositionalConjunction { self.x .iter() @@ -73,17 +65,19 @@ where &mut self, context: &mut PropagatorInitialisationContext, ) -> Result<(), PropositionalConjunction> { + let mut lower_bound_left_hand_side = 0_i64; self.x.iter().enumerate().for_each(|(i, x_i)| { let _ = context.register( x_i.clone(), DomainEvents::LOWER_BOUND, LocalId::from(i as u32), ); + lower_bound_left_hand_side += context.lower_bound(x_i) as i64; + self.current_bounds[i] = context.new_stateful_integer(context.lower_bound(x_i) as i64); }); + self.lower_bound_left_hand_side = context.new_stateful_integer(lower_bound_left_hand_side); - self.recalculate_incremental_state(context.as_readonly()); - - if let Some(conjunction) = self.detect_inconsistency(context.as_readonly()) { + if let Some(conjunction) = self.detect_inconsistency(context.as_stateful_readonly()) { Err(conjunction) } else { Ok(()) @@ -92,10 +86,10 @@ where fn detect_inconsistency( &self, - context: PropagationContext, + context: StatefulPropagationContext, ) -> Option { - if (self.c as i64) < self.lower_bound_left_hand_side { - Some(self.create_conflict_reason(context)) + if (self.c as i64) < context.value(self.lower_bound_left_hand_side) { + Some(self.create_conflict_reason(context.as_readonly())) } else { None } @@ -103,31 +97,27 @@ where fn notify( &mut self, - context: PropagationContext, + mut context: StatefulPropagationContext, local_id: LocalId, _event: OpaqueDomainEvent, ) -> EnqueueDecision { let index = local_id.unpack() as usize; - let x_i = &self.x[index]; - let old_bound = self.current_bounds[index]; - let new_bound = context.lower_bound(x_i); + + let old_bound = context.value(self.current_bounds[index]); + let new_bound = context.lower_bound(x_i) as i64; pumpkin_assert_simple!( old_bound < new_bound, "propagator should only be triggered when lower bounds are tightened, old_bound={old_bound}, new_bound={new_bound}" ); - self.current_bounds[index] = new_bound; - self.lower_bound_left_hand_side += (new_bound - old_bound) as i64; + context.add_assign(self.lower_bound_left_hand_side, new_bound - old_bound); + context.assign(self.current_bounds[index], new_bound); EnqueueDecision::Enqueue } - fn synchronise(&mut self, context: PropagationContext) { - self.recalculate_incremental_state(context); - } - fn priority(&self) -> u32 { 0 } @@ -137,14 +127,14 @@ where } fn propagate(&mut self, mut context: PropagationContextMut) -> PropagationStatusCP { - if let Some(conjunction) = self.detect_inconsistency(context.as_readonly()) { + if let Some(conjunction) = self.detect_inconsistency(context.as_stateful_readonly()) { return Err(conjunction.into()); } let lower_bound_left_hand_side = - match TryInto::::try_into(self.lower_bound_left_hand_side) { + match TryInto::::try_into(context.value(self.lower_bound_left_hand_side)) { Ok(bound) => bound, - Err(_) if self.lower_bound_left_hand_side.is_positive() => { + Err(_) if context.value(self.lower_bound_left_hand_side).is_positive() => { // We cannot fit the `lower_bound_left_hand_side` into an i32 due to an // overflow (hence the check that the lower-bound on the left-hand side is // positive) @@ -200,7 +190,7 @@ where let lower_bound_left_hand_side = match TryInto::::try_into(lower_bound_left_hand_side) { Ok(bound) => bound, - Err(_) if self.lower_bound_left_hand_side.is_positive() => { + Err(_) if context.value(self.lower_bound_left_hand_side).is_positive() => { // We cannot fit the `lower_bound_left_hand_side` into an i32 due to an // overflow (hence the check that the lower-bound on the left-hand side is // positive) diff --git a/pumpkin-solver/src/propagators/arithmetic/linear_not_equal.rs b/pumpkin-solver/src/propagators/arithmetic/linear_not_equal.rs index 0c4eaf9f..1b04c5aa 100644 --- a/pumpkin-solver/src/propagators/arithmetic/linear_not_equal.rs +++ b/pumpkin-solver/src/propagators/arithmetic/linear_not_equal.rs @@ -7,6 +7,7 @@ use crate::basic_types::PropositionalConjunction; use crate::engine::cp::propagation::ReadDomains; use crate::engine::domain_events::DomainEvents; use crate::engine::opaque_domain_event::OpaqueDomainEvent; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::EnqueueDecision; use crate::engine::propagation::LocalId; use crate::engine::propagation::PropagationContext; @@ -72,7 +73,7 @@ where fn notify( &mut self, - context: PropagationContext, + context: StatefulPropagationContext, local_id: LocalId, _event: OpaqueDomainEvent, ) -> EnqueueDecision { @@ -250,7 +251,7 @@ where .expect("Expected to be able to fit i64 into i32"), reason, )?; - } else if num_fixed == self.terms.len() && lhs == self.rhs.into() { + } else if num_fixed == self.terms.len() && lhs == self.rhs as i64 { let failure_reason: PropositionalConjunction = self .terms .iter() @@ -318,7 +319,7 @@ impl LinearNotEqualPropagator { let number_of_fixed_terms_is_correct = self.number_of_fixed_terms == expected_number_of_fixed_terms; - let expected_fixed_lhs = self + let expected_fixed_lhs: i32 = self .terms .iter() .filter_map(|x_i| { diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/explanations/big_step.rs b/pumpkin-solver/src/propagators/cumulative/time_table/explanations/big_step.rs index 9ad375d5..902e8a6c 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/explanations/big_step.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/explanations/big_step.rs @@ -1,8 +1,8 @@ use std::cmp::max; use std::rc::Rc; -use crate::engine::cp::propagation::propagation_context::ReadDomains; use crate::engine::propagation::PropagationContext; +use crate::engine::propagation::ReadDomains; use crate::predicate; use crate::predicates::Predicate; use crate::predicates::PropositionalConjunction; diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/explanations/naive.rs b/pumpkin-solver/src/propagators/cumulative/time_table/explanations/naive.rs index 6aac42c4..c9761ce5 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/explanations/naive.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/explanations/naive.rs @@ -1,7 +1,7 @@ use std::rc::Rc; -use crate::engine::cp::propagation::propagation_context::ReadDomains; use crate::engine::propagation::PropagationContext; +use crate::engine::propagation::ReadDomains; use crate::predicate; use crate::predicates::Predicate; use crate::predicates::PropositionalConjunction; diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/explanations/pointwise.rs b/pumpkin-solver/src/propagators/cumulative/time_table/explanations/pointwise.rs index 78d16e5e..7e035445 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/explanations/pointwise.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/explanations/pointwise.rs @@ -1,8 +1,8 @@ use std::rc::Rc; -use crate::engine::cp::propagation::propagation_context::ReadDomains; -use crate::engine::propagation::propagation_context::HasAssignments; +use crate::engine::propagation::contexts::propagation_context::HasAssignments; use crate::engine::propagation::PropagationContextMut; +use crate::engine::propagation::ReadDomains; use crate::engine::EmptyDomain; use crate::options::CumulativeExplanationType; use crate::predicate; diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/checks.rs b/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/checks.rs index 6c526e86..59ff17cc 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/checks.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/checks.rs @@ -148,7 +148,7 @@ pub(crate) fn overlap_updated_profile( // A sanity check, there is a new profile to create consisting // of a combination of the previous profile and the updated task - if profile.height + task.resource_usage + task.resource_usage > capacity { + if profile.height + task.resource_usage > capacity { // The addition of the new mandatory part to the profile // caused an overflow of the resource return Err(ResourceProfile { diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/synchronisation.rs b/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/synchronisation.rs index d7b142f0..58f671c9 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/synchronisation.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/synchronisation.rs @@ -4,8 +4,8 @@ use super::debug::are_mergeable; use super::debug::merge_profiles; use crate::basic_types::Inconsistency; use crate::basic_types::PropagationStatusCP; -use crate::engine::cp::propagation::propagation_context::ReadDomains; use crate::engine::propagation::PropagationContext; +use crate::engine::propagation::ReadDomains; use crate::propagators::create_time_table_over_interval_from_scratch; use crate::propagators::cumulative::time_table::propagation_handler::create_conflict_explanation; use crate::propagators::CumulativeParameters; diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs b/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs index 0769a3bd..04712e2c 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/over_interval_incremental_propagator/time_table_over_interval_incremental.rs @@ -6,6 +6,7 @@ use super::insertion; use super::removal; use crate::basic_types::PropagationStatusCP; use crate::engine::opaque_domain_event::OpaqueDomainEvent; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::EnqueueDecision; use crate::engine::propagation::LocalId; use crate::engine::propagation::PropagationContext; @@ -363,7 +364,7 @@ impl Propagator fn notify( &mut self, - context: PropagationContext, + context: StatefulPropagationContext, local_id: LocalId, event: OpaqueDomainEvent, ) -> EnqueueDecision { @@ -379,7 +380,7 @@ impl Propagator &self.parameters, &self.updatable_structures, &updated_task, - context, + context.as_readonly(), self.time_table.is_empty(), ); @@ -388,7 +389,7 @@ impl Propagator insert_update(&updated_task, &mut self.updatable_structures, result.update); update_bounds_task( - context, + context.as_readonly(), self.updatable_structures.get_stored_bounds_mut(), &updated_task, ); diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs b/pumpkin-solver/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs index 8cf21171..ba63c001 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/per_point_incremental_propagator/time_table_per_point_incremental.rs @@ -5,6 +5,7 @@ use std::rc::Rc; use crate::basic_types::PropagationStatusCP; use crate::engine::opaque_domain_event::OpaqueDomainEvent; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::EnqueueDecision; use crate::engine::propagation::LocalId; use crate::engine::propagation::PropagationContext; @@ -375,7 +376,7 @@ impl Propagator fn notify( &mut self, - context: PropagationContext, + context: StatefulPropagationContext, local_id: LocalId, event: OpaqueDomainEvent, ) -> EnqueueDecision { @@ -391,7 +392,7 @@ impl Propagator &self.parameters, &self.updatable_structures, &updated_task, - context, + context.as_readonly(), self.time_table.is_empty(), ); @@ -400,7 +401,7 @@ impl Propagator insert_update(&updated_task, &mut self.updatable_structures, result.update); update_bounds_task( - context, + context.as_readonly(), self.updatable_structures.get_stored_bounds_mut(), &updated_task, ); diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/propagation_handler.rs b/pumpkin-solver/src/propagators/cumulative/time_table/propagation_handler.rs index e78ff7d5..9dca270f 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/propagation_handler.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/propagation_handler.rs @@ -14,10 +14,10 @@ use super::explanations::naive::create_naive_propagation_explanation; use super::explanations::pointwise::create_pointwise_conflict_explanation; use super::explanations::pointwise::create_pointwise_propagation_explanation; use super::CumulativeExplanationType; -use crate::engine::cp::propagation::propagation_context::ReadDomains; -use crate::engine::propagation::propagation_context::HasAssignments; +use crate::engine::propagation::contexts::HasAssignments; use crate::engine::propagation::PropagationContext; use crate::engine::propagation::PropagationContextMut; +use crate::engine::propagation::ReadDomains; use crate::engine::EmptyDomain; use crate::predicates::PropositionalConjunction; use crate::propagators::cumulative::time_table::explanations::pointwise; @@ -435,6 +435,7 @@ pub(crate) mod test_propagation_handler { use crate::engine::propagation::PropagatorId; use crate::engine::reason::ReasonStore; use crate::engine::Assignments; + use crate::engine::TrailedAssignments; use crate::predicate; use crate::predicates::Predicate; use crate::predicates::PropositionalConjunction; @@ -446,6 +447,7 @@ pub(crate) mod test_propagation_handler { propagation_handler: CumulativePropagationHandler, reason_store: ReasonStore, assignments: Assignments, + stateful_assignments: TrailedAssignments, } impl TestPropagationHandler { @@ -454,10 +456,12 @@ pub(crate) mod test_propagation_handler { let reason_store = ReasonStore::default(); let assignments = Assignments::default(); + let stateful_assignments = TrailedAssignments::default(); Self { propagation_handler, reason_store, assignments, + stateful_assignments, } } @@ -518,6 +522,7 @@ pub(crate) mod test_propagation_handler { .propagation_handler .propagate_lower_bound_with_explanations( &mut PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut SemanticMinimiser::default(), @@ -578,6 +583,7 @@ pub(crate) mod test_propagation_handler { .propagation_handler .propagate_chain_of_lower_bounds_with_explanations( &mut PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut SemanticMinimiser::default(), @@ -625,6 +631,7 @@ pub(crate) mod test_propagation_handler { .propagation_handler .propagate_upper_bound_with_explanations( &mut PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut SemanticMinimiser::default(), @@ -685,6 +692,7 @@ pub(crate) mod test_propagation_handler { .propagation_handler .propagate_chain_of_upper_bounds_with_explanations( &mut PropagationContextMut::new( + &mut self.stateful_assignments, &mut self.assignments, &mut self.reason_store, &mut SemanticMinimiser::default(), @@ -706,16 +714,15 @@ pub(crate) mod test_propagation_handler { .assignments .get_reason_for_predicate_brute_force(predicate); let mut propagator_store = PropagatorStore::default(); - let reason = self - .reason_store - .get_or_compute( - reason_ref, - ExplanationContext::from(&self.assignments), - &mut propagator_store, - ) - .expect("reason_ref should not be stale"); + let mut reason = vec![]; + let _ = self.reason_store.get_or_compute( + reason_ref, + ExplanationContext::from(&self.assignments), + &mut propagator_store, + &mut reason, + ); - reason.iter().copied().collect() + reason.into() } } } diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/time_table_over_interval.rs b/pumpkin-solver/src/propagators/cumulative/time_table/time_table_over_interval.rs index ac5008ea..f43baa55 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/time_table_over_interval.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/time_table_over_interval.rs @@ -4,6 +4,7 @@ use super::time_table_util::propagate_based_on_timetable; use super::time_table_util::should_enqueue; use crate::basic_types::PropagationStatusCP; use crate::engine::opaque_domain_event::OpaqueDomainEvent; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::EnqueueDecision; use crate::engine::propagation::LocalId; use crate::engine::propagation::PropagationContext; @@ -109,7 +110,7 @@ impl Propagator for TimeTableOverIntervalPropaga fn notify( &mut self, - context: PropagationContext, + context: StatefulPropagationContext, local_id: LocalId, event: OpaqueDomainEvent, ) -> EnqueueDecision { @@ -123,12 +124,12 @@ impl Propagator for TimeTableOverIntervalPropaga &self.parameters, &self.updatable_structures, &updated_task, - context, + context.as_readonly(), self.is_time_table_empty, ); update_bounds_task( - context, + context.as_readonly(), self.updatable_structures.get_stored_bounds_mut(), &updated_task, ); diff --git a/pumpkin-solver/src/propagators/cumulative/time_table/time_table_per_point.rs b/pumpkin-solver/src/propagators/cumulative/time_table/time_table_per_point.rs index e333c367..0751d680 100644 --- a/pumpkin-solver/src/propagators/cumulative/time_table/time_table_per_point.rs +++ b/pumpkin-solver/src/propagators/cumulative/time_table/time_table_per_point.rs @@ -10,6 +10,7 @@ use super::time_table_util::should_enqueue; use crate::basic_types::PropagationStatusCP; use crate::engine::cp::propagation::ReadDomains; use crate::engine::opaque_domain_event::OpaqueDomainEvent; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::EnqueueDecision; use crate::engine::propagation::LocalId; use crate::engine::propagation::PropagationContext; @@ -102,7 +103,7 @@ impl Propagator for TimeTablePerPointPropagator< fn notify( &mut self, - context: PropagationContext, + context: StatefulPropagationContext, local_id: LocalId, event: OpaqueDomainEvent, ) -> EnqueueDecision { @@ -116,14 +117,14 @@ impl Propagator for TimeTablePerPointPropagator< &self.parameters, &self.updatable_structures, &updated_task, - context, + context.as_readonly(), self.is_time_table_empty, ); // Note that the non-incremental proapgator does not make use of `result.updated` since it // propagates from scratch anyways update_bounds_task( - context, + context.as_readonly(), self.updatable_structures.get_stored_bounds_mut(), &updated_task, ); diff --git a/pumpkin-solver/src/propagators/cumulative/utils/util.rs b/pumpkin-solver/src/propagators/cumulative/utils/util.rs index 3c80599a..c87d06a7 100644 --- a/pumpkin-solver/src/propagators/cumulative/utils/util.rs +++ b/pumpkin-solver/src/propagators/cumulative/utils/util.rs @@ -8,7 +8,7 @@ use enumset::enum_set; use crate::engine::cp::propagation::ReadDomains; use crate::engine::domain_events::DomainEvents; use crate::engine::propagation::local_id::LocalId; -use crate::engine::propagation::propagation_context::PropagationContext; +use crate::engine::propagation::PropagationContext; use crate::engine::propagation::PropagatorInitialisationContext; use crate::engine::variables::IntegerVariable; use crate::engine::IntDomainEvent; diff --git a/pumpkin-solver/src/propagators/element.rs b/pumpkin-solver/src/propagators/element.rs index 62e6a5e3..00c63f5c 100644 --- a/pumpkin-solver/src/propagators/element.rs +++ b/pumpkin-solver/src/propagators/element.rs @@ -93,14 +93,20 @@ where Ok(()) } - fn lazy_explanation(&mut self, code: u64, _: ExplanationContext) -> &[Predicate] { + fn lazy_explanation(&mut self, code: u64, context: ExplanationContext) -> &[Predicate] { let payload = RightHandSideReason::from_bits(code); self.rhs_reason_buffer.clear(); self.rhs_reason_buffer - .extend(self.array.iter().map(|variable| match payload.bound() { - Bound::Lower => predicate![variable >= payload.value()], - Bound::Upper => predicate![variable <= payload.value()], + .extend(self.array.iter().enumerate().map(|(idx, variable)| { + if context.contains(&self.index, idx as i32) { + match payload.bound() { + Bound::Lower => predicate![variable >= payload.value()], + Bound::Upper => predicate![variable <= payload.value()], + } + } else { + predicate![self.index != idx as i32] + } })); &self.rhs_reason_buffer @@ -129,15 +135,17 @@ where &self, context: &mut PropagationContextMut<'_>, ) -> PropagationStatusCP { - let (rhs_lb, rhs_ub) = - self.array - .iter() - .fold((i32::MAX, i32::MIN), |(rhs_lb, rhs_ub), element| { - ( - i32::min(rhs_lb, context.lower_bound(element)), - i32::max(rhs_ub, context.upper_bound(element)), - ) - }); + let (rhs_lb, rhs_ub) = self + .array + .iter() + .enumerate() + .filter(|(idx, _)| context.contains(&self.index, *idx as i32)) + .fold((i32::MAX, i32::MIN), |(rhs_lb, rhs_ub), (_, element)| { + ( + i32::min(rhs_lb, context.lower_bound(element)), + i32::max(rhs_ub, context.upper_bound(element)), + ) + }); context.set_lower_bound( &self.rhs, @@ -353,4 +361,39 @@ mod tests { conjunction!([index == 1] & [rhs <= 9]) ); } + + #[test] + fn index_hole_propagates_bounds_on_rhs() { + let mut solver = TestSolver::default(); + + let x_0 = solver.new_variable(3, 10); + let x_1 = solver.new_variable(0, 15); + let x_2 = solver.new_variable(7, 9); + let x_3 = solver.new_variable(14, 15); + + let index = solver.new_variable(0, 3); + solver.remove(index, 1).expect("Value can be removed"); + + let rhs = solver.new_variable(-10, 30); + + let _ = solver + .new_propagator(ElementPropagator::new( + vec![x_0, x_1, x_2, x_3].into(), + index, + rhs, + )) + .expect("no empty domains"); + + solver.assert_bounds(rhs, 3, 15); + + assert_eq!( + solver.get_reason_int(predicate![rhs >= 3]), + conjunction!([x_0 >= 3] & [x_2 >= 3] & [x_3 >= 3] & [index != 1]) + ); + + assert_eq!( + solver.get_reason_int(predicate![rhs <= 15]), + conjunction!([x_0 <= 15] & [x_2 <= 15] & [x_3 <= 15] & [index != 1]) + ); + } } diff --git a/pumpkin-solver/src/propagators/nogoods/nogood_propagator.rs b/pumpkin-solver/src/propagators/nogoods/nogood_propagator.rs index 4b7e02ab..88b78283 100644 --- a/pumpkin-solver/src/propagators/nogoods/nogood_propagator.rs +++ b/pumpkin-solver/src/propagators/nogoods/nogood_propagator.rs @@ -15,7 +15,8 @@ use crate::engine::conflict_analysis::Mode; use crate::engine::nogoods::Lbd; use crate::engine::opaque_domain_event::OpaqueDomainEvent; use crate::engine::predicates::predicate::Predicate; -use crate::engine::propagation::propagation_context::HasAssignments; +use crate::engine::propagation::contexts::HasAssignments; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::EnqueueDecision; use crate::engine::propagation::ExplanationContext; use crate::engine::propagation::LocalId; @@ -777,7 +778,7 @@ impl Propagator for NogoodPropagator { fn notify( &mut self, - _context: PropagationContext, + _context: StatefulPropagationContext, local_id: LocalId, event: OpaqueDomainEvent, ) -> EnqueueDecision { @@ -1327,6 +1328,13 @@ impl NogoodPropagator { // This is an inefficient implementation for testing purposes let nogood = &self.nogoods[nogood_id]; + if nogood.is_deleted { + // The nogood has already been deleted, meaning that it could be that the call to + // `propagate` would not find any propagations using it due to the watchers being + // deleted + return Ok(()); + } + // First we get the number of falsified predicates let has_falsified_predicate = nogood .predicates @@ -1501,6 +1509,7 @@ mod tests { let nogood = conjunction!([a >= 2] & [b >= 1] & [c >= 10]); { let mut context = PropagationContextMut::new( + &mut solver.stateful_assignments, &mut solver.assignments, &mut solver.reason_store, &mut solver.semantic_minimiser, @@ -1541,6 +1550,7 @@ mod tests { let nogood = conjunction!([a >= 2] & [b >= 1] & [c >= 10]); { let mut context = PropagationContextMut::new( + &mut solver.stateful_assignments, &mut solver.assignments, &mut solver.reason_store, &mut solver.semantic_minimiser, diff --git a/pumpkin-solver/src/propagators/reified_propagator.rs b/pumpkin-solver/src/propagators/reified_propagator.rs index 03e649c5..56aa8e18 100644 --- a/pumpkin-solver/src/propagators/reified_propagator.rs +++ b/pumpkin-solver/src/propagators/reified_propagator.rs @@ -1,6 +1,7 @@ use crate::basic_types::Inconsistency; use crate::basic_types::PropagationStatusCP; use crate::engine::opaque_domain_event::OpaqueDomainEvent; +use crate::engine::propagation::contexts::StatefulPropagationContext; use crate::engine::propagation::EnqueueDecision; use crate::engine::propagation::LocalId; use crate::engine::propagation::PropagationContext; @@ -50,12 +51,16 @@ impl ReifiedPropagator { impl Propagator for ReifiedPropagator { fn notify( &mut self, - context: PropagationContext, + context: StatefulPropagationContext, local_id: LocalId, event: OpaqueDomainEvent, ) -> EnqueueDecision { if local_id < self.reification_literal_id { - let decision = self.propagator.notify(context, local_id, event); + let decision = self.propagator.notify( + StatefulPropagationContext::new(context.stateful_assignments, context.assignments), + local_id, + event, + ); self.filter_enqueue_decision(context, decision) } else { pumpkin_assert_simple!(local_id == self.reification_literal_id); @@ -162,7 +167,10 @@ impl ReifiedPropagator { Prop: Propagator, { if !context.is_literal_fixed(&self.reification_literal) { - if let Some(conjunction) = self.propagator.detect_inconsistency(context.as_readonly()) { + if let Some(conjunction) = self + .propagator + .detect_inconsistency(context.as_stateful_readonly()) + { context.assign_literal(&self.reification_literal, false, conjunction)?; } } @@ -170,7 +178,7 @@ impl ReifiedPropagator { Ok(()) } - fn find_inconsistency(&mut self, context: PropagationContext<'_>) -> bool { + fn find_inconsistency(&mut self, context: StatefulPropagationContext<'_>) -> bool { if self.inconsistency.is_none() { self.inconsistency = self.propagator.detect_inconsistency(context); } @@ -180,7 +188,7 @@ impl ReifiedPropagator { fn filter_enqueue_decision( &mut self, - context: PropagationContext<'_>, + context: StatefulPropagationContext<'_>, decision: EnqueueDecision, ) -> EnqueueDecision { if decision == EnqueueDecision::Skip { @@ -231,7 +239,7 @@ mod tests { .new_propagator(ReifiedPropagator::new( GenericPropagator::new( move |_: PropagationContextMut| Err(t1.clone().into()), - move |_: PropagationContext| Some(t2.clone()), + move |_: StatefulPropagationContext| Some(t2.clone()), |_: &mut PropagatorInitialisationContext| Ok(()), ), reification_literal, @@ -258,7 +266,7 @@ mod tests { ctx.set_lower_bound(&var, 3, conjunction!())?; Ok(()) }, - |_: PropagationContext| None, + |_: StatefulPropagationContext| None, |_: &mut PropagatorInitialisationContext| Ok(()), ), reification_literal, @@ -291,7 +299,7 @@ mod tests { .new_propagator(ReifiedPropagator::new( GenericPropagator::new( move |_: PropagationContextMut| Err(conjunction!([var >= 1]).into()), - |_: PropagationContext| None, + |_: StatefulPropagationContext| None, |_: &mut PropagatorInitialisationContext| Ok(()), ), reification_literal, @@ -324,7 +332,7 @@ mod tests { .new_propagator(ReifiedPropagator::new( GenericPropagator::new( |_: PropagationContextMut| Ok(()), - |_: PropagationContext| None, + |_: StatefulPropagationContext| None, move |_: &mut PropagatorInitialisationContext| Err(conjunction!([var >= 0])), ), reification_literal, @@ -345,7 +353,7 @@ mod tests { .new_propagator(ReifiedPropagator::new( GenericPropagator::new( |_: PropagationContextMut| Ok(()), - move |context: PropagationContext| { + move |context: StatefulPropagationContext| { if context.is_fixed(&var) { Some(conjunction!([var == 5])) } else { @@ -374,7 +382,8 @@ mod tests { for GenericPropagator where Propagation: Fn(PropagationContextMut) -> PropagationStatusCP + 'static, - ConsistencyCheck: Fn(PropagationContext) -> Option + 'static, + ConsistencyCheck: + Fn(StatefulPropagationContext) -> Option + 'static, Init: Fn(&mut PropagatorInitialisationContext) -> Result<(), PropositionalConjunction> + 'static, { @@ -391,7 +400,7 @@ mod tests { fn detect_inconsistency( &self, - context: PropagationContext, + context: StatefulPropagationContext, ) -> Option { (self.consistency_check)(context) } @@ -417,7 +426,7 @@ mod tests { impl GenericPropagator where Propagation: Fn(PropagationContextMut) -> PropagationStatusCP, - ConsistencyCheck: Fn(PropagationContext) -> Option, + ConsistencyCheck: Fn(StatefulPropagationContext) -> Option, Init: Fn(&mut PropagatorInitialisationContext) -> Result<(), PropositionalConjunction>, { pub(crate) fn new( diff --git a/pumpkin-solver/tests/helpers/flatzinc.rs b/pumpkin-solver/tests/helpers/flatzinc.rs index a8eedde5..60712d2c 100644 --- a/pumpkin-solver/tests/helpers/flatzinc.rs +++ b/pumpkin-solver/tests/helpers/flatzinc.rs @@ -14,6 +14,7 @@ pub(crate) enum Value { Int(i32), Bool(bool), IntArray(Vec), + BoolArray(Vec), } impl FromStr for Value { @@ -31,31 +32,49 @@ impl FromStr for Value { } } -struct IntArrayError; -impl Display for IntArrayError { +struct ArrayError; +impl Display for ArrayError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("Could not parse int array") + f.write_str("Could not parse array") } } -fn create_array_from_string(s: &str) -> Result { - let captures = Regex::new(r"array1d\([0-9]+\.\.[0-9]+,\s*\[(\d+(?:,\s\d+)*\d*)\]\)") +fn create_array_from_string(s: &str) -> Result { + let int_captures = Regex::new(r"array1d\([0-9]+\.\.[0-9]+,\s*\[(-?\d+(?:,\s-?\d+)*-?\d*)\]\)") .unwrap() .captures_iter(s) .next(); - if let Some(captures) = captures { - Ok(Value::IntArray( - captures + if let Some(int_captures) = int_captures { + return Ok(Value::IntArray( + int_captures .get(1) .unwrap() .as_str() .split(", ") .map(|integer| integer.parse::().unwrap()) .collect::>(), - )) - } else { - Err(IntArrayError) + )); } + + let bool_captures = Regex::new( + r"array1d\([0-9]+\.\.[0-9]+,\s*\[((true|false)(?:,\s(true|false))*(true|false)*)\]\)", + ) + .unwrap() + .captures_iter(s) + .next(); + if let Some(bool_captures) = bool_captures { + return Ok(Value::BoolArray( + bool_captures + .get(1) + .unwrap() + .as_str() + .split(", ") + .map(|bool| bool.parse::().unwrap()) + .collect::>(), + )); + } + + Err(ArrayError) } #[derive(Debug)] diff --git a/pumpkin-solver/tests/mzn_optimization/unfixed_objective.expected b/pumpkin-solver/tests/mzn_optimization/unfixed_objective.expected index a6be86f6..dcf477d1 100644 --- a/pumpkin-solver/tests/mzn_optimization/unfixed_objective.expected +++ b/pumpkin-solver/tests/mzn_optimization/unfixed_objective.expected @@ -1,6 +1,3 @@ -objective = 2; -other = 1; ----------- objective = 3; other = 1; ----------