Skip to content

Commit 5e6f23b

Browse files
committed
[CP-SAT] new test; chane shared tree parameters; fix shared_tree crash
1 parent 0c9fd19 commit 5e6f23b

File tree

7 files changed

+33
-8
lines changed

7 files changed

+33
-8
lines changed

ortools/sat/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1873,6 +1873,7 @@ cc_library(
18731873
":sat_base",
18741874
":sat_parameters_cc_proto",
18751875
"//ortools/base",
1876+
"//ortools/base:mathlimits",
18761877
"//ortools/base:mathutil",
18771878
"//ortools/base:stl_util",
18781879
"//ortools/util:random_engine",

ortools/sat/python/cp_model_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,12 @@ def testInterval(self):
10141014
self.assertEqual(size_expr, 2)
10151015
self.assertEqual(str(end_expr), "(x + 2)")
10161016

1017+
def testAbsentInterval(self):
1018+
print("testInterval")
1019+
model = cp_model.CpModel()
1020+
i = model.new_optional_interval_var(1, 0, 1, False, "")
1021+
self.assertEqual(0, i.index)
1022+
10171023
def testOptionalInterval(self):
10181024
print("testOptionalInterval")
10191025
model = cp_model.CpModel()

ortools/sat/samples/code_samples.bzl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
load("@pip_deps//:requirements.bzl", "requirement")
1717
load("@rules_python//python:defs.bzl", "py_binary", "py_test")
1818

19-
2019
def code_sample_cc(name):
2120
native.cc_binary(
2221
name = name + "_cc",

ortools/sat/sat_parameters.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,7 @@ message SatParameters {
10941094
// Minimum number of restarts before a worker will replace a subtree
10951095
// that looks "bad" based on the average LBD of learned clauses.
10961096
optional int32 shared_tree_worker_min_restarts_per_subtree = 282
1097-
[default = 32];
1097+
[default = 1];
10981098

10991099
// If true, workers share more of the information from their local trail.
11001100
// Specifically, literals implied by the shared tree decisions and

ortools/sat/util.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "absl/strings/string_view.h"
3737
#include "absl/types/span.h"
3838
#include "ortools/base/logging.h"
39+
#include "ortools/base/mathlimits.h"
3940
#include "ortools/sat/model.h"
4041
#include "ortools/sat/sat_base.h"
4142
#include "ortools/sat/sat_parameters.pb.h"

ortools/sat/work_assignment.cc

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -626,15 +626,15 @@ bool SharedTreeWorker::AddImplications() {
626626
rev_num_processed_implications_.resize(level + 1, 0);
627627
auto& num_processed_implications = rev_num_processed_implications_[level];
628628
reversible_int_repository_->SaveState(&num_processed_implications);
629-
absl::Span<const ProtoLiteral> implied_literals =
630-
assigned_tree_.Implications(level).subspan(num_processed_implications);
629+
absl::Span<const Literal> implied_literals =
630+
absl::MakeConstSpan(assigned_tree_implications_[level - 1])
631+
.subspan(num_processed_implications);
631632
bool added_clause = false;
632-
for (const ProtoLiteral& impl : implied_literals) {
633-
Literal lit(DecodeDecision(impl));
633+
for (Literal impl : implied_literals) {
634634
++num_processed_implications;
635-
if (sat_solver_->Assignment().LiteralIsTrue(lit)) continue;
635+
if (sat_solver_->Assignment().LiteralIsTrue(impl)) continue;
636636
added_clause = true;
637-
if (!AddDecisionImplication(lit, level)) return true;
637+
if (!AddDecisionImplication(impl, level)) return true;
638638
}
639639
if (objective_ != nullptr &&
640640
objective_->objective_var != kNoIntegerVariable) {
@@ -687,10 +687,19 @@ bool SharedTreeWorker::SyncWithLocalTrail() {
687687
<< " assigned=" << assigned_tree_.MaxLevel();
688688
manager_->CloseTree(assigned_tree_, level + 1);
689689
assigned_tree_literals_.clear();
690+
assigned_tree_implications_.clear();
690691
sat_solver_->Backtrack(0);
691692
} else {
692693
// The next level is implied by the current one.
693694
assigned_tree_.SetLevelImplied(level + 1);
695+
if (level > 0) {
696+
assigned_tree_implications_[level - 1].insert(
697+
assigned_tree_implications_[level - 1].end(),
698+
assigned_tree_implications_[level].begin(),
699+
assigned_tree_implications_[level].end());
700+
}
701+
assigned_tree_implications_.erase(assigned_tree_implications_.begin() +
702+
level);
694703
assigned_tree_literals_.erase(assigned_tree_literals_.begin() + level);
695704
}
696705
}
@@ -760,6 +769,7 @@ void SharedTreeWorker::MaybeProposeSplit() {
760769
manager_->ProposeSplit(assigned_tree_, *encoded);
761770
if (assigned_tree_.MaxLevel() > assigned_tree_literals_.size()) {
762771
assigned_tree_literals_.push_back(split_decision);
772+
assigned_tree_implications_.push_back({});
763773
}
764774
CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel());
765775
}
@@ -808,9 +818,15 @@ bool SharedTreeWorker::SyncWithSharedTree() {
808818
VLOG(2) << "Assigned level: " << assigned_tree_.MaxLevel() << " "
809819
<< parameters_->name();
810820
assigned_tree_literals_.clear();
821+
assigned_tree_implications_.clear();
811822
for (int i = 1; i <= assigned_tree_.MaxLevel(); ++i) {
812823
assigned_tree_literals_.push_back(
813824
DecodeDecision(assigned_tree_.Decision(i)));
825+
std::vector<Literal> implications;
826+
for (const ProtoLiteral& impl : assigned_tree_.Implications(i)) {
827+
implications.push_back(DecodeDecision(impl));
828+
}
829+
assigned_tree_implications_.push_back(std::move(implications));
814830
}
815831
return true;
816832
}

ortools/sat/work_assignment.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ class ProtoLiteral {
6767
return H::combine(std::move(h), literal.proto_var_, literal.lb_);
6868
}
6969

70+
// Note you should only decode integer literals at the root level.
7071
Literal Decode(CpModelMapping*, IntegerEncoder*) const;
7172
static std::optional<ProtoLiteral> Encode(Literal, CpModelMapping*,
7273
IntegerEncoder*);
@@ -324,6 +325,7 @@ class SharedTreeWorker {
324325

325326
ProtoTrail assigned_tree_;
326327
std::vector<Literal> assigned_tree_literals_;
328+
std::vector<std::vector<Literal>> assigned_tree_implications_;
327329
// How many restarts had happened when the current tree was assigned?
328330
int64_t tree_assignment_restart_ = -1;
329331

0 commit comments

Comments
 (0)