From b05315de21fed2f67e96e7560d23c9c6d83ce0fd Mon Sep 17 00:00:00 2001 From: Corentin Le Molgat Date: Mon, 22 Sep 2025 17:24:20 +0200 Subject: [PATCH] sat: backport from main --- ortools/sat/BUILD.bazel | 25 +- ortools/sat/README.md | 2 +- ortools/sat/clause.cc | 58 ++- ortools/sat/clause.h | 2 + ortools/sat/cp_model_expand.cc | 339 +++++++-------- ortools/sat/cp_model_lns.cc | 38 +- ortools/sat/cp_model_lns.h | 22 +- ortools/sat/cp_model_lns_test.cc | 4 +- ortools/sat/cp_model_presolve.cc | 276 ++++++++---- ortools/sat/cp_model_presolve.h | 7 +- ortools/sat/cp_model_presolve_test.cc | 404 ++++++++++++++++++ ortools/sat/cp_model_search.cc | 23 +- ortools/sat/cp_model_search.h | 9 +- ortools/sat/cp_model_solver.cc | 26 +- ortools/sat/cp_model_solver_helpers.cc | 104 +++-- ortools/sat/cp_model_solver_test.cc | 44 ++ ortools/sat/cuts_test.cc | 66 +-- ortools/sat/diffn_cuts.cc | 20 +- ortools/sat/drat_checker.cc | 73 ++-- ortools/sat/drat_checker.h | 59 +-- ortools/sat/drat_checker_test.cc | 105 +++-- ortools/sat/drat_proof_handler.cc | 4 +- ortools/sat/feasibility_jump.h | 10 +- ortools/sat/implied_bounds_test.cc | 12 +- ortools/sat/integer.cc | 17 +- ortools/sat/integer.h | 3 +- ortools/sat/integer_base.h | 38 +- ortools/sat/integer_search.cc | 19 +- ortools/sat/linear_constraint_manager.cc | 18 +- ortools/sat/linear_constraint_manager_test.cc | 26 +- ortools/sat/linear_constraint_test.cc | 40 +- ortools/sat/linear_propagation_test.cc | 70 +++ ortools/sat/linear_relaxation.cc | 2 +- ortools/sat/linear_relaxation_test.cc | 128 +++--- ortools/sat/model.h | 29 +- ortools/sat/optimization.cc | 11 + ortools/sat/pb_constraint.cc | 182 ++++---- ortools/sat/pb_constraint.h | 8 +- ortools/sat/pb_constraint_test.cc | 28 +- ortools/sat/presolve_context.cc | 18 +- ortools/sat/presolve_context.h | 7 +- ortools/sat/python/BUILD.bazel | 24 +- ortools/sat/python/cp_model_test.py | 1 + ortools/sat/routing_cuts_test.cc | 12 +- ortools/sat/sat_base.h | 85 +++- ortools/sat/sat_parameters.proto | 17 +- ortools/sat/sat_runner.cc | 6 +- ortools/sat/sat_solver.cc | 116 +++-- ortools/sat/sat_solver.h | 14 +- ortools/sat/scheduling_cuts.cc | 11 +- ortools/sat/scheduling_cuts_test.cc | 28 +- ortools/sat/scheduling_helpers_test.cc | 4 +- ortools/sat/shaving_solver.cc | 28 +- ortools/sat/stat_tables.cc | 16 +- ortools/sat/subsolver.cc | 10 +- ortools/sat/synchronization.cc | 120 +++--- ortools/sat/synchronization.h | 30 +- ortools/sat/work_assignment.cc | 10 +- ortools/sat/work_assignment.h | 2 +- 59 files changed, 1904 insertions(+), 1006 deletions(-) diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index ae05db8de0..1f41559194 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -71,10 +71,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//ortools/base", - "//ortools/base:typeid", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log:check", - "@abseil-cpp//absl/meta:type_traits", ], ) @@ -715,6 +713,7 @@ cc_library( "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", "//ortools/util:time_limit", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/cleanup", "@abseil-cpp//absl/container:btree", @@ -877,10 +876,17 @@ cc_test( ":cp_model_cc_proto", ":cp_model_checker", ":cp_model_solver", + ":cp_model_solver_helpers", ":cp_model_test_utils", + ":cp_model_utils", + ":drat_checker", + ":drat_proof_handler", ":lp_utils", ":model", + ":sat_base", ":sat_parameters_cc_proto", + ":sat_solver", + ":synchronization", "//ortools/base:gmock_main", "//ortools/base:parse_test_proto", "//ortools/linear_solver:linear_solver_cc_proto", @@ -1339,7 +1345,6 @@ cc_library( "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/container:inlined_vector", "@abseil-cpp//absl/log:check", - "@abseil-cpp//absl/meta:type_traits", "@abseil-cpp//absl/strings", "@abseil-cpp//absl/types:span", "@protobuf", @@ -1374,15 +1379,12 @@ cc_library( name = "sat_base", hdrs = ["sat_base.h"], deps = [ - ":model", "//ortools/base", "//ortools/base:strong_vector", - "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:strong_integers", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/log:check", - "@abseil-cpp//absl/strings", "@abseil-cpp//absl/strings:str_format", "@abseil-cpp//absl/types:span", ], @@ -1431,6 +1433,7 @@ cc_library( "//ortools/util:stats", "//ortools/util:strong_integers", "//ortools/util:time_limit", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", @@ -1633,6 +1636,7 @@ cc_library( "//ortools/util:stats", "//ortools/util:strong_integers", "//ortools/util:time_limit", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", @@ -1706,7 +1710,6 @@ cc_library( "//ortools/util:saturated_arithmetic", "//ortools/util:stats", "//ortools/util:strong_integers", - "@abseil-cpp//absl/cleanup", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/hash", "@abseil-cpp//absl/log:check", @@ -1727,7 +1730,6 @@ cc_test( "//ortools/base:gmock_main", "//ortools/base:strong_vector", "//ortools/util:strong_integers", - "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/types:span", ], @@ -2342,7 +2344,10 @@ cc_test( ":sat_solver", "//ortools/base:gmock_main", "//ortools/util:strong_integers", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/random", + "@abseil-cpp//absl/random:distributions", "@abseil-cpp//absl/types:span", ], ) @@ -3704,6 +3709,10 @@ cc_test( name = "diffn_util_test", size = "small", srcs = ["diffn_util_test.cc"], + target_compatible_with = select({ + "@platforms//os:linux": [], + "//conditions:default": ["@platforms//:incompatible"], + }), deps = [ ":2d_orthogonal_packing_testing", ":diffn_util", diff --git a/ortools/sat/README.md b/ortools/sat/README.md index 73521c8e9e..90ca2211b8 100644 --- a/ortools/sat/README.md +++ b/ortools/sat/README.md @@ -1,4 +1,4 @@ -# CP/SAT +# CP-SAT This directory contains a next-gen Constraint Programming (CP) solver with clause learning. It is built on top of an efficient SAT/max-SAT solver whose diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index 5e36480b76..8be8fa79ab 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -23,6 +23,7 @@ #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -190,8 +191,20 @@ bool ClauseManager::PropagateOnFalse(Literal false_literal, Trail* trail) { // clause using this convention. literals[0] = other_watched_literal; literals[1] = false_literal; + + int propagation_level = trail->CurrentDecisionLevel(); + if (trail->ChronologicalBacktrackingEnabled()) { + const int size = it->clause->size(); + propagation_level = trail->AssignmentLevel(false_literal); + for (int i = 2; i < size; ++i) { + propagation_level = std::max( + propagation_level, trail->AssignmentLevel(literals[i])); + } + } + reasons_[trail->Index()] = it->clause; - trail->Enqueue(other_watched_literal, propagator_id_); + trail->EnqueueAtLevel(other_watched_literal, propagator_id_, + propagation_level); *new_it++ = *it; } } @@ -215,6 +228,18 @@ absl::Span ClauseManager::Reason(const Trail& /*trail*/, return reasons_[trail_index]->PropagationReason(); } +void ClauseManager::Reimply(Trail* trail, int old_trail_index) { + const Literal literal = (*trail)[old_trail_index]; + const int level = trail->AssignmentLevel(literal); + CHECK_LE(trail->Index(), old_trail_index); + reasons_[trail->Index()] = reasons_[old_trail_index]; + DCHECK(absl::c_all_of( + reasons_[trail->Index()]->PropagationReason(), + [&](Literal l) { return trail->AssignmentLevel(l) <= level; })); + DCHECK_EQ(reasons_[trail->Index()]->FirstLiteral(), literal); + trail->EnqueueAtLevel(literal, propagator_id_, level); +} + SatClause* ClauseManager::ReasonClause(int trail_index) const { return reasons_[trail_index]; } @@ -272,9 +297,9 @@ bool ClauseManager::AttachAndPropagate(SatClause* clause, Trail* trail) { if (num_literal_not_false == 1) { // To maintain the validity of the 2-watcher algorithm, we need to watch // the false literal with the highest decision level. - int max_level = trail->Info(literals[1].Variable()).level; + int max_level = trail->AssignmentLevel(literals[1]); for (int i = 2; i < size; ++i) { - const int level = trail->Info(literals[i].Variable()).level; + const int level = trail->AssignmentLevel(literals[i]); if (level > max_level) { max_level = level; std::swap(literals[1], literals[i]); @@ -283,8 +308,12 @@ bool ClauseManager::AttachAndPropagate(SatClause* clause, Trail* trail) { // Propagates literals[0] if it is unassigned. if (!trail->Assignment().LiteralIsTrue(literals[0])) { + DCHECK(absl::c_all_of(clause->PropagationReason(), [&](Literal l) { + return trail->AssignmentLevel(l) <= max_level && + trail->Assignment().LiteralIsFalse(l); + })); reasons_[trail->Index()] = clause; - trail->Enqueue(literals[0], propagator_id_); + trail->EnqueueAtLevel(literals[0], propagator_id_, max_level); } } @@ -616,12 +645,12 @@ bool BinaryImplicationGraph::AddBinaryClause(Literal a, Literal b) { if (assignment.LiteralIsFalse(b)) return false; } else { reasons_[trail_->Index()] = a; - trail_->Enqueue(b, propagator_id_); + trail_->EnqueueAtLevel(b, propagator_id_, trail_->AssignmentLevel(a)); } } else if (assignment.LiteralIsFalse(b)) { if (!assignment.LiteralIsAssigned(a)) { reasons_[trail_->Index()] = b; - trail_->Enqueue(a, propagator_id_); + trail_->EnqueueAtLevel(a, propagator_id_, trail_->AssignmentLevel(b)); } } } @@ -829,6 +858,8 @@ bool BinaryImplicationGraph::Propagate(Trail* trail) { DCHECK(assignment.LiteralIsTrue(true_literal)); if (!implies_something[true_literal]) continue; + const int level = trail->AssignmentLevel(true_literal); + // Note(user): This update is not exactly correct because in case of // conflict we don't inspect that much clauses. But doing ++num_inspections_ // inside the loop does slow down the code by a few percent. @@ -852,7 +883,7 @@ bool BinaryImplicationGraph::Propagate(Trail* trail) { } else { // Propagation. reasons_[trail->Index()] = true_literal.Negated(); - trail->FastEnqueue(literal); + trail->EnqueueAtLevel(literal, propagator_id_, level); } } @@ -880,7 +911,7 @@ bool BinaryImplicationGraph::Propagate(Trail* trail) { } else { // Propagation. reasons_[trail->Index()] = true_literal.Negated(); - trail->FastEnqueue(literal.Negated()); + trail->EnqueueAtLevel(literal.Negated(), propagator_id_, level); } } } @@ -894,6 +925,13 @@ absl::Span BinaryImplicationGraph::Reason( return {&reasons_[trail_index], 1}; } +void BinaryImplicationGraph::Reimply(Trail* trail, int old_trail_index) { + const Literal literal = (*trail)[old_trail_index]; + const int level = trail->AssignmentLevel(literal); + reasons_[trail->Index()] = reasons_[old_trail_index]; + trail->EnqueueAtLevel(literal, propagator_id_, level); +} + // Here, we remove all the literal whose negation are implied by the negation of // the 1-UIP literal (which always appear first in the given conflict). Note // that this algorithm is "optimal" in the sense that it leads to a minimized @@ -1088,12 +1126,12 @@ void BinaryImplicationGraph::MinimizeConflictExperimental( int index = 1; for (int i = 1; i < conflict->size(); ++i) { const Literal lit = (*conflict)[i]; - const int lit_level = trail.Info(lit.Variable()).level; + const int lit_level = trail.AssignmentLevel(lit); bool keep_literal = true; for (const Literal implied : implications_and_amos_[lit].literals()) { if (is_marked_[implied]) { DCHECK_LE(lit_level, trail.Info(implied.Variable()).level); - if (lit_level == trail.Info(implied.Variable()).level && + if (lit_level == trail.AssignmentLevel(implied) && is_simplified_[implied]) { continue; } diff --git a/ortools/sat/clause.h b/ortools/sat/clause.h index a5f1dbed99..c7a697f419 100644 --- a/ortools/sat/clause.h +++ b/ortools/sat/clause.h @@ -177,6 +177,7 @@ class ClauseManager : public SatPropagator { bool Propagate(Trail* trail) final; absl::Span Reason(const Trail& trail, int trail_index, int64_t conflict_id) const final; + void Reimply(Trail* trail, int old_trail_index) final; // Returns the reason of the variable at given trail_index. This only works // for variable propagated by this class and is almost the same as Reason() @@ -518,6 +519,7 @@ class BinaryImplicationGraph : public SatPropagator { bool Propagate(Trail* trail) final; absl::Span Reason(const Trail& trail, int trail_index, int64_t conflict_id) const final; + void Reimply(Trail* trail, int old_trail_index) final; // Resizes the data structure. void Resize(int num_variables); diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index 8d345ef85f..c957f679f3 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -165,10 +165,8 @@ class EnforcedDomains { std::sort(vars.begin(), vars.end()); for (const int var : vars) { // enforcement_literal => var in domain - ConstraintProto* const imply = context_->working_model->add_constraints(); - *imply->mutable_enforcement_literal() = - constraint_->enforcement_literal(); - LinearConstraintProto* const lin = imply->mutable_linear(); + LinearConstraintProto* const lin = + context_->AddEnforcedConstraint(constraint_)->mutable_linear(); lin->add_vars(var); lin->add_coeffs(1); FillDomainInProto(domains_.at(var), lin); @@ -248,9 +246,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, // Add enforced linear for demand. { - ConstraintProto* new_ct = context->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = - reservoir_ct->enforcement_literal(); + ConstraintProto* new_ct = context->AddEnforcedConstraint(reservoir_ct); new_ct->add_enforcement_literal(start_var); LinearConstraintProto* lin = new_ct->mutable_linear(); FillDomainInProto(0, lin); @@ -288,9 +284,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, // Add enforced linear for time. { - ConstraintProto* new_ct = context->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = - reservoir_ct->enforcement_literal(); + ConstraintProto* new_ct = context->AddEnforcedConstraint(reservoir_ct); new_ct->add_enforcement_literal(arc_i_j); LinearConstraintProto* lin = new_ct->mutable_linear(); FillDomainInProto(0, std::numeric_limits::max(), lin); @@ -301,9 +295,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, // Add enforced linear for demand. { - ConstraintProto* new_ct = context->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = - reservoir_ct->enforcement_literal(); + ConstraintProto* new_ct = context->AddEnforcedConstraint(reservoir_ct); new_ct->add_enforcement_literal(arc_i_j); LinearConstraintProto* lin = new_ct->mutable_linear(); FillDomainInProto(0, lin); @@ -350,9 +342,7 @@ void ExpandReservoirUsingPrecedences(bool max_level_is_constraining, if (demand_i > 0 && !max_level_is_constraining) continue; if (demand_i < 0 && !min_level_is_constraining) continue; - ConstraintProto* new_cumul = context->working_model->add_constraints(); - *new_cumul->mutable_enforcement_literal() = - reservoir_ct->enforcement_literal(); + ConstraintProto* new_cumul = context->AddEnforcedConstraint(reservoir_ct); LinearConstraintProto* new_linear = new_cumul->mutable_linear(); int64_t offset = 0; @@ -450,9 +440,7 @@ void ExpandReservoir(ConstraintProto* reservoir_ct, PresolveContext* context) { // terms though. if (num_negatives == 0 || num_positives == 0) { const int true_literal = context->GetTrueLiteral(); - ConstraintProto* new_ct = context->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = - reservoir_ct->enforcement_literal(); + ConstraintProto* new_ct = context->AddEnforcedConstraint(reservoir_ct); LinearConstraintProto* sum = new_ct->mutable_linear(); FillDomainInProto(reservoir.min_level(), reservoir.max_level(), sum); for (int i = 0; i < num_events; ++i) { @@ -483,9 +471,7 @@ void ExpandReservoir(ConstraintProto* reservoir_ct, PresolveContext* context) { // Active => new_var == demand. { - ConstraintProto* demand_ct = - context->working_model->add_constraints(); - demand_ct->add_enforcement_literal(active); + ConstraintProto* demand_ct = context->AddEnforcedConstraint({active}); LinearConstraintProto* lin = demand_ct->mutable_linear(); FillDomainInProto(0, lin); lin->add_vars(new_var); @@ -594,13 +580,6 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { return; } - // Create a new constraint with the same enforcement as ct. - auto new_enforced_constraint = [&]() { - ConstraintProto* new_ct = context->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); - return new_ct; - }; - // div_expr = expr / mod_expr. const int div_var = context->NewIntVar( context->DomainSuperSetOf(expr).PositiveDivisionBySuperset( @@ -610,7 +589,7 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { div_expr.add_coeffs(1); LinearArgumentProto* const div_proto = - new_enforced_constraint()->mutable_int_div(); + context->AddEnforcedConstraint(ct)->mutable_int_div(); *div_proto->mutable_target() = div_expr; *div_proto->add_exprs() = expr; *div_proto->add_exprs() = mod_expr; @@ -631,14 +610,14 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { prod_expr.add_coeffs(1); LinearArgumentProto* const int_prod = - new_enforced_constraint()->mutable_int_prod(); + context->AddEnforcedConstraint(ct)->mutable_int_prod(); *int_prod->mutable_target() = prod_expr; *int_prod->add_exprs() = div_expr; *int_prod->add_exprs() = mod_expr; // expr - prod_expr = target_expr. LinearConstraintProto* const lin = - new_enforced_constraint()->mutable_linear(); + context->AddEnforcedConstraint(ct)->mutable_linear(); FillDomainInProto(0, lin); AddLinearExpressionToLinearConstraint(expr, 1, lin); AddLinearExpressionToLinearConstraint(prod_expr, -1, lin); @@ -664,11 +643,10 @@ void ExpandIntProd(ConstraintProto* ct, PresolveContext* context) { context->DomainSuperSetOf(right)); const int new_var = context->NewIntVar(new_domain); new_vars.push_back(new_var); - ConstraintProto* new_ct = context->working_model->add_constraints(); // TODO(user): since we copy the enforcement literals in the final int // prod constraint below, this is not strictly necessary. Is it better with // or without? - *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* new_ct = context->AddEnforcedConstraint(ct); LinearArgumentProto* const int_prod = new_ct->mutable_int_prod(); *int_prod->add_exprs() = left; *int_prod->add_exprs() = right; @@ -678,8 +656,7 @@ void ExpandIntProd(ConstraintProto* ct, PresolveContext* context) { terms.front() = int_prod->target(); } - ConstraintProto* new_ct = context->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* new_ct = context->AddEnforcedConstraint(ct); LinearArgumentProto* const final_int_prod = new_ct->mutable_int_prod(); *final_int_prod->add_exprs() = terms[0]; *final_int_prod->add_exprs() = terms[1]; @@ -803,14 +780,13 @@ void ExpandInverse(ConstraintProto* ct, PresolveContext* context) { const int f_i_j = context->GetOrCreateVarValueEncoding(f_i, j); const int r_j_i = context->GetOrCreateVarValueEncoding(r_j, i); if (f_i_j != r_j_i) { - ConstraintProto* eq = context->working_model->add_constraints(); - *eq->mutable_enforcement_literal() = ct->enforcement_literal(); - eq->add_enforcement_literal(f_i_j); - eq->mutable_bool_and()->add_literals(r_j_i); - eq = context->working_model->add_constraints(); - *eq->mutable_enforcement_literal() = ct->enforcement_literal(); - eq->add_enforcement_literal(r_j_i); - eq->mutable_bool_and()->add_literals(f_i_j); + ConstraintProto* eq_direct = context->AddEnforcedConstraint(ct); + eq_direct->add_enforcement_literal(f_i_j); + eq_direct->mutable_bool_and()->add_literals(r_j_i); + + ConstraintProto* eq_inverse = context->AddEnforcedConstraint(ct); + eq_inverse->add_enforcement_literal(r_j_i); + eq_inverse->mutable_bool_and()->add_literals(f_i_j); } } } @@ -833,8 +809,7 @@ void ExpandLinMax(ConstraintProto* ct, PresolveContext* context) { // First. // - enforcement literals => target >= ai for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { - ConstraintProto* new_ct = context->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* new_ct = context->AddEnforcedConstraint(ct); LinearConstraintProto* lin = new_ct->mutable_linear(); FillDomainInProto(0, std::numeric_limits::max(), lin); AddLinearExpressionToLinearConstraint(ct->lin_max().target(), 1, lin); @@ -851,18 +826,18 @@ void ExpandLinMax(ConstraintProto* ct, PresolveContext* context) { enforcement_literals.push_back(new_bool); enforcement_literals.push_back(NegatedRef(new_bool)); } else { - ConstraintProto* exactly_one = context->working_model->add_constraints(); - *exactly_one->mutable_enforcement_literal() = ct->enforcement_literal(); + BoolArgumentProto* exactly_one = + context->AddEnforcedConstraint(ct)->mutable_exactly_one(); for (int i = 0; i < num_exprs; ++i) { const int new_bool = context->NewBoolVar("lin max expansion"); - exactly_one->mutable_exactly_one()->add_literals(new_bool); + exactly_one->add_literals(new_bool); enforcement_literals.push_back(new_bool); } } for (int i = 0; i < num_exprs; ++i) { - ConstraintProto* new_ct = context->working_model->add_constraints(); - new_ct->add_enforcement_literal(enforcement_literals[i]); + ConstraintProto* new_ct = + context->AddEnforcedConstraint({enforcement_literals[i]}); LinearConstraintProto* lin = new_ct->mutable_linear(); FillDomainInProto(std::numeric_limits::min(), 0, lin); AddLinearExpressionToLinearConstraint(ct->lin_max().target(), 1, lin); @@ -893,8 +868,7 @@ void ExpandElementWhenTargetShareVarWithIndex( const int64_t index_value = AffineExpressionValueAt(index, v); const int64_t target_value = AffineExpressionValueAt(target, v); const LinearExpressionProto& expr = element.exprs(index_value); - ConstraintProto* imply = context->working_model->add_constraints(); - *imply->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* imply = context->AddEnforcedConstraint(ct); imply->add_enforcement_literal( context->GetOrCreateVarValueEncoding(index_var, v)); FillDomainInProto(target_value, imply->mutable_linear()); @@ -916,10 +890,6 @@ void ExpandConstantArrayElement(ConstraintProto* ct, PresolveContext* context, const int index_var = index.vars(0); const LinearExpressionProto& target = element.linear_target(); - // This BoolOrs implements the deduction that if all index literals pointing - // to the same value in the constant array are false, then this value is no - // no longer valid for the target variable. They are created only for values - // that have multiples literals supporting them. absl::btree_map> supports; for (const int64_t v : reduced_index_var_domain.Values()) { const int64_t index_value = AffineExpressionValueAt(index, v); @@ -927,38 +897,34 @@ void ExpandConstantArrayElement(ConstraintProto* ct, PresolveContext* context, supports[expr_value].push_back(v); } - // While this is not strictly needed since all value in the index will be - // covered, it allows to easily detect this fact in the presolve. - // - // TODO(user): Do we need the support part ? Is this discovered by probing? - ConstraintProto* const exactly_one = - context->working_model->add_constraints(); - *exactly_one->mutable_enforcement_literal() = ct->enforcement_literal(); - for (const auto& [expr_value, support] : supports) { + // This is redundant, but it improves solving. + ConstraintProto* new_ct = context->AddEnforcedConstraint(ct); + BoolArgumentProto* exactly_one = new_ct->mutable_exactly_one(); + + for (const auto& [expr_value, supporting_index_var_values] : supports) { const int target_literal = context->GetOrCreateAffineValueEncoding(target, expr_value); - // enforcement_literal && not(indices supporting value) => target != value - ConstraintProto* const bool_or = context->working_model->add_constraints(); - *bool_or->mutable_enforcement_literal() = ct->enforcement_literal(); - bool_or->mutable_bool_or()->add_literals(NegatedRef(target_literal)); - for (const int64_t v : support) { - const int index_literal = - context->GetOrCreateVarValueEncoding(index_var, v); - bool_or->mutable_bool_or()->add_literals(index_literal); - - // enforcement_literal && index == v => target == value - if (index_literal != target_literal) { - ConstraintProto* const eq = context->working_model->add_constraints(); - *eq->mutable_enforcement_literal() = ct->enforcement_literal(); - eq->add_enforcement_literal(index_literal); - eq->mutable_bool_and()->add_literals(target_literal); + if (supporting_index_var_values.size() == 1 && + ct->enforcement_literal().empty()) { + const int index_literal = context->GetOrCreateVarValueEncoding( + index_var, supporting_index_var_values[0]); + if (!context->StoreBooleanEqualityRelation(target_literal, + index_literal)) { + return; + } + exactly_one->add_literals(index_literal); + } else { + // enforcement => exactly_one(target != expr_value, index in support) + ConstraintProto* link = context->AddEnforcedConstraint(ct); + link->mutable_exactly_one()->add_literals(NegatedRef(target_literal)); + for (const int64_t v : supporting_index_var_values) { + const int index_literal = + context->GetOrCreateVarValueEncoding(index_var, v); + link->mutable_exactly_one()->add_literals(index_literal); + exactly_one->add_literals(index_literal); } - - // Helps presolve. - exactly_one->mutable_exactly_one()->add_literals(index_literal); } } - context->UpdateRuleStats("element: expanded value element"); ct->Clear(); } @@ -972,18 +938,17 @@ void ExpandVariableElement(ConstraintProto* ct, PresolveContext* context, const int index_var = index.vars(0); const LinearExpressionProto& target = element.linear_target(); - ConstraintProto* exactly_one = context->working_model->add_constraints(); - *exactly_one->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* new_ct = context->AddEnforcedConstraint(ct); + BoolArgumentProto* exactly_one = new_ct->mutable_exactly_one(); for (const int64_t v : reduced_index_var_domain.Values()) { const int64_t index_value = AffineExpressionValueAt(index, v); DCHECK_GE(index_value, 0); DCHECK_LT(index_value, element.exprs_size()); const int index_lit = context->GetOrCreateVarValueEncoding(index_var, v); - exactly_one->mutable_exactly_one()->add_literals(index_lit); + exactly_one->add_literals(index_lit); - ConstraintProto* const imply = context->working_model->add_constraints(); - *imply->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* const imply = context->AddEnforcedConstraint(ct); imply->add_enforcement_literal(index_lit); FillDomainInProto(0, imply->mutable_linear()); AddLinearExpressionToLinearConstraint(target, -1, imply->mutable_linear()); @@ -1033,8 +998,7 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) { return; } // enforcement_literal => index in [0, size - 1] - ConstraintProto* const index_ct = context->working_model->add_constraints(); - *index_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* const index_ct = context->AddEnforcedConstraint(ct); FillDomainInProto(0, size - 1, index_ct->mutable_linear()); AddLinearExpressionToLinearConstraint(index, 1, index_ct->mutable_linear()); context->CanonicalizeLinearConstraint(index_ct); @@ -1044,8 +1008,7 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) { index.vars_size() == 0 || reduced_index_var_domain.IsFixed(); if (reduced_index_domain_is_fixed) { DCHECK(!ct->enforcement_literal().empty() || context->IsFixed(index)); - ConstraintProto* const eq = context->working_model->add_constraints(); - *eq->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* const eq = context->AddEnforcedConstraint(ct); FillDomainInProto(0, eq->mutable_linear()); AddLinearExpressionToLinearConstraint(target, 1, eq->mutable_linear()); const int64_t reduced_index_fixed_value = @@ -1091,7 +1054,8 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) { // enforcement_literals && literals[i] true => encoding[values[i]] true // enforcement_literals => one of literals[i in I(j)] true || encoding[j] false // where I(j) = {i | values[i] = j}. This also implicitly uses the fact that -// exactly one alternative is true. +// exactly one literals is true. Note that we will use exactly_one in the +// encoding if possible. void LinkLiteralsAndValues(absl::Span enforcement_literals, absl::Span literals, absl::Span values, @@ -1108,6 +1072,9 @@ void LinkLiteralsAndValues(absl::Span enforcement_literals, encoding_lit_to_support[encoding.at(values[i])].push_back(literals[i]); } + // Using an exactly one convey more structure and has a better linear + // relaxation. Even if we could theorically infer it back from the other + // encoding. for (const auto& [encoding_lit, support] : encoding_lit_to_support) { CHECK(!support.empty()); if (support.size() == 1 && enforcement_literals.empty()) { @@ -1115,25 +1082,12 @@ void LinkLiteralsAndValues(absl::Span enforcement_literals, return; } } else { - // The `ct` constraint ensures that if all tuples supporting a value are - // false, then this value must be false (if the automaton constraint is - // enforced). - ConstraintProto* ct = context->working_model->add_constraints(); - *ct->mutable_enforcement_literal() = {enforcement_literals.begin(), - enforcement_literals.end()}; - BoolArgumentProto* bool_or = ct->mutable_bool_or(); - bool_or->add_literals(NegatedRef(encoding_lit)); + BoolArgumentProto* exo = + context->AddEnforcedConstraint(enforcement_literals) + ->mutable_exactly_one(); + exo->add_literals(NegatedRef(encoding_lit)); for (const int lit : support) { - bool_or->add_literals(lit); - // Conversely, if a tuple supporting a value is selected, this value - // must be selected (if the automaton constraint is enforced). - if (lit != encoding_lit) { - ConstraintProto* inv_ct = context->working_model->add_constraints(); - *inv_ct->mutable_enforcement_literal() = { - enforcement_literals.begin(), enforcement_literals.end()}; - inv_ct->add_enforcement_literal(lit); - inv_ct->mutable_bool_and()->add_literals(encoding_lit); - } + exo->add_literals(lit); } } } @@ -1151,9 +1105,7 @@ void AddImplyInReachableValues(absl::Span enforcement_literals, if (reachable_values.size() == encoding.size()) return; // No constraint. if (reachable_values.size() <= encoding.size() / 2) { // Bool or encoding. - ConstraintProto* ct = context->working_model->add_constraints(); - *ct->mutable_enforcement_literal() = {enforcement_literals.begin(), - enforcement_literals.end()}; + ConstraintProto* ct = context->AddEnforcedConstraint(enforcement_literals); ct->add_enforcement_literal(literal); BoolArgumentProto* bool_or = ct->mutable_bool_or(); for (const int64_t v : reachable_values) { @@ -1163,9 +1115,7 @@ void AddImplyInReachableValues(absl::Span enforcement_literals, // Bool and encoding. absl::flat_hash_set set(reachable_values.begin(), reachable_values.end()); - ConstraintProto* ct = context->working_model->add_constraints(); - *ct->mutable_enforcement_literal() = {enforcement_literals.begin(), - enforcement_literals.end()}; + ConstraintProto* ct = context->AddEnforcedConstraint(enforcement_literals); ct->add_enforcement_literal(literal); BoolArgumentProto* bool_and = ct->mutable_bool_and(); for (const auto [value, literal] : encoding) { @@ -1413,9 +1363,7 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { // Part2, add all 3-clauses: enforcement_literal && (in_state, label) => // out_state. for (int i = 0; i < num_tuples; ++i) { - ConstraintProto* bool_or_ct = context->working_model->add_constraints(); - *bool_or_ct->mutable_enforcement_literal() = ct->enforcement_literal(); - auto* bool_or = bool_or_ct->mutable_bool_or(); + auto* bool_or = context->AddEnforcedConstraint(ct)->mutable_bool_or(); bool_or->add_literals(NegatedRef(in_encoding.at(in_states[i]))); bool_or->add_literals(NegatedRef(encoding.at(labels[i]))); bool_or->add_literals(out_encoding.at(out_states[i])); @@ -1440,11 +1388,8 @@ void ExpandAutomaton(ConstraintProto* ct, PresolveContext* context) { // Note that we do not need the ExactlyOneConstraint(tuple_literals) // because it is already implicitly encoded since we have exactly one // transition value. But adding one seems to help. - ConstraintProto* exactly_one_ct = - context->working_model->add_constraints(); - *exactly_one_ct->mutable_enforcement_literal() = - ct->enforcement_literal(); - BoolArgumentProto* exactly_one = exactly_one_ct->mutable_exactly_one(); + BoolArgumentProto* exactly_one = + context->AddEnforcedConstraint(ct)->mutable_exactly_one(); for (int i = 0; i < num_tuples; ++i) { int tuple_literal; if (in_count[in_states[i]] == 1 && !in_encoding.empty()) { @@ -1577,8 +1522,7 @@ void ExpandNegativeTable(ConstraintProto* ct, PresolveContext* context) { } // Note: if the clause is empty, then the model is infeasible. - ConstraintProto* tuple_ct = context->working_model->add_constraints(); - *tuple_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* tuple_ct = context->AddEnforcedConstraint(ct); BoolArgumentProto* bool_or = tuple_ct->mutable_bool_or(); for (const int lit : clause) { bool_or->add_literals(lit); @@ -1599,50 +1543,77 @@ void ProcessOneCompressedColumn( std::optional table_is_active_literal, PresolveContext* context) { DCHECK_EQ(tuple_literals.size(), values.size()); + // Some pre-computations. // Collect pairs of value-literal. - // Add the constraint literal => one of values. - // - // TODO(user): If we have n - 1 values, we could add the constraint that - // tuple literal => not(last_value) instead? - std::vector> pairs; + absl::flat_hash_set value_is_multiple; std::vector any_values_literals; + std::vector> pairs; for (int i = 0; i < values.size(); ++i) { if (values[i].empty()) { any_values_literals.push_back(tuple_literals[i]); continue; } + for (const int64_t v : values[i]) { + pairs.emplace_back(v, tuple_literals[i]); + } + if (values[i].size() > 1) { + value_is_multiple.insert(values[i].begin(), values[i].end()); + } + } - ConstraintProto* ct = context->working_model->add_constraints(); - ct->add_enforcement_literal(tuple_literals[i]); + // Try to use exactly one in the encoding if we can. + bool use_exo = true; + if (table_is_active_literal.has_value()) use_exo = false; + if (!any_values_literals.empty()) use_exo = false; - // It is slightly better to use a bool_and if size is 1 instead of - // reconverting it at a later stage. - auto* literals = - values[i].size() == 1 ? ct->mutable_bool_and() : ct->mutable_bool_or(); + // Add the constraint literal => one of values. + for (int i = 0; i < values.size(); ++i) { + if (values[i].empty()) continue; + + if (use_exo && values[i].size() == 1 && + !value_is_multiple.contains(values[i][0])) { + // nothing to do here since the implication is covered by the exactly one. + continue; + } + + ConstraintProto* ct = context->AddEnforcedConstraint({tuple_literals[i]}); + if (values[i].size() == 1) { + // It is slightly better to use a bool_and if size is 1 instead of + // reconverting it at a later stage. + const int v = values[i][0]; + ct->mutable_bool_and()->add_literals( + context->GetOrCreateVarValueEncoding(variable, v)); + continue; + } + + // TODO(user): If we have n - 1 values, we could add the constraint that + // tuple literal => not(last_value) instead? + auto* literals = ct->mutable_bool_or(); for (const int64_t v : values[i]) { DCHECK(context->DomainContains(variable, v)); literals->add_literals(context->GetOrCreateVarValueEncoding(variable, v)); - pairs.emplace_back(v, tuple_literals[i]); } } // Regroup literal with the same value and add for each the clause: If all the // tuples containing a value are false, then this value must be false too. - std::vector selected; std::sort(pairs.begin(), pairs.end()); for (int i = 0; i < pairs.size();) { - selected.clear(); const int64_t value = pairs[i].first; - for (; i < pairs.size() && pairs[i].first == value; ++i) { - selected.push_back(pairs[i].second); - } // A value is supported if one tuple is still active, or a covering 'any' // tuple is still active, or the table can still be inactive. + // + // Note that if a value only appear individually in each tuple, and the + // table is not enforced, then we have an exactly one. This seems to helps a + // bit, especially the linear relaxation. BoolArgumentProto* no_support = - context->working_model->add_constraints()->mutable_bool_or(); - for (const int lit : selected) { - no_support->add_literals(lit); + use_exo && !value_is_multiple.contains(value) + ? context->working_model->add_constraints()->mutable_exactly_one() + : context->working_model->add_constraints()->mutable_bool_or(); + + for (; i < pairs.size() && pairs[i].first == value; ++i) { + no_support->add_literals(pairs[i].second); } for (const int lit : any_values_literals) { no_support->add_literals(lit); @@ -1693,38 +1664,73 @@ void AddSizeTwoTable( int num_implications = 0; int num_clause_added = 0; int num_large_clause_added = 0; + int num_exo_added = 0; + int num_equivalences_added = 0; auto add_support_constraint = - [context, &num_clause_added, &num_large_clause_added, &num_implications]( - int lit, absl::Span support_literals, - int max_support_size) { + [context, &num_clause_added, &num_large_clause_added, &num_implications, + &num_exo_added, &num_equivalences_added]( + int lit, absl::Span support_literals, int max_support_size, + const absl::btree_map>& other_map) { if (support_literals.size() == max_support_size) return; if (support_literals.size() == 1) { - context->AddImplication(lit, support_literals.front()); - num_implications++; - } else { - BoolArgumentProto* bool_or = - context->working_model->add_constraints()->mutable_bool_or(); - for (const int support_literal : support_literals) { - bool_or->add_literals(support_literal); + const int support_literal = support_literals.front(); + const auto& it = other_map.find(support_literal); + CHECK(it != other_map.end()); + if (it->second.size() > 1) { + context->AddImplication(lit, support_literal); + num_implications++; + } else { + if (!context->StoreBooleanEqualityRelation(lit, support_literal)) { + return; + } + ++num_equivalences_added; } - bool_or->add_literals(NegatedRef(lit)); - num_clause_added++; - if (support_literals.size() > max_support_size / 2) { - num_large_clause_added++; + } else { + bool exclusive = true; + for (const int support_literal : support_literals) { + const auto& it = other_map.find(support_literal); + CHECK(it != other_map.end()); + if (it->second.size() > 1) { + exclusive = false; + break; + } + } + if (exclusive) { + BoolArgumentProto* exo = context->working_model->add_constraints() + ->mutable_exactly_one(); + for (const int support_literal : support_literals) { + exo->add_literals(support_literal); + } + exo->add_literals(NegatedRef(lit)); + ++num_exo_added; + } else { + BoolArgumentProto* bool_or = + context->working_model->add_constraints()->mutable_bool_or(); + for (const int support_literal : support_literals) { + bool_or->add_literals(support_literal); + } + bool_or->add_literals(NegatedRef(lit)); + num_clause_added++; + if (support_literals.size() > max_support_size / 2) { + num_large_clause_added++; + } } } }; for (const auto& it : left_to_right) { - add_support_constraint(it.first, it.second, values_per_var[1].size()); + add_support_constraint(it.first, it.second, values_per_var[1].size(), + right_to_left); } for (const auto& it : right_to_left) { - add_support_constraint(it.first, it.second, values_per_var[0].size()); + add_support_constraint(it.first, it.second, values_per_var[0].size(), + left_to_right); } VLOG(2) << "Table: 2 variables, " << tuples.size() << " tuples encoded using " << num_clause_added << " clauses, including " << num_large_clause_added << " large clauses, " << num_implications - << " implications"; + << " implications, " << num_exo_added << " exactly ones, " + << num_equivalences_added << " equivalences."; } // A "WCSP" (weighted constraint programming) problem is usually encoded as @@ -2679,8 +2685,7 @@ void MaybeExpandAllDiff(ConstraintProto* ct, PresolveContext* context, } } - ConstraintProto* const new_ct = context->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + ConstraintProto* const new_ct = context->AddEnforcedConstraint(ct); BoolArgumentProto* at_most_or_equal_one = is_a_permutation ? new_ct->mutable_exactly_one() : new_ct->mutable_at_most_one(); diff --git a/ortools/sat/cp_model_lns.cc b/ortools/sat/cp_model_lns.cc index 36990c16a9..f000f6406a 100644 --- a/ortools/sat/cp_model_lns.cc +++ b/ortools/sat/cp_model_lns.cc @@ -115,7 +115,7 @@ void NeighborhoodGeneratorHelper::Synchronize() { bool new_variables_have_been_fixed = false; if (!model_variables.empty()) { - absl::MutexLock domain_lock(&domain_mutex_); + absl::MutexLock domain_lock(domain_mutex_); for (int i = 0; i < model_variables.size(); ++i) { const int var = model_variables[i]; @@ -219,8 +219,8 @@ void NeighborhoodGeneratorHelper::InitializeHelperData() { // Recompute all the data when new variables have been fixed. Note that this // shouldn't be called if there is no change as it is in O(problem size). void NeighborhoodGeneratorHelper::RecomputeHelperData() { - absl::MutexLock graph_lock(&graph_mutex_); - absl::ReaderMutexLock domain_lock(&domain_mutex_); + absl::MutexLock graph_lock(graph_mutex_); + absl::ReaderMutexLock domain_lock(domain_mutex_); // Do basic presolving to have a more precise graph. // Here we just remove trivially true constraints. @@ -421,7 +421,7 @@ Neighborhood NeighborhoodGeneratorHelper::FullNeighborhood() const { neighborhood.is_reduced = false; neighborhood.is_generated = true; { - absl::ReaderMutexLock lock(&domain_mutex_); + absl::ReaderMutexLock lock(domain_mutex_); *neighborhood.delta.mutable_variables() = model_proto_with_only_variables_.variables(); } @@ -464,7 +464,7 @@ std::vector NeighborhoodGeneratorHelper::KeepActiveIntervals( const CpSolverResponse& initial_solution) const { std::vector filtered_intervals; filtered_intervals.reserve(unfiltered_intervals.size()); - absl::ReaderMutexLock lock(&domain_mutex_); + absl::ReaderMutexLock lock(domain_mutex_); for (const int i : unfiltered_intervals) { if (IntervalIsActive(i, initial_solution)) filtered_intervals.push_back(i); } @@ -1082,7 +1082,7 @@ Neighborhood NeighborhoodGeneratorHelper::FixGivenVariables( // Fill in neighborhood.delta all variable domains. int num_fixed = 0; { - absl::ReaderMutexLock domain_lock(&domain_mutex_); + absl::ReaderMutexLock domain_lock(domain_mutex_); for (int i = 0; i < num_variables; ++i) { const IntegerVariableProto& current_var = model_proto_with_only_variables_.variables(i); @@ -1128,7 +1128,7 @@ Neighborhood NeighborhoodGeneratorHelper::FixGivenVariables( // // TODO(user): If there is just one component, we can skip some computation. { - absl::ReaderMutexLock graph_lock(&graph_mutex_); + absl::ReaderMutexLock graph_lock(graph_mutex_); std::vector count(components_.size(), 0); const int num_variables = neighborhood.delta.variables().size(); for (int var = 0; var < num_variables; ++var) { @@ -1205,7 +1205,7 @@ Neighborhood NeighborhoodGeneratorHelper::RelaxGivenVariables( absl::Span relaxed_variables) const { Bitset64 fixed_variables(NumVariables()); { - absl::ReaderMutexLock graph_lock(&graph_mutex_); + absl::ReaderMutexLock graph_lock(graph_mutex_); for (const int i : active_variables_) { fixed_variables.Set(i); } @@ -1218,7 +1218,7 @@ Neighborhood NeighborhoodGeneratorHelper::FixAllVariables( const CpSolverResponse& initial_solution) const { Bitset64 fixed_variables(NumVariables()); { - absl::ReaderMutexLock graph_lock(&graph_mutex_); + absl::ReaderMutexLock graph_lock(graph_mutex_); for (const int i : active_variables_) { fixed_variables.Set(i); } @@ -1229,7 +1229,7 @@ Neighborhood NeighborhoodGeneratorHelper::FixAllVariables( CpModelProto NeighborhoodGeneratorHelper::UpdatedModelProtoCopy() const { CpModelProto updated_model = model_proto_; { - absl::MutexLock domain_lock(&domain_mutex_); + absl::MutexLock domain_lock(domain_mutex_); *updated_model.mutable_variables() = model_proto_with_only_variables_.variables(); } @@ -1241,14 +1241,14 @@ bool NeighborhoodGenerator::ReadyToGenerate() const { } double NeighborhoodGenerator::GetUCBScore(int64_t total_num_calls) const { - absl::ReaderMutexLock mutex_lock(&generator_mutex_); + absl::ReaderMutexLock mutex_lock(generator_mutex_); DCHECK_GE(total_num_calls, num_calls_); if (num_calls_ <= 10) return std::numeric_limits::infinity(); return current_average_ + sqrt((2 * log(total_num_calls)) / num_calls_); } absl::Span NeighborhoodGenerator::Synchronize() { - absl::MutexLock mutex_lock(&generator_mutex_); + absl::MutexLock mutex_lock(generator_mutex_); // To make the whole update process deterministic, we currently sort the // SolveData. @@ -1334,7 +1334,7 @@ std::vector NeighborhoodGeneratorHelper::ImprovableObjectiveVariablesWhileHoldingLock( const CpSolverResponse& initial_solution) const { std::vector result; - absl::ReaderMutexLock lock(&domain_mutex_); + absl::ReaderMutexLock lock(domain_mutex_); for (const int var : active_objective_variables_) { const auto& domain = model_proto_with_only_variables_.variables(var).domain(); @@ -1386,7 +1386,7 @@ Neighborhood RelaxRandomConstraintsGenerator::Generate( std::vector relaxed_variables; { - absl::ReaderMutexLock graph_lock(&helper_.graph_mutex_); + absl::ReaderMutexLock graph_lock(helper_.graph_mutex_); const int num_active_constraints = helper_.ConstraintToVar().size(); std::vector active_constraints(num_active_constraints); for (int c = 0; c < num_active_constraints; ++c) { @@ -1437,7 +1437,7 @@ Neighborhood VariableGraphNeighborhoodGenerator::Generate( std::vector random_variables; { - absl::ReaderMutexLock graph_lock(&helper_.graph_mutex_); + absl::ReaderMutexLock graph_lock(helper_.graph_mutex_); std::vector initial_vars = helper_.ImprovableObjectiveVariablesWhileHoldingLock(initial_solution); @@ -1507,7 +1507,7 @@ Neighborhood ArcGraphNeighborhoodGenerator::Generate( int num_active_vars = 0; std::vector active_objective_vars; { - absl::ReaderMutexLock graph_lock(&helper_.graph_mutex_); + absl::ReaderMutexLock graph_lock(helper_.graph_mutex_); num_active_vars = helper_.ActiveVariablesWhileHoldingLock().size(); active_objective_vars = helper_.ImprovableObjectiveVariablesWhileHoldingLock(initial_solution); @@ -1595,7 +1595,7 @@ Neighborhood ConstraintGraphNeighborhoodGenerator::Generate( std::vector random_variables; { - absl::ReaderMutexLock graph_lock(&helper_.graph_mutex_); + absl::ReaderMutexLock graph_lock(helper_.graph_mutex_); const int num_active_vars = helper_.ActiveVariablesWhileHoldingLock().size(); const int target_size = std::ceil(data.difficulty * num_active_vars); @@ -1658,7 +1658,7 @@ Neighborhood DecompositionGraphNeighborhoodGenerator::Generate( // might not want to lock the graph for so long? it is just a reader lock // though. { - absl::ReaderMutexLock graph_lock(&helper_.graph_mutex_); + absl::ReaderMutexLock graph_lock(helper_.graph_mutex_); const int num_active_vars = helper_.ActiveVariablesWhileHoldingLock().size(); @@ -2808,7 +2808,7 @@ Neighborhood RelaxationInducedNeighborhoodGenerator::Generate( } neighborhood.source_info = reduced_domains.source_info; - absl::ReaderMutexLock graph_lock(&helper_.graph_mutex_); + absl::ReaderMutexLock graph_lock(helper_.graph_mutex_); // Fix the variables in the local model. for (const std::pair& fixed_var : reduced_domains.fixed_vars) { diff --git a/ortools/sat/cp_model_lns.h b/ortools/sat/cp_model_lns.h index 91b823b4e8..de156e9c4f 100644 --- a/ortools/sat/cp_model_lns.h +++ b/ortools/sat/cp_model_lns.h @@ -156,25 +156,25 @@ class NeighborhoodGeneratorHelper : public SubSolver { // Returns the list of "active" variables. std::vector ActiveVariables() const { std::vector result; - absl::ReaderMutexLock lock(&graph_mutex_); + absl::ReaderMutexLock lock(graph_mutex_); result = active_variables_; return result; } int NumActiveVariables() const { - absl::ReaderMutexLock lock(&graph_mutex_); + absl::ReaderMutexLock lock(graph_mutex_); return active_variables_.size(); } std::vector ActiveObjectiveVariables() const { std::vector result; - absl::ReaderMutexLock lock(&graph_mutex_); + absl::ReaderMutexLock lock(graph_mutex_); result = active_objective_variables_; return result; } bool DifficultyMeansFullNeighborhood(double difficulty) const { - absl::ReaderMutexLock lock(&graph_mutex_); + absl::ReaderMutexLock lock(graph_mutex_); const int target_size = static_cast(std::ceil(difficulty * active_variables_.size())); return target_size == active_variables_.size(); @@ -470,7 +470,7 @@ class NeighborhoodGenerator { double GetUCBScore(int64_t total_num_calls) const; void AddSolveData(SolveData data) { - absl::MutexLock mutex_lock(&generator_mutex_); + absl::MutexLock mutex_lock(generator_mutex_); solve_data_.push_back(data); } @@ -484,19 +484,19 @@ class NeighborhoodGenerator { // Number of times this generator was called. int64_t num_calls() const { - absl::MutexLock mutex_lock(&generator_mutex_); + absl::MutexLock mutex_lock(generator_mutex_); return num_calls_; } // Number of time the neighborhood was fully solved (OPTIMAL/INFEASIBLE). int64_t num_fully_solved_calls() const { - absl::MutexLock mutex_lock(&generator_mutex_); + absl::MutexLock mutex_lock(generator_mutex_); return num_fully_solved_calls_; } // Out of num_calls(), how many improved the given solution. int64_t num_improving_calls() const { - absl::MutexLock mutex_lock(&generator_mutex_); + absl::MutexLock mutex_lock(generator_mutex_); return num_improving_calls_; } @@ -504,19 +504,19 @@ class NeighborhoodGenerator { // the best solution. Note that this count improvement to the best known // solution not the base one used to generate one neighborhood. int64_t num_consecutive_non_improving_calls() const { - absl::MutexLock mutex_lock(&generator_mutex_); + absl::MutexLock mutex_lock(generator_mutex_); return num_consecutive_non_improving_calls_; } // The current difficulty of this generator double difficulty() const { - absl::MutexLock mutex_lock(&generator_mutex_); + absl::MutexLock mutex_lock(generator_mutex_); return difficulty_.value(); } // The current time limit that the sub-solve should use on this generator. double deterministic_limit() const { - absl::MutexLock mutex_lock(&generator_mutex_); + absl::MutexLock mutex_lock(generator_mutex_); return deterministic_limit_; } diff --git a/ortools/sat/cp_model_lns_test.cc b/ortools/sat/cp_model_lns_test.cc index 68dc701a14..83d95d6093 100644 --- a/ortools/sat/cp_model_lns_test.cc +++ b/ortools/sat/cp_model_lns_test.cc @@ -568,7 +568,7 @@ TEST(NeighborhoodGeneratorHelperTest, BoundAreUpdatedOnSynchronize) { // No change since not synchronized. { - absl::ReaderMutexLock lock(&helper.graph_mutex_); + absl::ReaderMutexLock lock(helper.graph_mutex_); EXPECT_TRUE(helper.IsActive(0)); } EXPECT_EQ(ReadDomainFromProto(helper.FullNeighborhood().delta.variables(0)), @@ -577,7 +577,7 @@ TEST(NeighborhoodGeneratorHelperTest, BoundAreUpdatedOnSynchronize) { // New bound are properly there. { - absl::ReaderMutexLock lock(&helper.graph_mutex_); + absl::ReaderMutexLock lock(helper.graph_mutex_); EXPECT_FALSE(helper.IsActive(0)); } EXPECT_EQ(ReadDomainFromProto(helper.FullNeighborhood().delta.variables(0)), diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 611ef32088..cfd5573507 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -419,7 +419,8 @@ bool CpModelPresolver::PresolveBoolOr(ConstraintProto* ct) { // Note this function does not update the constraint graph. It assumes this is // done elsewhere. ABSL_MUST_USE_RESULT bool CpModelPresolver::MarkConstraintAsFalse( - ConstraintProto* ct, const std::string& reason) { + ConstraintProto* ct, std::string_view reason) { + DCHECK(!reason.empty()); if (HasEnforcementLiteral(*ct)) { // Change the constraint to a bool_or. ct->mutable_bool_or()->clear_literals(); @@ -453,8 +454,7 @@ bool CpModelPresolver::PresolveBoolAnd(ConstraintProto* ct) { ct->enforcement_literal().begin(), ct->enforcement_literal().end()); for (const int literal : ct->bool_and().literals()) { if (context_->LiteralIsFalse(literal)) { - context_->UpdateRuleStats("bool_and: always false"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "bool_and: always false"); } if (context_->LiteralIsTrue(literal)) { changed = true; @@ -466,8 +466,7 @@ bool CpModelPresolver::PresolveBoolAnd(ConstraintProto* ct) { continue; } if (enforcement_literals_set.contains(NegatedRef(literal))) { - context_->UpdateRuleStats("bool_and: x => not x"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "bool_and: x => not x"); } if (context_->VariableIsUniqueAndRemovable(literal)) { // This is a "dual" reduction. @@ -479,8 +478,7 @@ bool CpModelPresolver::PresolveBoolAnd(ConstraintProto* ct) { } if (context_->tmp_literal_set.contains(NegatedRef(literal))) { - context_->UpdateRuleStats("bool_and: cannot be enforced"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "bool_and: cannot be enforced"); } const auto [_, inserted] = context_->tmp_literal_set.insert(literal); @@ -921,8 +919,8 @@ bool CpModelPresolver::PropagateAndReduceAffineMax(ConstraintProto* ct) { } if (reachable_target_values.empty() || valid_variable_values.empty()) { - context_->UpdateRuleStats("lin_max: infeasible affine_max constraint"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, + "lin_max: infeasible affine_max constraint"); } { @@ -1014,8 +1012,7 @@ bool CpModelPresolver::PropagateAndReduceLinMax(ConstraintProto* ct) { if (target.vars().empty()) { if (!Domain(infered_min, infered_max).Contains(target.offset())) { - context_->UpdateRuleStats("lin_max: infeasible"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "lin_max: infeasible"); } } if (target.vars().size() <= 1) { // Affine @@ -1134,8 +1131,7 @@ bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { } if (ct->lin_max().exprs().empty()) { - context_->UpdateRuleStats("lin_max: no exprs"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "lin_max: no exprs"); } // Try to reduce lin_max using known relation. @@ -1350,7 +1346,7 @@ bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { } if (all_booleans) { if (literals.empty()) { - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "lin_max: all boolean and no support"); } // At least one true; @@ -2302,7 +2298,7 @@ bool CpModelPresolver::DivideLinearByGcd(ConstraintProto* ct) { const Domain rhs = ReadDomainFromProto(ct->linear()); FillDomainInProto(rhs.InverseMultiplicationBy(gcd), ct->mutable_linear()); if (ct->linear().domain_size() == 0) { - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "linear: not satisfied after GCD."); } } return false; @@ -2318,9 +2314,8 @@ bool CpModelPresolver::CanonicalizeLinear(ConstraintProto* ct, bool* changed) { if (context_->ModelIsUnsat()) return false; if (ct->linear().domain().empty()) { - context_->UpdateRuleStats("linear: no domain"); *changed = true; - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "linear: no domain"); } *changed = context_->CanonicalizeLinearConstraint(ct); @@ -2421,9 +2416,8 @@ bool CpModelPresolver::RemoveSingletonInLinear(ConstraintProto* ct) { context_->UpdateRuleStats( "TODO independent linear: minimize single linear constraint"); } else if (result.infeasible) { - context_->UpdateRuleStats( - "independent linear: no DP solution to simple constraint"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse( + ct, "independent linear: no DP solution to simple constraint"); } else { if (ct->enforcement_literal().empty()) { // Just fix everything. @@ -2771,8 +2765,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { .InverseMultiplicationBy(ct->linear().coeffs(0)) .IntersectionWith(var_domain); if (rhs.IsEmpty()) { - context_->UpdateRuleStats("linear1: infeasible"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "linear1: infeasible"); } if (rhs == var_domain) { context_->UpdateRuleStats("linear1: always true"); @@ -2906,8 +2899,7 @@ bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) { return RemoveConstraint(ct); } } else if (status == RelationStatus::IS_FALSE) { - context_->UpdateRuleStats("linear2: infeasible relation"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "linear2: infeasible relation"); } else if (ct->enforcement_literal().empty()) { known_linear2_.Add(expr2, lb, ub); } @@ -2951,8 +2943,7 @@ bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) { const bool implied_true = context_->DomainOf(var).IntersectionWith(rhs_if_false).IsEmpty(); if (implied_true && implied_false) { - context_->UpdateRuleStats("linear2: infeasible."); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "linear2: infeasible."); } else if (implied_true) { context_->UpdateRuleStats("linear2: Boolean with one feasible value."); @@ -3056,9 +3047,8 @@ bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) { int64_t x0 = 0; int64_t y0 = 0; if (!SolveDiophantineEquationOfSizeTwo(a, b, cte, x0, y0)) { - context_->UpdateRuleStats( - "linear2: implied ax + by = cte has no solutions"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse( + ct, "linear2: implied ax + by = cte has no solutions"); } const Domain reduced_domain = context_->DomainOf(var1) @@ -3069,9 +3059,8 @@ bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) { .InverseMultiplicationBy(-a)); if (reduced_domain.IsEmpty()) { // no solution - context_->UpdateRuleStats( - "linear2: implied ax + by = cte has no solutions"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse( + ct, "linear2: implied ax + by = cte has no solutions"); } if (reduced_domain.Size() == 1) { @@ -3111,12 +3100,12 @@ bool CpModelPresolver::PresolveSmallLinear(ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; if (ct->linear().vars().empty()) { - context_->UpdateRuleStats("linear: empty"); const Domain rhs = ReadDomainFromProto(ct->linear()); if (rhs.Contains(0)) { + context_->UpdateRuleStats("linear: empty"); return RemoveConstraint(ct); } else { - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "linear: empty"); } } else if (ct->linear().vars().size() == 1) { return PresolveLinearOfSizeOne(ct); @@ -3149,8 +3138,7 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { linear_constraint.coeffs(), linear_constraint.domain(0), lbs, ubs); if (!diophantine_solution.has_solutions) { - context_->UpdateRuleStats("diophantine: equality has no solutions"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "diophantine: equality has no solutions"); } if (diophantine_solution.no_reformulation_needed) return false; // Only first coefficients of kernel_basis elements and special_solution could @@ -3351,7 +3339,7 @@ void CpModelPresolver::TryToReduceCoefficientsOfLinearConstraint( // Mark trivially false constraint as such. This should have been already // done, but we require non-negative quantity below. if (lb_sum > rhs.Max() || rhs.Min() > ub_sum) { - (void)MarkConstraintAsFalse(ct); + (void)MarkConstraintAsFalse(ct, "linear: trivially false"); context_->UpdateConstraintVariableUsage(c); return; } @@ -3360,6 +3348,7 @@ void CpModelPresolver::TryToReduceCoefficientsOfLinearConstraint( const bool use_ub = max_variation > rhs_ub; const bool use_lb = max_variation > rhs_lb; if (!use_ub && !use_lb) { + context_->UpdateRuleStats("linear: trivially true"); (void)RemoveConstraint(ct); context_->UpdateConstraintVariableUsage(c); return; @@ -3479,7 +3468,7 @@ void CpModelPresolver::TryToReduceCoefficientsOfLinearConstraint( const int64_t new_rhs_ub = use_ub ? shift_lb + ub_feasible_.CurrentMax() : shift_ub; if (new_rhs_lb > new_rhs_ub) { - (void)MarkConstraintAsFalse(ct); + (void)MarkConstraintAsFalse(ct, "linear: false after simplification"); context_->UpdateConstraintVariableUsage(c); return; } @@ -3509,7 +3498,7 @@ void CpModelPresolver::TryToReduceCoefficientsOfLinearConstraint( const int64_t new_rhs_ub = use_ub ? lb_sum + ub_feasible_.CurrentMax() : ub_sum; if (new_rhs_lb > new_rhs_ub) { - (void)MarkConstraintAsFalse(ct); + (void)MarkConstraintAsFalse(ct, "linear: reduce rhs with DP"); context_->UpdateConstraintVariableUsage(c); return; } @@ -3560,7 +3549,7 @@ void CpModelPresolver::TryToReduceCoefficientsOfLinearConstraint( mutable_linear->mutable_coeffs()->Truncate(new_size); const Domain new_rhs = Domain(-minus_new_lb, new_ub); if (new_rhs.IsEmpty()) { - (void)MarkConstraintAsFalse(ct); + (void)MarkConstraintAsFalse(ct, "linear: false after approximate gcd"); } else { FillDomainInProto(new_rhs, mutable_linear); } @@ -3778,8 +3767,8 @@ void CpModelPresolver::ProcessOneLinearWithAmo(int ct_index, Domain(min_bool_activity, max_bool_activity)); if (activity.IntersectionWith(rhs).IsEmpty()) { // Note that this covers min_bool_activity > max_bool_activity. - context_->UpdateRuleStats("linear + amo: infeasible linear constraint"); - (void)MarkConstraintAsFalse(ct); + (void)MarkConstraintAsFalse(ct, + "linear + amo: infeasible linear constraint"); context_->UpdateConstraintVariableUsage(ct_index); return; } else if (activity.IsIncludedIn(rhs)) { @@ -3883,9 +3872,8 @@ void CpModelPresolver::ProcessOneLinearWithAmo(int ct_index, if (temp_set_.contains(NegatedRef(lit))) { // A literal must be true but is incompatible with what the enforcement // implies. The constraint must be false! - context_->UpdateRuleStats( - "linear + amo: advanced infeasible linear constraint"); - (void)MarkConstraintAsFalse(ct); + (void)MarkConstraintAsFalse( + ct, "linear + amo: advanced infeasible linear constraint"); context_->UpdateConstraintVariableUsage(ct_index); return; } @@ -4009,8 +3997,7 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index, // Incorporate the implied rhs information. Domain rhs = old_rhs.SimplifyUsingImpliedDomain(implied_rhs); if (rhs.IsEmpty()) { - context_->UpdateRuleStats("linear: infeasible"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "linear: infeasible"); } if (rhs != old_rhs) { if (ct_index != -1) context_->UpdateRuleStats("linear: simplified rhs"); @@ -4812,8 +4799,8 @@ bool CpModelPresolver::PresolveLinearOnBooleans(ConstraintProto* ct) { min_sum + min_coeff > rhs_domain.Max()) || (!rhs_domain.Contains(max_sum) && max_sum - min_coeff < rhs_domain.Min())) { - context_->UpdateRuleStats("linear: all booleans and trivially false"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, + "linear: all booleans and trivially false"); } if (Domain(min_sum, max_sum).IsIncludedIn(rhs_domain)) { context_->UpdateRuleStats("linear: all booleans and trivially true"); @@ -4991,8 +4978,8 @@ bool CpModelPresolver::PresolveInterval(int c, ConstraintProto* ct) { // If the size is < 0, then the interval cannot be performed. if (!ct->enforcement_literal().empty() && context_->SizeMax(c) < 0) { - context_->UpdateRuleStats("interval: negative size implies unperformed"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, + "interval: negative size implies unperformed"); } if (ct->enforcement_literal().empty()) { @@ -5465,8 +5452,7 @@ bool CpModelPresolver::PresolveTable(ConstraintProto* ct) { context_->UpdateRuleStats("table: negative table without tuples"); return RemoveConstraint(ct); } else { - context_->UpdateRuleStats("table: positive table without tuples"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "table: positive table without tuples"); } } @@ -5481,8 +5467,7 @@ bool CpModelPresolver::PresolveTable(ConstraintProto* ct) { context_->UpdateRuleStats("table: always true"); return RemoveConstraint(ct); } else { - context_->UpdateRuleStats("table: always false"); - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "table: always false"); } return RemoveConstraint(ct); } @@ -5874,7 +5859,9 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { // Case 1: size > 0. Interval must be unperformed. if (context_->SizeMin(interval_index) > 0) { - if (!MarkConstraintAsFalse(interval_ct)) { + if (!MarkConstraintAsFalse( + interval_ct, + "no_overlap: duplicate interval with positive size")) { return false; } context_->UpdateConstraintVariableUsage(interval_index); @@ -7257,7 +7244,7 @@ bool CpModelPresolver::PresolveCircuit(ConstraintProto* ct) { // All the node must have some incoming and outgoing arcs. for (int i = 0; i < num_nodes; ++i) { if (incoming_arcs[i].empty() || outgoing_arcs[i].empty()) { - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "circuit: node with no arcs"); } } @@ -7375,7 +7362,8 @@ bool CpModelPresolver::PresolveCircuit(ConstraintProto* ct) { for (int n = 0; n < num_nodes; ++n) { if (!visited[n] && !has_self_arc[n]) { // We have a subircuit, but it doesn't cover all the mandatory nodes. - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse( + ct, "circuit: non-covering fixed subcircuit"); } } context_->UpdateRuleStats("circuit: fully specified."); @@ -7449,7 +7437,7 @@ bool CpModelPresolver::PresolveAutomaton(ConstraintProto* ct) { const LinearExpressionProto& expr = proto->exprs(time); if (context_->IsFixed(expr)) { if (!reachable_labels[time].contains(context_->FixedValue(expr))) { - return MarkConstraintAsFalse(ct); + return MarkConstraintAsFalse(ct, "automaton: unsat"); } } else { std::vector unscaled_reachable_labels; @@ -8888,10 +8876,38 @@ void CpModelPresolver::ExpandObjective() { timer.AddCounter("issues", num_issues); } -void CpModelPresolver::MergeNoOverlapConstraints() { +namespace { +bool MaxCliqueHasMadeSomeChanges( + int old_num_constraints, int old_num_entries, + const std::vector>& cliques, bool no_overlaps) { + int new_num_constraints = 0; + int new_num_entries = 0; + for (const std::vector& clique : cliques) { + if (clique.empty()) continue; + new_num_constraints++; + new_num_entries += clique.size(); + } + if (old_num_constraints != new_num_constraints || + old_num_entries != new_num_entries) { + const std::string_view ct_name = + no_overlaps ? "no-overlaps" : "no-overlap_2ds"; + const std::string_view entry_name = + no_overlaps ? "intervals" : "rectangles"; + VLOG(1) << absl::StrCat("Merged ", old_num_constraints, " ", ct_name, " (", + old_num_entries, " ", entry_name, ") into ", + new_num_constraints, " ", ct_name, " (", + new_num_entries, " ", entry_name, ")."); + return true; + } + return false; +} + +} // namespace + +bool CpModelPresolver::MergeNoOverlapConstraints() { PresolveTimer timer("MergeNoOverlap", logger_, time_limit_); - if (context_->ModelIsUnsat()) return; - if (time_limit_->LimitReached()) return; + if (context_->ModelIsUnsat()) return false; + if (time_limit_->LimitReached()) return true; const int num_constraints = context_->working_model->constraints_size(); int old_num_no_overlaps = 0; @@ -8916,7 +8932,7 @@ void CpModelPresolver::MergeNoOverlapConstraints() { old_num_no_overlaps++; old_num_intervals += clique.size(); } - if (old_num_no_overlaps == 0) return; + if (old_num_no_overlaps == 0) return true; // We reuse the max-clique code from sat. Model local_model; @@ -8936,16 +8952,10 @@ void CpModelPresolver::MergeNoOverlapConstraints() { time_limit_->ResetHistory(); - int new_num_no_overlaps = 0; - int new_num_intervals = 0; - for (int i = 0; i < cliques.size(); ++i) { - new_num_no_overlaps++; - new_num_intervals += cliques[i].size(); - } - - if (old_num_intervals == new_num_intervals && - old_num_no_overlaps == new_num_no_overlaps) { - return; + if (!MaxCliqueHasMadeSomeChanges(old_num_no_overlaps, old_num_intervals, + cliques, + /*no_overlaps=*/true)) { + return true; } // Remove previous no_overlap constraints and add the new recomputed ones. @@ -8964,12 +8974,93 @@ void CpModelPresolver::MergeNoOverlapConstraints() { ct->mutable_no_overlap()->add_intervals(l.Variable().value()); } } - VLOG(1) << absl::StrCat("Merged ", old_num_no_overlaps, " no-overlaps (", - old_num_intervals, " intervals) into ", - new_num_no_overlaps, " no-overlaps (", - new_num_intervals, " intervals)."); context_->UpdateRuleStats("no_overlap: merged constraints"); context_->UpdateNewConstraintsVariableUsage(); + return true; +} + +bool CpModelPresolver::MergeNoOverlap2DConstraints() { + PresolveTimer timer("MergeNoOverlap2D", logger_, time_limit_); + if (context_->ModelIsUnsat()) return false; + if (time_limit_->LimitReached()) return true; + + const int num_constraints = context_->working_model->constraints_size(); + int old_num_no_overlap_2ds = 0; + int old_num_rectangles = 0; + + // Extract the no-overlap constraints with no enforcement literals. + // TODO(user): generalize this to merge constraints with the same + // enforcement literals? + std::vector no_overlap2d_index; + std::vector> cliques; + absl::flat_hash_map, int> rectangle_to_index; + std::vector> index_to_rectangle; + for (int c = 0; c < num_constraints; ++c) { + const ConstraintProto& ct = context_->working_model->constraints(c); + if (ct.constraint_case() != ConstraintProto::kNoOverlap2D) continue; + if (HasEnforcementLiteral(ct)) continue; + std::vector clique; + for (int i = 0; i < ct.no_overlap_2d().x_intervals_size(); ++i) { + const std::pair rect = {ct.no_overlap_2d().x_intervals(i), + ct.no_overlap_2d().y_intervals(i)}; + const auto [it, inserted] = + rectangle_to_index.insert({rect, rectangle_to_index.size()}); + if (inserted) index_to_rectangle.push_back(rect); + clique.push_back(Literal(BooleanVariable(it->second), true)); + } + cliques.push_back(clique); + no_overlap2d_index.push_back(c); + + old_num_no_overlap_2ds++; + old_num_rectangles += clique.size(); + } + if (old_num_no_overlap_2ds == 0) return true; + + // We reuse the max-clique code from sat. + Model local_model; + local_model.GetOrCreate()->Resize(num_constraints); + local_model.GetOrCreate()->MergeWithGlobalTimeLimit(time_limit_); + auto* graph = local_model.GetOrCreate(); + graph->Resize(num_constraints); + for (const std::vector& clique : cliques) { + // All variables at false is always a valid solution of the local model, + // so this should never return UNSAT. + CHECK(graph->AddAtMostOne(clique)); + } + CHECK(graph->DetectEquivalences()); + graph->TransformIntoMaxCliques( + &cliques, + SafeDoubleToInt64(context_->params().merge_no_overlap_work_limit())); + + time_limit_->ResetHistory(); + + if (!MaxCliqueHasMadeSomeChanges(old_num_no_overlap_2ds, old_num_rectangles, + cliques, + /*no_overlaps=*/false)) { + return true; + } + + // Remove previous no_overlap constraints and add the new recomputed ones. + for (int i = 0; i < cliques.size(); ++i) { + const int ct_index = no_overlap2d_index[i]; + if (RemoveConstraint( + context_->working_model->mutable_constraints(ct_index))) { + context_->UpdateConstraintVariableUsage(ct_index); + } + } + for (int i = 0; i < cliques.size(); ++i) { + if (cliques[i].empty()) continue; + ConstraintProto* ct = context_->working_model->add_constraints(); + for (const Literal l : cliques[i]) { + CHECK(l.IsPositive()); + const std::pair rect = index_to_rectangle[l.Variable().value()]; + ct->mutable_no_overlap_2d()->add_x_intervals(rect.first); + ct->mutable_no_overlap_2d()->add_y_intervals(rect.second); + } + } + context_->UpdateRuleStats("no_overlap_2d: merged constraints"); + context_->UpdateNewConstraintsVariableUsage(); + return true; } // TODO(user): Should we take into account the exactly_one constraints? note @@ -9349,9 +9440,9 @@ bool CpModelPresolver::ProcessSetPPCSubset(int subset_c, int superset_c, } if (reachable.IntersectionWith(superset_rhs).IsEmpty()) { // TODO(user): constraint might become bool_or. - context_->UpdateRuleStats("setppc: removed infeasible linear constraint"); *stop_processing_superset = true; - return MarkConstraintAsFalse(superset_ct); + return MarkConstraintAsFalse( + superset_ct, "setppc: removed infeasible linear constraint"); } // We reuse the normal linear constraint code to propagate domains of @@ -9623,8 +9714,8 @@ void CpModelPresolver::DetectIncludedEnforcement() { if (context_->tmp_literal_set.contains(ref)) { context_->UpdateRuleStats("bool_and: filtered literal"); } else if (context_->tmp_literal_set.contains(NegatedRef(ref))) { - context_->UpdateRuleStats("bool_and: must be false"); - if (!MarkConstraintAsFalse(superset_ct)) return; + if (!MarkConstraintAsFalse(superset_ct, "bool_and: must be false")) + return; context_->UpdateConstraintVariableUsage(superset_c); detector.StopProcessingCurrentSuperset(); return; @@ -10165,8 +10256,8 @@ void CpModelPresolver::DetectDuplicateConstraints() { const Domain rhs = rep_domain.IntersectionWith(d); if (rhs.IsEmpty()) { if (!MarkConstraintAsFalse( - context_->working_model->mutable_constraints(rep))) { - SOLVER_LOG(logger_, "Unsat after merging two linear constraints"); + context_->working_model->mutable_constraints(rep), + "duplicate: false after merging")) { return; } @@ -10443,9 +10534,12 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // IsFixed() do not work on empty domain. if (rhs.IsEmpty()) { - context_->UpdateRuleStats("duplicate: linear1 infeasible"); - if (!MarkConstraintAsFalse(rep_ct)) return; - if (!MarkConstraintAsFalse(dup_ct)) return; + if (!MarkConstraintAsFalse(rep_ct, + "duplicate: linear1 infeasible")) + return; + if (!MarkConstraintAsFalse(dup_ct, + "duplicate: linear1 infeasible")) + return; context_->UpdateConstraintVariableUsage(rep); context_->UpdateConstraintVariableUsage(dup); continue; @@ -12461,7 +12555,7 @@ void CpModelPresolver::MaybeTransferLinear1ToAnotherVariable(int var) { const Domain current = context_->DomainOf(new_var); new_domain = new_domain.IntersectionWith(current); if (new_domain.IsEmpty()) { - if (!MarkConstraintAsFalse(ct)) return; + if (!MarkConstraintAsFalse(ct, "linear1: unsat transfer")) return; } else if (new_domain == current) { ct->Clear(); } else { @@ -12533,7 +12627,8 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { .IntersectionWith(context_->DomainOf(var)); if (implied.IsEmpty()) { if (!MarkConstraintAsFalse( - context_->working_model->mutable_constraints(unique_c))) { + context_->working_model->mutable_constraints(unique_c), + "encoding: empty implied domain")) { return; } context_->UpdateConstraintVariableUsage(unique_c); @@ -13831,9 +13926,8 @@ CpSolverStatus CpModelPresolver::Presolve() { } if (context_->ModelIsUnsat()) return InfeasibleStatus(); - // Regroup no-overlaps into max-cliques. - MergeNoOverlapConstraints(); - if (context_->ModelIsUnsat()) return InfeasibleStatus(); + if (!MergeNoOverlapConstraints()) return InfeasibleStatus(); + if (!MergeNoOverlap2DConstraints()) return InfeasibleStatus(); // Tries to spread the objective amongst many variables. // We re-do a canonicalization with the final linear expression. diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index 725f5f8a82..8191eea89a 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -299,7 +299,8 @@ class CpModelPresolver { void LookAtVariableWithDegreeTwo(int var); void ProcessVariableInTwoAtMostOrExactlyOne(int var); - void MergeNoOverlapConstraints(); + bool MergeNoOverlapConstraints(); + bool MergeNoOverlap2DConstraints(); // Assumes that all [constraint_index, multiple] in block are linear // constraint that contains multiple * common_part and perform the @@ -347,8 +348,8 @@ class CpModelPresolver { bool ExploitEquivalenceRelations(int c, ConstraintProto* ct); ABSL_MUST_USE_RESULT bool RemoveConstraint(ConstraintProto* ct); - ABSL_MUST_USE_RESULT bool MarkConstraintAsFalse( - ConstraintProto* ct, const std::string& reason = ""); + ABSL_MUST_USE_RESULT bool MarkConstraintAsFalse(ConstraintProto* ct, + std::string_view reason); std::vector* postsolve_mapping_; PresolveContext* context_; diff --git a/ortools/sat/cp_model_presolve_test.cc b/ortools/sat/cp_model_presolve_test.cc index f73272ac0e..37ef2c7b2d 100644 --- a/ortools/sat/cp_model_presolve_test.cc +++ b/ortools/sat/cp_model_presolve_test.cc @@ -2504,6 +2504,410 @@ TEST(PresolveCpModelTest, NoOverlap2DSplitSingletonBoxes) { EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); } +TEST(PresolveCpModelTest, NoOverlap2DMerge) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + constraints { + interval { + start { vars: 0 coeffs: 1 } + end { vars: 0 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 1 coeffs: 1 } + end { vars: 1 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 2 coeffs: 1 } + end { vars: 2 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 3 coeffs: 1 } + end { vars: 3 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 4 coeffs: 1 } + end { vars: 4 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 5 coeffs: 1 } + end { vars: 5 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 2, 3 ] + y_intervals: [ 0, 1, 2, 3 ] + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 4, 5 ] + y_intervals: [ 0, 1, 4, 5 ] + } + } + constraints { + no_overlap_2d { + x_intervals: [ 2, 3, 4, 5 ] + y_intervals: [ 2, 3, 4, 5 ] + } + } + )pb"); + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + constraints { + interval { + start { vars: 0 coeffs: 1 } + end { vars: 0 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 1 coeffs: 1 } + end { vars: 1 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 2 coeffs: 1 } + end { vars: 2 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 3 coeffs: 1 } + end { vars: 3 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 4 coeffs: 1 } + end { vars: 4 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 5 coeffs: 1 } + end { vars: 5 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 2, 3, 5, 4 ] + y_intervals: [ 0, 1, 2, 3, 5, 4 ] + } + } + )pb"); + + SatParameters params; + params.set_keep_all_feasible_solutions_in_presolve(true); + const CpModelProto presolved_model = PresolveForTest(initial_model, params); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + +TEST(PresolveCpModelTest, NoOverlap2DMergePartial) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + constraints { + interval { + start { vars: 0 coeffs: 1 } + end { vars: 0 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 1 coeffs: 1 } + end { vars: 1 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 2 coeffs: 1 } + end { vars: 2 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 3 coeffs: 1 } + end { vars: 3 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 4 coeffs: 1 } + end { vars: 4 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 5 coeffs: 1 } + end { vars: 5 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 2, 3, 4 ] + y_intervals: [ 0, 1, 2, 3, 4 ] + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 3, 4, 5 ] + y_intervals: [ 0, 1, 3, 4, 5 ] + } + } + constraints { + no_overlap_2d { + x_intervals: [ 1, 3, 4, 5 ] + y_intervals: [ 1, 3, 4, 5 ] + } + } + )pb"); + + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + constraints { + interval { + start { vars: 0 coeffs: 1 } + end { vars: 0 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 1 coeffs: 1 } + end { vars: 1 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 2 coeffs: 1 } + end { vars: 2 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 3 coeffs: 1 } + end { vars: 3 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 4 coeffs: 1 } + end { vars: 4 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 5 coeffs: 1 } + end { vars: 5 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 2, 3, 4 ] + y_intervals: [ 0, 1, 2, 3, 4 ] + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 3, 4, 5 ] + y_intervals: [ 0, 1, 3, 4, 5 ] + } + } + )pb"); + + SatParameters params; + params.set_keep_all_feasible_solutions_in_presolve(true); + const CpModelProto presolved_model = PresolveForTest(initial_model, params); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + +TEST(PresolveCpModelTest, NoOverlap2DMergeWithOverlaps) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + constraints { + interval { + start { vars: 0 coeffs: 1 } + end { vars: 0 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 1 coeffs: 1 } + end { vars: 1 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 2 coeffs: 1 } + end { vars: 2 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 3 coeffs: 1 } + end { vars: 3 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 4 coeffs: 1 } + end { vars: 4 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 5 coeffs: 1 } + end { vars: 5 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 2, 3, 4 ] + y_intervals: [ 0, 1, 2, 3, 4 ] + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 2, 4, 5 ] + y_intervals: [ 0, 1, 2, 4, 5 ] + } + } + constraints { + no_overlap_2d { + x_intervals: [ 1, 3, 4, 5 ] + y_intervals: [ 1, 3, 4, 5 ] + } + } + )pb"); + + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 0, 10 ] } + constraints { + interval { + start { vars: 0 coeffs: 1 } + end { vars: 0 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 1 coeffs: 1 } + end { vars: 1 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 2 coeffs: 1 } + end { vars: 2 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 3 coeffs: 1 } + end { vars: 3 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 4 coeffs: 1 } + end { vars: 4 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + interval { + start { vars: 5 coeffs: 1 } + end { vars: 5 coeffs: 1 offset: 5 } + size { offset: 5 } + } + } + constraints { + no_overlap_2d { + x_intervals: [ 0, 1, 2, 3, 4, 5 ] + y_intervals: [ 0, 1, 2, 3, 4, 5 ] + } + } + )pb"); + + SatParameters params; + params.set_keep_all_feasible_solutions_in_presolve(true); + const CpModelProto presolved_model = PresolveForTest(initial_model, params); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + TEST(PresolveCpModelTest, IntProdWithLeftConstant) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index 2876df4c9a..c1c3db2b80 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -377,7 +377,6 @@ std::function ConstructHeuristicSearchStrategy( } heuristics.push_back(SchedulingSearchHeuristic(model)); - CHECK(!heuristics.empty()); return SequentialSearch(std::move(heuristics)); } return nullptr; @@ -428,26 +427,22 @@ std::function ConstructHintSearchStrategy( return FollowHint(vars, values, model); } -std::function ConstructFixedSearchStrategy( - std::function user_search, - std::function heuristic_search, - std::function integer_completion, Model* model) { +void ConstructFixedSearchStrategy(SearchHeuristics* h, Model* model) { // We start by the user specified heuristic. std::vector> heuristics; - if (user_search != nullptr) { - heuristics.push_back(user_search); + if (h->user_search != nullptr) { + heuristics.push_back(h->user_search); } - if (heuristic_search != nullptr) { - heuristics.push_back(heuristic_search); - } - if (heuristics.empty()) { + if (h->heuristic_search != nullptr) { + heuristics.push_back(h->heuristic_search); + } else { heuristics.push_back(PseudoCost(model)); } - if (integer_completion != nullptr) { - heuristics.push_back(integer_completion); + if (h->integer_completion_search != nullptr) { + heuristics.push_back(h->integer_completion_search); } - return SequentialSearch(std::move(heuristics)); + h->fixed_search = SequentialSearch(std::move(heuristics)); } std::function InstrumentSearchStrategy( diff --git a/ortools/sat/cp_model_search.h b/ortools/sat/cp_model_search.h index a14f192619..b714f3f0d7 100644 --- a/ortools/sat/cp_model_search.h +++ b/ortools/sat/cp_model_search.h @@ -89,12 +89,9 @@ std::function ConstructHintSearchStrategy( const CpModelProto& cp_model_proto, CpModelMapping* mapping, Model* model); // Constructs our "fixed" search strategy which start with -// ConstructUserSearchStrategy() but is completed by a couple of automatic -// heuristics. -std::function ConstructFixedSearchStrategy( - std::function user_search, - std::function heuristic_search, - std::function integer_completion, Model* model); +// ConstructUserSearchStrategy() if present, but is completed by a couple of +// automatic heuristics. +void ConstructFixedSearchStrategy(SearchHeuristics* h, Model* model); // For debugging fixed-search: display information about the named variables // domain before taking each decision. Note that we copy the instrumented diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index ea046c6487..e42b8f83b0 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1064,7 +1064,7 @@ class FullProblemSolver : public SubSolver { // parameter provided by the user). if (shared_->SearchIsDone()) return false; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (previous_task_is_completed_) { if (solving_first_chunk_) return true; if (split_in_chunks_) return true; @@ -1074,7 +1074,7 @@ class FullProblemSolver : public SubSolver { std::function GenerateTask(int64_t /*task_id*/) override { { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); previous_task_is_completed_ = false; } return [this]() { @@ -1132,7 +1132,7 @@ class FullProblemSolver : public SubSolver { solving_first_chunk_ = false; // Make sure we count the loading/hint dtime. - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); dtime_since_last_sync_ += time_limit->GetElapsedDeterministicTime() - init_dtime; @@ -1155,7 +1155,7 @@ class FullProblemSolver : public SubSolver { const double saved_dtime = time_limit->GetElapsedDeterministicTime(); SolveLoadedCpModel(shared_->model_proto, &local_model_); - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); previous_task_is_completed_ = true; dtime_since_last_sync_ += time_limit->GetElapsedDeterministicTime() - saved_dtime; @@ -1166,7 +1166,7 @@ class FullProblemSolver : public SubSolver { // happen here (bound sharing, RINS neighborhood, objective). Fix that so we // can have a deterministic parallel mode. void Synchronize() override { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); AddTaskDeterministicDuration(dtime_since_last_sync_); shared_->time_limit->AdvanceDeterministicTime(dtime_since_last_sync_); dtime_since_last_sync_ = 0.0; @@ -1211,18 +1211,18 @@ class FeasibilityPumpSolver : public SubSolver { bool TaskIsAvailable() override { if (shared_->SearchIsDone()) return false; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return previous_task_is_completed_; } std::function GenerateTask(int64_t /*task_id*/) override { { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); previous_task_is_completed_ = false; } return [this]() { { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (solving_first_chunk_) { LoadFeasibilityPump(shared_->model_proto, local_model_.get()); // No new task will be scheduled for this worker if there is no @@ -1243,7 +1243,7 @@ class FeasibilityPumpSolver : public SubSolver { } { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); dtime_since_last_sync_ += time_limit->GetElapsedDeterministicTime() - saved_dtime; } @@ -1254,13 +1254,13 @@ class FeasibilityPumpSolver : public SubSolver { return; } - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); previous_task_is_completed_ = true; }; } void Synchronize() override { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); AddTaskDeterministicDuration(dtime_since_last_sync_); shared_->time_limit->AdvanceDeterministicTime(dtime_since_last_sync_); dtime_since_last_sync_ = 0.0; @@ -1393,7 +1393,7 @@ class LnsSolver : public SubSolver { // Presolve and solve the LNS fragment. size_t buffer_size; { - absl::MutexLock l(&next_arena_size_mutex_); + absl::MutexLock l(next_arena_size_mutex_); buffer_size = next_arena_size_; } google::protobuf::Arena arena( @@ -1687,7 +1687,7 @@ class LnsSolver : public SubSolver { ", p:", fully_solved_proportion, "]"); } { - absl::MutexLock l(&next_arena_size_mutex_); + absl::MutexLock l(next_arena_size_mutex_); next_arena_size_ = arena.SpaceUsed(); } }; diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index d2aaaecb67..c8fc9a3413 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -31,6 +31,7 @@ #include "ortools/base/helpers.h" #include "ortools/base/options.h" #endif // __PORTABLE_PLATFORM__ +#include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_set.h" #include "absl/flags/flag.h" @@ -40,7 +41,6 @@ #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/time/time.h" #include "absl/types/span.h" #include "google/protobuf/arena.h" #include "ortools/algorithms/sparse_permutation.h" @@ -49,6 +49,7 @@ #include "ortools/port/proto_utils.h" #include "ortools/sat/clause.h" #include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_checker.h" #include "ortools/sat/cp_model_loader.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/cp_model_postsolve.h" @@ -126,6 +127,14 @@ void InitializeDebugSolution(const CpModelProto& model_proto, Model* model) { if (shared_response == nullptr) return; if (shared_response->DebugSolution().empty()) return; + if (!SolutionIsFeasible(model_proto, shared_response->DebugSolution())) { + // TODO(user): we should probably CHECK-fail. + SOLVER_LOG(model->GetOrCreate(), + "Debug solution is not feasible."); + return; + } + SOLVER_LOG(model->GetOrCreate(), "Debug solution is feasible."); + // Copy the proto values. DebugSolution& debug_sol = *model->GetOrCreate(); debug_sol.proto_values = shared_response->DebugSolution(); @@ -161,40 +170,64 @@ void InitializeDebugSolution(const CpModelProto& model_proto, Model* model) { // it in the sat solver for debugging too. if (boolean_solution.size() == debug_sol.proto_values.size() && !model_proto.has_objective()) { - LOG(INFO) << "Loaded pure Boolean debugging solution."; + SOLVER_LOG(model->GetOrCreate(), + "Loaded pure Boolean debugging solution."); model->GetOrCreate()->LoadDebugSolution(boolean_solution); } // The objective variable is usually not part of the proto, but it is still // nice to have it, so we recompute it here. auto* objective_def = model->Get(); - if (objective_def != nullptr) { - const IntegerVariable objective_var = objective_def->objective_var; - const int64_t objective_value = - ComputeInnerObjective(model_proto.objective(), debug_sol.proto_values); - debug_sol.ivar_has_value[objective_var] = true; - debug_sol.ivar_has_value[NegationOf(objective_var)] = true; - debug_sol.ivar_values[objective_var] = objective_value; - debug_sol.ivar_values[NegationOf(objective_var)] = -objective_value; + if (objective_def != nullptr && + objective_def->objective_var != kNoIntegerVariable) { + if (absl::c_all_of(objective_def->vars, [&debug_sol](IntegerVariable var) { + return var < debug_sol.ivar_has_value.end_index() && + debug_sol.ivar_has_value[var]; + })) { + const IntegerVariable objective_var = objective_def->objective_var; + if (objective_var + 1 >= debug_sol.ivar_has_value.size()) { + debug_sol.ivar_has_value.resize(objective_var + 2, false); + debug_sol.ivar_values.resize(objective_var + 2, 0); + } + IntegerValue objective_value = 0; + for (int i = 0; i < objective_def->vars.size(); ++i) { + objective_value += objective_def->coeffs[i] * + debug_sol.ivar_values[objective_def->vars[i]]; + } + SOLVER_LOG( + model->GetOrCreate(), + absl::StrCat("Debug solution objective value: ", + objective_def->ScaleIntegerObjective(objective_value))); + debug_sol.ivar_has_value[objective_var] = true; + debug_sol.ivar_has_value[NegationOf(objective_var)] = true; + debug_sol.ivar_values[objective_var] = objective_value; + debug_sol.ivar_values[NegationOf(objective_var)] = -objective_value; + debug_sol.inner_objective_value = objective_value; + } } // We also register a DEBUG callback to check our reasons. auto* encoder = model->GetOrCreate(); - const auto checker = [mapping = mapping, encoder, debug_sol, model]( + const auto checker = [mapping = &mapping, encoder, model]( absl::Span clause, absl::Span integers) { + const DebugSolution* debug_sol = model->Get(); + if (!debug_sol || debug_sol->proto_values.empty()) return true; + bool is_satisfied = false; int num_bools = 0; int num_ints = 0; - std::vector> to_print; + std::vector> to_print; for (const Literal l : clause) { // First case, this Boolean is mapped. { const int proto_var = - mapping.GetProtoVariableFromBooleanVariable(l.Variable()); + mapping->GetProtoVariableFromBooleanVariable(l.Variable()); if (proto_var != -1) { - to_print.push_back({l, IntegerLiteral(), proto_var}); - if (debug_sol.proto_values[proto_var] == (l.IsPositive() ? 1 : 0)) { + CHECK_LT(proto_var, debug_sol->proto_values.size()); + to_print.push_back( + {l, IntegerLiteral(), debug_sol->proto_values[proto_var]}); + if (debug_sol->proto_values[proto_var] == (l.IsPositive() ? 1 : 0)) { is_satisfied = true; break; } @@ -207,13 +240,13 @@ void InitializeDebugSolution(const CpModelProto& model_proto, Model* model) { // We can use any of them, so if one is false, we use this one. bool all_true = true; for (const IntegerLiteral associated : encoder->GetIntegerLiterals(l)) { - const int proto_var = mapping.GetProtoVariableFromIntegerVariable( - PositiveVariable(associated.var)); - if (proto_var == -1) break; - int64_t value = debug_sol.proto_values[proto_var]; - to_print.push_back({l, associated, proto_var}); + if (associated.var >= debug_sol->ivar_has_value.end_index() || + !debug_sol->ivar_has_value[associated.var]) { + break; + } + const IntegerValue value = debug_sol->ivar_values[associated.var]; + to_print.push_back({l, associated, value}); - if (!VariableIsPositive(associated.var)) value = -value; if (value < associated.bound) { ++num_ints; all_true = false; @@ -226,20 +259,18 @@ void InitializeDebugSolution(const CpModelProto& model_proto, Model* model) { } } for (const IntegerLiteral i_lit : integers) { - const int proto_var = mapping.GetProtoVariableFromIntegerVariable( - PositiveVariable(i_lit.var)); - if (proto_var == -1) { + DCHECK(!i_lit.IsAlwaysFalse()); + if (i_lit.IsAlwaysTrue()) continue; + if (i_lit.var >= debug_sol->ivar_has_value.end_index() || + !debug_sol->ivar_has_value[i_lit.var]) { is_satisfied = true; break; } - int64_t value = debug_sol.proto_values[proto_var]; - to_print.push_back({Literal(kNoLiteralIndex), i_lit, proto_var}); + const IntegerValue value = debug_sol->ivar_values[i_lit.var]; + to_print.push_back({Literal(kNoLiteralIndex), i_lit, value}); - if (!VariableIsPositive(i_lit.var)) value = -value; - // Note the sign is inversed, we cannot have all literal false and all - // integer literal true. - if (value >= i_lit.bound) { + if (value < i_lit.bound) { is_satisfied = true; break; } @@ -250,9 +281,12 @@ void InitializeDebugSolution(const CpModelProto& model_proto, Model* model) { << model->GetOrCreate()->CurrentDecisionLevel(); LOG(INFO) << "literals (neg): " << clause; LOG(INFO) << "integer literals: " << integers; - for (const auto [l, i_lit, proto_var] : to_print) { - LOG(INFO) << l << " " << i_lit << " var=" << proto_var - << " value_in_sol=" << debug_sol.proto_values[proto_var]; + for (const auto [l, i_lit, solution_value] : to_print) { + const int proto_var = + mapping->GetProtoVariableFromIntegerVariable(i_lit.var); + LOG(INFO) << l << " " << i_lit << " var=" + << (proto_var == -1 ? "none" : absl::StrCat(proto_var)) + << " value_in_sol=" << solution_value; } } return is_satisfied; @@ -1596,9 +1630,7 @@ void LoadCpModel(const CpModelProto& model_proto, Model* model) { search_heuristics->integer_completion_search = ConstructIntegerCompletionSearchStrategy(mapping->GetVariableMapping(), objective_var, model); - search_heuristics->fixed_search = ConstructFixedSearchStrategy( - search_heuristics->user_search, search_heuristics->heuristic_search, - search_heuristics->integer_completion_search, model); + ConstructFixedSearchStrategy(search_heuristics, model); if (VLOG_IS_ON(3)) { search_heuristics->fixed_search = InstrumentSearchStrategy(model_proto, mapping->GetVariableMapping(), diff --git a/ortools/sat/cp_model_solver_test.cc b/ortools/sat/cp_model_solver_test.cc index 01419321e3..a81ec3d227 100644 --- a/ortools/sat/cp_model_solver_test.cc +++ b/ortools/sat/cp_model_solver_test.cc @@ -14,6 +14,7 @@ #include "ortools/sat/cp_model_solver.h" #include +#include #include #include @@ -27,10 +28,17 @@ #include "ortools/port/os.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" +#include "ortools/sat/cp_model_solver_helpers.h" #include "ortools/sat/cp_model_test_utils.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/drat_checker.h" +#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/lp_utils.h" #include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/sat/synchronization.h" #include "ortools/util/logging.h" namespace operations_research { @@ -5451,6 +5459,42 @@ TEST(PresolveCpModelTest, SolutionCrushBug) { EXPECT_EQ(response.status(), CpSolverStatus::INFEASIBLE); } +TEST(CpModelSolverTest, DratProofIsValidForRandom3Sat) { + int num_infeasible = 0; + for (int i = 0; i < 100; ++i) { + Model model; + SatSolver& solver = *model.GetOrCreate(); + auto drat_proof_handler = std::make_unique(); + solver.SetDratProofHandler(drat_proof_handler.get()); + + const int kNumVariables = 100; + CpModelProto model_proto = Random3SatProblem(kNumVariables); + + drat_proof_handler->SetNumVariables(model_proto.variables_size()); + for (const ConstraintProto& ct : model_proto.constraints()) { + if (ct.constraint_case() == ConstraintProto::ConstraintCase::kBoolOr) { + std::vector clause; + for (const int ref : ct.bool_or().literals()) { + clause.push_back( + Literal(BooleanVariable(PositiveRef(ref)), RefIsPositive(ref))); + } + drat_proof_handler->AddProblemClause(clause); + } + } + + LoadCpModel(model_proto, &model); + SolveLoadedCpModel(model_proto, &model); + if (model.GetOrCreate()->GetResponse().status() == + CpSolverStatus::INFEASIBLE) { + ++num_infeasible; + EXPECT_EQ(drat_proof_handler->Check(/*max_time_in_seconds=*/60), + DratChecker::Status::VALID); + } + } + LOG(INFO) << "num_infeasible: " << num_infeasible; + EXPECT_GT(num_infeasible, 0); +} + #endif // ORTOOLS_TARGET_OS_SUPPORTS_THREADS } // namespace diff --git a/ortools/sat/cuts_test.cc b/ortools/sat/cuts_test.cc index 851969a48d..fa2fc29aa9 100644 --- a/ortools/sat/cuts_test.cc +++ b/ortools/sat/cuts_test.cc @@ -167,7 +167,7 @@ TEST(CoverCutHelperTest, SimpleExample) { Model model; CoverCutHelper helper(&model); EXPECT_TRUE(helper.TrySimpleKnapsack(data)); - EXPECT_EQ(GetCutString(helper), "1*X0 1*X1 1*X2 <= 1"); + EXPECT_EQ(GetCutString(helper), "1*I0 1*I1 1*I2 <= 1"); EXPECT_EQ(helper.Info(), "lift=1"); } @@ -193,7 +193,7 @@ TEST(CoverCutHelperTest, WeirdExampleWithViolatedConstraint) { Model model; CoverCutHelper helper(&model); EXPECT_TRUE(helper.TrySimpleKnapsack(data)); - EXPECT_EQ(GetCutString(helper), "1*X0 1*X1 <= 9"); + EXPECT_EQ(GetCutString(helper), "1*I0 1*I1 <= 9"); EXPECT_EQ(helper.Info(), "lift=1"); } @@ -221,7 +221,7 @@ TEST(CoverCutHelperTest, LetchfordSouliLifting) { CoverCutHelper helper(&model); EXPECT_TRUE(helper.TryWithLetchfordSouliLifting(data)); EXPECT_EQ(GetCutString(helper), - "1*X0 1*X1 1*X2 1*X3 3*X4 3*X5 2*X6 1*X7 1*X8 1*X9 <= 3"); + "1*I0 1*I1 1*I2 1*I3 3*I4 3*I5 2*I6 1*I7 1*I8 1*I9 <= 3"); // For now, we only support Booleans in the cover. // Note that we don't care for variable not in the cover though. @@ -270,7 +270,7 @@ TEST(IntegerRoundingCutTest, LetchfordLodiExample1) { options.max_scaling = 2; LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( options, rhs, vars, coeffs, lp_values, &model); - EXPECT_EQ(constraint.DebugString(), "2*X0 1*X1 <= 2"); + EXPECT_EQ(constraint.DebugString(), "2*I0 1*I1 <= 2"); } TEST(IntegerRoundingCutTest, LetchfordLodiExample1Modified) { @@ -290,7 +290,7 @@ TEST(IntegerRoundingCutTest, LetchfordLodiExample1Modified) { // Note that the cut is only valid because the bound of x1 is one here. LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( RoundingOptions(), rhs, vars, coeffs, lp_values, &model); - EXPECT_EQ(constraint.DebugString(), "1*X0 1*X1 <= 1"); + EXPECT_EQ(constraint.DebugString(), "1*I0 1*I1 <= 1"); } TEST(IntegerRoundingCutTest, LetchfordLodiExample2) { @@ -306,7 +306,7 @@ TEST(IntegerRoundingCutTest, LetchfordLodiExample2) { std::vector lp_values{0.0, 2.25}; LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( RoundingOptions(), rhs, vars, coeffs, lp_values, &model); - EXPECT_EQ(constraint.DebugString(), "3*X0 2*X1 <= 4"); + EXPECT_EQ(constraint.DebugString(), "3*I0 2*I1 <= 4"); } TEST(IntegerRoundingCutTest, LetchfordLodiExample2WithNegatedCoeff) { @@ -323,10 +323,10 @@ TEST(IntegerRoundingCutTest, LetchfordLodiExample2WithNegatedCoeff) { LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( RoundingOptions(), rhs, vars, coeffs, lp_values, &model); - // We actually do not return like in the example "3*X0 -2*X1 <= 4" - // But the simpler X0 - X1 <= 2 which has the same violation (0.25) but a + // We actually do not return like in the example "3*I0 -2*I1 <= 4" + // But the simpler I0 - I1 <= 2 which has the same violation (0.25) but a // better norm. - EXPECT_EQ(constraint.DebugString(), "1*X0 -1*X1 <= 2"); + EXPECT_EQ(constraint.DebugString(), "1*I0 -1*I1 <= 2"); } // This used to trigger a failure with a wrong implied bound code path. @@ -349,7 +349,7 @@ TEST(IntegerRoundingCutTest, TestCaseUsedForDebugging) { LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( RoundingOptions(), rhs, vars, coeffs, lp_values, &model); - EXPECT_EQ(constraint.DebugString(), "-2*X0 -1*X1 -2*X2 -2*X3 2*X4 <= -2"); + EXPECT_EQ(constraint.DebugString(), "-2*I0 -1*I1 -2*I2 -2*I3 2*I4 <= -2"); } // The algo should find a "divisor" 2 when it lead to a good cut. @@ -373,7 +373,7 @@ TEST(IntegerRoundingCutTest, ZeroHalfCut) { std::vector lp_values{0.25, 1.25, 0.3125, 0.0}; LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( RoundingOptions(), rhs, vars, coeffs, lp_values, &model); - EXPECT_EQ(constraint.DebugString(), "3*X0 2*X1 4*X2 3*X3 <= 4"); + EXPECT_EQ(constraint.DebugString(), "3*I0 2*I1 4*I2 3*I3 <= 4"); } TEST(IntegerRoundingCutTest, LargeCoeffWithSmallImprecision) { @@ -386,12 +386,12 @@ TEST(IntegerRoundingCutTest, LargeCoeffWithSmallImprecision) { std::vector vars = {x0, x1}; std::vector coeffs = {IntegerValue(1e6), IntegerValue(-1)}; - // Note thate without adjustement, this returns 2 * X0 - X1 <= 2. + // Note thate without adjustement, this returns 2 * I0 - I1 <= 2. // TODO(user): expose parameters so this can be verified other than manually? std::vector lp_values{1.5, 0.1}; LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( RoundingOptions(), rhs, vars, coeffs, lp_values, &model); - EXPECT_EQ(constraint.DebugString(), "1*X0 <= 1"); + EXPECT_EQ(constraint.DebugString(), "1*I0 <= 1"); } TEST(IntegerRoundingCutTest, LargeCoeffWithSmallImprecision2) { @@ -404,12 +404,12 @@ TEST(IntegerRoundingCutTest, LargeCoeffWithSmallImprecision2) { std::vector vars = {x0, x1}; std::vector coeffs = {IntegerValue(1e6), IntegerValue(999999)}; - // Note thate without adjustement, this returns 2 * X0 + X1 <= 2. + // Note thate without adjustement, this returns 2 * I0 + I1 <= 2. // TODO(user): expose parameters so this can be verified other than manually? std::vector lp_values{1.49, 0.1}; LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( RoundingOptions(), rhs, vars, coeffs, lp_values, &model); - EXPECT_EQ(constraint.DebugString(), "1*X0 1*X1 <= 1"); + EXPECT_EQ(constraint.DebugString(), "1*I0 1*I1 <= 1"); } TEST(IntegerRoundingCutTest, MirOnLargerConstraint) { @@ -433,7 +433,7 @@ TEST(IntegerRoundingCutTest, MirOnLargerConstraint) { options.max_scaling = 4; LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( options, rhs, vars, coeffs, lp_values, &model); - EXPECT_EQ(constraint.DebugString(), "1*X6 2*X7 3*X8 4*X9 <= 4"); + EXPECT_EQ(constraint.DebugString(), "1*I6 2*I7 3*I8 4*I9 <= 4"); } TEST(IntegerRoundingCutTest, MirOnLargerConstraint2) { @@ -457,7 +457,7 @@ TEST(IntegerRoundingCutTest, MirOnLargerConstraint2) { LinearConstraint constraint = IntegerRoundingCutWithBoundsFromTrail( options, rhs, vars, coeffs, lp_values, &model); EXPECT_EQ(constraint.DebugString(), - "2*X1 3*X2 4*X3 6*X4 6*X5 8*X6 9*X7 10*X8 12*X9 <= 18"); + "2*I1 3*I2 4*I3 6*I4 6*I5 8*I6 9*I7 10*I8 12*I9 <= 18"); } std::vector ToIntegerValues(const std::vector input) { @@ -550,7 +550,7 @@ TEST(SquareCutGeneratorTest, TestBelowCut) { square.generate_cuts(manager); EXPECT_EQ(1, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - EndsWith("-5*X0 1*X1 <= 0")); + EndsWith("-5*I0 1*I1 <= 0")); } TEST(SquareCutGeneratorTest, TestBelowCutWithOffset) { @@ -564,7 +564,7 @@ TEST(SquareCutGeneratorTest, TestBelowCutWithOffset) { square.generate_cuts(manager); ASSERT_EQ(1, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - EndsWith("-6*X0 1*X1 <= -5")); + EndsWith("-6*I0 1*I1 <= -5")); } TEST(SquareCutGeneratorTest, TestNoBelowCut) { @@ -590,7 +590,7 @@ TEST(SquareCutGeneratorTest, TestAboveCut) { square.generate_cuts(manager); ASSERT_EQ(1, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - StartsWith("-6 <= -5*X0 1*X1")); + StartsWith("-6 <= -5*I0 1*I1")); } TEST(SquareCutGeneratorTest, TestNearlyAboveCut) { @@ -618,7 +618,7 @@ TEST(MultiplicationCutGeneratorTest, TestCut1) { mult.generate_cuts(manager); ASSERT_EQ(1, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - EndsWith("2*X0 1*X1 -1*X2 <= 2")); + EndsWith("2*I0 1*I1 -1*I2 <= 2")); } TEST(MultiplicationCutGeneratorTest, TestCut2) { @@ -634,7 +634,7 @@ TEST(MultiplicationCutGeneratorTest, TestCut2) { mult.generate_cuts(manager); ASSERT_EQ(1, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - EndsWith("3*X0 5*X1 -1*X2 <= 15")); + EndsWith("3*I0 5*I1 -1*I2 <= 15")); } TEST(MultiplicationCutGeneratorTest, TestCut3) { @@ -650,9 +650,9 @@ TEST(MultiplicationCutGeneratorTest, TestCut3) { mult.generate_cuts(manager); ASSERT_EQ(2, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - StartsWith("3 <= 3*X0 1*X1 -1*X2")); + StartsWith("3 <= 3*I0 1*I1 -1*I2")); EXPECT_THAT(manager->AllConstraints().back().constraint.DebugString(), - StartsWith("10 <= 2*X0 5*X1 -1*X2")); + StartsWith("10 <= 2*I0 5*I1 -1*I2")); } TEST(MultiplicationCutGeneratorTest, TestNoCut1) { @@ -698,7 +698,7 @@ TEST(AllDiffCutGeneratorTest, TestCut) { all_diff.generate_cuts(manager); ASSERT_EQ(1, manager->num_cuts()); EXPECT_EQ(manager->AllConstraints().front().constraint.DebugString(), - "50 <= 1*X0 1*X1 1*X2 <= 50"); + "50 <= 1*I0 1*I1 1*I2 <= 50"); } TEST(AllDiffCutGeneratorTest, TestCut2) { @@ -716,9 +716,9 @@ TEST(AllDiffCutGeneratorTest, TestCut2) { all_diff.generate_cuts(manager); ASSERT_EQ(2, manager->num_cuts()); EXPECT_EQ(manager->AllConstraints().front().constraint.DebugString(), - "25 <= 1*X1 1*X2 <= 40"); + "25 <= 1*I1 1*I2 <= 40"); EXPECT_EQ(manager->AllConstraints().back().constraint.DebugString(), - "50 <= 1*X0 1*X1 1*X2 <= 50"); + "50 <= 1*I0 1*I1 1*I2 <= 50"); } // We model the maximum of 3 affine functions: @@ -767,11 +767,11 @@ TEST(LinMaxCutsTest, BasicCuts1) { max_cuts.generate_cuts(manager); ASSERT_EQ(1, manager->num_cuts()); - // x vars are X0,X1 respectively, target is X2, z_vars are X3,X4,X5 + // x vars are I0,I1 respectively, target is I2, z_vars are I3,I4,I5 // respectively. // Most violated inequality is 2. EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - StartsWith("0 <= -2*X1 -1*X2 3*X3 1*X4 4*X5")); + StartsWith("0 <= -2*I1 -1*I2 3*I3 1*I4 4*I5")); InitializeLpValues({-1.0, -1.0, 2.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0}, &model); @@ -779,7 +779,7 @@ TEST(LinMaxCutsTest, BasicCuts1) { ASSERT_EQ(2, manager->num_cuts()); // Most violated inequality is 3. EXPECT_THAT(manager->AllConstraints().back().constraint.DebugString(), - StartsWith("0 <= 1*X1 -1*X2 2*X3 4*X4 1*X5")); + StartsWith("0 <= 1*I1 -1*I2 2*I3 4*I4 1*I5")); } // We model the maximum of 3 affine functions: @@ -806,7 +806,7 @@ TEST(LinMaxCutsTest, AffineCuts1) { BuildMaxAffineUpConstraint(target_expr, x, affines, &model, &builder)); // Note, the cut is not normalized. - EXPECT_EQ(builder.Build().DebugString(), "20*X1 <= 200"); + EXPECT_EQ(builder.Build().DebugString(), "20*I1 <= 200"); } // We model the maximum of 3 affine functions: @@ -832,7 +832,7 @@ TEST(LinMaxCutsTest, AffineCuts2) { ASSERT_TRUE( BuildMaxAffineUpConstraint(target_expr, x, affines, &model, &builder)); - EXPECT_EQ(builder.Build().DebugString(), "-9*X0 11*X1 <= 20"); + EXPECT_EQ(builder.Build().DebugString(), "-9*I0 11*I1 <= 20"); } // We model the maximum of 3 affine functions: @@ -1080,7 +1080,7 @@ TEST(CutDataTest, SimpleExample) { cut.FillFromParallelVectors(rhs, vars, coeffs, lp_values, lbs, ubs); cut.ComplementForSmallerLpValues(); - // 6 (X0' + 7) - 4 (X1' - 3) <= 9 + // 6 (I0' + 7) - 4 (I1' - 3) <= 9 ASSERT_EQ(cut.terms.size(), 2); EXPECT_EQ(cut.rhs, 9 - 4 * 3 - 6 * 7); EXPECT_EQ(cut.terms[0].coeff, 6); diff --git a/ortools/sat/diffn_cuts.cc b/ortools/sat/diffn_cuts.cc index 91485e84d9..960b03c0ce 100644 --- a/ortools/sat/diffn_cuts.cc +++ b/ortools/sat/diffn_cuts.cc @@ -128,8 +128,7 @@ struct DiffnEnergyEvent : DiffnBaseEvent { ", x_start_max = ", x_start_max.value(), ", x_end_min = ", x_end_min.value(), ", x_end_max = ", x_end_max.value(), ", y_min = ", y_min.value(), - ", y_max = ", y_max.value(), ", y_size = ", y_size.DebugString(), - ", energy = ", + ", y_max = ", y_max.value(), ", y_size = ", y_size, ", energy = ", decomposed_energy.empty() ? "{}" : absl::StrCat(decomposed_energy.size(), " terms"), @@ -373,15 +372,14 @@ DiffnCtEvent::DiffnCtEvent(int t, const SchedulingConstraintHelper* x_helper) : DiffnBaseEvent(t, x_helper) {} std::string DiffnCtEvent::DebugString() const { - return absl::StrCat("DiffnCtEvent(x_end = ", x_end.DebugString(), - ", x_start_min = ", x_start_min.value(), - ", x_start_max = ", x_start_max.value(), - ", x_size_min = ", x_size_min.value(), - ", x_lp_end = ", x_lp_end, ", y_min = ", y_min.value(), - ", y_max = ", y_max.value(), - ", y_size_min = ", y_size_min.value(), - ", energy_min = ", energy_min.value(), - ", use_energy = ", use_energy, ", lifted = ", lifted); + return absl::StrCat( + "DiffnCtEvent(x_end = ", x_end, ", x_start_min = ", x_start_min.value(), + ", x_start_max = ", x_start_max.value(), + ", x_size_min = ", x_size_min.value(), ", x_lp_end = ", x_lp_end, + ", y_min = ", y_min.value(), ", y_max = ", y_max.value(), + ", y_size_min = ", y_size_min.value(), + ", energy_min = ", energy_min.value(), ", use_energy = ", use_energy, + ", lifted = ", lifted); } // We generate the cut from the Smith's rule from: diff --git a/ortools/sat/drat_checker.cc b/ortools/sat/drat_checker.cc index 1189a0ae07..0372dd97cf 100644 --- a/ortools/sat/drat_checker.cc +++ b/ortools/sat/drat_checker.cc @@ -59,7 +59,7 @@ bool DratChecker::ClauseEquiv::operator()( } DratChecker::DratChecker() - : first_infered_clause_index_(kNoClauseIndex), + : first_inferred_clause_index_(kNoClauseIndex), clause_set_(0, ClauseHash(this), ClauseEquiv(this)), num_variables_(0) {} @@ -68,8 +68,9 @@ bool DratChecker::Clause::IsDeleted(ClauseIndex clause_index) const { } void DratChecker::AddProblemClause(absl::Span clause) { - DCHECK_EQ(first_infered_clause_index_, kNoClauseIndex); - const ClauseIndex clause_index = AddClause(clause); + DCHECK_EQ(first_inferred_clause_index_, kNoClauseIndex); + const ClauseIndex clause_index = MaybeAddClause(clause); + if (clause_index == kNoClauseIndex) return; const auto it = clause_set_.find(clause_index); if (it != clause_set_.end()) { @@ -80,27 +81,28 @@ void DratChecker::AddProblemClause(absl::Span clause) { } } -void DratChecker::AddInferedClause(absl::Span clause) { - const ClauseIndex infered_clause_index = AddClause(clause); - if (first_infered_clause_index_ == kNoClauseIndex) { - first_infered_clause_index_ = infered_clause_index; +void DratChecker::AddInferredClause(absl::Span clause) { + const ClauseIndex inferred_clause_index = MaybeAddClause(clause); + CHECK_NE(inferred_clause_index, kNoClauseIndex); + if (first_inferred_clause_index_ == kNoClauseIndex) { + first_inferred_clause_index_ = inferred_clause_index; } - const auto it = clause_set_.find(infered_clause_index); + const auto it = clause_set_.find(inferred_clause_index); if (it != clause_set_.end()) { clauses_[*it].num_copies += 1; - if (*it >= first_infered_clause_index_ && !clause.empty()) { + if (*it >= first_inferred_clause_index_ && !clause.empty()) { CHECK_EQ(clauses_[*it].rat_literal_index, clause[0].Index()); } RemoveLastClause(); } else { - clauses_[infered_clause_index].rat_literal_index = + clauses_[inferred_clause_index].rat_literal_index = clause.empty() ? kNoLiteralIndex : clause[0].Index(); - clause_set_.insert(infered_clause_index); + clause_set_.insert(inferred_clause_index); } } -ClauseIndex DratChecker::AddClause(absl::Span clause) { +ClauseIndex DratChecker::MaybeAddClause(absl::Span clause) { const int first_literal_index = literals_.size(); literals_.insert(literals_.end(), clause.begin(), clause.end()); // Sort the input clause in strictly increasing order (by sorting and then @@ -111,7 +113,10 @@ ClauseIndex DratChecker::AddClause(absl::Span clause) { literals_.end()); for (int i = first_literal_index + 1; i < literals_.size(); ++i) { - CHECK(literals_[i] != literals_[i - 1].Negated()); + if (literals_[i] == literals_[i - 1].Negated()) { + literals_.resize(first_literal_index); + return kNoClauseIndex; + } } clauses_.push_back( Clause(first_literal_index, literals_.size() - first_literal_index)); @@ -124,7 +129,9 @@ ClauseIndex DratChecker::AddClause(absl::Span clause) { void DratChecker::DeleteClause(absl::Span clause) { // Temporarily add 'clause' to find if it has been previously added. - const auto it = clause_set_.find(AddClause(clause)); + const ClauseIndex clause_index = MaybeAddClause(clause); + if (clause_index == kNoClauseIndex) return; + const auto it = clause_set_.find(clause_index); if (it != clause_set_.end()) { Clause& existing_clause = clauses_[*it]; existing_clause.num_copies -= 1; @@ -151,16 +158,16 @@ void DratChecker::RemoveLastClause() { // See Algorithm of Fig. 8 in 'Trimming while Checking Clausal Proofs'. DratChecker::Status DratChecker::Check(double max_time_in_seconds) { - // First check that the last infered clause is empty (this implies there - // should be at least one infered clause), and mark it as needed for the + // First check that the last inferred clause is empty (this implies there + // should be at least one inferred clause), and mark it as needed for the // proof. - if (clauses_.empty() || first_infered_clause_index_ == kNoClauseIndex || + if (clauses_.empty() || first_inferred_clause_index_ == kNoClauseIndex || clauses_.back().num_literals != 0) { return Status::INVALID; } clauses_.back().is_needed_for_proof = true; - // Checks the infered clauses in reversed order. The advantage of this order + // Checks the inferred clauses in reversed order. The advantage of this order // is that when checking a clause, one can mark all the clauses that are used // to check it. In turn, only these marked clauses need to be checked (and so // on recursively). By contrast, a forward iteration needs to check all the @@ -168,7 +175,7 @@ DratChecker::Status DratChecker::Check(double max_time_in_seconds) { const int64_t start_time_nanos = absl::GetCurrentTimeNanos(); TimeLimit time_limit(max_time_in_seconds); Init(); - for (ClauseIndex i(clauses_.size() - 1); i >= first_infered_clause_index_; + for (ClauseIndex i(clauses_.size() - 1); i >= first_inferred_clause_index_; --i) { if (time_limit.LimitReached()) { return Status::UNKNOWN; @@ -220,11 +227,11 @@ DratChecker::Status DratChecker::Check(double max_time_in_seconds) { } std::vector> DratChecker::GetUnsatSubProblem() const { - return GetClausesNeededForProof(ClauseIndex(0), first_infered_clause_index_); + return GetClausesNeededForProof(ClauseIndex(0), first_inferred_clause_index_); } std::vector> DratChecker::GetOptimizedProof() const { - return GetClausesNeededForProof(first_infered_clause_index_, + return GetClausesNeededForProof(first_inferred_clause_index_, ClauseIndex(clauses_.size())); } @@ -451,23 +458,23 @@ void DratChecker::MarkAsNeededForProof(Clause* clause) { void DratChecker::LogStatistics(int64_t duration_nanos) const { int problem_clauses_needed_for_proof = 0; - int infered_clauses_needed_for_proof = 0; + int inferred_clauses_needed_for_proof = 0; for (ClauseIndex i(0); i < clauses_.size(); ++i) { if (clauses_[i].is_needed_for_proof) { - if (i < first_infered_clause_index_) { + if (i < first_inferred_clause_index_) { ++problem_clauses_needed_for_proof; } else { - ++infered_clauses_needed_for_proof; + ++inferred_clauses_needed_for_proof; } } } LOG(INFO) << problem_clauses_needed_for_proof << " problem clauses needed for proof, out of " - << first_infered_clause_index_; - LOG(INFO) << infered_clauses_needed_for_proof - << " infered clauses needed for proof, out of " - << clauses_.size() - first_infered_clause_index_; - LOG(INFO) << num_rat_checks_ << " RAT infered clauses"; + << first_inferred_clause_index_; + LOG(INFO) << inferred_clauses_needed_for_proof + << " inferred clauses needed for proof, out of " + << clauses_.size() - first_inferred_clause_index_; + LOG(INFO) << num_rat_checks_ << " RAT inferred clauses"; LOG(INFO) << "verification time: " << 1e-9 * duration_nanos << " s"; } @@ -561,8 +568,8 @@ bool AddProblemClauses(const std::string& file_path, return result; } -bool AddInferedAndDeletedClauses(const std::string& file_path, - DratChecker* drat_checker) { +bool AddInferredAndDeletedClauses(const std::string& file_path, + DratChecker* drat_checker) { int line_number = 0; bool ends_with_empty_clause = false; std::vector literals; @@ -592,12 +599,12 @@ bool AddInferedAndDeletedClauses(const std::string& file_path, drat_checker->DeleteClause(literals); ends_with_empty_clause = false; } else { - drat_checker->AddInferedClause(literals); + drat_checker->AddInferredClause(literals); ends_with_empty_clause = literals.empty(); } } if (!ends_with_empty_clause) { - drat_checker->AddInferedClause({}); + drat_checker->AddInferredClause({}); } file.close(); return result; diff --git a/ortools/sat/drat_checker.h b/ortools/sat/drat_checker.h index ca1ff1f301..bbebd8aa74 100644 --- a/ortools/sat/drat_checker.h +++ b/ortools/sat/drat_checker.h @@ -47,37 +47,36 @@ class DratChecker { DratChecker(); ~DratChecker() = default; - // Returns the number of Boolean variables used in the problem and infered + // Returns the number of Boolean variables used in the problem and inferred // clauses. int num_variables() const { return num_variables_; } // Adds a clause of the problem that must be checked. The problem clauses must - // be added first, before any infered clause. The given clause must not - // contain a literal and its negation. Must not be called after Check(). + // be added first, before any inferred clause. Must not be called after + // Check(). void AddProblemClause(absl::Span clause); - // Adds a clause which is infered from the problem clauses and the previously - // infered clauses (that are have not been deleted). Infered clauses must be + // Adds a clause which is inferred from the problem clauses and the previously + // inferred clauses (that are have not been deleted). inferred clauses must be // added after the problem clauses. Clauses with the Reverse Asymmetric // Tautology (RAT) property for literal l must start with this literal. The // given clause must not contain a literal and its negation. Must not be // called after Check(). - void AddInferedClause(absl::Span clause); + void AddInferredClause(absl::Span clause); - // Deletes a problem or infered clause. The order of the literals does not + // Deletes a problem or inferred clause. The order of the literals does not // matter. In particular, it can be different from the order that was used // when the clause was added. Must not be called after Check(). void DeleteClause(absl::Span clause); - // Checks that the infered clauses form a DRAT proof that the problem clauses - // are UNSAT. For this the last added infered clause must be the empty clause - // and each infered clause must have either the Reverse Unit Propagation (RUP) - // or the Reverse Asymmetric Tautology (RAT) property with respect to the - // problem clauses and the previously infered clauses which are not deleted. - // Returns VALID if the proof is valid, INVALID if it is not, and UNKNOWN if - // the check timed out. - // WARNING: no new clause must be added or deleted after this method has been - // called. + // Checks that the inferred clauses form a DRAT proof that the problem clauses + // are UNSAT. For this the last added inferred clause must be the empty clause + // and each inferred clause must have either the Reverse Unit Propagation + // (RUP) or the Reverse Asymmetric Tautology (RAT) property with respect to + // the problem clauses and the previously inferred clauses which are not + // deleted. Returns VALID if the proof is valid, INVALID if it is not, and + // UNKNOWN if the check timed out. WARNING: no new clause must be added or + // deleted after this method has been called. enum Status { UNKNOWN, VALID, @@ -95,7 +94,7 @@ class DratChecker { std::vector> GetOptimizedProof() const; private: - // A problem or infered clause. The literals are specified as a subrange of + // A problem or inferred clause. The literals are specified as a subrange of // 'literals_' (namely the subrange from 'first_literal_index' to // 'first_literal_index' + 'num_literals' - 1), and are sorted in increasing // order *before Check() is called*. @@ -106,7 +105,7 @@ class DratChecker { int num_literals; // The clause literal to use to check the RAT property, or kNoLiteralIndex - // for problem clauses and empty infered clauses. + // for problem clauses and empty inferred clauses. LiteralIndex rat_literal_index = kNoLiteralIndex; // The *current* number of copies of this clause. This number is incremented @@ -132,7 +131,7 @@ class DratChecker { // Whether this clause is actually needed to check the DRAT proof. bool is_needed_for_proof = false; // Whether this clause is actually needed to check the current step (i.e. an - // infered clause) of the DRAT proof. This bool is always false, except in + // inferred clause) of the DRAT proof. This bool is always false, except in // MarkAsNeededForProof() that uses it temporarily. bool tmp_is_needed_for_proof_step = false; @@ -173,8 +172,10 @@ class DratChecker { bool operator()(ClauseIndex clause_index1, ClauseIndex clause_index2) const; }; - // Adds a clause and returns its index. - ClauseIndex AddClause(absl::Span clause); + // Adds a clause and returns its index. If the clause is always true (because + // it contains a literal and its negation), it is not added and kNoClauseIndex + // is returned. + ClauseIndex MaybeAddClause(absl::Span clause); // Removes the last clause added to 'clauses_'. void RemoveLastClause(); @@ -222,11 +223,11 @@ class DratChecker { void LogStatistics(int64_t duration_nanos) const; - // The index of the first infered clause in 'clauses_', or kNoClauseIndex if - // there is no infered clause. - ClauseIndex first_infered_clause_index_; + // The index of the first inferred clause in 'clauses_', or kNoClauseIndex if + // there is no inferred clause. + ClauseIndex first_inferred_clause_index_; - // The problem clauses, followed by the infered clauses. + // The problem clauses, followed by the inferred clauses. util_intops::StrongVector clauses_; // A content addressable set of the non-deleted clauses in clauses_. After @@ -295,7 +296,7 @@ class DratChecker { // --------------------------------------------------------------------------- // Statistics - // The number of infered clauses having the RAT property (but not the RUP + // The number of inferred clauses having the RAT property (but not the RUP // property). int num_rat_checks_; }; @@ -321,11 +322,11 @@ bool Resolve(absl::Span clause, // successfully parsed. bool AddProblemClauses(const std::string& file_path, DratChecker* drat_checker); -// Adds to the given drat checker the infered and deleted clauses from the file +// Adds to the given drat checker the inferred and deleted clauses from the file // at the given path, which must be in DRAT format. Returns true iff the file // was successfully parsed. -bool AddInferedAndDeletedClauses(const std::string& file_path, - DratChecker* drat_checker); +bool AddInferredAndDeletedClauses(const std::string& file_path, + DratChecker* drat_checker); // The file formats that can be used to save a list of clauses. enum SatFormat { diff --git a/ortools/sat/drat_checker_test.cc b/ortools/sat/drat_checker_test.cc index 6b7223a159..134a613ecd 100644 --- a/ortools/sat/drat_checker_test.cc +++ b/ortools/sat/drat_checker_test.cc @@ -15,7 +15,6 @@ #include #include -#include #include "absl/types/span.h" #include "gtest/gtest.h" @@ -42,7 +41,7 @@ DratChecker::Status CheckOptimizedProof(const DratChecker& drat_checker) { optimized_proof_checker.AddProblemClause(clause); } for (const auto& clause : drat_checker.GetOptimizedProof()) { - optimized_proof_checker.AddInferedClause(clause); + optimized_proof_checker.AddInferredClause(clause); } return optimized_proof_checker.Check(kMaxTimeInSeconds); } @@ -57,8 +56,8 @@ TEST(DratCheckerTest, CheckBasicSuccess) { checker.AddProblemClause(Literals({+1, -2})); checker.AddProblemClause(Literals({+2, -3})); - checker.AddInferedClause(Literals({-2})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({-2})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::VALID, checker.Check(kMaxTimeInSeconds)); EXPECT_EQ(DratChecker::Status::VALID, CheckOptimizedProof(checker)); @@ -75,10 +74,10 @@ TEST(DratCheckerTest, CheckBasicSuccessWithClauseAddedSeveralTimes) { // Add a clause two times and deletes it on)e time, there should still be one // copy left, which is needed for the rest )of the proof. - checker.AddInferedClause(Literals({-2})); - checker.AddInferedClause(Literals({-2})); + checker.AddInferredClause(Literals({-2})); + checker.AddInferredClause(Literals({-2})); checker.DeleteClause(Literals({-2})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::VALID, checker.Check(kMaxTimeInSeconds)); EXPECT_EQ(DratChecker::Status::VALID, CheckOptimizedProof(checker)); @@ -96,15 +95,15 @@ TEST(DratCheckerTest, CheckSimpleSuccess) { checker.AddProblemClause(Literals({-1, +2, +4})); checker.AddProblemClause(Literals({+1, -2, -4})); - checker.AddInferedClause(Literals({+1, +2})); + checker.AddInferredClause(Literals({+1, +2})); checker.DeleteClause(Literals({+1, +2, -3, +2})); // Duplicate literals. - checker.AddInferedClause(Literals({+1, +1})); // Duplicate literals. + checker.AddInferredClause(Literals({+1, +1})); // Duplicate literals. checker.DeleteClause(Literals({+1, +3, +4})); checker.DeleteClause( Literals({-4, -2, +1})); // Different order from clause #8. - checker.AddInferedClause(Literals({+2})); + checker.AddInferredClause(Literals({+2})); checker.DeleteClause(Literals({+2, +3, -4})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::VALID, checker.Check(kMaxTimeInSeconds)); EXPECT_EQ(DratChecker::Status::VALID, CheckOptimizedProof(checker)); @@ -125,14 +124,14 @@ TEST(DratCheckerTest, CheckComplexSuccessRupProof) { } } - checker.AddInferedClause(Literals({1, 2, 3})); - checker.AddInferedClause(Literals({1, 2})); - checker.AddInferedClause(Literals({1, 3})); - checker.AddInferedClause(Literals({1})); - checker.AddInferedClause(Literals({2, 3})); - checker.AddInferedClause(Literals({2})); - checker.AddInferedClause(Literals({3})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({1, 2, 3})); + checker.AddInferredClause(Literals({1, 2})); + checker.AddInferredClause(Literals({1, 3})); + checker.AddInferredClause(Literals({1})); + checker.AddInferredClause(Literals({2, 3})); + checker.AddInferredClause(Literals({2})); + checker.AddInferredClause(Literals({3})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::VALID, checker.Check(kMaxTimeInSeconds)); EXPECT_EQ(DratChecker::Status::VALID, CheckOptimizedProof(checker)); @@ -153,10 +152,10 @@ TEST(DratCheckerTest, CheckComplexSuccessRapProof) { } } - checker.AddInferedClause(Literals({1})); - checker.AddInferedClause(Literals({2})); - checker.AddInferedClause(Literals({3})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({1})); + checker.AddInferredClause(Literals({2})); + checker.AddInferredClause(Literals({3})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::VALID, checker.Check(kMaxTimeInSeconds)); EXPECT_EQ(DratChecker::Status::VALID, CheckOptimizedProof(checker)); @@ -178,18 +177,18 @@ TEST(DratCheckerTest, CheckComplexSuccessRapProofWithExtendedResolution) { } // Proof using additional variables not used in the problem clauses. - checker.AddInferedClause(Literals({5, 1, 2})); - checker.AddInferedClause(Literals({5, 1, -2})); - checker.AddInferedClause(Literals({5, -1, 2})); - checker.AddInferedClause(Literals({5, -1, -2})); - checker.AddInferedClause(Literals({-5, 3, 4})); - checker.AddInferedClause(Literals({-5, 3, -4})); - checker.AddInferedClause(Literals({-5, -3, 4})); - checker.AddInferedClause(Literals({-5, -3, -4})); - checker.AddInferedClause(Literals({5, 1})); - checker.AddInferedClause(Literals({5})); - checker.AddInferedClause(Literals({3})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({5, 1, 2})); + checker.AddInferredClause(Literals({5, 1, -2})); + checker.AddInferredClause(Literals({5, -1, 2})); + checker.AddInferredClause(Literals({5, -1, -2})); + checker.AddInferredClause(Literals({-5, 3, 4})); + checker.AddInferredClause(Literals({-5, 3, -4})); + checker.AddInferredClause(Literals({-5, -3, 4})); + checker.AddInferredClause(Literals({-5, -3, -4})); + checker.AddInferredClause(Literals({5, 1})); + checker.AddInferredClause(Literals({5})); + checker.AddInferredClause(Literals({3})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::VALID, checker.Check(kMaxTimeInSeconds)); EXPECT_EQ(DratChecker::Status::VALID, CheckOptimizedProof(checker)); @@ -209,10 +208,10 @@ TEST(DratCheckerTest, CheckBasicSuccessWithoutDeletedClauses) { checker.AddProblemClause(Literals({-1, +2, +4})); checker.AddProblemClause(Literals({+1, -2, -4})); - checker.AddInferedClause(Literals({+1, +2})); - checker.AddInferedClause(Literals({+1})); - checker.AddInferedClause(Literals({+2})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({+1, +2})); + checker.AddInferredClause(Literals({+1})); + checker.AddInferredClause(Literals({+2})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::VALID, checker.Check(kMaxTimeInSeconds)); EXPECT_EQ(DratChecker::Status::VALID, CheckOptimizedProof(checker)); @@ -231,8 +230,8 @@ TEST(DratCheckerTest, CheckBasicFailure) { checker.AddProblemClause(Literals({-1, +2, +4})); checker.AddProblemClause(Literals({+1, -2, -4})); - checker.AddInferedClause(Literals({+2})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({+2})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::INVALID, checker.Check(kMaxTimeInSeconds)); } @@ -250,14 +249,14 @@ TEST(DratCheckerTest, CheckFailureClauseNeededForProofDeleted) { checker.AddProblemClause(Literals({-1, +2, +4})); checker.AddProblemClause(Literals({+1, -2, -4})); - checker.AddInferedClause(Literals({+1, +2})); + checker.AddInferredClause(Literals({+1, +2})); checker.DeleteClause(Literals({+1, +2, -3})); - checker.AddInferedClause(Literals({+1})); + checker.AddInferredClause(Literals({+1})); checker.DeleteClause(Literals({+1, +3, +4})); checker.DeleteClause(Literals({+1, -2, -4})); checker.DeleteClause(Literals({+2, +3, -4})); - checker.AddInferedClause(Literals({+2})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({+2})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::INVALID, checker.Check(kMaxTimeInSeconds)); } @@ -275,11 +274,11 @@ TEST(DratCheckerTest, // Add and delete a clause two times, there should still be no copy left, // yielding an invalid proof because this clause is needed for the rest of the // proof. - checker.AddInferedClause(Literals({-2})); + checker.AddInferredClause(Literals({-2})); checker.DeleteClause(Literals({-2})); - checker.AddInferedClause(Literals({-2})); + checker.AddInferredClause(Literals({-2})); checker.DeleteClause(Literals({-2})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::INVALID, checker.Check(kMaxTimeInSeconds)); } @@ -297,11 +296,11 @@ TEST(DratCheckerTest, // Add and delete a clause two times, there should still be no copy left, // yielding an invalid proof because this clause is needed for the rest of the // proof. - checker.AddInferedClause(Literals({-2})); - checker.AddInferedClause(Literals({-2})); + checker.AddInferredClause(Literals({-2})); + checker.AddInferredClause(Literals({-2})); checker.DeleteClause(Literals({-2})); checker.DeleteClause(Literals({-2})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::INVALID, checker.Check(kMaxTimeInSeconds)); } @@ -310,7 +309,7 @@ TEST(DratCheckerTest, CheckBasicFailureTimeOut) { DratChecker checker; checker.AddProblemClause(Literals({+1})); checker.AddProblemClause(Literals({-1})); - checker.AddInferedClause(Literals({})); + checker.AddInferredClause(Literals({})); EXPECT_EQ(DratChecker::Status::UNKNOWN, checker.Check(-1.0)); } @@ -373,7 +372,7 @@ d 2 3 -4 0 file::Defaults())); EXPECT_TRUE(AddProblemClauses(cnf_file_path, &checker)); - EXPECT_TRUE(AddInferedAndDeletedClauses(drat_file_path, &checker)); + EXPECT_TRUE(AddInferredAndDeletedClauses(drat_file_path, &checker)); EXPECT_EQ(DratChecker::Status::VALID, checker.Check(kMaxTimeInSeconds)); } diff --git a/ortools/sat/drat_proof_handler.cc b/ortools/sat/drat_proof_handler.cc index 2739cd6da0..7058d70d30 100644 --- a/ortools/sat/drat_proof_handler.cc +++ b/ortools/sat/drat_proof_handler.cc @@ -81,7 +81,7 @@ void DratProofHandler::AddProblemClause(absl::Span clause) { void DratProofHandler::AddClause(absl::Span clause) { MapClause(clause); if (drat_checker_ != nullptr) { - drat_checker_->AddInferedClause(values_); + drat_checker_->AddInferredClause(values_); } if (drat_writer_ != nullptr) { drat_writer_->AddClause(values_); @@ -101,7 +101,7 @@ void DratProofHandler::DeleteClause(absl::Span clause) { DratChecker::Status DratProofHandler::Check(double max_time_in_seconds) { if (drat_checker_ != nullptr) { // The empty clause is not explicitly added by the solver. - drat_checker_->AddInferedClause({}); + drat_checker_->AddInferredClause({}); return drat_checker_->Check(max_time_in_seconds); } return DratChecker::Status::UNKNOWN; diff --git a/ortools/sat/feasibility_jump.h b/ortools/sat/feasibility_jump.h index e40f5d42b9..5957ea2bcf 100644 --- a/ortools/sat/feasibility_jump.h +++ b/ortools/sat/feasibility_jump.h @@ -375,7 +375,7 @@ class SharedLsStates { // This is thread safe. If we respect the max_parallelism guarantee, then // all states should be independent. LsState* GetNextState() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); int next = -1; const int num_states = states_.size(); for (int i = 0; i < num_states; ++i) { @@ -400,7 +400,7 @@ class SharedLsStates { } void Release(LsState* state) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (int i = 0; i < states_.size(); ++i) { if (state == states_[i].get()) { taken_[i] = false; @@ -410,7 +410,7 @@ class SharedLsStates { } void ResetLubyCounter() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); luby_counter_ = 0; } @@ -421,7 +421,7 @@ class SharedLsStates { // Also if options.use_restart, then num_batches_before_change is only // modified under lock, so this code should be thread safe. void ConfigureNextLubyRestart(LsState* state) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const int factor = std::max(1, params_.feasibility_jump_restart_factor()); CHECK(state->options.use_restart); const int64_t next = factor * SUniv(++luby_counter_); @@ -432,7 +432,7 @@ class SharedLsStates { void CollectStatistics(const LsState& state) { if (state.counters.num_batches == 0) return; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); options_to_stats_[state.options].AddFrom(state.counters); options_to_num_restarts_[state.options]++; } diff --git a/ortools/sat/implied_bounds_test.cc b/ortools/sat/implied_bounds_test.cc index b0092dad9a..2fc30001c4 100644 --- a/ortools/sat/implied_bounds_test.cc +++ b/ortools/sat/implied_bounds_test.cc @@ -236,12 +236,12 @@ TEST(DetectLinearEncodingOfProductsTest, MatchingElementEncodings) { builder.AddConstant(IntegerValue(-1)); // To be cleared. EXPECT_TRUE( model.GetOrCreate()->TryToLinearize(x0, x1, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "34*X1 34*X2 294*X3 + 6"); + EXPECT_EQ(builder.BuildExpression().DebugString(), "34*I1 34*I2 294*I3 + 6"); builder.Clear(); EXPECT_TRUE( model.GetOrCreate()->TryToLinearize(x1, x0, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "34*X1 34*X2 294*X3 + 6"); + EXPECT_EQ(builder.BuildExpression().DebugString(), "34*I1 34*I2 294*I3 + 6"); } TEST(DetectLinearEncodingOfProductsTest, MatchingEncodingAndSizeTwoEncoding) { @@ -270,11 +270,11 @@ TEST(DetectLinearEncodingOfProductsTest, MatchingEncodingAndSizeTwoEncoding) { builder.AddConstant(IntegerValue(-1)); // To be cleared. EXPECT_TRUE( model.GetOrCreate()->TryToLinearize(x0, x1, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "12*X3 2*X4 48*X5 + 12"); + EXPECT_EQ(builder.BuildExpression().DebugString(), "12*I3 2*I4 48*I5 + 12"); EXPECT_TRUE( model.GetOrCreate()->TryToLinearize(x1, x0, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "12*X3 2*X4 48*X5 + 12"); + EXPECT_EQ(builder.BuildExpression().DebugString(), "12*I3 2*I4 48*I5 + 12"); } TEST(DetectLinearEncodingOfProductsTest, BooleanAffinePosPosProduct) { @@ -397,11 +397,11 @@ TEST(DetectLinearEncodingOfProductsTest, AffineTimesConstant) { LinearConstraintBuilder builder(&model); EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( left, right, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "6*X0 + -3"); + EXPECT_EQ(builder.BuildExpression().DebugString(), "6*I0 + -3"); EXPECT_TRUE(model.GetOrCreate()->TryToLinearize( right, left, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "6*X0 + -3"); + EXPECT_EQ(builder.BuildExpression().DebugString(), "6*I0 + -3"); } TEST(DecomposeProductTest, MatchingElementEncodings) { diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 7691ffa88e..bd7783c6c7 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -1388,14 +1388,23 @@ bool IntegerTrail::ReasonIsValid( std::vector clause; clause.assign(literal_reason.begin(), literal_reason.end()); - std::vector lits; - lits.assign(integer_reason.begin(), integer_reason.end()); - MergeReasonInto(lits, &clause); - if (!debug_checker_(clause, {i_lit})) { + std::vector lits = {integer_reason.begin(), + integer_reason.end()}; + const IntegerLiteral negated_i_lit = + i_lit.IsAlwaysFalse() ? IntegerLiteral::TrueLiteral() : i_lit.Negated(); + lits.push_back(negated_i_lit); + if (!debug_checker_(clause, lits)) { LOG(INFO) << "Invalid reason for loaded solution: " << i_lit << " " << literal_reason << " " << integer_reason; return false; } + lits.pop_back(); + MergeReasonInto(lits, &clause); + if (!debug_checker_(clause, {negated_i_lit})) { + LOG(INFO) << "Invalid reason for loaded solution after merging: " << i_lit + << " " << literal_reason << " " << integer_reason; + return false; + } return true; } diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index b7aefdb1d9..7389f6381a 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -731,7 +731,8 @@ class IntegerTrail final : public SatPropagator { // simply do: return integer_trail_->ReportConflict(...); bool ReportConflict(absl::Span literal_reason, absl::Span integer_reason) { - DCHECK(ReasonIsValid(literal_reason, integer_reason)); + DCHECK(ReasonIsValid(IntegerLiteral::FalseLiteral(), literal_reason, + integer_reason)); std::vector* conflict = trail_->MutableConflict(); conflict->assign(literal_reason.begin(), literal_reason.end()); MergeReasonInto(integer_reason, conflict); diff --git a/ortools/sat/integer_base.h b/ortools/sat/integer_base.h index f757b792aa..5884bc7c05 100644 --- a/ortools/sat/integer_base.h +++ b/ortools/sat/integer_base.h @@ -193,7 +193,7 @@ inline PositiveOnlyIndex GetPositiveOnlyIndex(IntegerVariable var) { inline std::string IntegerTermDebugString(IntegerVariable var, IntegerValue coeff) { coeff = VariableIsPositive(var) ? coeff : -coeff; - return absl::StrCat(coeff.value(), "*X", var.value() / 2); + return absl::StrCat(coeff.value(), "*I", GetPositiveOnlyIndex(var)); } // Returns the vector of the negated variables. @@ -325,13 +325,14 @@ struct AffineExpression { bool IsConstant() const { return var == kNoIntegerVariable; } - std::string DebugString() const { - if (var == kNoIntegerVariable) return absl::StrCat(constant.value()); - if (constant == 0) { - return absl::StrCat("(", coeff.value(), " * X", var.value(), ")"); + template + friend void AbslStringify(Sink& sink, const AffineExpression& expr) { + if (expr.constant == 0) { + absl::Format(&sink, "(%v)", IntegerTermDebugString(expr.var, expr.coeff)); } else { - return absl::StrCat("(", coeff.value(), " * X", var.value(), " + ", - constant.value(), ")"); + absl::Format(&sink, "(%v + %d)", + IntegerTermDebugString(expr.var, expr.coeff), + expr.constant.value()); } } @@ -434,9 +435,18 @@ struct LinearExpression2 { template friend void AbslStringify(Sink& sink, const LinearExpression2& expr) { - absl::Format(&sink, "%d X%d + %d X%d", expr.coeffs[0].value(), - expr.vars[0].value(), expr.coeffs[1].value(), - expr.vars[1].value()); + if (expr.coeffs[0] == 0) { + if (expr.coeffs[1] == 0) { + absl::Format(&sink, "0"); + } else { + absl::Format(&sink, "%v", + IntegerTermDebugString(expr.vars[1], expr.coeffs[1])); + } + } else { + absl::Format(&sink, "%v + %v", + IntegerTermDebugString(expr.vars[0], expr.coeffs[0]), + IntegerTermDebugString(expr.vars[1], expr.coeffs[1])); + } } }; @@ -525,11 +535,19 @@ struct IntegerDomains // can check that various derived constraint do not exclude this solution (if it // is a known optimal solution for instance). struct DebugSolution { + void Clear() { + proto_values.clear(); + ivar_has_value.clear(); + ivar_values.clear(); + } + // This is the value of all proto variables. // It should be of the same size of the PRESOLVED model and should correspond // to a solution to the presolved model. std::vector proto_values; + IntegerValue inner_objective_value = kMinIntegerValue; + // This is filled from proto_values at load-time, and using the // cp_model_mapping, we cache the solution of the integer variables that are // mapped. Note that it is possible that not all integer variable are mapped. diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 65b0396422..d33467767f 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -297,8 +297,10 @@ UnassignedVarWithLowestMinAtItsMinHeuristic( std::function SequentialSearch( std::vector> heuristics) { - for (const auto& h : heuristics) { - CHECK(h != nullptr); + if (DEBUG_MODE) { + for (const auto& h : heuristics) { + DCHECK(h != nullptr); + } } return [heuristics]() { for (const auto& h : heuristics) { @@ -994,19 +996,16 @@ std::function RandomizeOnRestartHeuristic( weights.push_back(1); } - // Add heuristic search if present. + // Add model based heuristic search if present. if (heuristics.heuristic_search != nullptr) { policies.push_back( SequentialSearch({heuristics.heuristic_search, sat_policy, heuristics.integer_completion_search})); weights.push_back(1); - } - - if (policies.size() == 1) { - CHECK(heuristics.fixed_search != nullptr); - policies.push_back( - SequentialSearch({heuristics.fixed_search, sat_policy, - heuristics.integer_completion_search})); + } else if (heuristics.user_search == nullptr) { + // Add pseudo cost search if nothing else is present. + policies.push_back(SequentialSearch( + {PseudoCost(model), sat_policy, heuristics.integer_completion_search})); weights.push_back(1); } diff --git a/ortools/sat/linear_constraint_manager.cc b/ortools/sat/linear_constraint_manager.cc index 77b5419a97..41300bb3fe 100644 --- a/ortools/sat/linear_constraint_manager.cc +++ b/ortools/sat/linear_constraint_manager.cc @@ -957,17 +957,23 @@ void LinearConstraintManager::AddAllConstraintsToLp() { bool LinearConstraintManager::DebugCheckConstraint( const LinearConstraint& cut) { - if (model_->Get() == nullptr) return true; - const auto& debug_solution = *(model_->Get()); + const DebugSolution* debug_solution = model_->Get(); + if (debug_solution == nullptr || debug_solution->proto_values.empty()) { + return true; + } - IntegerValue activity(0); + absl::int128 activity(0); for (int i = 0; i < cut.num_terms; ++i) { const IntegerVariable var = cut.vars[i]; const IntegerValue coeff = cut.coeffs[i]; - CHECK(debug_solution.ivar_has_value[var]); - activity += coeff * debug_solution.ivar_values[var]; + if (var >= debug_solution->ivar_has_value.size() || + !debug_solution->ivar_has_value[var]) { + return true; + } + activity += + absl::int128(coeff.value()) * debug_solution->ivar_values[var].value(); } - if (activity > cut.ub || activity < cut.lb) { + if (activity > cut.ub.value() || activity < cut.lb.value()) { LOG(INFO) << cut.DebugString(); LOG(INFO) << "activity " << activity << " not in [" << cut.lb << "," << cut.ub << "]"; diff --git a/ortools/sat/linear_constraint_manager_test.cc b/ortools/sat/linear_constraint_manager_test.cc index b98914bcf6..95756c0048 100644 --- a/ortools/sat/linear_constraint_manager_test.cc +++ b/ortools/sat/linear_constraint_manager_test.cc @@ -112,7 +112,7 @@ TEST(LinearConstraintManagerTest, DuplicateDetection) { EXPECT_EQ(manager.AllConstraints().size(), 1); EXPECT_EQ(manager.AllConstraints().front().constraint.DebugString(), - "0 <= 1*X0 <= 3"); + "0 <= 1*I0 <= 3"); } void SetLpValue(IntegerVariable v, double value, Model* model) { @@ -142,7 +142,7 @@ TEST(LinearConstraintManagerTest, DuplicateDetectionCuts) { EXPECT_EQ(manager.AllConstraints().size(), 1); EXPECT_EQ(manager.AllConstraints().front().constraint.DebugString(), - "0 <= 1*X0 <= 3"); + "0 <= 1*I0 <= 3"); } TEST(LinearConstraintManagerTest, DuplicateDetectionCauseLpChange) { @@ -171,7 +171,7 @@ TEST(LinearConstraintManagerTest, DuplicateDetectionCauseLpChange) { EXPECT_EQ(manager.AllConstraints().size(), 1); EXPECT_EQ(manager.AllConstraints().front().constraint.DebugString(), - "0 <= 1*X0 <= 3"); + "0 <= 1*I0 <= 3"); } TEST(LinearConstraintManagerTest, OnlyAddInfeasibleConstraints) { @@ -351,7 +351,7 @@ TEST(LinearConstraintManagerTest, SimplificationRemoveFixedVariable) { } const LinearConstraintManager::ConstraintIndex index(0); - EXPECT_EQ("0 <= 3*X0 -4*X1 7*X2 <= 11", + EXPECT_EQ("0 <= 3*I0 -4*I1 7*I2 <= 11", manager.AllConstraints()[index].constraint.DebugString()); // ChangeLp will trigger the simplification. @@ -360,7 +360,7 @@ TEST(LinearConstraintManagerTest, SimplificationRemoveFixedVariable) { glop::BasisState state; EXPECT_TRUE(manager.ChangeLp(&state)); EXPECT_EQ(1, manager.num_shortened_constraints()); - EXPECT_EQ("20 <= 3*X0 7*X2 <= 31", + EXPECT_EQ("20 <= 3*I0 7*I2 <= 31", manager.AllConstraints()[index].constraint.DebugString()); // We also test that the constraint equivalence work with the change. @@ -372,7 +372,7 @@ TEST(LinearConstraintManagerTest, SimplificationRemoveFixedVariable) { manager.Add(ct.Build()); } EXPECT_EQ(manager.AllConstraints().size(), 1); - EXPECT_EQ("20 <= 3*X0 7*X2 <= 21", + EXPECT_EQ("20 <= 3*I0 7*I2 <= 21", manager.AllConstraints()[index].constraint.DebugString()); } @@ -392,7 +392,7 @@ TEST(LinearConstraintManagerTest, SimplificationStrenghtenUb) { const LinearConstraintManager::ConstraintIndex index(0); EXPECT_EQ(2, manager.num_coeff_strenghtening()); EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), - EndsWith("3*X0 -5*X1 5*X2 <= 75")); + EndsWith("3*I0 -5*I1 5*I2 <= 75")); } TEST(LinearConstraintManagerTest, SimplificationStrenghtenLb) { @@ -411,7 +411,7 @@ TEST(LinearConstraintManagerTest, SimplificationStrenghtenLb) { const LinearConstraintManager::ConstraintIndex index(0); EXPECT_EQ(2, manager.num_coeff_strenghtening()); EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), - StartsWith("-45 <= 3*X0 -5*X1 5*X2")); + StartsWith("-45 <= 3*I0 -5*I1 5*I2")); } TEST(LinearConstraintManagerTest, AdvancedStrenghtening1) { @@ -430,7 +430,7 @@ TEST(LinearConstraintManagerTest, AdvancedStrenghtening1) { const LinearConstraintManager::ConstraintIndex index(0); EXPECT_EQ(3, manager.num_coeff_strenghtening()); EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), - StartsWith("2 <= 1*X0 1*X1 1*X2")); + StartsWith("2 <= 1*I0 1*I1 1*I2")); } TEST(LinearConstraintManagerTest, AdvancedStrenghtening2) { @@ -449,7 +449,7 @@ TEST(LinearConstraintManagerTest, AdvancedStrenghtening2) { const LinearConstraintManager::ConstraintIndex index(0); EXPECT_EQ(2, manager.num_coeff_strenghtening()); EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), - StartsWith("16 <= 9*X0 7*X1 9*X2")); + StartsWith("16 <= 9*I0 7*I1 9*I2")); } TEST(LinearConstraintManagerTest, AdvancedStrenghtening3) { @@ -466,12 +466,12 @@ TEST(LinearConstraintManagerTest, AdvancedStrenghtening3) { manager.Add(ct.Build()); // TODO(user): Technically, because the 5 are "enforcement" the inner - // constraint is 4*X2 >= 5 which can be rewriten and X2 >= 2, and we could - // instead have 2X0 + 2X1 + X2 >= 2 which should be tighter. + // constraint is 4*I2 >= 5 which can be rewriten and I2 >= 2, and we could + // instead have 2I0 + 2I1 + I2 >= 2 which should be tighter. const LinearConstraintManager::ConstraintIndex index(0); EXPECT_EQ(1, manager.num_coeff_strenghtening()); EXPECT_THAT(manager.AllConstraints()[index].constraint.DebugString(), - StartsWith("5 <= 5*X0 5*X1 3*X2")); + StartsWith("5 <= 5*I0 5*I1 3*I2")); } } // namespace diff --git a/ortools/sat/linear_constraint_test.cc b/ortools/sat/linear_constraint_test.cc index a41e5a7e9b..c8931e636f 100644 --- a/ortools/sat/linear_constraint_test.cc +++ b/ortools/sat/linear_constraint_test.cc @@ -240,44 +240,44 @@ TEST(LinearConstraintBuilderTest, AddLiterals) { const BooleanVariable d = model.Add(NewBooleanVariable()); // Create integer views. - model.Add(NewIntegerVariableFromLiteral(Literal(b, true))); // X0 - model.Add(NewIntegerVariableFromLiteral(Literal(b, false))); // X1 - model.Add(NewIntegerVariableFromLiteral(Literal(c, false))); // X2 - model.Add(NewIntegerVariableFromLiteral(Literal(d, false))); // X3 - model.Add(NewIntegerVariableFromLiteral(Literal(d, true))); // X4 + model.Add(NewIntegerVariableFromLiteral(Literal(b, true))); // I0 + model.Add(NewIntegerVariableFromLiteral(Literal(b, false))); // I1 + model.Add(NewIntegerVariableFromLiteral(Literal(c, false))); // I2 + model.Add(NewIntegerVariableFromLiteral(Literal(d, false))); // I3 + model.Add(NewIntegerVariableFromLiteral(Literal(d, true))); // I4 // When we have both view, we use the lowest IntegerVariable. { LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); EXPECT_TRUE(builder.AddLiteralTerm(Literal(b, true), IntegerValue(1))); - EXPECT_EQ(builder.Build().DebugString(), "1*X0 <= 1"); + EXPECT_EQ(builder.Build().DebugString(), "1*I0 <= 1"); } { LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); EXPECT_TRUE(builder.AddLiteralTerm(Literal(b, false), IntegerValue(1))); - EXPECT_EQ(builder.Build().DebugString(), "-1*X0 <= 0"); + EXPECT_EQ(builder.Build().DebugString(), "-1*I0 <= 0"); } { LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); EXPECT_TRUE(builder.AddLiteralTerm(Literal(d, true), IntegerValue(1))); - EXPECT_EQ(builder.Build().DebugString(), "-1*X3 <= 0"); + EXPECT_EQ(builder.Build().DebugString(), "-1*I3 <= 0"); } { LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); EXPECT_TRUE(builder.AddLiteralTerm(Literal(d, false), IntegerValue(1))); - EXPECT_EQ(builder.Build().DebugString(), "1*X3 <= 1"); + EXPECT_EQ(builder.Build().DebugString(), "1*I3 <= 1"); } // When we have just one view, we use the one we have. { LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); EXPECT_TRUE(builder.AddLiteralTerm(Literal(c, true), IntegerValue(1))); - EXPECT_EQ(builder.Build().DebugString(), "-1*X2 <= 0"); + EXPECT_EQ(builder.Build().DebugString(), "-1*I2 <= 0"); } { LinearConstraintBuilder builder(&model, kMinIntegerValue, IntegerValue(1)); EXPECT_TRUE(builder.AddLiteralTerm(Literal(c, false), IntegerValue(1))); - EXPECT_EQ(builder.Build().DebugString(), "1*X2 <= 1"); + EXPECT_EQ(builder.Build().DebugString(), "1*I2 <= 1"); } } @@ -288,31 +288,31 @@ TEST(LinearConstraintBuilderTest, AddConstant) { builder1.AddTerm(IntegerVariable(0), IntegerValue(5)); builder1.AddTerm(IntegerVariable(2), IntegerValue(10)); builder1.AddConstant(IntegerValue(3)); - EXPECT_EQ(builder1.Build().DebugString(), "5*X0 10*X1 <= 7"); + EXPECT_EQ(builder1.Build().DebugString(), "5*I0 10*I1 <= 7"); LinearConstraintBuilder builder2(&model, IntegerValue(4), kMaxIntegerValue); builder2.AddTerm(IntegerVariable(0), IntegerValue(5)); builder2.AddTerm(IntegerVariable(2), IntegerValue(10)); builder2.AddConstant(IntegerValue(-3)); - EXPECT_EQ(builder2.Build().DebugString(), "7 <= 5*X0 10*X1"); + EXPECT_EQ(builder2.Build().DebugString(), "7 <= 5*I0 10*I1"); LinearConstraintBuilder builder3(&model, kMinIntegerValue, IntegerValue(10)); builder3.AddTerm(IntegerVariable(0), IntegerValue(5)); builder3.AddTerm(IntegerVariable(2), IntegerValue(10)); builder3.AddConstant(IntegerValue(-3)); - EXPECT_EQ(builder3.Build().DebugString(), "5*X0 10*X1 <= 13"); + EXPECT_EQ(builder3.Build().DebugString(), "5*I0 10*I1 <= 13"); LinearConstraintBuilder builder4(&model, IntegerValue(4), kMaxIntegerValue); builder4.AddTerm(IntegerVariable(0), IntegerValue(5)); builder4.AddTerm(IntegerVariable(2), IntegerValue(10)); builder4.AddConstant(IntegerValue(3)); - EXPECT_EQ(builder4.Build().DebugString(), "1 <= 5*X0 10*X1"); + EXPECT_EQ(builder4.Build().DebugString(), "1 <= 5*I0 10*I1"); LinearConstraintBuilder builder5(&model, IntegerValue(4), IntegerValue(10)); builder5.AddTerm(IntegerVariable(0), IntegerValue(5)); builder5.AddTerm(IntegerVariable(2), IntegerValue(10)); builder5.AddConstant(IntegerValue(3)); - EXPECT_EQ(builder5.Build().DebugString(), "1 <= 5*X0 10*X1 <= 7"); + EXPECT_EQ(builder5.Build().DebugString(), "1 <= 5*I0 10*I1 <= 7"); } TEST(CleanTermsAndFillConstraintTest, VarAndItsNegation) { @@ -321,7 +321,7 @@ TEST(CleanTermsAndFillConstraintTest, VarAndItsNegation) { terms.push_back({IntegerVariable(5), IntegerValue(4)}); LinearConstraint constraint; CleanTermsAndFillConstraint(&terms, &constraint); - EXPECT_EQ(constraint.DebugString(), "0 <= 3*X2 <= 0"); + EXPECT_EQ(constraint.DebugString(), "0 <= 3*I2 <= 0"); } TEST(LinearConstraintBuilderTest, AddQuadraticLowerBound) { @@ -333,7 +333,7 @@ TEST(LinearConstraintBuilderTest, AddQuadraticLowerBound) { LinearConstraintBuilder builder1(&model, kMinIntegerValue, IntegerValue(10)); AffineExpression a0(x0, IntegerValue(3), IntegerValue(2)); // 3 * x0 + 2. builder1.AddQuadraticLowerBound(a0, x1, integer_trail); - EXPECT_EQ(builder1.Build().DebugString(), "9*X0 8*X1 <= 28"); + EXPECT_EQ(builder1.Build().DebugString(), "9*I0 8*I1 <= 28"); } TEST(LinearConstraintBuilderTest, AddQuadraticLowerBoundAffineIsVar) { @@ -344,7 +344,7 @@ TEST(LinearConstraintBuilderTest, AddQuadraticLowerBoundAffineIsVar) { IntegerVariable x1 = model.Add(NewIntegerVariable(3, 6)); LinearConstraintBuilder builder1(&model, kMinIntegerValue, IntegerValue(10)); builder1.AddQuadraticLowerBound(x0, x1, integer_trail); - EXPECT_EQ(builder1.Build().DebugString(), "3*X0 2*X1 <= 16"); + EXPECT_EQ(builder1.Build().DebugString(), "3*I0 2*I1 <= 16"); } TEST(LinearConstraintBuilderTest, AddQuadraticLowerBoundAffineIsConstant) { @@ -354,7 +354,7 @@ TEST(LinearConstraintBuilderTest, AddQuadraticLowerBoundAffineIsConstant) { IntegerVariable x0 = model.Add(NewIntegerVariable(2, 5)); LinearConstraintBuilder builder1(&model, kMinIntegerValue, IntegerValue(10)); builder1.AddQuadraticLowerBound(IntegerValue(4), x0, integer_trail); - EXPECT_EQ(builder1.Build().DebugString(), "4*X0 <= 10"); + EXPECT_EQ(builder1.Build().DebugString(), "4*I0 <= 10"); } TEST(LinExprTest, Bounds) { diff --git a/ortools/sat/linear_propagation_test.cc b/ortools/sat/linear_propagation_test.cc index 24b988d530..789aaf8874 100644 --- a/ortools/sat/linear_propagation_test.cc +++ b/ortools/sat/linear_propagation_test.cc @@ -15,8 +15,13 @@ #include +#include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" #include "absl/types/span.h" #include "gtest/gtest.h" #include "ortools/sat/integer.h" @@ -200,6 +205,71 @@ TEST(ReifiedWeightedSumTest, BoundToReifFalseLe) { EXPECT_FALSE(model.Get(Value(r))); } +TEST(AddWeightedSumLowerOrEqual, RandomTest) { + const int kNumTests = 10000; + absl::BitGen random; + for (int test = 0; test < kNumTests; ++test) { + const int num_variables = absl::Uniform(random, 1, 20); + std::vector solution(num_variables, 0); + for (int i = 0; i < num_variables; ++i) { + solution[i] = absl::Uniform(random, 0, 100); + } + Model model; + std::vector all_variables(num_variables); + std::vector all_variables_idx(num_variables); + for (int i = 0; i < num_variables; ++i) { + all_variables_idx[i] = i; + all_variables[i] = model.Add( + NewIntegerVariable(solution[i] - absl::Uniform(random, 0, 100), + solution[i] + absl::Uniform(random, 0, 100))); + } + const int num_constraints = absl::Uniform(random, 1, 100); + for (int j = 0; j < num_constraints; ++j) { + const int num_vars = absl::Uniform(random, 1, num_variables); + std::vector var_idx; + absl::c_sample(all_variables_idx, std::back_inserter(var_idx), num_vars, + random); + std::vector vars(num_vars); + for (int k = 0; k < num_vars; ++k) { + vars[k] = all_variables[var_idx[k]]; + } + std::vector coeffs(num_vars); + int64_t activity = 0; + for (int k = 0; k < num_vars; ++k) { + coeffs[k] = absl::Uniform(random, -10, 9); + if (coeffs[k] == 0) coeffs[k]++; + activity += coeffs[k] * solution[var_idx[k]]; + } + CHECK_EQ(coeffs.size(), vars.size()); + AddWeightedSumLowerOrEqual( + vars, coeffs, activity + absl::Uniform(random, 0, 40), &model); + if (absl::Bernoulli(random, 0.1)) { + CHECK(model.GetOrCreate()->Propagate()); + } + if (absl::Bernoulli(random, 0.1)) { + CHECK(model.GetOrCreate()->Propagate()); + } + if (absl::Bernoulli(random, 0.1)) { + IntegerTrail* integer_trail = model.GetOrCreate(); + const int var_idx = absl::Uniform(random, 0, num_variables); + const IntegerVariable var = all_variables[var_idx]; + if (absl::Bernoulli(random, 0.5)) { + if (integer_trail->UpperBound(var) > solution[var_idx]) { + CHECK(integer_trail->Enqueue(IntegerLiteral::LowerOrEqual( + var, integer_trail->UpperBound(var) - 1))); + } + } else { + if (integer_trail->LowerBound(var) < solution[var_idx]) { + CHECK(integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual( + var, integer_trail->LowerBound(var) + 1))); + } + } + } + CHECK(!model.GetOrCreate()->ModelIsUnsat()); + } + } +} + } // namespace } // namespace sat } // namespace operations_research diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index f697e2f8b9..8469497f50 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -886,7 +886,7 @@ void AddCumulativeRelaxation(const AffineExpression& capacity, if (sizes_gcd != 1 && !makespan.has_value()) { VLOG(2) << "Cumulative relaxation: sizes_gcd = " << sizes_gcd << ", demands_gcd = " << demands_gcd - << ", no makespan, capacity is " << capacity.DebugString(); + << ", no makespan, capacity is " << capacity; // We can simplify the capacity only if it is fixed. // TODO(user): We could use (capacity / demands_gcd) * demands_gcd. if (!integer_trail->IsFixed(capacity)) demands_gcd = 1; diff --git a/ortools/sat/linear_relaxation_test.cc b/ortools/sat/linear_relaxation_test.cc index b8f73b7739..bafed6b26c 100644 --- a/ortools/sat/linear_relaxation_test.cc +++ b/ortools/sat/linear_relaxation_test.cc @@ -70,7 +70,7 @@ TEST(AppendRelaxationForEqualityEncodingTest, DomainOfSize2) { // The variable (0) is equal to 8 - 4 * [var == 4]. EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "8 <= 1*X0 4*X1 <= 8"); + "8 <= 1*I0 4*I1 <= 8"); } // Convert the at_most_one to a linear constraint and call DebugString(). @@ -104,13 +104,13 @@ TEST(AppendRelaxationForEqualityEncodingTest, DomainOfSize4) { EXPECT_EQ(relaxation.linear_constraints.size(), 2); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "1 <= 1*X1 1*X2 1*X3 1*X4"); + "1 <= 1*I1 1*I2 1*I3 1*I4"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "1 <= 1*X0 -4*X2 -7*X3 -8*X4 <= 1"); + "1 <= 1*I0 -4*I2 -7*I3 -8*I4 <= 1"); EXPECT_EQ(relaxation.at_most_ones.size(), 1); EXPECT_EQ(AtMostOneAsString(relaxation.at_most_ones[0], &model), - "1*X1 1*X2 1*X3 1*X4 <= 1"); + "1*I1 1*I2 1*I3 1*I4 <= 1"); } TEST(AppendRelaxationForEqualityEncodingTest, PartialEncoding) { @@ -140,13 +140,13 @@ TEST(AppendRelaxationForEqualityEncodingTest, PartialEncoding) { EXPECT_EQ(relaxation.linear_constraints.size(), 2); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "2 <= 1*X0 2*X1 1*X2 -3*X3"); + "2 <= 1*I0 2*I1 1*I2 -3*I3"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "1*X0 10*X1 9*X2 5*X3 <= 10"); + "1*I0 10*I1 9*I2 5*I3 <= 10"); EXPECT_EQ(relaxation.at_most_ones.size(), 1); EXPECT_EQ(AtMostOneAsString(relaxation.at_most_ones[0], &model), - "1*X1 1*X2 1*X3 <= 1"); + "1*I1 1*I2 1*I3 <= 1"); } TEST(AppendPartialGreaterThanEncodingRelaxationTest, FullEncoding) { @@ -167,17 +167,17 @@ TEST(AppendPartialGreaterThanEncodingRelaxationTest, FullEncoding) { // The implications. EXPECT_EQ(relaxation.at_most_ones.size(), 2); EXPECT_EQ(AtMostOneAsString(relaxation.at_most_ones[0], &model), - "-1*X1 1*X2 <= 0"); + "-1*I1 1*I2 <= 0"); EXPECT_EQ(AtMostOneAsString(relaxation.at_most_ones[1], &model), - "-1*X2 1*X3 <= 0"); + "-1*I2 1*I3 <= 0"); // The "diffs" are 4,3,1. // Because here we have a full encoding, we actually have == 1. EXPECT_EQ(relaxation.linear_constraints.size(), 2); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "1 <= 1*X0 -4*X1 -3*X2 -1*X3"); + "1 <= 1*I0 -4*I1 -3*I2 -1*I3"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "-1 <= -1*X0 4*X1 3*X2 1*X3"); + "-1 <= -1*I0 4*I1 3*I2 1*I3"); } TEST(AppendPartialGreaterThanEncodingRelaxationTest, PartialEncoding) { @@ -203,19 +203,19 @@ TEST(AppendPartialGreaterThanEncodingRelaxationTest, PartialEncoding) { // The implications. EXPECT_EQ(relaxation.at_most_ones.size(), 2); EXPECT_EQ(AtMostOneAsString(relaxation.at_most_ones[0], &model), - "-1*X1 1*X2 <= 0"); + "-1*I1 1*I2 <= 0"); EXPECT_EQ(AtMostOneAsString(relaxation.at_most_ones[1], &model), - "-1*X2 1*X3 <= 0"); + "-1*I2 1*I3 <= 0"); // The first constraint is var >= 0 + (>=1) + (>=2) + 4*(>=6) EXPECT_EQ(relaxation.linear_constraints.size(), 2); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "0 <= 1*X0 -1*X1 -1*X2 -4*X3"); + "0 <= 1*I0 -1*I1 -1*I2 -4*I3"); // The second is var <= (>=1) + 4*(>=2) + 5*(>=6) which gives the bounds // <=0,<=1,<=5 and <=10. EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "0 <= -1*X0 1*X1 4*X2 5*X3"); + "0 <= -1*I0 1*I1 4*I2 5*I3"); } TEST(TryToLinearizeConstraint, BoolOr) { @@ -238,7 +238,7 @@ TEST(TryToLinearizeConstraint, BoolOr) { EXPECT_EQ(relaxation.linear_constraints.size(), 1); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "-1 <= -1*X0 -1*X1 1*X2"); + "-1 <= -1*I0 -1*I1 1*I2"); } TEST(TryToLinearizeConstraint, BoolOrLevel1) { @@ -283,9 +283,9 @@ TEST(TryToLinearizeConstraint, BoolAndSingleEnforcement) { EXPECT_EQ(relaxation.at_most_ones.size(), 2); EXPECT_EQ(AtMostOneAsString(relaxation.at_most_ones[0], &model), - "1*X0 1*X1 <= 1"); + "1*I0 1*I1 <= 1"); EXPECT_EQ(AtMostOneAsString(relaxation.at_most_ones[1], &model), - "1*X0 -1*X2 <= 0"); + "1*I0 -1*I2 <= 0"); } TEST(TryToLinearizeConstraint, BoolAndMultipleEnforcement) { @@ -307,12 +307,12 @@ TEST(TryToLinearizeConstraint, BoolAndMultipleEnforcement) { TryToLinearizeConstraint(initial_model, initial_model.constraints(0), /*linearization_level=*/2, &model, &relaxation); - // X0 & X3 => X2 ==1 & not(X1) == 1; + // I0 & I3 => I2 ==1 & not(I1) == 1; EXPECT_EQ(relaxation.linear_constraints.size(), 2); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "1*X0 1*X1 1*X3 <= 2"); + "1*I0 1*I1 1*I3 <= 2"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "1*X0 -1*X2 1*X3 <= 1"); + "1*I0 -1*I2 1*I3 <= 1"); } TEST(TryToLinearizeConstraint, BoolAndNoEnforcement) { @@ -382,9 +382,9 @@ TEST(TryToLinearizeConstraint, LinMaxLevel1Bis) { /*linearization_level=*/1, &model, &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 3); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*X0 -1*X3 <= 0"); - EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "1*X1 -1*X3 <= 0"); - EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), "-1*X2 -1*X3 <= 0"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*I0 -1*I3 <= 0"); + EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "1*I1 -1*I3 <= 0"); + EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), "-1*I2 -1*I3 <= 0"); } TEST(TryToLinearizeConstraint, EnforcedLinMaxLevel2) { @@ -415,11 +415,11 @@ TEST(TryToLinearizeConstraint, EnforcedLinMaxLevel2) { EXPECT_EQ(relaxation.linear_constraints.size(), 3); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "1*X0 -1*X3 10*X4 <= 10"); + "1*I0 -1*I3 10*I4 <= 10"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "1*X1 -1*X3 12*X4 <= 12"); + "1*I1 -1*I3 12*I4 <= 12"); EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), - "-1*X2 -1*X3 7*X4 <= 7"); + "-1*I2 -1*I3 7*I4 <= 7"); } TEST(TryToLinearizeConstraint, LinMaxSmall) { @@ -445,8 +445,8 @@ TEST(TryToLinearizeConstraint, LinMaxSmall) { // Take into account the constraints added by the cut generator. EXPECT_GE(relaxation.linear_constraints.size(), 2); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*X0 -1*X2 <= 0"); - EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "1*X1 -1*X2 <= 0"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*I0 -1*I2 <= 0"); + EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "1*I1 -1*I2 <= 0"); } TEST(TryToLinearizeConstraint, IntSquare) { @@ -471,10 +471,10 @@ TEST(TryToLinearizeConstraint, IntSquare) { EXPECT_EQ(relaxation.linear_constraints.size(), 3); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "-11*X0 1*X1 <= -10"); - EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "-2 <= -3*X0 1*X1"); + "-11*I0 1*I1 <= -10"); + EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "-2 <= -3*I0 1*I1"); EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), - "-90 <= -19*X0 1*X1"); + "-90 <= -19*I0 1*I1"); } TEST(TryToLinearizeConstraint, IntAbs) { @@ -498,10 +498,10 @@ TEST(TryToLinearizeConstraint, IntAbs) { /*linearization_level=*/1, &model, &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 3); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "-1*X0 1*X1 <= 0"); - EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "-1*X0 -1*X1 <= 0"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "-1*I0 1*I1 <= 0"); + EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "-1*I0 -1*I1 <= 0"); EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), - "50*X0 -10*X1 <= 1200"); + "50*I0 -10*I1 <= 1200"); } TEST(TryToLinearizeConstraint, LinMaxLevel1) { @@ -544,9 +544,9 @@ TEST(TryToLinearizeConstraint, LinMaxLevel1) { /*linearization_level=*/1, &model, &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 3); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "-1*X0 2*X1 <= 2"); - EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "-1*X0 -1*X2 <= 1"); - EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), "-1*X0 3*X3 <= 0"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "-1*I0 2*I1 <= 2"); + EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), "-1*I0 -1*I2 <= 1"); + EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), "-1*I0 3*I3 <= 0"); } TEST(AppendLinMaxRelaxation, BasicBehavior) { @@ -575,13 +575,13 @@ TEST(AppendLinMaxRelaxation, BasicBehavior) { EXPECT_EQ(literals.size(), 3); ASSERT_EQ(relaxation.linear_constraints.size(), 4); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "1 <= 1*X4 1*X5 1*X6 <= 1"); + "1 <= 1*I4 1*I5 1*I6 <= 1"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "-1*X0 1*X3 -7*X5 -2*X6 <= 0"); + "-1*I0 1*I3 -7*I5 -2*I6 <= 0"); EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), - "-1*X1 1*X3 -6*X4 -3*X6 <= 0"); + "-1*I1 1*I3 -6*I4 -3*I6 <= 0"); EXPECT_EQ(relaxation.linear_constraints[3].DebugString(), - "1*X2 1*X3 -14*X4 -16*X5 <= 0"); + "1*I2 1*I3 -14*I4 -16*I5 <= 0"); } TEST(AppendLinMaxRelaxation, BasicBehaviorExprs) { @@ -608,13 +608,13 @@ TEST(AppendLinMaxRelaxation, BasicBehaviorExprs) { EXPECT_EQ(literals.size(), 3); ASSERT_EQ(relaxation.linear_constraints.size(), 4); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "1 <= 1*X3 1*X4 1*X5 <= 1"); + "1 <= 1*I3 1*I4 1*I5 <= 1"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "1*X2 -1*X3 -3*X4 -2*X5 <= 0"); + "1*I2 -1*I3 -3*I4 -2*I5 <= 0"); EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), - "1*X0 2*X1 1*X2 -4*X3 -3*X5 <= 0"); + "1*I0 2*I1 1*I2 -4*I3 -3*I5 <= 0"); EXPECT_EQ(relaxation.linear_constraints[3].DebugString(), - "1*X0 -1*X1 1*X2 -3*X3 -3*X4 <= 0"); + "1*I0 -1*I1 1*I2 -3*I3 -3*I4 <= 0"); } TEST(AppendLinMaxRelaxation, BasicBehaviorExprs2) { @@ -646,13 +646,13 @@ TEST(AppendLinMaxRelaxation, BasicBehaviorExprs2) { EXPECT_EQ(literals.size(), 3); ASSERT_EQ(relaxation.linear_constraints.size(), 4); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "1 <= 1*X4 1*X5 1*X6 <= 1"); + "1 <= 1*I4 1*I5 1*I6 <= 1"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "2*X0 3*X1 -1*X3 -5*X4 -9*X5 -9*X6 <= 0"); + "2*I0 3*I1 -1*I3 -5*I4 -9*I5 -9*I6 <= 0"); EXPECT_EQ(relaxation.linear_constraints[2].DebugString(), - "2*X1 5*X2 -1*X3 2*X4 6*X5 2*X6 <= 0"); + "2*I1 5*I2 -1*I3 2*I4 6*I5 2*I6 <= 0"); EXPECT_EQ(relaxation.linear_constraints[3].DebugString(), - "2*X0 3*X2 -1*X3 -2*X4 -2*X5 <= 0"); + "2*I0 3*I2 -1*I3 -2*I4 -2*I5 <= 0"); } void AppendNoOverlapRelaxation(const ConstraintProto& ct, Model* model, @@ -705,7 +705,7 @@ TEST(AppendNoOverlapRelaxation, IntersectingIntervals) { AppendNoOverlapRelaxation(initial_model.constraints(0), &model, &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 1); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*X1 1*X4 <= 12"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*I1 1*I4 <= 12"); } TEST(AppendNoOverlapRelaxation, NoIntersection) { @@ -740,7 +740,7 @@ TEST(AppendNoOverlapRelaxation, NoIntersection) { AppendNoOverlapRelaxation(initial_model.constraints(2), &model, &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 1); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*X4 <= 11"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*I4 <= 11"); } TEST(AppendNoOverlapRelaxation, IntervalWithEnforcement) { @@ -777,7 +777,7 @@ TEST(AppendNoOverlapRelaxation, IntervalWithEnforcement) { AppendNoOverlapRelaxation(initial_model.constraints(2), &model, &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 1); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*X1 1*X6 <= 10"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "1*I1 1*I6 <= 10"); } TEST(AppendNoOverlapRelaxation, ZeroMinEnergy) { @@ -909,7 +909,7 @@ TEST(AppendCumulativeRelaxation, GcdOnFixedDemandsSizesAndCapacity) { AppendCumulativeRelaxation(initial_model.constraints(3), &model, &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 1); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "4*X3 1*X4 <= 6"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "4*I3 1*I4 <= 6"); } TEST(AppendCumulativeRelaxation, IgnoreZeroDemandOrSize) { @@ -980,7 +980,7 @@ TEST(AppendCumulativeRelaxation, IgnoreZeroDemandOrSize) { AppendCumulativeRelaxation(initial_model.constraints(5), &model, &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 1); - EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "4*X3 1*X4 <= 6"); + EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), "4*I3 1*I4 <= 6"); } TEST(AppendLinearConstraintRelaxation, NoEnforcementLiteral) { @@ -1007,7 +1007,7 @@ TEST(AppendLinearConstraintRelaxation, NoEnforcementLiteral) { EXPECT_EQ(relaxation.linear_constraints.size(), 1); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "3 <= 2*X0 1*X2 <= 4"); + "3 <= 2*I0 1*I2 <= 4"); } TEST(AppendLinearConstraintRelaxation, SmallLinearizationLevel) { @@ -1056,7 +1056,7 @@ TEST(AppendLinearConstraintRelaxation, PbConstraint) { &relaxation); EXPECT_EQ(relaxation.linear_constraints.size(), 1); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "3 <= 2*X0 1*X1 3*X2 <= 5"); + "3 <= 2*I0 1*I1 3*I2 <= 5"); } TEST(AppendLinearConstraintRelaxation, SmallConstraint) { @@ -1109,7 +1109,7 @@ TEST(AppendLinearConstraintRelaxation, SingleEnforcementLiteralLowerBound) { EXPECT_EQ(relaxation.linear_constraints.size(), 1); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "0 <= 2*X0 -3*X1 1*X2"); + "0 <= 2*I0 -3*I1 1*I2"); } TEST(AppendLinearConstraintRelaxation, SingleEnforcementLiteralUpperBound) { @@ -1137,7 +1137,7 @@ TEST(AppendLinearConstraintRelaxation, SingleEnforcementLiteralUpperBound) { EXPECT_EQ(relaxation.linear_constraints.size(), 1); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "2*X0 1*X1 1*X2 <= 4"); + "2*I0 1*I1 1*I2 <= 4"); } TEST(AppendLinearConstraintRelaxation, SingleEnforcementLiteralBothBounds) { @@ -1165,9 +1165,9 @@ TEST(AppendLinearConstraintRelaxation, SingleEnforcementLiteralBothBounds) { EXPECT_EQ(relaxation.linear_constraints.size(), 2); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "0 <= 2*X0 -2*X1 1*X2"); + "0 <= 2*I0 -2*I1 1*I2"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "2*X0 1*X1 1*X2 <= 4"); + "2*I0 1*I1 1*I2 <= 4"); } TEST(AppendLinearConstraintRelaxation, MultipleEnforcementLiteral) { @@ -1197,13 +1197,13 @@ TEST(AppendLinearConstraintRelaxation, MultipleEnforcementLiteral) { EXPECT_EQ(relaxation.linear_constraints.size(), 2); EXPECT_EQ(relaxation.linear_constraints[0].DebugString(), - "-4 <= 2*X0 -2*X1 1*X2 -2*X3 -2*X4"); + "-4 <= 2*I0 -2*I1 1*I2 -2*I3 -2*I4"); EXPECT_EQ(relaxation.linear_constraints[1].DebugString(), - "2*X0 1*X1 1*X2 1*X3 1*X4 <= 6"); + "2*I0 1*I1 1*I2 1*I3 1*I4 <= 6"); } // This used to generate the completely wrong constraint: -// 1*X0 -8*X1 1*X2 -8*X3 <= -6 before. +// 1*I0 -8*I1 1*I2 -8*I3 <= -6 before. TEST(AppendLinearConstraintRelaxation, BoundsNotTight) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } diff --git a/ortools/sat/model.h b/ortools/sat/model.h index 88f8140a32..f457a2d623 100644 --- a/ortools/sat/model.h +++ b/ortools/sat/model.h @@ -24,9 +24,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" -#include "absl/meta/type_traits.h" #include "ortools/base/logging.h" -#include "ortools/base/typeid.h" namespace operations_research { namespace sat { @@ -38,6 +36,25 @@ namespace sat { * constraints, watchers, solvers and provide a mechanism to wire them together. */ class Model { + // FastTypeId() evaluates at compile/link-time to a unique integer for + // the passed in type. Their values are neither contiguous nor small, making + // them unfit for using as an index into a vector, but a good match for keys + // into maps or straight up comparisons. Note that on 64-bit (unix) systems + // size_t is 64-bit while int is 32-bit and the compiler will happily and + // quietly assign such a 64-bit value to a 32-bit integer. While a client + // should never do that it SHOULD still be safe, assuming the BSS segment + // doesn't span more than 4GiB. + template + static inline size_t FastTypeId() { + static_assert(sizeof(char*) <= sizeof(size_t), + "ptr size too large for size_t"); + + // This static variable isn't actually used, only its address, so there are + // no concurrency issues. + static char dummy_var; + return reinterpret_cast(&dummy_var); + } + public: Model() = default; @@ -110,7 +127,7 @@ class Model { */ template T* GetOrCreate() { - const size_t type_id = gtl::FastTypeId(); + const size_t type_id = FastTypeId(); auto find = singletons_.find(type_id); if (find != singletons_.end()) { return static_cast(find->second); @@ -131,7 +148,7 @@ class Model { */ template const T* Get() const { - const auto& it = singletons_.find(gtl::FastTypeId()); + const auto& it = singletons_.find(FastTypeId()); return it != singletons_.end() ? static_cast(it->second) : nullptr; } @@ -141,7 +158,7 @@ class Model { */ template T* Mutable() const { - const auto& it = singletons_.find(gtl::FastTypeId()); + const auto& it = singletons_.find(FastTypeId()); return it != singletons_.end() ? static_cast(it->second) : nullptr; } @@ -175,7 +192,7 @@ class Model { */ template void Register(T* non_owned_class) { - const size_t type_id = gtl::FastTypeId(); + const size_t type_id = FastTypeId(); CHECK(!singletons_.contains(type_id)); singletons_[type_id] = non_owned_class; } diff --git a/ortools/sat/optimization.cc b/ortools/sat/optimization.cc index cbdc1949cf..a603761c57 100644 --- a/ortools/sat/optimization.cc +++ b/ortools/sat/optimization.cc @@ -234,6 +234,17 @@ SatSolver::Status MinimizeIntegerVariableWithLinearScanAndLazyEncoding( return SatSolver::LIMIT_REACHED; } + // The solver usually always solve a "restricted decision problem" + // obj < current_best. So when we have an optimal solution, then the + // problem is UNSAT, and any clauses we learn can break the debug solution. + // So we disable this checks once we found an optimal solution. + if (DEBUG_MODE) { + const DebugSolution* debug_sol = model->Get(); + if (debug_sol && objective <= debug_sol->inner_objective_value) { + model->GetOrCreate()->Clear(); + } + } + // Restrict the objective. sat_solver->Backtrack(0); if (!integer_trail->Enqueue( diff --git a/ortools/sat/pb_constraint.cc b/ortools/sat/pb_constraint.cc index 0e8616f703..7fc0c923db 100644 --- a/ortools/sat/pb_constraint.cc +++ b/ortools/sat/pb_constraint.cc @@ -22,7 +22,6 @@ #include #include -#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/hash/hash.h" #include "absl/log/check.h" @@ -599,7 +598,8 @@ bool UpperBoundedLinearConstraint::Propagate( } else { // Conflict. FillReason(*trail, trail_index, enforcement_literals, - /*propagated_variable=*/kNoBooleanVariable, &helper->conflict); + /*propagated_variable=*/kNoBooleanVariable, + &helper->temporary_tuples, &helper->conflict); return false; } } @@ -614,11 +614,14 @@ bool UpperBoundedLinearConstraint::Propagate( for (int i = starts_[index_ + 1]; i < already_propagated_end_; ++i) { if (trail->Assignment().LiteralIsFalse(literals_[i])) continue; if (trail->Assignment().LiteralIsTrue(literals_[i])) { - if (trail->Info(literals_[i].Variable()).trail_index > trail_index) { + const int literal_trail_index = + trail->Info(literals_[i].Variable()).trail_index; + if (literal_trail_index > trail_index) { if (enforcement_status == EnforcementStatus::IS_ENFORCED) { // Conflict. FillReason(*trail, trail_index, enforcement_literals, - literals_[i].Variable(), &helper->conflict); + literals_[i].Variable(), &helper->temporary_tuples, + &helper->conflict); helper->conflict.push_back(literals_[i].Negated()); Update(slack, threshold); return false; @@ -626,7 +629,8 @@ bool UpperBoundedLinearConstraint::Propagate( // Propagate the unique unassigned enforcement literal. for (const Literal literal : enforcement_literals) { if (!trail->Assignment().LiteralIsAssigned(literal)) { - helper->Enqueue(literal.Negated(), trail_index, this, trail); + helper->Enqueue(literal.Negated(), literal_trail_index, this, + trail); break; } } @@ -659,7 +663,9 @@ bool UpperBoundedLinearConstraint::Propagate( void UpperBoundedLinearConstraint::FillReason( const Trail& trail, int source_trail_index, absl::Span enforcement_literals, - BooleanVariable propagated_variable, std::vector* reason) { + BooleanVariable propagated_variable, + std::vector>* temporary_tuples, + std::vector* reason) { bool enforcement_propagation = false; reason->clear(); for (const Literal literal : enforcement_literals) { @@ -669,29 +675,36 @@ void UpperBoundedLinearConstraint::FillReason( enforcement_propagation = true; } } - const int enforcement_reason_size = reason->size(); - - Coefficient slack = rhs_; - Coefficient propagated_variable_coefficient(0); - Literal extra_literal_reason; - // Optimization: This will be set to the index of the last literal in the - // reason (they are sorted in decreasing coefficient order). - int last_i = 0; - int last_coeff_index = 0; // propagated_variable is set to kNoBooleanVariable when the constraint // becomes enforced when the slack is already negative. In this case, or when - // the enforcement can be propagated, the reason must include all the terms - // explaining the negative slack. The code below does that (while the 'else' - // part computes a reason for propagated_variable). extra_literal_reason is - // the literal which makes the slack become negative. - if (enforcement_propagation || propagated_variable == kNoBooleanVariable) { + // the enforcement can be propagated, the reason must include the literal + // which makes the slack become negative, called the "extra literal reason". + const bool add_extra_literal_reason = + enforcement_propagation || propagated_variable == kNoBooleanVariable; + + // Optimization for an "at most one" constraint. Note that the + // source_trail_index set by InitializeRhs() is ok in this case. + if (rhs_ == 1 && !add_extra_literal_reason) { + reason->push_back(trail[source_trail_index].Negated()); + return; + } + + // Compute all the literals of the constraint that were assigned to true at + // the time of the propagation, sorted by trail index. + // Vector of (trail_index, literal_index, coeff_index) tuples. + std::vector>& true_literals = *temporary_tuples; + true_literals.clear(); + Coefficient propagated_variable_coefficient(0); + { int literal_index = 0; int coeff_index = 0; - // Vector of (trail_index, literal_index, coeff_index) tuples. - std::vector> true_literals; for (Literal literal : literals_) { - if (trail.Assignment().LiteralIsTrue(literal)) { + if (literal.Variable() == propagated_variable) { + propagated_variable_coefficient = coeffs_[coeff_index]; + } + if (trail.Assignment().LiteralIsTrue(literal) && + trail.Info(literal.Variable()).trail_index <= source_trail_index) { true_literals.push_back({trail.Info(literal.Variable()).trail_index, literal_index, coeff_index}); } @@ -699,101 +712,54 @@ void UpperBoundedLinearConstraint::FillReason( if (literal_index == starts_[coeff_index + 1]) ++coeff_index; } std::sort(true_literals.begin(), true_literals.end()); - // Vector of (literal_index, coeff_index) pairs. - std::vector> reason_indices; - for (const auto& [trail_index, literal_index, coeff_index] : - true_literals) { - const Literal literal = literals_[literal_index]; - const Coefficient coeff = coeffs_[coeff_index]; - if (coeff > slack) { - propagated_variable_coefficient = coeff; - // This literal is added to the reason at the very end (see the cleanup - // below) because the code minimizing the reason assumes that it is not - // part of it. Another solution would be insert it at the beginning (and - // to increment enforcement_reason_size), but this is less efficient. - extra_literal_reason = literal.Negated(); - break; - } - if (trail.Info(literal.Variable()).level > 0) { - reason_indices.push_back({literal_index, coeff_index}); - } - slack -= coeff.value(); - } - std::sort(reason_indices.begin(), reason_indices.end(), std::greater<>()); - for (const auto& [literal_index, coeff_index] : reason_indices) { - reason->push_back(literals_[literal_index].Negated()); - } - if (!reason_indices.empty()) { - last_i = reason_indices.back().first; - last_coeff_index = reason_indices.back().second; - } - } else { - // Optimization for an "at most one" constraint. Note that the - // source_trail_index set by InitializeRhs() is ok in this case. - if (rhs_ == 1) { - reason->push_back(trail[source_trail_index].Negated()); - return; - } - - // Compute the initial reason which is formed by all the literals of the - // constraint that were assigned to true at the time of the propagation. - // We remove literals with a level of 0 since they are not needed. - // We also compute the slack at the time. - int coeff_index = coeffs_.size() - 1; - for (int i = literals_.size() - 1; i >= 0; --i) { - const Literal literal = literals_[i]; - if (literal.Variable() == propagated_variable) { - propagated_variable_coefficient = coeffs_[coeff_index]; - } else { - if (trail.Assignment().LiteralIsTrue(literal) && - trail.Info(literal.Variable()).trail_index <= source_trail_index) { - if (trail.Info(literal.Variable()).level > 0) { - reason->push_back(literal.Negated()); - last_i = i; - last_coeff_index = coeff_index; - } - slack -= coeffs_[coeff_index]; - } - } - if (i == starts_[coeff_index]) { - --coeff_index; - } - } } + + // Compute the initial reason which is formed by all the literals of the + // constraint that were assigned to true at the time of the propagation. We + // remove literals with a level of 0 since they are not needed. We also + // compute the slack at the time. + Coefficient slack = rhs_; + int new_size = 0; + for (int i = 0; i < true_literals.size(); ++i) { + auto [trail_index, literal_index, coeff_index] = true_literals[i]; + const Literal literal = literals_[literal_index]; + const Coefficient coeff = coeffs_[coeff_index]; + if (coeff > slack) { + propagated_variable_coefficient = coeff; + if (add_extra_literal_reason) { + reason->push_back(literal.Negated()); + } + break; + } + if (trail.Info(literal.Variable()).level > 0) { + true_literals[new_size++] = {trail_index, literal_index, coeff_index}; + } + slack -= coeff.value(); + } + true_literals.resize(new_size); DCHECK_GT(propagated_variable_coefficient, slack); DCHECK_GE(propagated_variable_coefficient, 0); - auto cleanup = absl::MakeCleanup([&] { - if (enforcement_propagation || propagated_variable == kNoBooleanVariable) { - reason->push_back(extra_literal_reason); - } - }); // In both cases, we can't minimize the reason further. - if (reason->size() <= enforcement_reason_size + 1 || coeffs_.size() == 1) { + if (true_literals.size() <= 1 || coeffs_.size() == 1) { + for (const auto& [unused1, literal_index, unused2] : true_literals) { + reason->push_back(literals_[literal_index].Negated()); + } return; } + // Remove literals with high trail indices from the reason as long as the + // limit is strictly positive. Coefficient limit = propagated_variable_coefficient - slack; DCHECK_GE(limit, 1); - - // Remove literals with small coefficients from the reason as long as the - // limit is still strictly positive. - int coeff_index = last_coeff_index; - if (coeffs_[coeff_index] >= limit) { - return; - } - for (int i = last_i; i < literals_.size(); ++i) { - const Literal literal = literals_[i]; - if (i == starts_[coeff_index + 1]) { - ++coeff_index; - if (coeffs_[coeff_index] >= limit) break; + for (int i = true_literals.size() - 1; i >= 0; --i) { + const auto [_, literal_index, coeff_index] = true_literals[i]; + const Coefficient coeff = coeffs_[coeff_index]; + if (coeff < limit) { + limit -= coeff.value(); + } else { + reason->push_back(literals_[literal_index].Negated()); } - DCHECK_GT(reason->size(), enforcement_reason_size); - if (literal.Negated() != reason->back()) continue; - limit -= coeffs_[coeff_index]; - reason->pop_back(); - if (reason->size() == enforcement_reason_size) break; - if (coeffs_[coeff_index] >= limit) break; } DCHECK_GE(limit, 1); } @@ -1175,7 +1141,7 @@ absl::Span PbConstraints::Reason(const Trail& trail, trail, reason_info.source_trail_index, enforcement_propagator_->GetEnforcementLiterals( reason_info.pb_constraint->enforcement_id()), - trail[trail_index].Variable(), reason); + trail[trail_index].Variable(), &enqueue_helper_.temporary_tuples, reason); return *reason; } diff --git a/ortools/sat/pb_constraint.h b/ortools/sat/pb_constraint.h index d70b6e4f72..1ceda586c1 100644 --- a/ortools/sat/pb_constraint.h +++ b/ortools/sat/pb_constraint.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -372,6 +373,9 @@ struct PbConstraintsEnqueueHelper { UpperBoundedLinearConstraint* pb_constraint; }; std::vector reasons; + + // A temporary vector of tuples used in FillReason(). + mutable std::vector> temporary_tuples; }; // This class contains half the propagation logic for a constraint of the form @@ -444,7 +448,8 @@ class UpperBoundedLinearConstraint { // Provided that the literal with given source_trail_index was the one that // propagated the conflict or the literal we want to explain, then this will - // compute the reason. + // compute the reason. temporary_tuples is only used as a temporary storage to + // avoid allocating a vector at each call. // // Some properties of the reason: // - Literals of level 0 are removed. @@ -460,6 +465,7 @@ class UpperBoundedLinearConstraint { void FillReason(const Trail& trail, int source_trail_index, absl::Span enforcement_literals, BooleanVariable propagated_variable, + std::vector>* temporary_tuples, std::vector* reason); // Same operation as SatSolver::ResolvePBConflict(), the only difference is diff --git a/ortools/sat/pb_constraint_test.cc b/ortools/sat/pb_constraint_test.cc index 482adc92d9..b8415dbbda 100644 --- a/ortools/sat/pb_constraint_test.cc +++ b/ortools/sat/pb_constraint_test.cc @@ -17,7 +17,6 @@ #include #include -#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/types/span.h" #include "gtest/gtest.h" @@ -335,11 +334,12 @@ TEST(UpperBoundedLinearConstraintTest, CompactReason) { EXPECT_EQ(trail.Index(), 4); EXPECT_EQ(trail[3], Literal(-4)); - // -1 do not need to be in the reason since {-3, -2} propagates exactly + // -2 do not need to be in the reason since {-3, -1} propagates exactly // the same way. cst.FillReason(trail, source_trail_index, /*enforcement_literals=*/{}, - Literal(-4).Variable(), &helper.conflict); - EXPECT_THAT(helper.conflict, LiteralsAre(-3, -2)); + Literal(-4).Variable(), &helper.temporary_tuples, + &helper.conflict); + EXPECT_THAT(helper.conflict, LiteralsAre(-3, -1)); } TEST(UpperBoundedLinearConstraintTest, ConflictAfterEnforcementStatusChange) { @@ -375,9 +375,9 @@ TEST(UpperBoundedLinearConstraintTest, ConflictAfterEnforcementStatusChange) { EnforcementStatus::IS_ENFORCED, enforcement_literals, &helper)); - // -1 do not need to be in the reason since {-4, -3, -2} propagates exactly + // -2 do not need to be in the reason since {-4, -3, -1} propagates exactly // the same way. - EXPECT_THAT(helper.conflict, LiteralsAre(-9, -3, -2, -4)); + EXPECT_THAT(helper.conflict, LiteralsAre(-9, -4, -3, -1)); } TEST(UpperBoundedLinearConstraintTest, PropagateEnforcementAfterStatusChange) { @@ -415,12 +415,13 @@ TEST(UpperBoundedLinearConstraintTest, PropagateEnforcementAfterStatusChange) { EXPECT_EQ(trail.Index(), 6); EXPECT_EQ(trail[5], Literal(-8)); - // -1 do not need to be in the reason since {-4, -3, -2} propagates exactly + // -2 do not need to be in the reason since {-4, -3, -1} propagates exactly // the same way. const PbConstraintsEnqueueHelper::ReasonInfo& reason = helper.reasons[5]; cst.FillReason(trail, reason.source_trail_index, enforcement_literals, - Literal(-8).Variable(), &helper.conflict); - EXPECT_THAT(helper.conflict, LiteralsAre(-9, -3, -2, -4)); + Literal(-8).Variable(), &helper.temporary_tuples, + &helper.conflict); + EXPECT_THAT(helper.conflict, LiteralsAre(-9, -4, -3, -1)); } TEST(UpperBoundedLinearConstraintTest, @@ -457,8 +458,9 @@ TEST(UpperBoundedLinearConstraintTest, const PbConstraintsEnqueueHelper::ReasonInfo& reason = helper.reasons[2]; cst.FillReason(trail, reason.source_trail_index, enforcement_literals, - Literal(-9).Variable(), &helper.conflict); - EXPECT_THAT(helper.conflict, LiteralsAre(-1, -2)); + Literal(-9).Variable(), &helper.temporary_tuples, + &helper.conflict); + EXPECT_THAT(helper.conflict, LiteralsAre(-2, -1)); } TEST(PbConstraintsTest, Duplicates) { @@ -511,8 +513,8 @@ TEST(PbConstraintsTest, BasicPropagation) { // Test the reason for each assignment. EXPECT_THAT(trail.Reason(Literal(-2).Variable()), LiteralsAre(+1)); - EXPECT_THAT(trail.Reason(Literal(-3).Variable()), LiteralsAre(+2, +1)); - EXPECT_THAT(trail.Reason(Literal(-4).Variable()), LiteralsAre(+3, +2, +1)); + EXPECT_THAT(trail.Reason(Literal(-3).Variable()), LiteralsAre(+1, +2)); + EXPECT_THAT(trail.Reason(Literal(-4).Variable()), LiteralsAre(+1, +2, +3)); // Untrail, and repropagate everything. csts.Untrail(trail, 0); diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index 85bd469bbc..a20c412361 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -148,6 +148,20 @@ int PresolveContext::GetTrueLiteral() { int PresolveContext::GetFalseLiteral() { return NegatedRef(GetTrueLiteral()); } +ConstraintProto* PresolveContext::AddEnforcedConstraint( + absl::Span enforcement_literals) { + ConstraintProto* const new_ct = working_model->add_constraints(); + *new_ct->mutable_enforcement_literal() = {enforcement_literals.begin(), + enforcement_literals.end()}; + return new_ct; +} + +ConstraintProto* PresolveContext::AddEnforcedConstraint(ConstraintProto* ct) { + ConstraintProto* const new_ct = working_model->add_constraints(); + *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + return new_ct; +} + // a => b. void PresolveContext::AddImplication(int a, int b) { if (a == b) return; @@ -618,7 +632,9 @@ bool PresolveContext::ConstraintIsOptional(int ct_ref) const { return contains_one_free_literal; } -void PresolveContext::UpdateRuleStats(const std::string& name, int num_times) { +void PresolveContext::UpdateRuleStats(std::string_view name, int num_times) { + DCHECK(!name.empty()); + // Hack: we don't want to count TODO rules as this is used to decide if // we loop again. const bool is_todo = name.size() >= 4 && name.substr(0, 4) == "TODO"; diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index 0ba11cf43c..8985c23a25 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -134,6 +134,11 @@ class PresolveContext { int GetTrueLiteral(); int GetFalseLiteral(); + // Shortcuts to create enforced constraints. + ConstraintProto* AddEnforcedConstraint( + absl::Span enforcement_literals); + ConstraintProto* AddEnforcedConstraint(ConstraintProto* ct); + // a => b. void AddImplication(int a, int b); @@ -307,7 +312,7 @@ class PresolveContext { // Stores a description of a rule that was just applied to have a summary of // what the presolve did at the end. - void UpdateRuleStats(const std::string& name, int num_times = 1); + void UpdateRuleStats(std::string_view name, int num_times = 1); // Updates the constraints <-> variables graph. This needs to be called each // time a constraint is modified. diff --git a/ortools/sat/python/BUILD.bazel b/ortools/sat/python/BUILD.bazel index 0148f932c2..302c08eb8b 100644 --- a/ortools/sat/python/BUILD.bazel +++ b/ortools/sat/python/BUILD.bazel @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Python wrapper for cp_model. +# Description: python wrapping of the C++ code at ../ load("@pip_deps//:requirements.bzl", "requirement") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") @@ -20,6 +20,7 @@ load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_python//python:py_library.bzl", "py_library") load("@rules_python//python:py_test.bzl", "py_test") +# This file is generated manually by running pybind11_mkdoc on linear_expr.h cc_library( name = "linear_expr_doc", hdrs = ["linear_expr_doc.h"], @@ -37,9 +38,10 @@ cc_library( "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:fixed_array", - "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/hash", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:span", ], ) @@ -48,6 +50,7 @@ cc_library( srcs = ["wrappers.cc"], hdrs = ["wrappers.h"], deps = [ + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:nullability", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", @@ -68,7 +71,10 @@ cc_binary( "//ortools/base", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:sat_parameters_cc_proto", + "@abseil-cpp//absl/flags:parse", + "@abseil-cpp//absl/flags:usage", "@abseil-cpp//absl/log:die_if_null", + "@abseil-cpp//absl/log:initialize", "@abseil-cpp//absl/strings:str_format", ], ) @@ -93,22 +99,29 @@ pybind_extension( ":linear_expr", ":linear_expr_doc", ":proto_builder_pybind11", - "//ortools/base:string_view_migration", + "//ortools/port:proto_utils", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:cp_model_utils", "//ortools/sat:sat_parameters_cc_proto", "//ortools/sat:swig_helper", + "//ortools/util:saturated_arithmetic", + "//ortools/util:sorted_interval_list", + "//ortools/util/python:sorted_interval_list", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/functional:any_invocable", + "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/strings", ], ) py_test( name = "cp_model_helper_test", + size = "small", srcs = ["cp_model_helper_test.py"], deps = [ ":cp_model_helper", - "//ortools/util/python:sorted_interval_list", requirement("absl-py"), + "//ortools/util/python:sorted_interval_list", ], ) @@ -126,11 +139,14 @@ py_library( py_test( name = "cp_model_test", + size = "small", srcs = ["cp_model_test.py"], + tags = ["noasan"], # Times out occasionally in ASAN mode. deps = [ ":cp_model", ":cp_model_helper", requirement("absl-py"), requirement("numpy"), + requirement("pandas"), ], ) diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index 382fdb426c..8de1996658 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -1879,6 +1879,7 @@ class CpModelTest(absltest.TestCase): model.maximize(x + 2 * y) solver = cp_model.CpSolver() + solver.parameters.num_workers = 1 status = solver.solve(model) self.assertEqual(cp_model.OPTIMAL, status) self.assertEqual(solver.num_booleans, 0) diff --git a/ortools/sat/routing_cuts_test.cc b/ortools/sat/routing_cuts_test.cc index 201d599d15..33bd8f328c 100644 --- a/ortools/sat/routing_cuts_test.cc +++ b/ortools/sat/routing_cuts_test.cc @@ -2255,9 +2255,9 @@ TEST(CreateStronglyConnectedGraphCutGeneratorTest, AnotherExample) { // However as an heuristic, we will wait another round to generate {1, 2, 3}. ASSERT_EQ(manager.num_cuts(), 2); EXPECT_THAT(manager.AllConstraints().front().constraint.DebugString(), - ::testing::StartsWith("1 <= 1*X3 1*X6")); + ::testing::StartsWith("1 <= 1*I3 1*I6")); EXPECT_THAT(manager.AllConstraints().back().constraint.DebugString(), - ::testing::StartsWith("1 <= 1*X1 1*X3")); + ::testing::StartsWith("1 <= 1*I1 1*I3")); } TEST(GenerateInterestingSubsetsTest, BasicExample) { @@ -2333,9 +2333,9 @@ TEST(CreateFlowCutGeneratorTest, BasicExample) { // The sets {2} and {3} will generate incoming flow cuts. EXPECT_EQ(manager.num_cuts(), 2); EXPECT_THAT(manager.AllConstraints().front().constraint.DebugString(), - ::testing::StartsWith("1 <= 1*X2")); + ::testing::StartsWith("1 <= 1*I2")); EXPECT_THAT(manager.AllConstraints().back().constraint.DebugString(), - ::testing::StartsWith("1 <= 1*X1 1*X3")); + ::testing::StartsWith("1 <= 1*I1 1*I3")); } TEST(CreateFlowCutGeneratorTest, WithMinusOneArcs) { @@ -2377,7 +2377,7 @@ TEST(CreateFlowCutGeneratorTest, WithMinusOneArcs) { // We artificially put bad LP values so that {1} generate outgoing flow cut. EXPECT_EQ(manager.num_cuts(), 1); EXPECT_THAT(manager.AllConstraints().front().constraint.DebugString(), - ::testing::StartsWith("1 <= 1*X1 1*X2")); + ::testing::StartsWith("1 <= 1*I1 1*I2")); } TEST(CreateCVRPCutGeneratorTest, InfeasiblePathCuts) { @@ -2451,7 +2451,7 @@ TEST(CreateCVRPCutGeneratorTest, InfeasiblePathCuts) { // Arcs with ID 2 (1->2) and ID 4 (2->3) should be in the cut. EXPECT_THAT(manager.AllConstraints().back().constraint.DebugString(), - ::testing::StartsWith("0 <= 1*X2 1*X4 <= 1")); + ::testing::StartsWith("0 <= 1*I2 1*I4 <= 1")); } } // namespace diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index ea1f48d31c..bb33404f73 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -318,6 +318,12 @@ class Trail { SetCurrentPropagatorId(propagator_id); FastEnqueue(true_literal); } + void EnqueueAtLevel(Literal true_literal, int propagator_id, int level) { + Enqueue(true_literal, propagator_id); + if (use_chronological_backtracking_) { + info_[true_literal.Variable()].level = level; + } + } // Specific Enqueue() version for the search decision. void EnqueueSearchDecision(Literal true_literal) { @@ -326,7 +332,7 @@ class Trail { // Specific Enqueue() version for a fixed variable. void EnqueueWithUnitReason(Literal true_literal) { - Enqueue(true_literal, AssignmentType::kUnitReason); + EnqueueAtLevel(true_literal, AssignmentType::kUnitReason, 0); } // Some constraints propagate a lot of literals at once. In these cases, it is @@ -336,6 +342,9 @@ class Trail { BooleanVariable reference_var) { reference_var_with_same_reason_as_[true_literal.Variable()] = reference_var; Enqueue(true_literal, AssignmentType::kSameReasonAs); + if (ChronologicalBacktrackingEnabled()) { + info_[true_literal.Variable()].level = Info(reference_var).level; + } } // Enqueues the given literal using the current content of @@ -357,6 +366,14 @@ class Trail { reasons_[var] = reasons_repository_[info_[var].trail_index]; old_type_[var] = info_[var].type; info_[var].type = AssignmentType::kCachedReason; + DCHECK_EQ(old_type_[var], AssignmentType::kCachedReason); + if (ChronologicalBacktrackingEnabled()) { + uint32_t level = 0; + for (const Literal literal : reasons_[var]) { + level = std::max(level, Info(literal.Variable()).level); + } + info_[var].level = level; + } return true; } @@ -420,6 +437,9 @@ class Trail { assignment_.UnassignLiteral(trail_[i]); } current_info_.trail_index = target_trail_index; + if (use_chronological_backtracking_) { + ReimplyAll(index); + } } // Changes the decision level used by the next Enqueue(). @@ -466,6 +486,8 @@ class Trail { return info_[var]; } + int AssignmentLevel(Literal lit) const { return Info(lit.Variable()).level; } + // Print the current literals on the trail. std::string DebugString() const { std::string result; @@ -481,7 +503,23 @@ class Trail { debug_checker_ = std::move(checker); } + bool ChronologicalBacktrackingEnabled() const { + return use_chronological_backtracking_; + } + + void EnableChronologicalBacktracking(bool enable) { + CHECK_EQ(CurrentDecisionLevel(), 0); + use_chronological_backtracking_ = enable; + } + private: + // Finds all literals between the current trail index and the given one + // assigned at the current level or lower, and re-enqueues them with the same + // reason. + void ReimplyAll(int old_trail_index); + + bool use_chronological_backtracking_ = false; + int64_t num_reimplied_literals_ = 0; int64_t num_untrailed_enqueues_ = 0; AssignmentInfo current_info_; VariablesAssignment assignment_; @@ -568,6 +606,16 @@ class SatPropagator { propagation_trail_index_ = std::min(propagation_trail_index_, trail_index); } + // Called if the implication at `old_trail_index` remains true after + // backtracking. If this propagator supports reimplication it should call + // `trail->EnqueueAtLevel`. + // This will be called after Untrail() when backtracking. + virtual void Reimply(Trail* /*trail*/, int /*old_trail_index*/) { + // It is inefficient and unexpected to call this on a propagator that + // doesn't support reimplication. + LOG(DFATAL) << "Reimply not implemented for " << name_ << "."; + } + // Explains why the literal at given trail_index was propagated by returning a // reason for this propagation. This will only be called for literals that are // on the trail and were propagated by this class. @@ -623,7 +671,7 @@ inline bool SatPropagator::PropagatePreconditionsAreSatisfied( return false; } if (propagation_trail_index_ < trail.Index() && - trail.Info(trail[propagation_trail_index_].Variable()).level != + trail.Info(trail[propagation_trail_index_].Variable()).level > trail.CurrentDecisionLevel()) { LOG(INFO) << "Issue in '" << name_ << "':" << " propagation_trail_index_=" << propagation_trail_index_ @@ -715,6 +763,39 @@ inline absl::Span Trail::Reason(BooleanVariable var, return reasons_[var]; } +inline void Trail::ReimplyAll(int old_trail_index) { + for (int i = Index(); i < old_trail_index; ++i) { + const Literal literal = trail_[i]; + const AssignmentInfo& info = Info(literal.Variable()); + if (info.level > current_info_.level) continue; + CHECK_LE(Index(), i); + CHECK(!Assignment().VariableIsAssigned(literal.Variable())); + if (info.type == AssignmentType::kSameReasonAs) { + // The reference variable must already be re-implied at this level, so we + // can just re-enqueue it without having to tell the propagator. + DCHECK_EQ(Info(ReferenceVarWithSameReason(literal.Variable())).level, + info.level); + DCHECK_LT( + Info(ReferenceVarWithSameReason(literal.Variable())).trail_index, + Index()); + EnqueueAtLevel(literal, AssignmentType::kSameReasonAs, info.level); + } else { + const int original_type = AssignmentType(literal.Variable()); + if (original_type >= AssignmentType::kFirstFreePropagationId) { + propagators_[original_type]->Reimply(this, i); + } else if (original_type == AssignmentType::kCachedReason) { + std::swap(reasons_repository_[Index()], reasons_repository_[i]); + reasons_[literal.Variable()] = reasons_repository_[Index()]; + EnqueueAtLevel(literal, original_type, info.level); + } else if (info.type == AssignmentType::kUnitReason || info.level == 0) { + CHECK(!Assignment().LiteralIsFalse(literal)); + EnqueueAtLevel(literal, AssignmentType::kUnitReason, info.level); + } + } + num_reimplied_literals_ += assignment_.LiteralIsTrue(literal); + } +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index 7e441b53cb..c077828c33 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -24,7 +24,7 @@ option java_multiple_files = true; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 330 +// NEXT TAG: 333 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -140,6 +140,21 @@ message SatParameters { // from the problem. optional bool subsumption_during_conflict_analysis = 56 [default = true]; + // If true, try to backtrack as little as possible on conflict and re-imply + // the clauses later. + // This means we discard less propagation than traditional backjumping, but + // requites additional bookkeeping to handle reimplication. + // See: https://doi.org/10.1007/978-3-319-94144-8_7 + optional bool use_chronological_backtracking = 330 [default = false]; + + // If chronological backtracking is enabled, this is the maximum number of + // levels we will backjump over, otherwise we will backtrack. + optional int32 max_backjump_levels = 331 [default = 50]; + + // If chronological backtracking is enabled, this is the minimum number of + // conflicts before we will consider backjumping. + optional int32 chronological_backtrack_min_conflicts = 332 [default = 1000]; + // ========================================================================== // Clause database management // ========================================================================== diff --git a/ortools/sat/sat_runner.cc b/ortools/sat/sat_runner.cc index c1dceb038b..25f0484f42 100644 --- a/ortools/sat/sat_runner.cc +++ b/ortools/sat/sat_runner.cc @@ -111,7 +111,7 @@ class LastSolutionPrinter { public: // Note that is prints the solution in the PB competition format. void MaybePrintLastSolution() { - absl::MutexLock lock(&mutex_); + absl::MutexLock lock(mutex_); if (last_solution_printed_) return; last_solution_printed_ = true; @@ -140,7 +140,7 @@ class LastSolutionPrinter { void set_num_variables(int num_variables) { num_variables_ = num_variables; } void set_last_solution(absl::Span solution) { - absl::MutexLock lock(&mutex_); + absl::MutexLock lock(mutex_); if (last_solution_printed_) return; last_solution_.assign(solution.begin(), solution.end()); } @@ -148,7 +148,7 @@ class LastSolutionPrinter { // Returns false if the solution has already been printed, else mark it as // printed by caller code. bool mark_last_solution_printed() { - const absl::MutexLock lock(&mutex_); + const absl::MutexLock lock(mutex_); if (last_solution_printed_) { return false; } diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index efa9f9e614..2dc7408443 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -22,6 +22,7 @@ #include #include +#include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" @@ -80,6 +81,8 @@ SatSolver::SatSolver(Model* model) is_relevant_for_core_computation_(true), drat_proof_handler_(nullptr), stats_("SatSolver") { + trail_->EnableChronologicalBacktracking( + parameters_->use_chronological_backtracking()); InitializePropagators(); } @@ -447,9 +450,11 @@ int SatSolver::AddLearnedClauseAndEnqueueUnitPropagation( SCOPED_TIME_STAT(&stats_); if (literals.size() == 1) { - // A length 1 clause fix a literal for all the search. - // ComputeBacktrackLevel() should have returned 0. - CHECK_EQ(CurrentDecisionLevel(), 0); + if (!trail_->ChronologicalBacktrackingEnabled()) { + // A length 1 clause fix a literal for all the search. + // ComputeBacktrackLevel() should have returned 0. + CHECK_EQ(CurrentDecisionLevel(), 0); + } trail_->EnqueueWithUnitReason(literals[0]); return /*lbd=*/1; } @@ -731,7 +736,6 @@ void SatSolver::ProcessCurrentConflict() { ++counters_.num_failures; const int conflict_trail_index = trail_->Index(); - const int conflict_decision_level = current_decision_level_; // A conflict occurred, compute a nice reason for this failure. same_reason_identifier_.Clear(); @@ -743,9 +747,11 @@ void SatSolver::ProcessCurrentConflict() { // // TODO(user): We might still want to "learn" the clause, especially if // it reduces to only one literal in which case we can just fix it. - const int highest_level = - DecisionLevel((*trail_)[max_trail_index].Variable()); - if (highest_level == 1) return; + const bool all_literals_at_assumption_level = + absl::c_all_of(trail_->FailingClause(), [&](Literal l) { + return trail_->Info(l.Variable()).level <= assumption_level_; + }); + if (all_literals_at_assumption_level) return; } ComputeFirstUIPConflict(max_trail_index, &learned_conflict_, @@ -864,7 +870,7 @@ void SatSolver::ProcessCurrentConflict() { // Continue with the normal clause flow, but use the PB conflict clause // if it has a lower backjump level. - if (pb_backjump_level < ComputeBacktrackLevel(learned_conflict_)) { + if (pb_backjump_level < ComputePropagationLevel(learned_conflict_)) { subsumed_clauses_.clear(); // Because the conflict changes. learned_conflict_.clear(); is_marked_.ClearAndResize(num_variables_); @@ -951,7 +957,17 @@ void SatSolver::ProcessCurrentConflict() { // Backtrack and add the reason to the set of learned clause. counters_.num_literals_learned += learned_conflict_.size(); - Backtrack(ComputeBacktrackLevel(learned_conflict_)); + const int conflict_level = + trail_->Info(learned_conflict_[0].Variable()).level; + const int backjump_levels = CurrentDecisionLevel() - conflict_level; + const bool should_backjump = + !trail_->ChronologicalBacktrackingEnabled() || + (num_failures() > parameters_->chronological_backtrack_min_conflicts() && + backjump_levels > parameters_->max_backjump_levels()); + const int backtrack_level = should_backjump + ? ComputePropagationLevel(learned_conflict_) + : std::max(0, conflict_level - 1); + Backtrack(backtrack_level); DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_)); // Note that we need to output the learned clause before cleaning the clause @@ -991,8 +1007,7 @@ void SatSolver::ProcessCurrentConflict() { // Create and attach the new learned clause. const int conflict_lbd = AddLearnedClauseAndEnqueueUnitPropagation( learned_conflict_, is_redundant); - restart_->OnConflict(conflict_trail_index, conflict_decision_level, - conflict_lbd); + restart_->OnConflict(conflict_trail_index, conflict_level, conflict_lbd); } SatSolver::Status SatSolver::ReapplyDecisionsUpTo( @@ -1643,7 +1658,7 @@ std::vector SatSolver::GetDecisionsFixing( // Marks all the literals of its reason. for (const Literal literal : trail_->Reason(marked_literal.Variable())) { const BooleanVariable var = literal.Variable(); - const int level = DecisionLevel(var); + const int level = AssignmentLevel(var); if (level > 0 && !is_marked_[var]) is_marked_.Set(var); } } @@ -1659,7 +1674,7 @@ void SatSolver::BumpReasonActivities(absl::Span literals) { SCOPED_TIME_STAT(&stats_); for (const Literal literal : literals) { const BooleanVariable var = literal.Variable(); - if (DecisionLevel(var) > 0) { + if (AssignmentLevel(var) > 0) { SatClause* clause = ReasonClauseOrNull(var); if (clause != nullptr) { BumpClauseActivity(clause); @@ -1734,15 +1749,15 @@ void SatSolver::UpdateClauseActivityIncrement() { bool SatSolver::IsConflictValid(absl::Span literals) { SCOPED_TIME_STAT(&stats_); if (literals.empty()) return false; - const int highest_level = DecisionLevel(literals[0].Variable()); + const int highest_level = AssignmentLevel(literals[0].Variable()); for (int i = 1; i < literals.size(); ++i) { - const int level = DecisionLevel(literals[i].Variable()); + const int level = AssignmentLevel(literals[i].Variable()); if (level <= 0 || level >= highest_level) return false; } return true; } -int SatSolver::ComputeBacktrackLevel(absl::Span literals) { +int SatSolver::ComputePropagationLevel(absl::Span literals) { SCOPED_TIME_STAT(&stats_); DCHECK_GT(CurrentDecisionLevel(), 0); @@ -1755,14 +1770,14 @@ int SatSolver::ComputeBacktrackLevel(absl::Span literals) { // AddLearnedClauseAndEnqueueUnitPropagation() to fix the literal and not // backtrack over it. Also, subsequent propagated variables may not have a // correct level in this case. - int backtrack_level = 0; + int propagation_level = 0; for (int i = 1; i < literals.size(); ++i) { - const int level = DecisionLevel(literals[i].Variable()); - backtrack_level = std::max(backtrack_level, level); + const int level = AssignmentLevel(literals[i].Variable()); + propagation_level = std::max(propagation_level, level); } - DCHECK_LT(backtrack_level, DecisionLevel(literals[0].Variable())); - DCHECK_LE(DecisionLevel(literals[0].Variable()), CurrentDecisionLevel()); - return backtrack_level; + DCHECK_LT(propagation_level, AssignmentLevel(literals[0].Variable())); + DCHECK_LE(AssignmentLevel(literals[0].Variable()), CurrentDecisionLevel()); + return propagation_level; } template @@ -1770,12 +1785,17 @@ int SatSolver::ComputeLbd(const LiteralList& literals) { SCOPED_TIME_STAT(&stats_); const int limit = parameters_->count_assumption_levels_in_lbd() ? 0 : assumption_level_; + int max_level = AssignmentLevel(literals.begin()->Variable()); + if (trail_->ChronologicalBacktrackingEnabled()) { + for (const Literal literal : literals) { + max_level = std::max(max_level, AssignmentLevel(literal.Variable())); + } + } // We know that the first literal is always of the highest level. - is_level_marked_.ClearAndResize( - SatDecisionLevel(DecisionLevel(literals.begin()->Variable()) + 1)); + is_level_marked_.ClearAndResize(SatDecisionLevel(max_level + 1)); for (const Literal literal : literals) { - const SatDecisionLevel level(DecisionLevel(literal.Variable())); + const SatDecisionLevel level(AssignmentLevel(literal.Variable())); DCHECK_GE(level, 0); if (level > limit && !is_level_marked_[level]) { is_level_marked_.Set(level); @@ -2146,12 +2166,20 @@ void SatSolver::ComputeFirstUIPConflict( subsumed_clauses->clear(); if (max_trail_index == -1) return; + absl::Span clause_to_expand = trail_->FailingClause(); + // max_trail_index is the maximum trail index appearing in the failing_clause // and its level (Which is almost always equals to the CurrentDecisionLevel(), // except for symmetry propagation). DCHECK_EQ(max_trail_index, ComputeMaxTrailIndex(trail_->FailingClause())); int trail_index = max_trail_index; - const int highest_level = DecisionLevel((*trail_)[trail_index].Variable()); + int highest_level = trail_->Info((*trail_)[max_trail_index].Variable()).level; + if (trail_->ChronologicalBacktrackingEnabled()) { + for (const Literal literal : clause_to_expand) { + highest_level = + std::max(highest_level, AssignmentLevel(literal.Variable())); + } + } if (highest_level == 0) return; // To find the 1-UIP conflict clause, we start by the failing_clause, and @@ -2169,7 +2197,6 @@ void SatSolver::ComputeFirstUIPConflict( // // This last literal will be the first UIP because by definition all the // propagation done at the current level will pass though it at some point. - absl::Span clause_to_expand = trail_->FailingClause(); SatClause* sat_clause = trail_->FailingSatClause(); DCHECK(!clause_to_expand.empty()); int num_literal_at_highest_level_that_needs_to_be_processed = 0; @@ -2178,8 +2205,9 @@ void SatSolver::ComputeFirstUIPConflict( int num_vars_at_positive_level_in_clause_to_expand = 0; for (const Literal literal : clause_to_expand) { const BooleanVariable var = literal.Variable(); - const int level = DecisionLevel(var); + const int level = AssignmentLevel(var); if (level == 0) continue; + DCHECK_LE(level, highest_level); ++num_vars_at_positive_level_in_clause_to_expand; if (!is_marked_[var]) { is_marked_.Set(var); @@ -2216,11 +2244,12 @@ void SatSolver::ComputeFirstUIPConflict( // Find next marked literal to expand from the trail. DCHECK_GT(num_literal_at_highest_level_that_needs_to_be_processed, 0); - while (!is_marked_[(*trail_)[trail_index].Variable()]) { + while ( + !is_marked_[(*trail_)[trail_index].Variable()] || + (trail_->ChronologicalBacktrackingEnabled() && + AssignmentLevel((*trail_)[trail_index].Variable()) < highest_level)) { --trail_index; DCHECK_GE(trail_index, 0); - DCHECK_EQ(DecisionLevel((*trail_)[trail_index].Variable()), - highest_level); } if (num_literal_at_highest_level_that_needs_to_be_processed == 1) { @@ -2311,7 +2340,7 @@ void SatSolver::ComputePBConflict(int max_trail_index, // So we can abort if the true assignment before that is at a lower level // TODO(user): Somewhat inefficient. // TODO(user): We could abort earlier... - const int current_level = DecisionLevel(var); + const int current_level = AssignmentLevel(var); int i = trail_index; while (i >= 0) { const BooleanVariable previous_var = (*trail_)[i].Variable(); @@ -2322,8 +2351,8 @@ void SatSolver::ComputePBConflict(int max_trail_index, } --i; } - if (i < 0 || DecisionLevel((*trail_)[i].Variable()) < current_level) { - backjump_level = i < 0 ? 0 : DecisionLevel((*trail_)[i].Variable()); + if (i < 0 || AssignmentLevel((*trail_)[i].Variable()) < current_level) { + backjump_level = i < 0 ? 0 : AssignmentLevel((*trail_)[i].Variable()); break; } @@ -2388,11 +2417,11 @@ void SatSolver::ComputePBConflict(int max_trail_index, max_sum += coeff; ++size; if (!trail_->Assignment().VariableIsAssigned(var) || - DecisionLevel(var) > backjump_level) { + AssignmentLevel(var) > backjump_level) { max_coeff_for_ge_level[backjump_level + 1] = std::max(max_coeff_for_ge_level[backjump_level + 1], coeff); } else { - const int level = DecisionLevel(var); + const int level = AssignmentLevel(var); if (trail_->Assignment().LiteralIsTrue(conflict->GetLiteral(var))) { sum_for_le_level[level] += coeff; } @@ -2470,13 +2499,13 @@ void SatSolver::MinimizeConflictSimple(std::vector* conflict) { for (int i = 1; i < conflict->size(); ++i) { const BooleanVariable var = (*conflict)[i].Variable(); bool can_be_removed = false; - if (DecisionLevel(var) != current_level) { + if (AssignmentLevel(var) != current_level) { // It is important not to call Reason(var) when it can be avoided. const absl::Span reason = trail_->Reason(var); if (!reason.empty()) { can_be_removed = true; for (Literal literal : reason) { - if (DecisionLevel(literal.Variable()) == 0) continue; + if (AssignmentLevel(literal.Variable()) == 0) continue; if (!is_marked_[literal.Variable()]) { can_be_removed = false; break; @@ -2532,7 +2561,7 @@ void SatSolver::MinimizeConflictRecursively(std::vector* conflict) { // implied if the 1-UIP literal is false, we can't just iterate on the // variables of the conflict here. for (BooleanVariable var : is_marked_.PositionsSetAtLeastOnce()) { - const int level = DecisionLevel(var); + const int level = AssignmentLevel(var); min_trail_index_per_level_[level] = std::min( min_trail_index_per_level_[level], trail_->Info(var).trail_index); } @@ -2563,7 +2592,7 @@ void SatSolver::MinimizeConflictRecursively(std::vector* conflict) { const int threshold = min_trail_index_per_level_.size() / 2; if (is_marked_.PositionsSetAtLeastOnce().size() < threshold) { for (BooleanVariable var : is_marked_.PositionsSetAtLeastOnce()) { - min_trail_index_per_level_[DecisionLevel(var)] = + min_trail_index_per_level_[AssignmentLevel(var)] = std::numeric_limits::max(); } } else { @@ -2657,7 +2686,8 @@ bool SatSolver::CanBeInferedFromConflictVariables(BooleanVariable variable) { bool abort_early = false; for (Literal literal : trail_->Reason(current_var)) { const BooleanVariable var = literal.Variable(); - DCHECK_NE(var, current_var); + DCHECK_NE(var, current_var) << trail_->Info(var).DebugString() + << " old: " << trail_->AssignmentType(var); const AssignmentInfo& info = trail_->Info(var); if (info.level == 0 || is_marked_[var]) continue; if (info.trail_index <= min_trail_index_per_level_[info.level] || @@ -2720,7 +2750,7 @@ void SatSolver::MinimizeConflictExperimental(std::vector* conflict) { for (Literal literal : *conflict) { const BooleanVariable var = literal.Variable(); is_marked_.Set(var); - const int level = DecisionLevel(var); + const int level = AssignmentLevel(var); if (level < current_level) { variables_sorted_by_level.push_back(WeightedVariable(var, level)); } @@ -2745,7 +2775,7 @@ void SatSolver::MinimizeConflictExperimental(std::vector* conflict) { const BooleanVariable reason_var = reason_literal.Variable(); // We ignore level 0 variables. - if (DecisionLevel(reason_var) == 0) continue; + if (AssignmentLevel(reason_var) == 0) continue; // We have a reason literal whose variable is not yet seen. // If there is more than one, break right away, we will not minimize the diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 63bfab81b4..b8991d132c 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -516,6 +516,10 @@ class SatSolver { clauses_propagator_->EnsureNewClauseIndexInitialized(); } + void EnableChronologicalBacktracking(bool value) { + trail_->EnableChronologicalBacktracking(value); + } + private: // All Solve() functions end up calling this one. Status SolveInternal(TimeLimit* time_limit, int64_t max_number_of_conflicts); @@ -564,8 +568,8 @@ class SatSolver { // Sets model_is_unsat_ to true and return false. bool SetModelUnsat(); - // Returns the decision level of a given variable. - int DecisionLevel(BooleanVariable var) const { + // Returns the decision level at which the given variable is assigned. + int AssignmentLevel(BooleanVariable var) const { return trail_->Info(var).level; } @@ -701,9 +705,9 @@ class SatSolver { // - There is no literal with a decision level of zero. bool IsConflictValid(absl::Span literals); - // Given the learned clause after a conflict, this computes the correct - // backtrack level to call Backtrack() with. - int ComputeBacktrackLevel(absl::Span literals); + // Given the learned clause after a conflict, this computes the level at which + // the new clause will propagate. + int ComputePropagationLevel(absl::Span literals); // The LBD (Literal Blocks Distance) is the number of different decision // levels at which the literals of the clause were assigned. Note that we diff --git a/ortools/sat/scheduling_cuts.cc b/ortools/sat/scheduling_cuts.cc index 97381e44f2..dcf2dc8cef 100644 --- a/ortools/sat/scheduling_cuts.cc +++ b/ortools/sat/scheduling_cuts.cc @@ -210,8 +210,8 @@ struct EnergyEvent { std::string DebugString() const { return absl::StrCat( "EnergyEvent(start_min = ", start_min, ", start_max = ", start_max, - ", end_min = ", end_min, ", end_max = ", end_max, - ", demand = ", demand.DebugString(), ", energy = ", + ", end_min = ", end_min, ", end_max = ", end_max, ", demand = ", demand, + ", energy = ", decomposed_energy.empty() ? "{}" : absl::StrCat(decomposed_energy.size(), " terms"), @@ -1104,10 +1104,9 @@ std::string CompletionTimeEvent::DebugString() const { return absl::StrCat( "CompletionTimeEvent(task_index = ", task_index, ", start_min = ", start_min, ", start_max = ", start_max, - ", size_min = ", size_min, ", end = ", end.DebugString(), - ", lp_end = ", lp_end, ", size_min = ", size_min, - " demand_min = ", demand_min, ", demand_is_fixed = ", demand_is_fixed, - ", energy_min = ", energy_min, + ", size_min = ", size_min, ", end = ", end, ", lp_end = ", lp_end, + ", size_min = ", size_min, " demand_min = ", demand_min, + ", demand_is_fixed = ", demand_is_fixed, ", energy_min = ", energy_min, ", use_decomposed_energy_min = ", use_decomposed_energy_min, ", lifted = ", lifted, ", decomposed_energy = [", absl::StrJoin(decomposed_energy, ", ", diff --git a/ortools/sat/scheduling_cuts_test.cc b/ortools/sat/scheduling_cuts_test.cc index 8cc568b410..34f21cc3af 100644 --- a/ortools/sat/scheduling_cuts_test.cc +++ b/ortools/sat/scheduling_cuts_test.cc @@ -98,9 +98,9 @@ TEST(CumulativeEnergyCutGenerator, TestCutTimeTableGenerator) { cumulative.generate_cuts(manager); ASSERT_EQ(1, manager->num_cuts()); - // 3*X3 1*X7 -1*X9 <= 0 -> Normalized to 3*X3 1*X7 <= 10 + // 3*I3 1*I7 -1*I9 <= 0 -> Normalized to 3*I3 1*I7 <= 10 EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - EndsWith("3*X3 1*X7 <= 10")); + EndsWith("3*I3 1*I7 <= 10")); } TEST(CumulativeEnergyCutGenerator, SameDemand) { @@ -169,23 +169,23 @@ TEST(CumulativeEnergyCutGenerator, SameDemand) { EXPECT_THAT( manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(0)] .constraint.DebugString(), - EndsWith("1*X9 <= 5")); + EndsWith("1*I9 <= 5")); EXPECT_THAT( manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(1)] .constraint.DebugString(), - EndsWith("1*X9 1*X10 <= 10")); + EndsWith("1*I9 1*I10 <= 10")); EXPECT_THAT( manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(2)] .constraint.DebugString(), - EndsWith("3*X9 2*X10 <= 30")); + EndsWith("3*I9 2*I10 <= 30")); EXPECT_THAT( manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(3)] .constraint.DebugString(), - EndsWith("5*X9 2*X10 <= 40")); + EndsWith("5*I9 2*I10 <= 40")); EXPECT_THAT( manager->AllConstraints()[LinearConstraintManager::ConstraintIndex(4)] .constraint.DebugString(), - EndsWith("2*X9 3*X10 <= 30")); + EndsWith("2*I9 3*I10 <= 30")); } TEST(CumulativeEnergyCutGenerator, SameDemandTimeTableGenerator) { @@ -241,12 +241,12 @@ TEST(CumulativeEnergyCutGenerator, SameDemandTimeTableGenerator) { cumulative.generate_cuts(manager); ASSERT_EQ(2, manager->num_cuts()); - // 1*X9 1*X9 <= X11 -> Normalized to 1*X9 <= 5 + // 1*I9 1*I9 <= I11 -> Normalized to 1*I9 <= 5 EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - EndsWith("1*X9 <= 5")); - // 1*X9 1*X10 <= X11 -> Normalized to 1*X9 1*X10 <= 10 + EndsWith("1*I9 <= 5")); + // 1*I9 1*I10 <= I11 -> Normalized to 1*I9 1*I10 <= 10 EXPECT_THAT(manager->AllConstraints().back().constraint.DebugString(), - EndsWith("1*X9 1*X10 <= 10")); + EndsWith("1*I9 1*I10 <= 10")); } TEST(CumulativeEnergyCutGenerator, DetectedPrecedence) { @@ -288,7 +288,7 @@ TEST(CumulativeEnergyCutGenerator, DetectedPrecedence) { ASSERT_EQ(1, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - EndsWith("1*X0 -1*X1 <= -3")); + EndsWith("1*I0 -1*I1 <= -3")); } TEST(CumulativeEnergyCutGenerator, DetectedPrecedenceRev) { @@ -331,7 +331,7 @@ TEST(CumulativeEnergyCutGenerator, DetectedPrecedenceRev) { ASSERT_EQ(1, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - EndsWith("1*X0 -1*X1 <= -3")); + EndsWith("1*I0 -1*I1 <= -3")); } TEST(CumulativeEnergyCutGenerator, DisjunctionOnStart) { @@ -374,7 +374,7 @@ TEST(CumulativeEnergyCutGenerator, DisjunctionOnStart) { ASSERT_EQ(1, manager->num_cuts()); EXPECT_THAT(manager->AllConstraints().front().constraint.DebugString(), - StartsWith("15 <= 2*X0 5*X1")); + StartsWith("15 <= 2*I0 5*I1")); } TEST(ComputeMinSumOfEndMinsTest, CombinationOf3) { diff --git a/ortools/sat/scheduling_helpers_test.cc b/ortools/sat/scheduling_helpers_test.cc index e6245037d2..4fd7cf2d45 100644 --- a/ortools/sat/scheduling_helpers_test.cc +++ b/ortools/sat/scheduling_helpers_test.cc @@ -130,7 +130,7 @@ TEST(SchedulingDemandHelperTest, LinearizedDemandWithAffineExpression) { LinearConstraintBuilder builder(&model); ASSERT_TRUE(demands_helper.AddLinearizedDemand(0, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "2*X3 + 5"); + EXPECT_EQ(builder.BuildExpression().DebugString(), "2*I3 + 5"); } TEST(SchedulingDemandHelperTest, LinearizedDemandWithDecomposedEnergy) { @@ -166,7 +166,7 @@ TEST(SchedulingDemandHelperTest, LinearizedDemandWithDecomposedEnergy) { demands_helper.CacheAllEnergyValues(); LinearConstraintBuilder builder(&model); ASSERT_TRUE(demands_helper.AddLinearizedDemand(0, &builder)); - EXPECT_EQ(builder.BuildExpression().DebugString(), "4*X4 2*X5"); + EXPECT_EQ(builder.BuildExpression().DebugString(), "4*I4 2*I5"); } TEST(SchedulingDemandHelperTest, FilteredDecomposedEnergy) { diff --git a/ortools/sat/shaving_solver.cc b/ortools/sat/shaving_solver.cc index 488ce0228e..2e03b59280 100644 --- a/ortools/sat/shaving_solver.cc +++ b/ortools/sat/shaving_solver.cc @@ -63,13 +63,13 @@ bool ObjectiveShavingSolver::TaskIsAvailable() { if (shared_->SearchIsDone()) return false; // We only support one task at the time. - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return !task_in_flight_; } std::function ObjectiveShavingSolver::GenerateTask(int64_t task_id) { { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); stop_current_chunk_.store(false); task_in_flight_ = true; objective_lb_ = shared_->response->GetInnerObjectiveLowerBound(); @@ -92,13 +92,13 @@ std::function ObjectiveShavingSolver::GenerateTask(int64_t task_id) { } shared_->response->NewSolution(solution_values, Info()); } else if (local_response.status() == CpSolverStatus::INFEASIBLE) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); shared_->response->UpdateInnerObjectiveBounds( Info(), current_objective_target_ub_ + 1, objective_ub_); } } - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); task_in_flight_ = false; if (local_sat_model_ != nullptr) { const double dtime = local_sat_model_->GetOrCreate() @@ -110,7 +110,7 @@ std::function ObjectiveShavingSolver::GenerateTask(int64_t task_id) { } void ObjectiveShavingSolver::Synchronize() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (!task_in_flight_) return; // We are just waiting for the inner code to check the time limit or @@ -172,7 +172,7 @@ bool ObjectiveShavingSolver::ResetAndSolveModel(int64_t task_id) { IntegerValue objective_lb; IntegerValue chosen_objective_ub; { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); objective_lb = objective_lb_; if (objective_ub_ - objective_lb <= local_params_.shaving_search_threshold()) { @@ -240,7 +240,7 @@ bool ObjectiveShavingSolver::ResetAndSolveModel(int64_t task_id) { const CpSolverStatus presolve_status = PresolveCpModel(context.get(), &postsolve_mapping_); if (presolve_status == CpSolverStatus::INFEASIBLE) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); shared_->response->UpdateInnerObjectiveBounds( Info(), chosen_objective_ub + 1, kMaxIntegerValue); return false; @@ -272,7 +272,7 @@ VariablesShavingSolver::VariablesShavingSolver( shared_bounds_id_ = shared_->bounds->RegisterNewId(); } - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (const IntegerVariableProto& var_proto : model_proto_.variables()) { var_domains_.push_back(ReadDomainFromProto(var_proto)); } @@ -284,7 +284,7 @@ VariablesShavingSolver::~VariablesShavingSolver() { if (!VLOG_IS_ON(1)) return; if (shared_ == nullptr || shared_->stats == nullptr) return; std::vector> stats; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); stats.push_back({"variable_shaving/num_vars_tried", num_vars_tried_}); stats.push_back({"variable_shaving/num_vars_shaved", num_vars_shaved_}); stats.push_back( @@ -303,7 +303,7 @@ void VariablesShavingSolver::ProcessLocalResponse( if (local_response.status() == CpSolverStatus::INFEASIBLE) return; const int64_t obj_lb = local_response.inner_objective_lower_bound(); - absl::MutexLock lock(&mutex_); + absl::MutexLock lock(mutex_); const Domain domain = var_domains_[state.var_index]; if (state.minimize) { const int64_t lb = obj_lb; @@ -328,7 +328,7 @@ void VariablesShavingSolver::ProcessLocalResponse( } if (local_response.status() != CpSolverStatus::INFEASIBLE) return; - absl::MutexLock lock(&mutex_); + absl::MutexLock lock(mutex_); const Domain domain = var_domains_[state.var_index]; Domain new_domain = domain; ++num_infeasible_found_; @@ -362,7 +362,7 @@ std::function VariablesShavingSolver::GenerateTask(int64_t task_id) { ProcessLocalResponse(local_response, state); } - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const double dtime = local_sat_model.GetOrCreate()->GetElapsedDeterministicTime(); AddTaskDeterministicDuration(dtime); @@ -371,7 +371,7 @@ std::function VariablesShavingSolver::GenerateTask(int64_t task_id) { } void VariablesShavingSolver::Synchronize() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); // We are just waiting for the inner code to check the time limit or // to return nicely. if (stop_current_chunk_) return; @@ -624,7 +624,7 @@ bool VariablesShavingSolver::ResetAndSolveModel(int64_t task_id, State* state, bool has_no_overlap_2d = false; { - absl::MutexLock lock(&mutex_); + absl::MutexLock lock(mutex_); if (!FindNextVar(state)) return false; CopyModelConnectedToVar(state, local_model, shaving_proto, &has_no_overlap_2d); diff --git a/ortools/sat/stat_tables.cc b/ortools/sat/stat_tables.cc index 3d0c5b8f1a..d9b0e427b3 100644 --- a/ortools/sat/stat_tables.cc +++ b/ortools/sat/stat_tables.cc @@ -37,7 +37,7 @@ namespace operations_research::sat { SharedStatTables::SharedStatTables() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); timing_table_.push_back( {"Task timing", "n [ min, max] avg dev time", @@ -73,13 +73,13 @@ SharedStatTables::SharedStatTables() { } void SharedStatTables::AddTimingStat(const SubSolver& subsolver) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); timing_table_.push_back({FormatName(subsolver.name()), subsolver.TimingInfo(), subsolver.DeterministicTimingInfo()}); } void SharedStatTables::AddSearchStat(absl::string_view name, Model* model) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); CpSolverResponse r; model->GetOrCreate()->FillSolveStatsInResponse(model, &r); @@ -92,7 +92,7 @@ void SharedStatTables::AddSearchStat(absl::string_view name, Model* model) { } void SharedStatTables::AddClausesStat(absl::string_view name, Model* model) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); SatSolver::Counters counters = model->GetOrCreate()->counters(); clauses_table_.push_back( {FormatName(name), FormatCounter(counters.num_minimizations), @@ -109,7 +109,7 @@ void SharedStatTables::AddClausesStat(absl::string_view name, Model* model) { } void SharedStatTables::AddLpStat(absl::string_view name, Model* model) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); // Sum per component for the lp_table. int64_t num_compo = 0; @@ -233,7 +233,7 @@ void SharedStatTables::AddLnsStat(absl::string_view name, int64_t num_improving_calls, double difficulty, double deterministic_limit) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const double fully_solved_proportion = static_cast(num_fully_solved_calls) / static_cast(std::max(int64_t{1}, num_calls)); @@ -251,7 +251,7 @@ void SharedStatTables::AddLsStat(absl::string_view name, int64_t num_batches, int64_t num_backtracks, int64_t num_weight_updates, int64_t num_scores_computed) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); ls_table_.push_back( {FormatName(name), FormatCounter(num_batches), FormatCounter(num_restarts), FormatCounter(num_linear_moves), @@ -263,7 +263,7 @@ void SharedStatTables::AddLsStat(absl::string_view name, int64_t num_batches, void SharedStatTables::Display(SolverLogger* logger) { if (!logger->LoggingIsEnabled()) return; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (timing_table_.size() > 1) SOLVER_LOG(logger, FormatTable(timing_table_)); if (search_table_.size() > 1) SOLVER_LOG(logger, FormatTable(search_table_)); if (clauses_table_.size() > 1) { diff --git a/ortools/sat/subsolver.cc b/ortools/sat/subsolver.cc index 266cf97785..237515be0e 100644 --- a/ortools/sat/subsolver.cc +++ b/ortools/sat/subsolver.cc @@ -235,7 +235,7 @@ void NonDeterministicLoop(std::vector>& subsolvers, // TODO(user): We could also directly register callback to set stopping // Boolean to false in a few places. if (!condition) { - mutex.Unlock(); + mutex.unlock(); SynchronizeAll(subsolvers); continue; } @@ -243,7 +243,7 @@ void NonDeterministicLoop(std::vector>& subsolvers, // The stopping condition is that we do not have anything else to generate // once all the task are done and synchronized. if (num_in_flight == 0) all_done = true; - mutex.Unlock(); + mutex.unlock(); } SynchronizeAll(subsolvers); @@ -251,7 +251,7 @@ void NonDeterministicLoop(std::vector>& subsolvers, { // We need to do that while holding the lock since substask below might // be currently updating the time via AddTaskDuration(). - const absl::MutexLock mutex_lock(&mutex); + const absl::MutexLock mutex_lock(mutex); ClearSubsolversThatAreDone(num_in_flight_per_subsolvers, subsolvers); best = NextSubsolverToSchedule(subsolvers, /*deterministic=*/false); if (VLOG_IS_ON(1) && time_limit->LimitReached()) { @@ -283,7 +283,7 @@ void NonDeterministicLoop(std::vector>& subsolvers, // Schedule next task. subsolvers[best]->NotifySelection(); { - absl::MutexLock mutex_lock(&mutex); + absl::MutexLock mutex_lock(mutex); num_in_flight++; num_in_flight_per_subsolvers[best]++; } @@ -295,7 +295,7 @@ void NonDeterministicLoop(std::vector>& subsolvers, timer.Start(); task(); - const absl::MutexLock mutex_lock(&mutex); + const absl::MutexLock mutex_lock(mutex); DCHECK(subsolvers[best] != nullptr); DCHECK_GT(num_in_flight_per_subsolvers[best], 0); num_in_flight_per_subsolvers[best]--; diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 10f6c64bb7..a1ed925bb3 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -95,7 +95,7 @@ SharedSolutionPool::Add(SharedSolutionRepository::Solution solution) { void SharedSolutionPool::Synchronize(absl::BitGenRef random) { // Update the "seeds" for the aternative path. if (alternative_path_.num_solutions_to_keep() > 0) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); auto process_solution = [this](const SharedSolutionRepository::Solution& solution) @@ -187,7 +187,7 @@ void SharedSolutionPool::Synchronize(absl::BitGenRef random) { // If we restarted, or we are at the beginning, pick a seed for the path. if (alternative_path_.NumSolutions() == 0) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); // Pick random bucket with bias. If the bucket is empty, we will scan // "worse" bucket until we find a solution. We never pick bucket 0. @@ -226,7 +226,7 @@ void SharedLPSolutionRepository::NewLPSolution( // We always prefer to keep the solution from the last synchronize batch. { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); solution->rank = -num_synchronization_; ++num_added_; new_solutions_.push_back(solution); @@ -235,19 +235,19 @@ void SharedLPSolutionRepository::NewLPSolution( void SharedIncompleteSolutionManager::AddSolution( const std::vector& lp_solution) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); ++num_added_; solutions_.push_back(lp_solution); if (solutions_.size() > 100) solutions_.pop_front(); } bool SharedIncompleteSolutionManager::HasSolution() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return !solutions_.empty(); } std::vector SharedIncompleteSolutionManager::PopLast() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (solutions_.empty()) return {}; ++num_queried_; @@ -292,7 +292,7 @@ std::string SatProgressMessage(absl::string_view event_or_solution_count, void SharedResponseManager::FillSolveStatsInResponse( Model* model, CpSolverResponse* response) { if (model == nullptr) return; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (const auto& set_stats : statistics_postprocessors_) { set_stats(model, response); } @@ -300,14 +300,14 @@ void SharedResponseManager::FillSolveStatsInResponse( void SharedResponseManager::LogMessage(absl::string_view prefix, absl::string_view message) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); SOLVER_LOG(logger_, absl::StrFormat("#%-5s %6.2fs %s", prefix, wall_timer_.Get(), message)); } void SharedResponseManager::LogMessageWithThrottling( absl::string_view prefix, absl::string_view message) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); int id; auto it = throttling_ids_.find(prefix); @@ -321,7 +321,7 @@ void SharedResponseManager::LogMessageWithThrottling( } bool SharedResponseManager::LoggingIsEnabled() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return logger_->LoggingIsEnabled(); } @@ -340,17 +340,17 @@ void SharedResponseManager::InitializeObjective(const CpModelProto& cp_model) { } void SharedResponseManager::SetSynchronizationMode(bool always_synchronize) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); always_synchronize_ = always_synchronize; } void SharedResponseManager::SetUpdateGapIntegralOnEachChange(bool set) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); update_integral_on_each_change_ = set; } void SharedResponseManager::UpdateGapIntegral() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); UpdateGapIntegralInternal(); } @@ -380,7 +380,7 @@ void SharedResponseManager::UpdateGapIntegralInternal() { void SharedResponseManager::SetGapLimitsFromParameters( const SatParameters& parameters) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (objective_or_null_ == nullptr) return; absolute_gap_limit_ = parameters.absolute_gap_limit(); relative_gap_limit_ = parameters.relative_gap_limit(); @@ -427,7 +427,7 @@ void SharedResponseManager::TestGapLimitsIfNeeded() { void SharedResponseManager::UpdateInnerObjectiveBounds( const std::string& update_info, IntegerValue lb, IntegerValue ub) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); CHECK(objective_or_null_ != nullptr); // The problem is already solved! @@ -504,7 +504,7 @@ void SharedResponseManager::UpdateInnerObjectiveBounds( // UNKNOWN -> INFEASIBLE void SharedResponseManager::NotifyThatImprovingProblemIsInfeasible( absl::string_view worker_info) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (best_status_ == CpSolverStatus::FEASIBLE || best_status_ == CpSolverStatus::OPTIMAL) { // We also use this status to indicate that we enumerated all solutions to @@ -524,24 +524,24 @@ void SharedResponseManager::NotifyThatImprovingProblemIsInfeasible( } void SharedResponseManager::AddUnsatCore(const std::vector& core) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); unsat_cores_ = core; } IntegerValue SharedResponseManager::GetInnerObjectiveLowerBound() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return synchronized_inner_objective_lower_bound_; } IntegerValue SharedResponseManager::GetInnerObjectiveUpperBound() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return synchronized_inner_objective_upper_bound_; } void SharedResponseManager::Synchronize() { solution_pool_.Synchronize(*random_); - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); synchronized_inner_objective_lower_bound_ = IntegerValue(inner_objective_lower_bound_); synchronized_inner_objective_upper_bound_ = @@ -554,49 +554,49 @@ void SharedResponseManager::Synchronize() { } IntegerValue SharedResponseManager::BestSolutionInnerObjectiveValue() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return IntegerValue(best_solution_objective_value_); } double SharedResponseManager::GapIntegral() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return gap_integral_; } void SharedResponseManager::AddSolutionPostprocessor( std::function*)> postprocessor) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); solution_postprocessors_.push_back(postprocessor); } void SharedResponseManager::AddResponsePostprocessor( std::function postprocessor) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); postprocessors_.push_back(postprocessor); } void SharedResponseManager::AddFinalResponsePostprocessor( std::function postprocessor) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); final_postprocessors_.push_back(postprocessor); } void SharedResponseManager::AddStatisticsPostprocessor( std::function postprocessor) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); statistics_postprocessors_.push_back(postprocessor); } int SharedResponseManager::AddSolutionCallback( std::function callback) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const int id = next_callback_id_++; callbacks_.emplace_back(id, std::move(callback)); return id; } void SharedResponseManager::UnregisterCallback(int callback_id) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (int i = 0; i < callbacks_.size(); ++i) { if (callbacks_[i].first == callback_id) { callbacks_.erase(callbacks_.begin() + i); @@ -608,14 +608,14 @@ void SharedResponseManager::UnregisterCallback(int callback_id) { int SharedResponseManager::AddLogCallback( std::function callback) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const int id = next_search_log_callback_id_++; search_log_callbacks_.emplace_back(id, std::move(callback)); return id; } void SharedResponseManager::UnregisterLogCallback(int callback_id) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (int i = 0; i < search_log_callbacks_.size(); ++i) { if (search_log_callbacks_[i].first == callback_id) { search_log_callbacks_.erase(search_log_callbacks_.begin() + i); @@ -627,14 +627,14 @@ void SharedResponseManager::UnregisterLogCallback(int callback_id) { int SharedResponseManager::AddBestBoundCallback( std::function callback) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const int id = next_best_bound_callback_id_++; best_bound_callbacks_.emplace_back(id, std::move(callback)); return id; } void SharedResponseManager::UnregisterBestBoundCallback(int callback_id) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (int i = 0; i < best_bound_callbacks_.size(); ++i) { if (best_bound_callbacks_[i].first == callback_id) { best_bound_callbacks_.erase(best_bound_callbacks_.begin() + i); @@ -693,7 +693,7 @@ CpSolverResponse SharedResponseManager::GetResponseInternal( } CpSolverResponse SharedResponseManager::GetResponse() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); CpSolverResponse result; if (solution_pool_.BestSolutions().NumSolutions() == 0) { result = GetResponseInternal({}, ""); @@ -728,7 +728,7 @@ CpSolverResponse SharedResponseManager::GetResponse() { void SharedResponseManager::AppendResponseToBeMerged( const CpSolverResponse& response) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return subsolver_responses_.push_back(response); } @@ -768,7 +768,7 @@ std::shared_ptr::Solution> SharedResponseManager::NewSolution(absl::Span solution_values, absl::string_view solution_info, Model* model, int source_id) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); std::shared_ptr::Solution> ret; // For SAT problems, we add the solution to the solution pool for retrieval @@ -912,7 +912,7 @@ SharedResponseManager::NewSolution(absl::Span solution_values, } bool SharedResponseManager::ProblemIsSolved() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return synchronized_best_status_ == CpSolverStatus::OPTIMAL || synchronized_best_status_ == CpSolverStatus::INFEASIBLE; } @@ -953,7 +953,7 @@ void SharedResponseManager::RegisterObjectiveBoundImprovement( } void SharedResponseManager::DisplayImprovementStatistics() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (!primal_improvements_count_.empty()) { std::vector> table; table.push_back( @@ -1043,7 +1043,7 @@ void SharedBoundsManager::ReportPotentialNewBounds( int num_improvements = 0; int num_symmetric_improvements = 0; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (int i = 0; i < variables.size(); ++i) { int var = variables[i]; if (var >= num_variables_) continue; @@ -1119,7 +1119,7 @@ void SharedBoundsManager::FixVariablesFromPartialSolution( absl::Span variables_to_fix) { // This function shouldn't be called if we has symmetry. CHECK(!has_symmetry_); - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); // Abort if incompatible. Note that we only check the position that we are // about to fix. This should be enough. Otherwise we might never accept any @@ -1163,7 +1163,7 @@ void SharedBoundsManager::FixVariablesFromPartialSolution( } void SharedBoundsManager::Synchronize() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (const int var : changed_variables_since_last_synchronize_.PositionsSetAtLeastOnce()) { DCHECK(!has_symmetry_ || var_to_representative_[var] == var); @@ -1177,7 +1177,7 @@ void SharedBoundsManager::Synchronize() { } int SharedBoundsManager::RegisterNewId() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const int id = id_to_changed_variables_.size(); id_to_changed_variables_.resize(id + 1); id_to_changed_variables_[id].ClearAndResize(num_variables_); @@ -1202,7 +1202,7 @@ void SharedBoundsManager::GetChangedBounds( new_upper_bounds->clear(); { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (const int var : id_to_changed_variables_[id].PositionsSetAtLeastOnce()) { DCHECK(!has_symmetry_ || var_to_representative_[var] == var); @@ -1245,7 +1245,7 @@ void SharedBoundsManager::GetChangedBounds( } void SharedBoundsManager::UpdateDomains(std::vector* domains) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); CHECK_EQ(domains->size(), synchronized_lower_bounds_.size()); for (int var = 0; var < domains->size(); ++var) { (*domains)[var] = (*domains)[var].IntersectionWith(Domain( @@ -1254,7 +1254,7 @@ void SharedBoundsManager::UpdateDomains(std::vector* domains) { } void SharedBoundsManager::LogStatistics(SolverLogger* logger) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (!bounds_exported_.empty()) { std::vector> table; table.push_back({"Improving bounds shared", "Num", "Sym"}); @@ -1268,7 +1268,7 @@ void SharedBoundsManager::LogStatistics(SolverLogger* logger) { } int SharedBoundsManager::NumBoundsExported(absl::string_view worker_name) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const auto it = bounds_exported_.find(worker_name); if (it == bounds_exported_.end()) return 0; return it->second.num_exported; @@ -1388,7 +1388,7 @@ SharedClausesManager::SharedClausesManager(bool always_synchronize) int SharedClausesManager::RegisterNewId(absl::string_view worker_name, bool may_terminate_early) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); num_full_workers_ += may_terminate_early ? 0 : 1; const int id = id_to_last_processed_binary_clause_.size(); id_to_last_processed_binary_clause_.resize(id + 1, 0); @@ -1401,7 +1401,7 @@ int SharedClausesManager::RegisterNewId(absl::string_view worker_name, } int SharedLinear2Bounds::RegisterNewId(std::string worker_name) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const int id = id_to_worker_name_.size(); id_to_stats_.resize(id + 1); @@ -1418,7 +1418,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { if (lit2 < lit1) std::swap(lit1, lit2); const auto p = std::make_pair(lit1, lit2); - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const auto [unused_it, inserted] = added_binary_clauses_set_.insert(p); if (inserted) { added_binary_clauses_.push_back(p); @@ -1435,7 +1435,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { } void SharedClausesManager::AddBatch(int id, CompactVectorVector batch) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); id_to_num_exported_[id] += batch.size(); pending_batches_.push_back(std::move(batch)); } @@ -1443,7 +1443,7 @@ void SharedClausesManager::AddBatch(int id, CompactVectorVector batch) { const CompactVectorVector& SharedClausesManager::GetUnseenClauses(int id) { std::vector> result; { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); id_to_last_finished_batch_[id] = id_to_last_returned_batch_[id]; if (id_to_last_returned_batch_[id] + 1 < batches_.size()) { id_to_last_returned_batch_[id] += 1; @@ -1458,7 +1458,7 @@ const CompactVectorVector& SharedClausesManager::GetUnseenClauses(int id) { void SharedClausesManager::GetUnseenBinaryClauses( int id, std::vector>* new_clauses) { new_clauses->clear(); - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const int last_binary_clause_seen = id_to_last_processed_binary_clause_[id]; if (last_binary_clause_seen >= last_visible_binary_clause_) return; @@ -1469,7 +1469,7 @@ void SharedClausesManager::GetUnseenBinaryClauses( } void SharedClausesManager::LogStatistics(SolverLogger* logger) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); absl::btree_map name_to_table_line; for (int id = 0; id < id_to_num_exported_.size(); ++id) { if (id_to_num_exported_[id] == 0) continue; @@ -1489,7 +1489,7 @@ void SharedClausesManager::LogStatistics(SolverLogger* logger) { // could merge small table with few columns. I am thinking list (row_name, // col_name, count) + function that create table? void SharedLinear2Bounds::LogStatistics(SolverLogger* logger) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); absl::btree_map name_to_table_line; for (int id = 0; id < id_to_stats_.size(); ++id) { const Stats stats = id_to_stats_[id]; @@ -1516,7 +1516,7 @@ void SharedLinear2Bounds::LogStatistics(SolverLogger* logger) { void SharedClausesManager::Synchronize() { std::vector> batches_to_merge; { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); last_visible_binary_clause_ = added_binary_clauses_.size(); const int num_workers = id_to_last_processed_binary_clause_.size(); if (num_workers <= 1) return; @@ -1551,7 +1551,7 @@ void SharedClausesManager::Synchronize() { } } if (next_batch.NumBufferedLiterals() > 0) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); VLOG(2) << "Merging batch"; batches_.push_back(next_batch.NextBatch()); } @@ -1561,7 +1561,7 @@ void SharedLinear2Bounds::Add(int id, Key expr, IntegerValue lb, IntegerValue ub) { DCHECK(expr.IsCanonicalized()) << expr; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); auto [it, inserted] = shared_bounds_.insert({expr, {lb, ub}}); if (inserted) { // It is new. @@ -1582,7 +1582,7 @@ void SharedLinear2Bounds::Add(int id, Key expr, IntegerValue lb, } int SharedLinear2Bounds::RegisterNewImportId(std::string name) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); const int import_id = import_id_to_index_.size(); import_id_to_name_.push_back(name); import_id_to_index_.push_back(0); @@ -1595,7 +1595,7 @@ std::vector< SharedLinear2Bounds::NewlyUpdatedBounds(int import_id) { std::vector>> result; - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); MaybeCompressNewlyUpdateKeys(); const int size = newly_updated_keys_.size(); for (int i = import_id_to_index_[import_id]; i < size; ++i) { @@ -1622,14 +1622,14 @@ void SharedLinear2Bounds::MaybeCompressNewlyUpdateKeys() { void SharedStatistics::AddStats( absl::Span> stats) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); for (const auto& [key, count] : stats) { stats_[key] += count; } } void SharedStatistics::Log(SolverLogger* logger) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (stats_.empty()) return; SOLVER_LOG(logger, "Stats across workers (summed):"); diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index e6cd92e29c..2094c073c2 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -142,30 +142,30 @@ class SharedSolutionRepository { void Synchronize(std::function f = nullptr); std::vector TableLineStats() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return {FormatName(name_), FormatCounter(num_added_), FormatCounter(num_queried_), FormatCounter(num_synchronization_)}; } int64_t NumRecentlyNonImproving() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return num_non_improving_; } void ClearSolutionsAndIncreaseSourceId() { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); new_solutions_.clear(); solutions_.clear(); ++source_id_; } int source_id() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return source_id_; } int num_queried() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return num_queried_; } @@ -321,7 +321,7 @@ class SharedIncompleteSolutionManager { std::vector PopLast(); std::vector TableLineStats() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return {FormatName("pump"), FormatCounter(num_added_), FormatCounter(num_queried_)}; } @@ -959,7 +959,7 @@ class SharedLinear2Bounds { // bounds that were not already known by the worker at the time of the import, // and we don't have this information here. void NotifyNumImported(int import_id, int num) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); import_id_to_num_imported_[import_id] += num; } @@ -1018,14 +1018,14 @@ class SharedStatistics { template int SharedSolutionRepository::NumSolutions() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return solutions_.size(); } template std::shared_ptr::Solution> SharedSolutionRepository::GetSolution(int i) const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (i >= solutions_.size()) return nullptr; ++num_queried_; return solutions_[i]; @@ -1033,7 +1033,7 @@ SharedSolutionRepository::GetSolution(int i) const { template int64_t SharedSolutionRepository::GetBestRank() const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (solutions_.empty()) return std::numeric_limits::max(); return solutions_[0]->rank; } @@ -1042,7 +1042,7 @@ template std::vector::Solution>> SharedSolutionRepository::GetBestNSolutions(int n) const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); // Sorted by rank and unique. DCHECK(absl::c_is_sorted(solutions_, [](const std::shared_ptr& a, @@ -1066,7 +1066,7 @@ SharedSolutionRepository::GetBestNSolutions(int n) const { template ValueType SharedSolutionRepository::GetVariableValueInSolution( int var_index, int solution_index) const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); return solutions_[solution_index]->variable_values[var_index]; } @@ -1075,7 +1075,7 @@ template std::shared_ptr::Solution> SharedSolutionRepository::GetRandomBiasedSolution( absl::BitGenRef random) const { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (solutions_.empty()) return nullptr; ++num_queried_; int index = 0; @@ -1122,7 +1122,7 @@ SharedSolutionRepository::Add(Solution solution) { std::make_shared(std::move(solution)); if (num_solutions_to_keep_ <= 0) return std::move(solution_ptr); { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); ++num_added_; solution_ptr->source_id = source_id_; new_solutions_.push_back(solution_ptr); @@ -1133,7 +1133,7 @@ SharedSolutionRepository::Add(Solution solution) { template void SharedSolutionRepository::Synchronize( std::function f) { - absl::MutexLock mutex_lock(&mutex_); + absl::MutexLock mutex_lock(mutex_); if (new_solutions_.empty()) { const int64_t diff = num_queried_ - num_queried_at_last_sync_; num_non_improving_ += diff; diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index a9a48286e4..eba1990f49 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -251,12 +251,12 @@ SharedTreeManager::SharedTreeManager(Model* model) } int SharedTreeManager::NumNodes() const { - absl::MutexLock mutex_lock(&mu_); + absl::MutexLock mutex_lock(mu_); return nodes_.size(); } bool SharedTreeManager::SyncTree(ProtoTrail& path) { - absl::MutexLock mutex_lock(&mu_); + absl::MutexLock mutex_lock(mu_); std::vector> nodes = GetAssignedNodes(path); if (!IsValid(path)) { path.Clear(); @@ -309,7 +309,7 @@ int SharedTreeManager::TrySplitTree(absl::Span decisions, ProtoTrail& path) { decisions = decisions.subspan(0, max_path_depth_ - path.MaxLevel()); if (decisions.empty()) return 0; - absl::MutexLock l(&mu_); + absl::MutexLock l(mu_); for (int i = 0; i < decisions.size(); ++i) { if (!TrySplitTreeLockHeld(decisions[i], path)) return i; } @@ -375,7 +375,7 @@ bool SharedTreeManager::TrySplitTreeLockHeld(ProtoLiteral decision, } void SharedTreeManager::ReplaceTree(ProtoTrail& path) { - absl::MutexLock mutex_lock(&mu_); + absl::MutexLock mutex_lock(mu_); std::vector> nodes = GetAssignedNodes(path); if (nodes.back().first->children[0] == nullptr && !nodes.back().first->closed && nodes.size() > 1) { @@ -545,7 +545,7 @@ SharedTreeManager::GetAssignedNodes(const ProtoTrail& path) { } void SharedTreeManager::CloseTree(ProtoTrail& path, int level) { - absl::MutexLock mutex_lock(&mu_); + absl::MutexLock mutex_lock(mu_); const int node_id_to_close = path.NodeIds(level).front(); path.Clear(); if (node_id_to_close < node_id_offset_) return; diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h index 580aa4196a..140debe21e 100644 --- a/ortools/sat/work_assignment.h +++ b/ortools/sat/work_assignment.h @@ -230,7 +230,7 @@ class SharedTreeManager { ABSL_LOCKS_EXCLUDED(mu_); void Restart() { - absl::MutexLock l(&mu_); + absl::MutexLock l(mu_); RestartLockHeld(); }