diff --git a/clasp/clause.h b/clasp/clause.h index b458d7a..3e3a4d6 100644 --- a/clasp/clause.h +++ b/clasp/clause.h @@ -397,7 +397,7 @@ class Clause final : public ClauseHead { StrengthenResult strengthen(Solver& s, Literal p, bool allowToShort) override; void detach(Solver&) override; [[nodiscard]] uint32_t size() const override; - void toLits(LitVec& out) const override; + LitView toLits(TempBuffer& tmp) const override; [[nodiscard]] bool contracted() const; [[nodiscard]] bool isSmall() const; [[nodiscard]] bool strengthened() const; @@ -545,7 +545,7 @@ class SharedLitsClause final : public ClauseHead { void destroy(Solver* s, bool detach) override; uint32_t isOpen(const Solver& s, const TypeSet& t, LitVec& freeLits) override; [[nodiscard]] uint32_t size() const override; - void toLits(LitVec& out) const override; + LitView toLits(TempBuffer& tmp) const override; private: SharedLitsClause(Solver& s, SharedLiterals* x, const Literal* lits, const InfoType&, bool addRef); diff --git a/clasp/cli/clasp_app.h b/clasp/cli/clasp_app.h index 6b6c9d5..9ac3a19 100644 --- a/clasp/cli/clasp_app.h +++ b/clasp/cli/clasp_app.h @@ -65,8 +65,7 @@ class WriteCnf { [[nodiscard]] bool unary(Literal, Literal) const; [[nodiscard]] bool binary(Literal, Literal, Literal) const; - FILE* str_; - LitVec lits_; + FILE* str_; }; class LemmaLogger { public: diff --git a/clasp/cli/clasp_cli_options.inl b/clasp/cli/clasp_cli_options.inl index 680161d..ffbea33 100644 --- a/clasp/cli/clasp_cli_options.inl +++ b/clasp/cli/clasp_cli_options.inl @@ -104,9 +104,9 @@ OPTION(share, "!,@1", ARG_EXT(defaultsTo("auto")->state(Value::value_defaulted), " %A: {auto|problem|learnt|all}", FUN(arg) {ContextParams::ShareMode x; return arg>>x && SET(SELF.shareMode, (uint32_t)x);}, GET((ContextParams::ShareMode)SELF.shareMode)) OPTION(learn_explicit, ",@2" , ARG(flag()), "Do not use Short Implication Graph for learning", STORE_FLAG(SELF.shortMode), GET(SELF.shortMode)) OPTION(short_simp_mode, ",@2" , ARG_EXT(arg("")->defaultsTo("no")->state(Value::value_defaulted), DEFINE_ENUM_MAPPING(ContextParams::ShortSimpMode, \ - MAP("no" , ContextParams::simp_no) , MAP("learnt", ContextParams::simp_learnt))),\ + MAP("no" , ContextParams::simp_no), MAP("learnt", ContextParams::simp_learnt), MAP("all", ContextParams::simp_all))),\ "Remove duplicate short constraints [%D]\n"\ - " %A: {no|learnt}", FUN(arg) {ContextParams::ShortSimpMode x; return arg>>x && SET(SELF.shortSimp, (uint32_t)x);}\ + " %A: {no|learnt|all}", FUN(arg) {ContextParams::ShortSimpMode x; return arg>>x && SET(SELF.shortSimp, (uint32_t)x);}\ , GET((ContextParams::ShortSimpMode)SELF.shortSimp))\ OPTION(sat_prepro , "!,@1", ARG(arg("")->implicit("2")), \ "Run SatELite-like preprocessing (Implicit: %I)\n" \ diff --git a/clasp/shared_context.h b/clasp/shared_context.h index a6c8e20..791d9b8 100644 --- a/clasp/shared_context.h +++ b/clasp/shared_context.h @@ -328,6 +328,11 @@ class ShortImplicationsGraph { * \return true iff a new implication was added. */ bool add(LitView lits, bool learnt); + //! Removes the given constraint from the implication graph. + /*! + * \pre The object is currently not shared. + */ + void remove(LitView lits, bool learnt); //! Removes p and its implications. /*! @@ -978,6 +983,7 @@ class SharedContext { [[nodiscard]] MinPtr minimizeNoCreate() const; //@} private: + bool preprocessShort(); bool unfreezeStep(); Literal addStepLit(); using VarVec = PodVector_t; diff --git a/clasp/solver_strategies.h b/clasp/solver_strategies.h index 41e91b4..a752141 100644 --- a/clasp/solver_strategies.h +++ b/clasp/solver_strategies.h @@ -585,6 +585,7 @@ struct ContextParams { enum ShortSimpMode { simp_no = 0, /*!< No additional simplifications. */ simp_learnt = 1, /*!< Drop duplicate learnt short clauses. */ + simp_all = 2, /*!< Drop all duplicate short clauses. */ }; //! How to handle physical sharing of (explicit) constraints. enum ShareMode { diff --git a/clasp/solver_types.h b/clasp/solver_types.h index c310ce0..5bb0bb3 100644 --- a/clasp/solver_types.h +++ b/clasp/solver_types.h @@ -348,6 +348,8 @@ class ClauseHead : public Constraint { static constexpr auto head_lits = 3u; static constexpr auto max_short_len = 5u; + using TempBuffer = std::array; + explicit ClauseHead(const InfoType& init); // base interface //! Propagates the head and calls updateWatch() if necessary. @@ -380,8 +382,8 @@ class ClauseHead : public Constraint { virtual void detach(Solver& s); //! Returns the size of this clause. [[nodiscard]] virtual uint32_t size() const = 0; - //! Returns the literals of this clause in out. - virtual void toLits(LitVec& out) const = 0; + //! Returns the literals of this clause (using the given buffer if needed). + virtual LitView toLits(TempBuffer& tmp) const = 0; //! Returns true if this clause is a valid "reverse antecedent" for p. virtual bool isReverseReason(const Solver& s, Literal p, uint32_t maxL, uint32_t maxN) = 0; struct StrengthenResult { diff --git a/src/clasp_app.cpp b/src/clasp_app.cpp index d97cfc4..1d497ee 100644 --- a/src/clasp_app.cpp +++ b/src/clasp_app.cpp @@ -797,9 +797,8 @@ WriteCnf::WriteCnf(const std::string& outFile) : str_(fopen(outFile.c_str(), "w" WriteCnf::~WriteCnf() { close(); } void WriteCnf::writeHeader(uint32_t numVars, uint32_t numCons) { fprintf(str_, "p cnf %u %u\n", numVars, numCons); } void WriteCnf::write(const ClauseHead* h) { - lits_.clear(); - h->toLits(lits_); - for (auto lit : lits_) { fprintf(str_, "%d ", toInt(lit)); } + ClauseHead::TempBuffer buffer; + for (auto lit : h->toLits(buffer)) { fprintf(str_, "%d ", toInt(lit)); } fprintf(str_, "%d\n", 0); } void WriteCnf::write(Var_t maxVar, const ShortImplicationsGraph& g) { diff --git a/src/clause.cpp b/src/clause.cpp index 3259375..4aa1c41 100644 --- a/src/clause.cpp +++ b/src/clause.cpp @@ -740,13 +740,22 @@ void Clause::undoLevel(Solver& s) { local_.setSize(t); } -void Clause::toLits(LitVec& out) const { - out.insert(out.end(), head_, (head_ + head_lits) - isSentinel(head_[2])); - auto t = const_cast(*this).tail(); - if (contracted()) { - while (not t.e++->flagged()) { ; } +LitView Clause::toLits(TempBuffer& tmp) const { + if (not isSmall()) { + const auto* eoc = const_cast(*this).end(); + if (contracted()) { + while (not eoc++->flagged()) { ; } + } + return {head_, eoc}; + } + auto x = std::copy(head_, (head_ + head_lits) - isSentinel(head_[2]), tmp.data()); + if (const auto* eoc = const_cast(*this).small(); *eoc != lit_false) { + *x++ = *eoc++; + if (*eoc != lit_false) { + *x++ = *eoc; + } } - out.insert(out.end(), t.begin(), t.end()); + return {tmp.data(), x}; } ClauseHead::StrengthenResult Clause::strengthen(Solver& s, Literal p, bool toShort) { @@ -989,7 +998,7 @@ uint32_t SharedLitsClause::isOpen(const Solver& s, const TypeSet& x, LitVec& fre return +ClauseHead::type(); } -void SharedLitsClause::toLits(LitVec& out) const { out.insert(out.end(), shared_->begin(), shared_->end()); } +LitView SharedLitsClause::toLits(TempBuffer&) const { return {shared_->begin(), shared_->end()}; } ClauseHead::StrengthenResult SharedLitsClause::strengthen(Solver&, Literal, bool) { return {}; } diff --git a/src/clingo.cpp b/src/clingo.cpp index 36039c6..2c680ad 100644 --- a/src/clingo.cpp +++ b/src/clingo.cpp @@ -400,12 +400,11 @@ void ClingoPropagator::reason(Solver&, Literal p, LitVec& r) { bool ClingoPropagator::simplify(Solver& s, bool) { if (not s.validVar(aux_.var())) { - LitVec cc; + ClauseHead::TempBuffer buffer; aux_ = lit_true; erase_if(db_, [&](Constraint* con) { if (ClauseHead* clause = con->clause(); clause && clause->aux()) { - cc.clear(); - clause->toLits(cc); + auto cc = clause->toLits(buffer); if (Literal x = *std::ranges::max_element(cc); not s.validVar(x.var())) { clause->destroy(&s, true); return true; diff --git a/src/shared_context.cpp b/src/shared_context.cpp index a4421dc..8136d79 100644 --- a/src/shared_context.cpp +++ b/src/shared_context.cpp @@ -206,7 +206,7 @@ bool ShortImplicationsGraph::add(LitView lits, bool learnt) { Literal p = lits[0], q = lits[1], r = (tern ? lits[2] : lit_false); p.unflag(), q.unflag(), r.unflag(); if (not shared_) { - bool simp = learnt && simp_ == ContextParams::simp_learnt; + bool simp = simp_ == ContextParams::simp_all || (learnt && simp_ == ContextParams::simp_learnt); if (simp && contains(getList(~p).left_view(), q)) { return true; } @@ -241,6 +241,31 @@ bool ShortImplicationsGraph::add(LitView lits, bool learnt) { #endif return false; } +void ShortImplicationsGraph::remove(LitView lits, bool learnt) { + assert(not shared_); + bool tern = lits.size() == 3u; + auto& stats = (tern ? tern_ : bin_)[learnt]; + unsigned i = 0, rem = 0; + for (auto x : lits) { + auto& w = getList(~x); + auto sz = w.left_size() + w.right_size(); + if (not tern) { + w.erase_left_unordered(std::find(w.left_begin(), w.left_end(), lits[1 - i])); + } + else { + Tern t = {lits[(i + 1) % 3], lits[(i + 2) % 3]}; + w.erase_right_unordered(std::find_if(w.right_begin(), w.right_end(), [&t](const Tern& e) { + return contains(t, e[0]) && contains(t, e[1]); + })); + } + rem += sz != (w.left_size() + w.right_size()); + w.try_shrink(); + ++i; + } + if (rem) { + --stats; + } +} void ShortImplicationsGraph::removeBin(Literal other, Literal sat) { --bin_[other.flagged()]; @@ -1022,8 +1047,11 @@ bool SharedContext::endInit(bool attachAll) { initStats(*master()); heuristic.simplify(); SatPrePtr temp = std::move(satPrepro); - bool ok = not master()->hasConflict() && master()->preparePost() && (not temp || temp->preprocess(*this)) && - master()->endInit(); + bool ok = not master()->hasConflict() && master()->preparePost() && (not temp || temp->preprocess(*this)); + if (ok && not temp && btig_.simpMode() == ContextParams::simp_all) { + ok = preprocessShort(); + } + ok = ok && master()->endInit(); satPrepro = std::move(temp); master()->dbIdx_ = size32(master()->constraints_); lastTopLevel_ = master()->assign_.front; @@ -1208,6 +1236,90 @@ uint32_t SharedContext::problemComplexity() const { } return numConstraints(); } +bool SharedContext::preprocessShort() { + auto& s = *master(); + auto& assign = s.assign_; + LitVec lits; + LitVec tern; + for (Var_t v = 1; v < assign.numVars() && not s.hasConflict(); ++v) { + if (assign.value(v) != value_free) { + continue; + } + for (Literal lit : {posLit(v), negLit(v)}) { + if (marked(lit)) { + continue; + } + tern.clear(); + bool ok = true; + auto qFront = assign.assigned(); + assign.assign(lit, 0, lit_true); + do { + ok = btig_.forEach(assign.trail[qFront++], [&](Literal p, Literal q, Literal r = lit_false) { + if (r == lit_false) { + return assign.assign(q, 0, p); + } + auto vq = assign.value(q.var()); + auto vr = assign.value(r.var()); + auto ante = Antecedent(p); + if (vr == trueValue(r) || vq == trueValue(q)) { + if (assign.reason(r.var()).asUint() == ante.asUint() || + assign.reason(q.var()).asUint() == ante.asUint()) { + tern.push_back(~p); + tern.push_back(q); + tern.push_back(r); + } + return true; + } + if (vr == vq) { + return vr == value_free; + } + if (vq) { + if (assign.reason(q.var()).asUint() == ante.asUint()) { + tern.push_back(q.flag()); + tern.push_back(~p); + tern.push_back(r); + } + return assign.assign(r, 0, Antecedent(p, ~q)); + } + if (assign.reason(r.var()).asUint() == ante.asUint()) { + tern.push_back(r.flag()); + tern.push_back(~p); + tern.push_back(q); + } + return assign.assign(q, 0, Antecedent(p, ~r)); + }); + } while (ok && qFront < assign.assigned()); + if (ok) { + for (auto i = 0u; i < size32(tern); i += 3) { + bool sat = not tern[i].flagged(); + bool learnt = tern[i + 1].flagged() || tern[i + 2].flagged(); + tern[i].unflag(); + btig_.remove(std::span(tern.data() + i, 3), learnt); + if (not sat) { + btig_.add(std::span(tern.data() + i + 1, 2), learnt); + } + } + } + while (assign.trail.back() != lit) { + if (not marked(assign.trail.back())) { + mark(assign.trail.back()); + lits.push_back(assign.trail.back()); + } + assign.undoLast(); + } + assign.undoLast(); + if (not ok) { + master()->force(~lit) && master()->propagate(); + break; + } + } + } + while (not lits.empty()) { + unmark(lits.back().var()); + lits.pop_back(); + } + return master()->simplify(); +} ///////////////////////////////////////////////////////////////////////////////////////// // Distributor ///////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/solver.cpp b/src/solver.cpp index ce330a8..2c019a8 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -517,11 +517,10 @@ Literal Solver::popVars(uint32_t num, bool popLearnt, ConstraintDB* popAux) { if (popLearnt) { shared_->report("removing aux constraints", this); ConstraintDB::size_type os = 0; - LitVec cc; + ClauseHead::TempBuffer buffer; for (Constraint* con : learnts_) { if (ClauseHead* clause = con->clause(); clause && clause->aux()) { - cc.clear(); - clause->toLits(cc); + auto cc = clause->toLits(buffer); if (std::ranges::any_of(cc, [&pop](Literal x) { return x >= pop; })) { con->destroy(this, true); continue; @@ -1500,30 +1499,33 @@ uint32_t Solver::simplifyConflictClause(LitVec& cc, ConstraintInfo& info, Clause } // 3. check if final clause subsumes rhs if (rhs) { + ClauseHead::TempBuffer buffer; conflict_.clear(); - rhs->toLits(conflict_); - auto open = size32(cc); markSeen(cc[0].var()); - for (auto it = conflict_.begin(), end = conflict_.end(); it != end && open; ++it) { + auto rhsLits = rhs->toLits(buffer); + auto marked = std::ssize(cc); + for (auto maxMissing = std::ssize(rhsLits) - marked; auto lit : rhsLits) { // NOTE: at this point the DB might not be fully simplified, // e.g. because of mt or lookahead, hence we must explicitly // check for literals assigned on DL 0 - open -= level(it->var()) > 0 && seen(it->var()); + if (not seen(lit.var()) || level(lit.var()) == 0) { + if (--maxMissing < 0) { + break; + } + conflict_.push_back(lit); // potentially redundant literal + } + else if (--marked == 0 && otfsRemove(rhs, &cc) == nullptr) { + rhs = nullptr; // rhs is subsumed by cc and was removed + break; + } } - rhs = open ? nullptr : otfsRemove(rhs, &cc); - if (rhs) { // rhs is subsumed by cc but could not be removed. + if (rhs && marked <= 0) { // rhs is subsumed by cc but could not be removed. // TODO: we could reuse rhs instead of learning cc // but this would complicate the calling code. - if (cc_.size() < conflict_.size()) { - bool litRemoved = true; - // For now, we only try to strengthen rhs. - for (auto it = conflict_.begin(), end = conflict_.end(); it != end && litRemoved; ++it) { - if (not seen(it->var()) || level(it->var()) == 0) { - litRemoved = rhs->strengthen(*this, *it, false).litRemoved; - } - } - if (not litRemoved) { - rhs = nullptr; + // For now, we only try to strengthen rhs. + for (auto lit : conflict_) { + if (not rhs->strengthen(*this, lit, false).litRemoved) { + break; } } } diff --git a/tests/clause_creator_test.cpp b/tests/clause_creator_test.cpp index 3e081b3..22a8f17 100644 --- a/tests/clause_creator_test.cpp +++ b/tests/clause_creator_test.cpp @@ -424,12 +424,12 @@ TEST_CASE("ClauseCreator integrate", "[constraint][core]") { REQUIRE(temp[0] == d); REQUIRE(temp[1] == a); - SharedLiterals* p(SharedLiterals::newShareable(cl, ConstraintType::other)); - ClauseCreator::Result r = ClauseCreator::integrate(s, p, ClauseCreator::clause_no_add); - temp.clear(); - r.local->clause()->toLits(temp); - REQUIRE(temp[0] == d); - REQUIRE(temp[1] == a); + SharedLiterals* p(SharedLiterals::newShareable(cl, ConstraintType::other)); + ClauseCreator::Result r = ClauseCreator::integrate(s, p, ClauseCreator::clause_no_add); + ClauseHead::TempBuffer buffer; + auto lits = r.local->clause()->toLits(buffer); + REQUIRE(lits[0] == d); + REQUIRE(lits[1] == a); r.local->destroy(&s, true); } SECTION("test integrate unsat") { @@ -567,10 +567,10 @@ TEST_CASE("ClauseCreator integrate", "[constraint][core]") { ClauseCreator::Result r = ClauseCreator::integrate(s, p, ClauseCreator::clause_no_add); REQUIRE(r.ok()); REQUIRE(r.local != 0); - cl.clear(); - r.local->toLits(cl); - REQUIRE(cl.size() == 5); - REQUIRE_FALSE(contains(cl, d)); + ClauseHead::TempBuffer buffer; + auto lits = r.local->toLits(buffer); + REQUIRE(lits.size() == 5); + REQUIRE_FALSE(contains(lits, d)); } SECTION("test facts are removed from learnt") { ctx.enableStats(1); diff --git a/tests/clause_test.cpp b/tests/clause_test.cpp index f789e89..f6a4966 100644 --- a/tests/clause_test.cpp +++ b/tests/clause_test.cpp @@ -29,8 +29,11 @@ #include #include +namespace Clasp { +static bool operator==(LitView lhs, LitView rhs) { return std::ranges::equal(lhs, rhs); } +} // namespace Clasp namespace Clasp::Test { -static int countWatches(const Solver& s, ClauseHead* c, const LitVec& lits) { +static int countWatches(const Solver& s, ClauseHead* c, LitView lits) { int w = 0; for (auto lit : lits) { w += s.hasWatch(~lit, c); } return w; @@ -331,8 +334,8 @@ TEST_CASE("Clause", "[core][constraint]") { uint32_t si = cl->size(); REQUIRE(si == 5); cl->strengthen(solver, posLit(4)); - LitVec clause2; - cl->toLits(clause2); + ClauseHead::TempBuffer buffer; + auto clause2 = cl->toLits(buffer); REQUIRE(clause2.size() == 5); for (auto lit : clause) { REQUIRE((contains(clause2, lit) || lit == posLit(4))); } } @@ -506,10 +509,10 @@ TEST_CASE("Clause", "[core][constraint]") { SECTION("testClone") { Solver& solver2 = ctx.pushSolver(); ctx.endInit(true); - cl = createClause(solver, makeLits(clLits, 3, 3)); - auto* clone = cl->cloneAttach(solver2)->clause(); - LitVec lits; - clone->toLits(lits); + cl = createClause(solver, makeLits(clLits, 3, 3)); + auto* clone = cl->cloneAttach(solver2)->clause(); + ClauseHead::TempBuffer buffer; + auto lits = clone->toLits(buffer); REQUIRE(lits == clLits); REQUIRE(countWatches(solver2, clone, lits) == 2); clone->destroy(&solver2, true); @@ -519,8 +522,7 @@ TEST_CASE("Clause", "[core][constraint]") { solver.propagate(); cl->simplify(solver); clone = cl->cloneAttach(solver2)->clause(); - lits.clear(); - clone->toLits(lits); + lits = clone->toLits(buffer); REQUIRE(lits.size() == 4); REQUIRE(countWatches(solver2, clone, lits) == 2); clone->destroy(&solver2, true); @@ -1020,9 +1022,9 @@ TEST_CASE("Shared clause", "[core][constraint]") { ClauseHead* cl = createShared(solver, makeLits(clLits, 3, 2), ClauseInfo()); Solver& solver2 = ctx.pushSolver(); ctx.endInit(true); - auto* clone = cl->cloneAttach(solver2)->clause(); - LitVec lits; - clone->toLits(lits); + auto* clone = cl->cloneAttach(solver2)->clause(); + ClauseHead::TempBuffer buffer; + auto lits = clone->toLits(buffer); REQUIRE(lits == clLits); REQUIRE(countWatches(solver2, clone, clLits) == 2); cl->destroy(ctx.master(), true); diff --git a/tests/cli_test.cpp b/tests/cli_test.cpp index f59be13..73d3e9a 100644 --- a/tests/cli_test.cpp +++ b/tests/cli_test.cpp @@ -457,7 +457,8 @@ TEST_CASE("Cli options", "[cli]") { REQUIRE(config.getValue(key, val) == 2); REQUIRE(val == "no"); - for (auto [x, y] : {std::pair{ContextParams::simp_learnt, "learnt"}, std::pair{ContextParams::simp_no, "no"}}) { + for (auto [x, y] : {std::pair{ContextParams::simp_learnt, "learnt"}, std::pair{ContextParams::simp_all, "all"}, + std::pair{ContextParams::simp_no, "no"}}) { CAPTURE(y); REQUIRE(1 == config.setValue(key, y)); REQUIRE(config.shortSimp == x);