From 01dd97f64e588621093a1e71bd32f9832ca06035 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Wed, 18 Oct 2023 15:47:37 +0200 Subject: [PATCH] [CP-SAT] support int_prod with arity > 2; fix a few bugs, mostly around unsat models; add parameters for lp tolerance; optimize the code on hot spot --- ortools/sat/clause.cc | 4 +- ortools/sat/constraint_violation.cc | 8 +-- ortools/sat/cp_model.proto | 9 ++-- ortools/sat/cp_model_checker.cc | 33 +++++++----- ortools/sat/cp_model_expand.cc | 38 ++++++++++++++ ortools/sat/cp_model_loader.cc | 54 ++++++++++++++------ ortools/sat/cp_model_presolve.cc | 21 ++++---- ortools/sat/cp_model_search.cc | 4 +- ortools/sat/cp_model_solver.cc | 21 +++----- ortools/sat/cumulative.cc | 2 +- ortools/sat/encoding.cc | 14 ++--- ortools/sat/feasibility_jump.cc | 2 +- ortools/sat/feasibility_pump.cc | 20 +++----- ortools/sat/feasibility_pump.h | 2 - ortools/sat/integer.cc | 12 ++--- ortools/sat/integer.h | 5 +- ortools/sat/linear_programming_constraint.cc | 6 ++- ortools/sat/optimization.cc | 15 ++++-- ortools/sat/parameters_validation.cc | 5 +- ortools/sat/presolve_context.cc | 3 ++ ortools/sat/probing.cc | 6 +-- ortools/sat/sat_base.h | 19 +++++++ ortools/sat/sat_parameters.proto | 15 ++++-- ortools/sat/sat_solver.cc | 11 ++-- ortools/sat/sat_solver.h | 8 +-- ortools/sat/simplification.cc | 6 +-- 26 files changed, 216 insertions(+), 127 deletions(-) diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index ff65bfade5..b17185b31d 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -106,7 +106,7 @@ bool LiteralWatchers::PropagateOnFalse(Literal false_literal, Trail* trail) { SCOPED_TIME_STAT(&stats_); DCHECK(is_clean_); std::vector& watchers = watchers_on_false_[false_literal]; - const VariablesAssignment& assignment = trail->Assignment(); + const auto assignment = AssignmentView(trail->Assignment()); // Note(user): It sounds better to inspect the list in order, this is because // small clauses like binary or ternary clauses will often propagate and thus @@ -715,7 +715,7 @@ bool BinaryImplicationGraph::PropagateOnTrue(Literal true_literal, Trail* trail) { SCOPED_TIME_STAT(&stats_); - const VariablesAssignment& assignment = trail->Assignment(); + const auto assignment = AssignmentView(trail->Assignment()); DCHECK(assignment.LiteralIsTrue(true_literal)); // Note(user): This update is not exactly correct because in case of conflict diff --git a/ortools/sat/constraint_violation.cc b/ortools/sat/constraint_violation.cc index cae939dbc7..f48c96f4aa 100644 --- a/ortools/sat/constraint_violation.cc +++ b/ortools/sat/constraint_violation.cc @@ -978,10 +978,10 @@ int64_t CompiledIntProdConstraint::ComputeViolation( absl::Span solution) { const int64_t target_value = ExprValue(ct_proto().int_prod().target(), solution); - DCHECK_EQ(ct_proto().int_prod().exprs_size(), 2); - const int64_t prod_value = - ExprValue(ct_proto().int_prod().exprs(0), solution) * - ExprValue(ct_proto().int_prod().exprs(1), solution); + int64_t prod_value = 1; + for (const LinearExpressionProto& expr : ct_proto().int_prod().exprs()) { + prod_value *= ExprValue(expr, solution); + } return std::abs(target_value - prod_value); } diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index 08e6405694..716d9a7a69 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -368,11 +368,10 @@ message ConstraintProto { // variables. By convention, because we can just remove term equal to one, // the empty product forces the target to be one. // - // Note that the solver checks for potential integer overflow. So it is - // recommended to limit the domain of the variables such that the product - // fits in [INT_MIN + 1..INT_MAX - 1]. - // - // TODO(user): Support more than two terms in the product. + // Note that the solver checks for potential integer overflow. So the + // product of the maximum absolute value of all the terms (using the initial + // domain) should fit on an int64. Otherwise the model will be declared + // invalid. LinearArgumentProto int_prod = 11; // The lin_max constraint forces the target to equal the maximum of all diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 83724ab62b..e7c2bd1866 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -266,6 +266,12 @@ std::string ValidateLinearExpression(const CpModelProto& model, return absl::StrCat("Possible overflow in linear expression: ", ProtobufShortDebugString(expr)); } + for (const int ref : expr.vars()) { + if (!RefIsPositive(ref)) { + return absl::StrCat("Invalid negated reference in linear expression: ", + ProtobufShortDebugString(expr)); + } + } return ""; } @@ -332,26 +338,27 @@ std::string ValidateIntModConstraint(const CpModelProto& model, std::string ValidateIntProdConstraint(const CpModelProto& model, const ConstraintProto& ct) { - if (ct.int_prod().exprs().size() != 2) { - return absl::StrCat("An int_prod constraint should have exactly 2 terms: ", - ProtobufShortDebugString(ct)); - } if (!ct.int_prod().has_target()) { return absl::StrCat("An int_prod constraint should have a target: ", ProtobufShortDebugString(ct)); } - RETURN_IF_NOT_EMPTY(ValidateAffineExpression(model, ct.int_prod().exprs(0))); - RETURN_IF_NOT_EMPTY(ValidateAffineExpression(model, ct.int_prod().exprs(1))); + for (const LinearExpressionProto& expr : ct.int_prod().exprs()) { + RETURN_IF_NOT_EMPTY(ValidateAffineExpression(model, expr)); + } RETURN_IF_NOT_EMPTY(ValidateAffineExpression(model, ct.int_prod().target())); - // Detect potential overflow if some of the variables span across 0. - const LinearExpressionProto& expr0 = ct.int_prod().exprs(0); - const LinearExpressionProto& expr1 = ct.int_prod().exprs(1); - const Domain product_domain = - Domain({MinOfExpression(model, expr0), MaxOfExpression(model, expr0)}) - .ContinuousMultiplicationBy(Domain( - {MinOfExpression(model, expr1), MaxOfExpression(model, expr1)})); + // Detect potential overflow. + Domain product_domain(1); + for (const LinearExpressionProto& expr : ct.int_prod().exprs()) { + product_domain = product_domain.ContinuousMultiplicationBy( + {MinOfExpression(model, expr), MaxOfExpression(model, expr)}); + } + if (product_domain.Max() <= -std ::numeric_limits::max() || + product_domain.Min() >= std::numeric_limits::max()) { + return absl::StrCat("integer overflow in constraint: ", + ProtobufShortDebugString(ct)); + } if ((product_domain.Max() == std::numeric_limits::max() && product_domain.Min() < 0) || (product_domain.Min() == std::numeric_limits::min() && diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index 0745528f6a..5b610f7e2e 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -271,6 +272,39 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { context->UpdateRuleStats("int_mod: expanded"); } +void ExpandNonBinaryIntProd(ConstraintProto* ct, PresolveContext* context) { + CHECK_GT(ct->int_prod().exprs_size(), 2); + std::deque terms( + {ct->int_prod().exprs().begin(), ct->int_prod().exprs().end()}); + while (terms.size() > 2) { + const LinearExpressionProto& left = terms[0]; + const LinearExpressionProto& right = terms[1]; + const Domain new_domain = + context->DomainSuperSetOf(left).ContinuousMultiplicationBy( + context->DomainSuperSetOf(right)); + const int new_var = context->NewIntVar(new_domain); + LinearArgumentProto* const int_prod = + context->working_model->add_constraints()->mutable_int_prod(); + *int_prod->add_exprs() = left; + *int_prod->add_exprs() = right; + int_prod->mutable_target()->add_vars(new_var); + int_prod->mutable_target()->add_coeffs(1); + terms.pop_front(); + terms.pop_front(); + terms.push_back(int_prod->target()); + } + + LinearArgumentProto* const final_int_prod = + context->working_model->add_constraints()->mutable_int_prod(); + *final_int_prod->add_exprs() = terms[0]; + *final_int_prod->add_exprs() = terms[1]; + *final_int_prod->mutable_target() = ct->int_prod().target(); + + context->UpdateRuleStats(absl::StrCat( + "int_prod: expanded int_prod with arity ", ct->int_prod().exprs_size())); + ct->Clear(); +} + // TODO(user): Move this into the presolve instead? void ExpandIntProdWithBoolean(int bool_ref, const LinearExpressionProto& int_expr, @@ -294,6 +328,10 @@ void ExpandIntProdWithBoolean(int bool_ref, void ExpandIntProd(ConstraintProto* ct, PresolveContext* context) { const LinearArgumentProto& int_prod = ct->int_prod(); + if (int_prod.exprs_size() > 2) { + ExpandNonBinaryIntProd(ct, context); + return; + } if (int_prod.exprs_size() != 2) return; const LinearExpressionProto& a = int_prod.exprs(0); const LinearExpressionProto& b = int_prod.exprs(1); diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index 35dbe06e5b..d19a61fb9c 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -51,11 +51,13 @@ #include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" #include "ortools/sat/pb_constraint.h" +#include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/symmetry.h" #include "ortools/sat/timetable.h" +#include "ortools/sat/util.h" #include "ortools/util/logging.h" #include "ortools/util/sorted_interval_list.h" #include "ortools/util/strong_integers.h" @@ -863,7 +865,9 @@ void PropagateEncodingFromEquivalenceRelations(const CpModelProto& model_proto, if (intermediate % coeff2 != 0) { // Using this function deals properly with UNSAT. ++num_set_to_false; - sat_solver->AddUnitClause(value_literal.literal.Negated()); + if (!sat_solver->AddUnitClause(value_literal.literal.Negated())) { + return; + } continue; } ++num_associations; @@ -1203,7 +1207,7 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { } m->Add(ClauseConstraint(clause)); } else { - VLOG(1) << "Trivially UNSAT constraint: " << ct.DebugString(); + VLOG(1) << "Trivially UNSAT constraint: " << ct; m->GetOrCreate()->NotifyThatModelIsUnsat(); } return; @@ -1320,7 +1324,7 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { ct.linear().domain(0) != min_sum && ct.linear().domain(0) != max_sum && encoder->VariableIsFullyEncoded(vars[0]) && encoder->VariableIsFullyEncoded(vars[1])) { - VLOG(3) << "Load AC version of " << ct.DebugString() << ", var0 domain = " + VLOG(3) << "Load AC version of " << ct << ", var0 domain = " << integer_trail->InitialVariableDomain(vars[0]) << ", var1 domain = " << integer_trail->InitialVariableDomain(vars[1]); @@ -1336,8 +1340,7 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { single_value != min_sum && single_value != max_sum && encoder->VariableIsFullyEncoded(vars[0]) && encoder->VariableIsFullyEncoded(vars[1])) { - VLOG(3) << "Load NAC version of " << ct.DebugString() - << ", var0 domain = " + VLOG(3) << "Load NAC version of " << ct << ", var0 domain = " << integer_trail->InitialVariableDomain(vars[0]) << ", var1 domain = " << integer_trail->InitialVariableDomain(vars[1]) @@ -1469,19 +1472,36 @@ void LoadAllDiffConstraint(const ConstraintProto& ct, Model* m) { void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); const AffineExpression prod = mapping->Affine(ct.int_prod().target()); - CHECK_EQ(ct.int_prod().exprs_size(), 2) - << "General int_prod not supported yet."; - - const AffineExpression expr0 = mapping->Affine(ct.int_prod().exprs(0)); - const AffineExpression expr1 = mapping->Affine(ct.int_prod().exprs(1)); - if (VLOG_IS_ON(1)) { - LinearConstraintBuilder builder(m); - if (m->GetOrCreate()->TryToLinearize(expr0, expr1, - &builder)) { - VLOG(1) << "Product " << ct.DebugString() << " can be linearized"; + std::vector terms; + for (const LinearExpressionProto& expr : ct.int_prod().exprs()) { + terms.push_back(mapping->Affine(expr)); + } + switch (terms.size()) { + case 0: { + auto* integer_trail = m->GetOrCreate(); + auto* sat_solver = m->GetOrCreate(); + if (!integer_trail->Enqueue(prod.LowerOrEqual(1), {}) || + !integer_trail->Enqueue(prod.GreaterOrEqual(1), {})) { + sat_solver->NotifyThatModelIsUnsat(); + } + break; + } + case 1: { + LinearConstraintBuilder builder(m, /*lb=*/0, /*ub=*/0); + builder.AddTerm(prod, 1); + builder.AddTerm(terms[0], -1); + LoadLinearConstraint(builder.Build(), m); + break; + } + case 2: { + m->Add(ProductConstraint(terms[0], terms[1], prod)); + break; + } + default: { + LOG(FATAL) << "Loading int_prod with arity > 2, should not be here."; + break; } } - m->Add(ProductConstraint(expr0, expr1, prod)); } void LoadIntDivConstraint(const ConstraintProto& ct, Model* m) { @@ -1497,7 +1517,7 @@ void LoadIntDivConstraint(const ConstraintProto& ct, Model* m) { LinearConstraintBuilder builder(m); if (m->GetOrCreate()->TryToLinearize(num, denom, &builder)) { - VLOG(1) << "Division " << ct.DebugString() << " can be linearized"; + VLOG(1) << "Division " << ct << " can be linearized"; } } m->Add(DivisionConstraint(num, denom, div)); diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 079729e98f..6450ae5abe 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -1105,7 +1105,7 @@ bool CpModelPresolver::PresolveIntAbs(ConstraintProto* ct) { arg->add_domain(0); AddLinearExpressionToLinearConstraint(target_expr, 1, arg); AddLinearExpressionToLinearConstraint(expr, -1, arg); - if (!CanonicalizeLinear(new_ct)) return false; + CanonicalizeLinear(new_ct); context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -1119,7 +1119,7 @@ bool CpModelPresolver::PresolveIntAbs(ConstraintProto* ct) { arg->add_domain(0); AddLinearExpressionToLinearConstraint(target_expr, 1, arg); AddLinearExpressionToLinearConstraint(expr, 1, arg); - if (!CanonicalizeLinear(new_ct)) return false; + CanonicalizeLinear(new_ct); context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -1418,7 +1418,7 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { literals.push_back(lit); } - // This is a bool constraint! + // This is a Boolean constraint! context_->UpdateRuleStats("int_prod: all Boolean."); { ConstraintProto* new_ct = context_->working_model->add_constraints(); @@ -3599,13 +3599,16 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index, if (!SubstituteVariable( var, var_coeff, *ct, context_->working_model->mutable_constraints(c))) { - // The function do not modify the constraint. - // It is possible we already started performing substitution, but that - // is usually not the case, and still correct. + // The function above can fail because of overflow, but also if the + // constraint was not canonicalized yet and the variable is actually not + // there (we have var - var for instance). // - // This can happen if the constraint was not canonicalized and the - // variable is actually not there (we have var - var for instance). - CanonicalizeLinear(context_->working_model->mutable_constraints(c)); + // TODO(user): we canonicalize it right away, but I am not sure it is + // really needed. + if (CanonicalizeLinear( + context_->working_model->mutable_constraints(c))) { + context_->UpdateConstraintVariableUsage(c); + } abort = true; break; } diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index 5de6dea783..c7a3ad1941 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -362,7 +362,9 @@ ConstructIntegerCompletionSearchStrategy( const std::vector& variable_mapping, IntegerVariable objective_var, Model* model) { const auto& params = *model->GetOrCreate(); - if (!params.instantiate_all_variables()) return nullptr; + if (!params.instantiate_all_variables()) { + return []() { return BooleanOrIntegerLiteral(); }; + } std::vector decisions; for (const IntegerVariable var : variable_mapping) { diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index d1a0f71f64..3555136e37 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -672,9 +672,6 @@ std::string CpSolverResponseStats(const CpSolverResponse& response, namespace { -#if !defined(__PORTABLE_PLATFORM__) -#endif // __PORTABLE_PLATFORM__ - // This should be called on the presolved model. It will read the file // specified by --cp_model_load_debug_solution and properly fill the // model->Get() proto vector. @@ -1474,6 +1471,7 @@ void LoadBaseModel(const CpModelProto& model_proto, Model* model) { // Fully encode variables as needed by the search strategy. AddFullEncodingFromSearchBranching(model_proto, model); + if (sat_solver->ModelIsUnsat()) return unsat(); // Reserve space for the precedence relations. model->GetOrCreate()->Resize( @@ -1863,7 +1861,7 @@ void SolveLoadedCpModel(const CpModelProto& model_proto, Model* model) { } }; - // Make sure we are not at a postive level. + // Make sure we are not at a positive level. if (!model->GetOrCreate()->ResetToLevelZero()) { shared_response_manager->NotifyThatImprovingProblemIsInfeasible( model->Name()); @@ -2311,7 +2309,8 @@ CpSolverResponse SolvePureSatModel(const CpModelProto& model_proto, if (ct.enforcement_literal_size() == 0) { for (const int ref : ct.bool_and().literals()) { const Literal b = get_literal(ref); - solver->AddUnitClause(b); + // We should report infeasible below. + if (!solver->AddUnitClause(b)) continue; } } else { // a => b @@ -2345,7 +2344,7 @@ CpSolverResponse SolvePureSatModel(const CpModelProto& model_proto, if (domain.Min() == domain.Max()) { const Literal ref_literal = domain.Min() == 0 ? get_literal(ref).Negated() : get_literal(ref); - solver->AddUnitClause(ref_literal); + if (!solver->AddUnitClause(ref_literal)) break; } } @@ -3170,9 +3169,6 @@ class LnsSolver : public SubSolver { debug_copy = lns_fragment; } -#if !defined(__PORTABLE_PLATFORM__) -#endif // __PORTABLE_PLATFORM__ - if (absl::GetFlag(FLAGS_cp_model_dump_lns)) { // TODO(user): export the delta too if needed. const std::string lns_name = @@ -3565,7 +3561,7 @@ void SolveCpModelParallel(const CpModelProto& model_proto, // schedule more than the available number of threads. They will just be // interleaved. We will get an higher diversity, but use more memory. const int num_feasibility_jump = - params.interleave_search() + (params.interleave_search() || !params.use_feasibility_jump()) ? 0 : (params.test_feasibility_jump() ? num_available : (num_available + 1) / 2); @@ -3992,9 +3988,6 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { wall_timer->Start(); user_timer->Start(); -#if !defined(__PORTABLE_PLATFORM__) -#endif // __PORTABLE_PLATFORM__ - #if !defined(__PORTABLE_PLATFORM__) // Dump initial model? if (absl::GetFlag(FLAGS_cp_model_dump_models)) { @@ -4007,9 +4000,7 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { DumpModelProto(model_proto, model_proto.name()); } } -#endif // __PORTABLE_PLATFORM__ -#if !defined(__PORTABLE_PLATFORM__) // Override parameters? if (!absl::GetFlag(FLAGS_cp_model_params).empty()) { SatParameters params = *model->GetOrCreate(); diff --git a/ortools/sat/cumulative.cc b/ortools/sat/cumulative.cc index 75dc829da6..4423100e07 100644 --- a/ortools/sat/cumulative.cc +++ b/ortools/sat/cumulative.cc @@ -306,7 +306,7 @@ std::function CumulativeTimeDecomposition( for (IntegerValue time = min_start; time < max_end; ++time) { std::vector literals_with_coeff; for (int t = 0; t < num_tasks; ++t) { - sat_solver->Propagate(); + if (!sat_solver->Propagate()) return; const IntegerValue start_min = integer_trail->LowerBound(start_vars[t]); const IntegerValue end_max = integer_trail->UpperBound(end_vars[t]); if (end_max <= time || time < start_min || fixed_demands[t] == 0) { diff --git a/ortools/sat/encoding.cc b/ortools/sat/encoding.cc index 838bee8864..2a3a1e12a7 100644 --- a/ortools/sat/encoding.cc +++ b/ortools/sat/encoding.cc @@ -199,7 +199,7 @@ void EncodingNode::ApplyWeightUpperBound(Coefficient gap, SatSolver* solver) { std::max(0, (weight_lb_ - lb_) + static_cast(num_allowed.value())); if (size() <= new_size) return; for (int i = new_size; i < size(); ++i) { - solver->AddUnitClause(literal(i).Negated()); + if (!solver->AddUnitClause(literal(i).Negated())) return; } literals_.resize(new_size); ub_ = lb_ + new_size; @@ -208,7 +208,7 @@ void EncodingNode::ApplyWeightUpperBound(Coefficient gap, SatSolver* solver) { void EncodingNode::TransformToBoolean(SatSolver* solver) { if (size() > 1) { for (int i = 1; i < size(); ++i) { - solver->AddUnitClause(literal(i).Negated()); + if (!solver->AddUnitClause(literal(i).Negated())) return; } literals_.resize(1); ub_ = lb_ + 1; @@ -220,7 +220,7 @@ void EncodingNode::TransformToBoolean(SatSolver* solver) { // TODO(user): Avoid creating a Boolean just to fix it! IncreaseNodeSize(this, solver); CHECK_EQ(size(), 2); - solver->AddUnitClause(literal(1).Negated()); + if (!solver->AddUnitClause(literal(1).Negated())) return; literals_.resize(1); ub_ = lb_ + 1; } @@ -362,7 +362,7 @@ void IncreaseNodeSize(EncodingNode* node, SatSolver* solver) { { const int ib = target - (a->lb() - 1); if ((ib - 1) == b->lb() - 1) { - solver->AddUnitClause(n->GreaterThan(target)); + if (!solver->AddUnitClause(n->GreaterThan(target))) return; } if ((ib - 1) >= b->lb() && (ib - 1) < b->current_ub()) { solver->AddBinaryClause(n->GreaterThan(target), @@ -378,7 +378,7 @@ void IncreaseNodeSize(EncodingNode* node, SatSolver* solver) { b->GreaterThan(ib)); } if (ib == b->ub()) { - solver->AddUnitClause(n->GreaterThan(target).Negated()); + if (!solver->AddUnitClause(n->GreaterThan(target).Negated())) return; } } } @@ -399,7 +399,7 @@ EncodingNode FullMerge(Coefficient upper_bound, EncodingNode* a, solver->AddBinaryClause(n.literal(ia), a->literal(ia).Negated()); } else { // Fix the variable to false because of the given upper_bound. - solver->AddUnitClause(a->literal(ia).Negated()); + if (!solver->AddUnitClause(a->literal(ia).Negated())) return n; } } for (int ib = 0; ib < b->size(); ++ib) { @@ -411,7 +411,7 @@ EncodingNode FullMerge(Coefficient upper_bound, EncodingNode* a, solver->AddBinaryClause(n.literal(ib), b->literal(ib).Negated()); } else { // Fix the variable to false because of the given upper_bound. - solver->AddUnitClause(b->literal(ib).Negated()); + if (!solver->AddUnitClause(b->literal(ib).Negated())) return n; } } for (int ia = 0; ia < a->size(); ++ia) { diff --git a/ortools/sat/feasibility_jump.cc b/ortools/sat/feasibility_jump.cc index a4fb0c8c52..30c8643d21 100644 --- a/ortools/sat/feasibility_jump.cc +++ b/ortools/sat/feasibility_jump.cc @@ -623,7 +623,7 @@ std::pair FeasibilityJumpSolver::ComputeGeneralJump(int var) { const int64_t min_delta = domain[i].start - current_value; const int64_t max_delta = domain[i].end - current_value; const auto& [delta, score] = RangeConvexMinimum( - result, min_delta, max_delta + 1, [&](int64_t delta) -> double { + min_delta, max_delta + 1, [&](int64_t delta) -> double { return ComputeScore(ScanWeights(), var, delta, /*linear_only=*/false); }); if (score < result.second) result = std::make_pair(delta, score); diff --git a/ortools/sat/feasibility_pump.cc b/ortools/sat/feasibility_pump.cc index cab21b72a1..766001c3d1 100644 --- a/ortools/sat/feasibility_pump.cc +++ b/ortools/sat/feasibility_pump.cc @@ -189,7 +189,7 @@ bool FeasibilityPump::Solve() { if (integer_solution_is_feasible_) MaybePushToRepo(); } - if (model_is_unsat_) return false; + if (sat_solver_->ModelIsUnsat()) return false; PrintStats(); MaybePushToRepo(); @@ -567,7 +567,7 @@ bool FeasibilityPump::ActiveLockBasedRounding() { bool FeasibilityPump::PropagationRounding() { if (!lp_solution_is_set_) return false; - sat_solver_->ResetToLevelZero(); + if (!sat_solver_->ResetToLevelZero()) return false; // Compute an order in which we will fix variables and do the propagation. std::vector rounding_order; @@ -623,7 +623,7 @@ bool FeasibilityPump::PropagationRounding() { (domain.Contains(ceil_value) && ub.value() >= ceil_value); if (domain.IsEmpty()) { integer_solution_[var_index] = rounded_value; - model_is_unsat_ = true; + sat_solver_->NotifyThatModelIsUnsat(); return false; } @@ -678,20 +678,12 @@ bool FeasibilityPump::PropagationRounding() { integer_encoder_->GetOrCreateLiteralAssociatedToEquality(var, value); } - if (!sat_solver_->FinishPropagation()) { - model_is_unsat_ = true; - return false; - } + if (!sat_solver_->FinishPropagation()) return false; sat_solver_->EnqueueDecisionAndBacktrackOnConflict(to_enqueue); - - if (sat_solver_->ModelIsUnsat()) { - model_is_unsat_ = true; - return false; - } + if (sat_solver_->ModelIsUnsat()) return false; } - sat_solver_->ResetToLevelZero(); integer_solution_is_set_ = true; - return true; + return sat_solver_->ResetToLevelZero(); } void FeasibilityPump::FillIntegerSolutionStats() { diff --git a/ortools/sat/feasibility_pump.h b/ortools/sat/feasibility_pump.h index c6b56a862d..219099a936 100644 --- a/ortools/sat/feasibility_pump.h +++ b/ortools/sat/feasibility_pump.h @@ -236,8 +236,6 @@ class FeasibilityPump { // TODO(user): Tune default value. Expose as parameter. int max_fp_iterations_ = 20; - - bool model_is_unsat_ = false; }; } // namespace sat diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 8106d743f6..6293c22fd6 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -364,12 +364,10 @@ void IntegerEncoder::AssociateToIntegerLiteral(Literal literal, const IntegerValue min(domain.Min()); const IntegerValue max(domain.Max()); if (i_lit.bound <= min) { - sat_solver_->AddUnitClause(literal); - return; + return (void)sat_solver_->AddUnitClause(literal); } if (i_lit.bound > max) { - sat_solver_->AddUnitClause(literal.Negated()); - return; + return (void)sat_solver_->AddUnitClause(literal.Negated()); } if (index >= encoding_by_var_.size()) { @@ -470,8 +468,7 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal, // Fix literal for value outside the domain. if (!domain.Contains(value.value())) { - sat_solver_->AddUnitClause(literal.Negated()); - return; + return (void)sat_solver_->AddUnitClause(literal.Negated()); } // Update equality_by_var. Note that due to the @@ -485,8 +482,7 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal, // Fix literal for constant domain. if (domain.IsFixed()) { - sat_solver_->AddUnitClause(literal); - return; + return (void)sat_solver_->AddUnitClause(literal); } const IntegerLiteral ge = IntegerLiteral::GreaterOrEqual(var, value); diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index f8ee5f1ccb..b177267d00 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -625,7 +625,10 @@ class IntegerEncoder { const Literal literal_true = Literal(sat_solver_->NewBooleanVariable(), true); literal_index_true_ = literal_true.Index(); - sat_solver_->AddUnitClause(literal_true); + + // This might return false if we are already UNSAT. + // TODO(user): Make sure we abort right away on unsat! + (void)sat_solver_->AddUnitClause(literal_true); } return Literal(literal_index_true_); } diff --git a/ortools/sat/linear_programming_constraint.cc b/ortools/sat/linear_programming_constraint.cc index 08aadd60e4..72860b6d32 100644 --- a/ortools/sat/linear_programming_constraint.cc +++ b/ortools/sat/linear_programming_constraint.cc @@ -232,10 +232,12 @@ LinearProgrammingConstraint::LinearProgrammingConstraint( // Tweak the default parameters to make the solve incremental. simplex_params_.set_use_dual_simplex(true); simplex_params_.set_cost_scaling(glop::GlopParameters::MEAN_COST_SCALING); + simplex_params_.set_primal_feasibility_tolerance( + parameters_.lp_primal_tolerance()); + simplex_params_.set_dual_feasibility_tolerance( + parameters_.lp_dual_tolerance()); if (parameters_.use_exact_lp_reason()) { simplex_params_.set_change_status_to_imprecise(false); - simplex_params_.set_primal_feasibility_tolerance(1e-7); - simplex_params_.set_dual_feasibility_tolerance(1e-7); } simplex_.SetParameters(simplex_params_); if (parameters_.use_branching_in_lp() || diff --git a/ortools/sat/optimization.cc b/ortools/sat/optimization.cc index f1ef3978c8..e09763f2b9 100644 --- a/ortools/sat/optimization.cc +++ b/ortools/sat/optimization.cc @@ -166,7 +166,7 @@ void MinimizeCoreWithSearch(TimeLimit* limit, SatSolver* solver, << core->size(); } - solver->ResetToLevelZero(); + (void)solver->ResetToLevelZero(); solver->mutable_logger()->EnableLogging(old_log_state); } @@ -182,10 +182,15 @@ bool ProbeLiteral(Literal assumption, SatSolver* solver) { // TODO(user): Still use it if the problem is Boolean only. const auto status = solver->ResetAndSolveWithGivenAssumptions( {assumption}, /*max_number_of_conflicts=*/1'000); - solver->ResetToLevelZero(); + if (!solver->ResetToLevelZero()) return false; if (status == SatSolver::ASSUMPTIONS_UNSAT) { - solver->AddUnitClause(assumption.Negated()); - solver->Propagate(); + if (!solver->AddUnitClause(assumption.Negated())) { + return false; + } + if (!solver->Propagate()) { + solver->NotifyThatModelIsUnsat(); + return false; + } } solver->mutable_logger()->EnableLogging(old_log_state); @@ -827,7 +832,7 @@ SatSolver::Status CoreBasedOptimizer::OptimizeWithSatEncoding( if (parameters_->core_minimization_level() > 1) { MinimizeCoreWithSearch(time_limit_, sat_solver_, &core); } - sat_solver_->ResetToLevelZero(); + if (!sat_solver_->ResetToLevelZero()) return SatSolver::INFEASIBLE; FilterAssignedLiteral(sat_solver_->Assignment(), &core); if (core.empty()) return SatSolver::INFEASIBLE; diff --git a/ortools/sat/parameters_validation.cc b/ortools/sat/parameters_validation.cc index ab114956ef..4518577baf 100644 --- a/ortools/sat/parameters_validation.cc +++ b/ortools/sat/parameters_validation.cc @@ -65,6 +65,8 @@ std::string ValidateParameters(const SatParameters& params) { TEST_IS_FINITE(glucose_max_decay); TEST_IS_FINITE(glucose_decay_increment); TEST_IS_FINITE(clause_activity_decay); + TEST_IS_FINITE(lp_dual_tolerance); + TEST_IS_FINITE(lp_primal_tolerance); TEST_IS_FINITE(max_clause_activity_value); TEST_IS_FINITE(restart_dl_average_ratio); TEST_IS_FINITE(restart_lbd_average_ratio); @@ -78,11 +80,11 @@ std::string ValidateParameters(const SatParameters& params) { TEST_IS_FINITE(merge_no_overlap_work_limit); TEST_IS_FINITE(merge_at_most_one_work_limit); TEST_IS_FINITE(min_orthogonality_for_lp_constraints); + TEST_IS_FINITE(mip_var_scaling); TEST_IS_FINITE(cut_max_active_count_value); TEST_IS_FINITE(cut_active_count_decay); TEST_IS_FINITE(shaving_search_deterministic_time); TEST_IS_FINITE(mip_max_bound); - TEST_IS_FINITE(mip_var_scaling); TEST_IS_FINITE(mip_wanted_precision); TEST_IS_FINITE(mip_check_precision); TEST_IS_FINITE(mip_max_valid_magnitude); @@ -123,6 +125,7 @@ std::string ValidateParameters(const SatParameters& params) { TEST_POSITIVE(glucose_decay_increment_period); TEST_POSITIVE(shared_tree_max_nodes_per_worker); + TEST_POSITIVE(mip_var_scaling); TEST_NON_NEGATIVE(mip_wanted_precision); TEST_NON_NEGATIVE(max_time_in_seconds); diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index bd83e22ea7..97f3c50686 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -2182,6 +2182,9 @@ bool LoadModelForProbing(PresolveContext* context, Model* local_model) { local_model); ExtractEncoding(model_proto, local_model); auto* sat_solver = local_model->GetOrCreate(); + if (sat_solver->ModelIsUnsat()) { + return context->NotifyThatModelIsUnsat("Initial loading for probing"); + } for (const ConstraintProto& ct : model_proto.constraints()) { if (mapping->ConstraintIsAlreadyLoaded(&ct)) continue; CHECK(LoadConstraint(ct, local_model)); diff --git a/ortools/sat/probing.cc b/ortools/sat/probing.cc index f1240440af..0540b1d012 100644 --- a/ortools/sat/probing.cc +++ b/ortools/sat/probing.cc @@ -113,7 +113,7 @@ bool Prober::ProbeOneVariableInternal(BooleanVariable b) { // Fix variable and add new binary clauses. if (!sat_solver_->RestoreSolverToAssumptionLevel()) return false; for (const Literal l : to_fix_at_true_) { - sat_solver_->AddUnitClause(l); + if (!sat_solver_->AddUnitClause(l)) return false; } to_fix_at_true_.clear(); if (!sat_solver_->FinishPropagation()) return false; @@ -516,7 +516,7 @@ bool FailedLiteralProbingRound(ProbingOptions options, Model* model) { for (const Literal literal : to_fix) { if (!assignment.LiteralIsTrue(literal)) { ++num_explicit_fix; - sat_solver->AddUnitClause(literal); + if (!sat_solver->AddUnitClause(literal)) return false; } } to_fix.clear(); @@ -747,7 +747,7 @@ bool FailedLiteralProbingRound(ProbingOptions options, Model* model) { if (!sat_solver->ResetToLevelZero()) return false; for (const Literal literal : to_fix) { ++num_explicit_fix; - sat_solver->AddUnitClause(literal); + if (!sat_solver->AddUnitClause(literal)) return false; } to_fix.clear(); if (!sat_solver->FinishPropagation()) return false; diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index bd5884d170..5eedef7056 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -198,6 +198,25 @@ class VariablesAssignment { // - assignment_.IsSet(literal.Index() ^ 1]) means literal is false. // - If both are false, then the variable (and the literal) is unassigned. Bitset64 assignment_; + + friend class AssignmentView; +}; + +// For "hot" loop, it is better not to reload the Bitset64 pointer on each +// check. +class AssignmentView { + public: + explicit AssignmentView(const VariablesAssignment& assignment) + : view_(assignment.assignment_.const_view()) {} + + bool LiteralIsFalse(Literal literal) const { + return view_[literal.NegatedIndex()]; + } + + bool LiteralIsTrue(Literal literal) const { return view_[literal.Index()]; } + + private: + Bitset64::ConstView view_; }; // Forward declaration. diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index c82fed1d03..9e385baf95 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -23,7 +23,7 @@ option csharp_namespace = "Google.OrTools.Sat"; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 265 +// NEXT TAG: 268 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -1082,8 +1082,10 @@ message SatParameters { // Parameters for an heuristic similar to the one described in the paper: // "Feasibility Jump: an LP-free Lagrangian MIP heuristic", Bjørnar // Luteberget, Giorgio Sartor, 2023, Mathematical Programming Computation. - // - // The test_feasibility_jump is used to only enable this for benchmarking. + optional bool use_feasibility_jump = 265 [default = true]; + + // Disable every other type of subsolver, setting this turns CP-SAT into a + // pure local-search solver. optional bool test_feasibility_jump = 240 [default = false]; // On each restart, we randomly choose if we use decay (with this parameter) @@ -1329,6 +1331,13 @@ message SatParameters { // time to do such polish step. optional bool polish_lp_solution = 175 [default = false]; + // The internal LP tolerance used by CP-SAT. These applies to the internal and + // scaled problem. If the domain of your variables are large it might be good + // to use lower tolerance. If your problem is binary with low coefficient, it + // might be good to use higher one to speed-up the lp solves. + optional double lp_primal_tolerance = 266 [default = 1e-7]; + optional double lp_dual_tolerance = 267 [default = 1e-7]; + // Temporary flag util the feature is more mature. This convert intervals to // the newer proto format that support affine start/var/end instead of just // variables. diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 23a383cf78..998e5c6692 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -232,7 +232,7 @@ bool SatSolver::AddProblemClause(absl::Span literals, } } - AddProblemClauseInternal(literals_scratchpad_); + if (!AddProblemClauseInternal(literals_scratchpad_)) return false; // Tricky: The PropagationIsDone() condition shouldn't change anything for a // pure SAT problem, however in the CP-SAT context, calling Propagate() can @@ -1226,7 +1226,7 @@ void SatSolver::TryToMinimizeClause(SatClause* clause) { if (!Assignment().VariableIsAssigned(candidate[0].Variable())) { counters_.minimization_num_removed_literals += clause->size(); trail_->EnqueueWithUnitReason(candidate[0]); - FinishPropagation(); + return (void)FinishPropagation(); } return; } @@ -1241,8 +1241,7 @@ void SatSolver::TryToMinimizeClause(SatClause* clause) { // This is needed in the corner case where this was the first binary clause // of the problem so that PropagationIsDone() returns true on the newly // created BinaryImplicationGraph. - FinishPropagation(); - return; + return (void)FinishPropagation(); } counters_.minimization_num_removed_literals += @@ -1766,6 +1765,7 @@ bool SatSolver::PropagationIsDone() const { // part or the full integer part... bool SatSolver::Propagate() { SCOPED_TIME_STAT(&stats_); + DCHECK(!ModelIsUnsat()); // Because we might potentially iterate often on this list below, we remove // empty propagators. @@ -2692,8 +2692,7 @@ std::string SatStatusString(SatSolver::Status status) { void MinimizeCore(SatSolver* solver, std::vector* core) { std::vector result; - - solver->ResetToLevelZero(); + if (!solver->ResetToLevelZero()) return; for (const Literal lit : *core) { if (solver->Assignment().LiteralIsTrue(lit)) continue; result.push_back(lit); diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 7e9b8fc15e..e2869e47b4 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -103,7 +103,7 @@ class SatSolver { // solve a subproblem where some variables are fixed. Note that it is more // efficient to add such unit clause before all the others. // Returns false if the problem is detected to be UNSAT. - bool AddUnitClause(Literal true_literal); + ABSL_MUST_USE_RESULT bool AddUnitClause(Literal true_literal); // Same as AddProblemClause() below, but for small clauses. bool AddBinaryClause(Literal a, Literal b); @@ -311,11 +311,11 @@ class SatSolver { // Advanced usage. Finish the progation if it was interrupted. Note that this // might run into conflict and will propagate again until a fixed point is // reached or the model was proven UNSAT. Returns IsModelUnsat(). - bool FinishPropagation(); + ABSL_MUST_USE_RESULT bool FinishPropagation(); // Like Backtrack(0) but make sure the propagation is finished and return // false if unsat was detected. This also removes any assumptions level. - bool ResetToLevelZero(); + ABSL_MUST_USE_RESULT bool ResetToLevelZero(); // Changes the assumptions level and the current solver assumptions. Returns // false if the model is UNSAT or ASSUMPTION_UNSAT, true otherwise. @@ -443,7 +443,7 @@ class SatSolver { // Performs propagation of the recently enqueued elements. // Mainly visible for testing. - bool Propagate(); + ABSL_MUST_USE_RESULT bool Propagate(); // This must be called at level zero. It will spend the given num decision and // use propagation to try to minimize some clauses from the database. diff --git a/ortools/sat/simplification.cc b/ortools/sat/simplification.cc index c3bb087833..724fc780b6 100644 --- a/ortools/sat/simplification.cc +++ b/ortools/sat/simplification.cc @@ -1209,7 +1209,7 @@ void ProbeAndFindEquivalentLiteral( const Literal true_lit = assignment.LiteralIsTrue(Literal(i)) ? Literal(rep) : Literal(rep).Negated(); - solver->AddUnitClause(true_lit); + if (!solver->AddUnitClause(true_lit)) return; if (drat_proof_handler != nullptr) { drat_proof_handler->AddClause({true_lit}); } @@ -1223,7 +1223,7 @@ void ProbeAndFindEquivalentLiteral( const Literal true_lit = assignment.LiteralIsTrue(Literal(rep)) ? Literal(i) : Literal(i).Negated(); - solver->AddUnitClause(true_lit); + if (!solver->AddUnitClause(true_lit)) return; if (drat_proof_handler != nullptr) { drat_proof_handler->AddClause({true_lit}); } @@ -1233,7 +1233,7 @@ void ProbeAndFindEquivalentLiteral( const Literal true_lit = assignment.LiteralIsTrue(Literal(i)) ? Literal(rep) : Literal(rep).Negated(); - solver->AddUnitClause(true_lit); + if (!solver->AddUnitClause(true_lit)) return; if (drat_proof_handler != nullptr) { drat_proof_handler->AddClause({true_lit}); }