Skip to content

Commit 00adfb5

Browse files
committed
[CP-SAT] more bugfixes
1 parent 12cb30c commit 00adfb5

17 files changed

+149
-89
lines changed

ortools/sat/all_different.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ bool AllDifferentConstraint::Propagate() {
275275
successor_.AppendToLastVector(value);
276276

277277
// Seed with previous matching.
278-
if (prev_matching_[x] == value) {
278+
if (prev_matching_[x] == value && value_to_variable_[value] == -1) {
279279
variable_to_value_[x] = prev_matching_[x];
280280
value_to_variable_[prev_matching_[x]] = x;
281281
}
@@ -373,6 +373,7 @@ bool AllDifferentConstraint::Propagate() {
373373
if (assignment.LiteralIsFalse(x_lit)) continue;
374374

375375
const int value_node = value + num_variables_;
376+
DCHECK_LT(value_node, component_number_.size());
376377
if (variable_to_value_[x] != value &&
377378
component_number_[x] != component_number_[value_node]) {
378379
// We can deduce that x != value. To explain, force x == value,
@@ -383,6 +384,8 @@ bool AllDifferentConstraint::Propagate() {
383384
variable_visited_.assign(num_variables_, false);
384385
// Undo x -> old_value and old_variable -> value.
385386
const int old_variable = value_to_variable_[value];
387+
DCHECK_GE(old_variable, 0);
388+
DCHECK_LT(old_variable, num_variables_);
386389
variable_to_value_[old_variable] = -1;
387390
const int old_value = variable_to_value_[x];
388391
value_to_variable_[old_value] = -1;

ortools/sat/cp_constraints.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,20 +135,22 @@ bool GreaterThanAtLeastOneOfPropagator::Propagate() {
135135

136136
int first_non_false = 0;
137137
const int size = exprs_.size();
138+
Literal* const selectors = selectors_.data();
139+
AffineExpression* const exprs = exprs_.data();
138140
for (int i = 0; i < size; ++i) {
139-
if (assignment.LiteralIsTrue(selectors_[i])) return true;
141+
if (assignment.LiteralIsTrue(selectors[i])) return true;
140142

141143
// The permutation is needed to have proper lazy reason.
142-
if (assignment.LiteralIsFalse(selectors_[i])) {
144+
if (assignment.LiteralIsFalse(selectors[i])) {
143145
if (i != first_non_false) {
144-
std::swap(selectors_[i], selectors_[first_non_false]);
145-
std::swap(exprs_[i], exprs_[first_non_false]);
146+
std::swap(selectors[i], selectors[first_non_false]);
147+
std::swap(exprs[i], exprs[first_non_false]);
146148
}
147149
++first_non_false;
148150
continue;
149151
}
150152

151-
const IntegerValue min = integer_trail_->LowerBound(exprs_[i]);
153+
const IntegerValue min = integer_trail_->LowerBound(exprs[i]);
152154
if (min < target_min) {
153155
target_min = min;
154156

ortools/sat/cp_model_loader.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,14 +1286,14 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) {
12861286
rhs_min = std::max(rhs_min, min_sum.value());
12871287
rhs_max = std::min(rhs_max, max_sum.value());
12881288

1289-
auto* detector = m->GetOrCreate<GreaterThanAtLeastOneOfDetector>();
1289+
auto* repository = m->GetOrCreate<BinaryRelationRepository>();
12901290
const Literal lit = mapping->Literal(ct.enforcement_literal(0));
12911291
const Domain domain = ReadDomainFromProto(ct.linear());
12921292
if (vars.size() == 1) {
1293-
detector->Add(lit, {vars[0], coeffs[0]}, {}, rhs_min, rhs_max);
1293+
repository->Add(lit, {vars[0], coeffs[0]}, {}, rhs_min, rhs_max);
12941294
} else if (vars.size() == 2) {
1295-
detector->Add(lit, {vars[0], coeffs[0]}, {vars[1], coeffs[1]}, rhs_min,
1296-
rhs_max);
1295+
repository->Add(lit, {vars[0], coeffs[0]}, {vars[1], coeffs[1]}, rhs_min,
1296+
rhs_max);
12971297
}
12981298
}
12991299

ortools/sat/cp_model_solver_helpers.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,7 @@ void LoadCpModel(const CpModelProto& model_proto, Model* model) {
12531253
// Note that we do that before we finish loading the problem (objective and
12541254
// LP relaxation), because propagation will be faster at this point and it
12551255
// should be enough for the purpose of this auto-detection.
1256+
model->GetOrCreate<BinaryRelationRepository>()->Build();
12561257
if (parameters.auto_detect_greater_than_at_least_one_of()) {
12571258
model->GetOrCreate<GreaterThanAtLeastOneOfDetector>()
12581259
->AddGreaterThanAtLeastOneOfConstraints(model);

ortools/sat/cuts.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,8 +1666,7 @@ BoolRLTCutHelper::~BoolRLTCutHelper() {
16661666
shared_stats_->AddStats(stats);
16671667
}
16681668

1669-
void BoolRLTCutHelper::Initialize(
1670-
const absl::flat_hash_map<IntegerVariable, glop::ColIndex>& lp_vars) {
1669+
void BoolRLTCutHelper::Initialize(absl::Span<const IntegerVariable> lp_vars) {
16711670
product_detector_->InitializeBooleanRLTCuts(lp_vars, *lp_values_);
16721671
enabled_ = !product_detector_->BoolRLTCandidates().empty();
16731672
}

ortools/sat/cuts.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,7 @@ class BoolRLTCutHelper {
582582

583583
// Precompute data according to the current lp relaxation.
584584
// This also restrict any Boolean to be currently appearing in the LP.
585-
void Initialize(
586-
const absl::flat_hash_map<IntegerVariable, glop::ColIndex>& lp_vars);
585+
void Initialize(absl::Span<const IntegerVariable> lp_vars);
587586

588587
// Tries RLT separation of the input constraint. Returns true on success.
589588
bool TrySimpleSeparation(const CutData& input_ct);

ortools/sat/implied_bounds.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ void ProductDetector::UpdateRLTMaps(
797797

798798
// TODO(user): limit work if too many ternary.
799799
void ProductDetector::InitializeBooleanRLTCuts(
800-
const absl::flat_hash_map<IntegerVariable, glop::ColIndex>& lp_vars,
800+
absl::Span<const IntegerVariable> lp_vars,
801801
const util_intops::StrongVector<IntegerVariable, double>& lp_values) {
802802
// TODO(user): Maybe we shouldn't reconstruct this every time, but it is hard
803803
// in case of multiple lps to make sure we don't use variables not in the lp
@@ -808,14 +808,19 @@ void ProductDetector::InitializeBooleanRLTCuts(
808808
// We will list all interesting multiplicative candidate for each variable.
809809
bool_rlt_candidates_.clear();
810810
const int size = ternary_clauses_with_view_.size();
811+
if (size == 0) return;
812+
813+
is_in_lp_vars_.resize(integer_trail_->NumIntegerVariables().value());
814+
for (const IntegerVariable var : lp_vars) is_in_lp_vars_.Set(var);
815+
811816
for (int i = 0; i < size; i += 3) {
812817
const IntegerVariable var1 = ternary_clauses_with_view_[i];
813818
const IntegerVariable var2 = ternary_clauses_with_view_[i + 1];
814819
const IntegerVariable var3 = ternary_clauses_with_view_[i + 2];
815820

816-
if (!lp_vars.contains(PositiveVariable(var1))) continue;
817-
if (!lp_vars.contains(PositiveVariable(var2))) continue;
818-
if (!lp_vars.contains(PositiveVariable(var3))) continue;
821+
if (!is_in_lp_vars_[PositiveVariable(var1)]) continue;
822+
if (!is_in_lp_vars_[PositiveVariable(var2)]) continue;
823+
if (!is_in_lp_vars_[PositiveVariable(var3)]) continue;
819824

820825
// If we have l1 + l2 + l3 >= 1, then for all (i, j) pair we have
821826
// !li * !lj <= lk. We are looking for violation like this.
@@ -830,6 +835,10 @@ void ProductDetector::InitializeBooleanRLTCuts(
830835
UpdateRLTMaps(lp_values, NegationOf(var2), 1.0 - lp2, NegationOf(var3),
831836
1.0 - lp3, var1, lp1);
832837
}
838+
839+
// Clear.
840+
// TODO(user): Just switch to memclear() when dense.
841+
for (const IntegerVariable var : lp_vars) is_in_lp_vars_.ClearBucket(var);
833842
}
834843

835844
} // namespace sat

ortools/sat/implied_bounds.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ class ProductDetector {
298298
// Experimental. Find violated inequality of the form l1 * l2 <= l3.
299299
// And set-up data structure to query this efficiently.
300300
void InitializeBooleanRLTCuts(
301-
const absl::flat_hash_map<IntegerVariable, glop::ColIndex>& lp_vars,
301+
absl::Span<const IntegerVariable> lp_vars,
302302
const util_intops::StrongVector<IntegerVariable, double>& lp_values);
303303

304304
// BoolRLTCandidates()[var] contains the list of factor for which we have
@@ -385,6 +385,8 @@ class ProductDetector {
385385
// as NegatedVariable(). This is a flat vector of size multiple of 3.
386386
std::vector<IntegerVariable> ternary_clauses_with_view_;
387387

388+
Bitset64<IntegerVariable> is_in_lp_vars_;
389+
388390
// Stats.
389391
int64_t num_products_ = 0;
390392
int64_t num_int_products_ = 0;

ortools/sat/implied_bounds_test.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,8 +665,7 @@ TEST(ProductDetectorTest, RLT) {
665665
lp_values[x] = 0.7;
666666
lp_values[y] = 0.9;
667667
lp_values[z] = 0.2;
668-
const absl::flat_hash_map<IntegerVariable, glop::ColIndex> lp_vars = {
669-
{x, glop::ColIndex(0)}, {y, glop::ColIndex(1)}, {z, glop::ColIndex(2)}};
668+
std::vector<IntegerVariable> lp_vars = {x, y, z};
670669
detector->InitializeBooleanRLTCuts(lp_vars, lp_values);
671670

672671
// (1 - X) * Y <= Z, 0.3 * 0.9 == 0.27 <= 0.2, interesting!

ortools/sat/linear_constraint_manager.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ struct ModelReducedCosts
5656
ModelReducedCosts() = default;
5757
};
5858

59+
// Stores the mapping integer_variable -> glop::ColIndex.
60+
// This is shared across all LP, which is fine since there are disjoint.
61+
struct ModelLpVariableMapping
62+
: public util_intops::StrongVector<IntegerVariable, glop::ColIndex> {
63+
ModelLpVariableMapping() = default;
64+
};
65+
5966
// Knowing the symmetry of the IP problem should allow us to
6067
// solve the LP faster via "folding" techniques.
6168
//

0 commit comments

Comments
 (0)