diff --git a/ortools/linear_solver/proto_solver/sat_proto_solver.cc b/ortools/linear_solver/proto_solver/sat_proto_solver.cc index 8db86cd5dc..2eaa939b1e 100644 --- a/ortools/linear_solver/proto_solver/sat_proto_solver.cc +++ b/ortools/linear_solver/proto_solver/sat_proto_solver.cc @@ -160,7 +160,8 @@ MPSolutionResponse TimeLimitResponse(SolverLogger& logger) { MPSolutionResponse SatSolveProto( LazyMutableCopy request, std::atomic* interrupt_solve, std::function logging_callback, - std::function solution_callback) { + std::function solution_callback, + std::function best_bound_callback) { sat::SatParameters params; params.set_log_search_progress(request->enable_internal_solver_output()); @@ -431,6 +432,9 @@ MPSolutionResponse SatSolveProto( solution_callback(post_solve(cp_response)); })); } + if (best_bound_callback != nullptr) { + sat_model.Add(sat::NewBestBoundCallback(best_bound_callback)); + } // Solve. const sat::CpSolverResponse cp_response = diff --git a/ortools/linear_solver/proto_solver/sat_proto_solver.h b/ortools/linear_solver/proto_solver/sat_proto_solver.h index 6421501325..8a2ee65df1 100644 --- a/ortools/linear_solver/proto_solver/sat_proto_solver.h +++ b/ortools/linear_solver/proto_solver/sat_proto_solver.h @@ -49,11 +49,18 @@ namespace operations_research { // found by the solver. The solver may call solution_callback from multiple // threads, but it will ensure that at most one thread executes // solution_callback at a time. +// +// The optional best_bound_callback will be called each time the best bound is +// improved. The solver may call solution_callback from multiple +// threads, but it will ensure that at most one thread executes +// solution_callback at a time. It is guaranteed that the best bound is strictly +// improving. MPSolutionResponse SatSolveProto( LazyMutableCopy request, std::atomic* interrupt_solve = nullptr, std::function logging_callback = nullptr, - std::function solution_callback = nullptr); + std::function solution_callback = nullptr, + std::function best_bound_callback = nullptr); // Returns a string that describes the version of the CP-SAT solver. std::string SatSolverVersion(); diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 6bfffcae16..fd80c3cff9 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -1470,11 +1470,11 @@ cc_library( "//ortools/util:stats", "//ortools/util:strong_integers", "//ortools/util:time_limit", - "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/functional:function_ref", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:vlog_is_on", @@ -1524,6 +1524,7 @@ cc_library( ":implied_bounds", ":integer", ":integer_base", + ":lrat_proof_handler", ":model", ":sat_base", ":sat_parameters_cc_proto", diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index c85c4bda45..391ff8b073 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -366,8 +366,7 @@ void ClauseManager::InternalDetach(SatClause* clause) { } if (lrat_proof_handler_ != nullptr) { const auto it = clause_id_.find(clause); - // TODO(user): why is it necessary to keep binary clauses? - if (it != clause_id_.end() && size != 2) { + if (it != clause_id_.end()) { lrat_proof_handler_->DeleteClauses({it->second}); clause_id_.erase(it); } diff --git a/ortools/sat/clause.h b/ortools/sat/clause.h index 86b81b32dd..f7bf615f95 100644 --- a/ortools/sat/clause.h +++ b/ortools/sat/clause.h @@ -797,15 +797,6 @@ class BinaryImplicationGraph : public SatPropagator { void SetDratProofHandler(DratProofHandler* drat_proof_handler); - // Changes the reason of the variable at trail index to a binary reason. - // Note that the implication "new_reason => trail_[trail_index]" should be - // part of the implication graph. - void ChangeReason(int trail_index, Literal new_reason) { - CHECK(trail_->Assignment().LiteralIsTrue(new_reason)); - reasons_[trail_index] = new_reason.Negated(); - trail_->ChangeReason(trail_index, propagator_id_); - } - // The literals that are "directly" implied when literal is set to true. This // is not a full "reachability". It includes at most ones propagation. The set // of all direct implications is enough to describe the implications graph diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index ae881dddd0..a6d91535a6 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -13141,6 +13141,130 @@ void CpModelPresolver::MaybeTransferLinear1ToAnotherVariable(int var) { context_->MarkVariableAsRemoved(var); } +namespace { +enum class EncodingLinear1Type { + kTypeUnknown = 0, + kVarEqValue = 1, + kVarNeValue = 2, + kVarGeValue = 3, + kVarLeValue = 4, + kVarInDomain = 5, + kIgnore = 6, + kAbort = 7, + kUnsat = 8, +}; + +template +void AbslStringify(Sink& sink, EncodingLinear1Type type) { + switch (type) { + case EncodingLinear1Type::kTypeUnknown: + sink.Append("kTypeUnknown"); + return; + case EncodingLinear1Type::kVarEqValue: + sink.Append("kVarEqValue"); + return; + case EncodingLinear1Type::kVarNeValue: + sink.Append("kVarNeValue"); + return; + case EncodingLinear1Type::kVarGeValue: + sink.Append("kVarGeValue"); + return; + case EncodingLinear1Type::kVarLeValue: + sink.Append("kVarLeValue"); + return; + case EncodingLinear1Type::kVarInDomain: + sink.Append("kVarInDomain"); + return; + case EncodingLinear1Type::kIgnore: + sink.Append("kIgnore"); + return; + case EncodingLinear1Type::kAbort: + sink.Append("kAbort"); + return; + case EncodingLinear1Type::kUnsat: + sink.Append("kUnsat"); + return; + } +} + +struct EncodingLinear1 { + EncodingLinear1Type type = EncodingLinear1Type::kTypeUnknown; + int64_t value = std::numeric_limits::min(); + Domain rhs; + int enforcement_literal; + int constraint_index; + + static EncodingLinear1 FromConstraint( + PresolveContext* context, int constraint_index, const Domain& var_domain, + int64_t& num_implied_literals_in_complex_domains) { + const ConstraintProto& ct = + context->working_model->constraints(constraint_index); + const Domain rhs = ReadDomainFromProto(ct.linear()) + .InverseMultiplicationBy(ct.linear().coeffs(0)) + .IntersectionWith(var_domain); + EncodingLinear1 lin; + lin.enforcement_literal = ct.enforcement_literal(0); + lin.constraint_index = constraint_index; + + if (rhs.IsEmpty()) { + if (!context->SetLiteralToFalse(lin.enforcement_literal)) { + lin.type = EncodingLinear1Type::kUnsat; + } else { + lin.type = EncodingLinear1Type::kIgnore; + } + } else if (rhs.IsFixed()) { + if (!var_domain.Contains(rhs.FixedValue())) { + if (!context->SetLiteralToFalse(lin.enforcement_literal)) { + lin.type = EncodingLinear1Type::kUnsat; + } + lin.type = EncodingLinear1Type::kIgnore; + } else { + lin.type = EncodingLinear1Type::kVarEqValue; + lin.value = rhs.FixedValue(); + } + } else { + const Domain complement = var_domain.IntersectionWith(rhs.Complement()); + if (complement.IsEmpty()) { + lin.type = EncodingLinear1Type::kIgnore; + } else if (complement.IsFixed()) { + CHECK(var_domain.Contains(complement.FixedValue())); + lin.type = EncodingLinear1Type::kVarNeValue; + lin.value = complement.FixedValue(); + } else if (rhs.Min() > complement.Max()) { + lin.type = EncodingLinear1Type::kVarGeValue; + lin.value = rhs.Min(); + num_implied_literals_in_complex_domains = + CapAdd(num_implied_literals_in_complex_domains, complement.Size()); + } else if (rhs.Max() < complement.Min()) { + lin.type = EncodingLinear1Type::kVarLeValue; + lin.value = rhs.Max(); + num_implied_literals_in_complex_domains = + CapAdd(num_implied_literals_in_complex_domains, complement.Size()); + } else { + lin.type = EncodingLinear1Type::kVarInDomain; + lin.rhs = rhs; + num_implied_literals_in_complex_domains = + CapAdd(num_implied_literals_in_complex_domains, complement.Size()); + } + } + return lin; + } + + std::string ToString() const { + return absl::StrCat("EncodingLinear1(type: ", type, ", value: ", value, + ", rhs: ", rhs.ToString(), + ", enforcement_literal: ", enforcement_literal, + ", constraint_index: ", constraint_index, ")"); + } +}; + +template +void AbslStringify(Sink& sink, const EncodingLinear1& lin) { + sink.Append(lin.ToString()); +} + +} // namespace + // TODO(user): We can still remove the variable even if we want to keep // all feasible solutions for the cases when we have a full encoding. // Similarly this shouldn't break symmetry, but we do need to do it for all @@ -13232,77 +13356,135 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { } } - std::vector encoded_values; - absl::flat_hash_map> value_to_equal_literals; - absl::flat_hash_map> value_to_not_equal_literals; - std::vector> enforced_domains; const Domain var_domain = context_->DomainOf(var); + std::vector encoded_values; + std::vector linear_ones; + absl::btree_map> tmp_ge_to_literals; + absl::btree_map> tmp_le_to_literals; + absl::btree_map encoded_le_literal; + absl::btree_map encoded_ge_literal; + int64_t num_implied_literals_in_complex_domains = 0; - bool abort = false; + int num_var_eq_value = 0; + int num_var_ne_value = 0; + int num_var_ge_value = 0; + int num_var_le_value = 0; + int num_var_in_domain = 0; + + const auto insert_var_le_value_literal = [&](int64_t value, int literal) { + if (!tmp_le_to_literals[value].insert(literal).second) return; + DCHECK_LT(value, var_domain.Max()); + const int64_t next_value = var_domain.ValueAtOrAfter(value + 1); + if (tmp_ge_to_literals[next_value].contains(NegatedRef(literal))) { + const auto [it, inserted] = encoded_le_literal.insert({value, literal}); + if (!inserted) { + VLOG(2) << "Duplicate var_le_value literal: " << literal + << " for value: " << value << " previous: " << it->second; + } else { + VLOG(3) << " - insert " << literal << " => var <= " << value; + VLOG(3) << " - insert " << NegatedRef(literal) + << " => var >= " << next_value; + DCHECK(encoded_ge_literal.insert({next_value, NegatedRef(literal)}) + .second); + } + } + }; + + const auto insert_var_ge_value_literal = [&](int64_t value, int literal) { + if (!tmp_ge_to_literals[value].insert(literal).second) return; + DCHECK_GT(value, var_domain.Min()); + const int64_t previous_value = var_domain.ValueAtOrBefore(value - 1); + if (tmp_le_to_literals[previous_value].contains(NegatedRef(literal))) { + const auto [it, inserted] = + encoded_le_literal.insert({previous_value, NegatedRef(literal)}); + if (!inserted) { + VLOG(2) << "Duplicate var_le_value literal: " << NegatedRef(literal) + << " for value: " << previous_value + << " previous: " << it->second; + } else { + VLOG(3) << " - insert " << NegatedRef(literal) + << " => var <= " << previous_value; + VLOG(3) << " - insert " << literal << " => var >= " << value; + DCHECK(encoded_ge_literal.insert({value, literal}).second); + } + } + }; for (const int c : context_->VarToConstraints(var)) { if (c < 0) continue; const ConstraintProto& ct = context_->working_model->constraints(c); - CHECK_EQ(ct.constraint_case(), ConstraintProto::kLinear); - CHECK_EQ(ct.linear().vars().size(), 1); - CHECK(RefIsPositive(ct.linear().vars(0))); - CHECK_EQ(ct.linear().vars(0), var); - const int64_t coeff = ct.linear().coeffs(0); + DCHECK_EQ(ct.constraint_case(), ConstraintProto::kLinear); + DCHECK_EQ(ct.linear().vars().size(), 1); + DCHECK(RefIsPositive(ct.linear().vars(0))); + DCHECK_EQ(ct.linear().vars(0), var); if (ct.enforcement_literal().size() != 1) { - abort = true; - break; - } - const Domain rhs = ReadDomainFromProto(ct.linear()) - .InverseMultiplicationBy(coeff) - .IntersectionWith(var_domain); - const int enforcement_literal = ct.enforcement_literal(0); - - if (rhs.IsEmpty()) { - if (!context_->SetLiteralToFalse(enforcement_literal)) { - return; - } + context_->UpdateRuleStats( + "TODO variables: linear1 with multiple enforcement literals"); return; - } else if (rhs.IsFixed()) { - if (!var_domain.Contains(rhs.FixedValue())) { - if (!context_->SetLiteralToFalse(enforcement_literal)) { - return; - } - } else { - encoded_values.push_back(rhs.FixedValue()); - value_to_equal_literals[rhs.FixedValue()].push_back( - ct.enforcement_literal(0)); - } - } else { - const Domain complement = var_domain.IntersectionWith(rhs.Complement()); - if (complement.IsEmpty()) { - // TODO(user): This should be dealt with elsewhere. - abort = true; - break; - } - if (complement.IsFixed()) { - CHECK(var_domain.Contains(complement.FixedValue())); - encoded_values.push_back(complement.FixedValue()); - value_to_not_equal_literals[complement.FixedValue()].push_back( - enforcement_literal); - } else { - enforced_domains.push_back({rhs, enforcement_literal}); - num_implied_literals_in_complex_domains = - CapAdd(num_implied_literals_in_complex_domains, complement.Size()); - } } - } - if (abort) { - context_->UpdateRuleStats("TODO variables: only used in complex linear1"); - return; + const EncodingLinear1 lin = EncodingLinear1::FromConstraint( + context_, c, var_domain, num_implied_literals_in_complex_domains); + VLOG(3) << "ProcessVariableOnlyUsedInEncoding(): var(" << var + << ") domain: " << var_domain << " linear1: " << lin; + switch (lin.type) { + case EncodingLinear1Type::kTypeUnknown: + LOG(FATAL) << "Unset EncodingLinear1 type"; + return; + case EncodingLinear1Type::kVarEqValue: + encoded_values.push_back(lin.value); + if (lin.value == var_domain.Min()) { + insert_var_le_value_literal(lin.value, lin.enforcement_literal); + } else if (lin.value == var_domain.Max()) { + insert_var_ge_value_literal(lin.value, lin.enforcement_literal); + } + ++num_var_eq_value; + break; + case EncodingLinear1Type::kVarNeValue: + encoded_values.push_back(lin.value); + if (lin.value == var_domain.Min()) { + insert_var_ge_value_literal(var_domain.ValueAtOrAfter(lin.value + 1), + lin.enforcement_literal); + } else if (lin.value == var_domain.Max()) { + insert_var_le_value_literal(var_domain.ValueAtOrBefore(lin.value - 1), + lin.enforcement_literal); + } + ++num_var_ne_value; + break; + case EncodingLinear1Type::kVarGeValue: + insert_var_ge_value_literal(lin.value, lin.enforcement_literal); + ++num_var_ge_value; + break; + case EncodingLinear1Type::kVarLeValue: + insert_var_le_value_literal(lin.value, lin.enforcement_literal); + ++num_var_le_value; + break; + case EncodingLinear1Type::kVarInDomain: + ++num_var_in_domain; + break; + case EncodingLinear1Type::kIgnore: + break; + case EncodingLinear1Type::kAbort: + context_->UpdateRuleStats( + "TODO variables: only used in complex linear1"); + return; + case EncodingLinear1Type::kUnsat: + return; + } + linear_ones.push_back(lin); } // We force the full encoding if the variable is mostly encoded and some // linear1 involves domains that do not correspond to value encodings. + // We also fully encode if the set if var <= value literals is bigger than + // half the size of the domain. gtl::STLSortAndRemoveDuplicates(&encoded_values); bool fully_encode_var = encoded_values.size() == var_domain.Size(); - if (!enforced_domains.empty()) { - if (context_->IsMostlyFullyEncoded(var) || var_domain.Size() <= 32) { + const bool has_non_value_encodings = + num_var_ge_value + num_var_le_value + num_var_in_domain > 0; + if (has_non_value_encodings) { + if (context_->IsMostlyFullyEncoded(var) || var_domain.Size() <= 32 || + 2 * encoded_le_literal.size() > var_domain.Size()) { fully_encode_var = true; encoded_values.clear(); encoded_values.reserve(var_domain.Size()); @@ -13312,99 +13494,268 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { } } - if (value_to_not_equal_literals.empty() && value_to_equal_literals.empty() && - !fully_encode_var) { + // A variable without order encodings linear1 can still have le and ge + // literals for its min and max values. + // We clean those in that case. + if (num_var_le_value + num_var_ge_value == 0) { + encoded_le_literal.clear(); + encoded_ge_literal.clear(); + tmp_le_to_literals.clear(); + tmp_ge_to_literals.clear(); + } + + if (encoded_values.empty()) { // This variable has no value encoding. Either enforced_domains is // empty, and in that case, we will not do anything about it, or the // variable is not used anymore, and it will be removed later. return; } - if (!enforced_domains.empty() && + VLOG(2) << "ProcessVariableOnlyUsedInEncoding(): var(" << var + << "): " << var_domain + << ", #encoded_values: " << encoded_values.size() + << ", #ordered_values: " << encoded_le_literal.size() + << ", #var_eq_value: " << num_var_eq_value + << ", #var_ne_value: " << num_var_ne_value + << ", #var_ge_value: " << num_var_ge_value + << ", #var_le_value: " << num_var_le_value + << ", #var_in_domain: " << num_var_in_domain + << ", #implied_literals_in_complex_domains: " + << num_implied_literals_in_complex_domains; + if (has_non_value_encodings && (!fully_encode_var || num_implied_literals_in_complex_domains > 2500)) { - VLOG(2) << "Abort in ProcessVariableOnlyUsedInEncoding(). var(" << var - << "): " << var_domain - << ", #var_eq_value: " << value_to_equal_literals.size() - << ", #var_ne_value: " << value_to_not_equal_literals.size() - << ", #var_in_domain: " << enforced_domains.size() - << ", #implied_literals_in_complex_domains: " + VLOG(2) << "Abort - fully_encode_var: " << fully_encode_var + << ", num_implied_literals_in_complex_domains: " << num_implied_literals_in_complex_domains; context_->UpdateRuleStats( - "TODO variables: only used in value encoding and order encoding."); + "TODO variables: only used in large value encoding and order " + "encoding."); return; } - // Link all Boolean in our linear1 to the encoding literals. Note that we - // should hopefully already have detected such literal before and this - // should add trivial implications. - absl::btree_map value_to_encoding_literal; - for (const int64_t v : encoded_values) { - const int encoding_lit = context_->GetOrCreateVarValueEncoding(var, v); - value_to_encoding_literal[v] = encoding_lit; + // We process value encoding literals first. + // Store the values and literals before we delete the linear1 constraints. + // + // NOTE: we actually create the value encoding literals here, and not just + // simple Boolean variables, as we may abort during the objective + // substitution. + absl::btree_map value_encoding_to_literal; + for (const int64_t value : encoded_values) { + value_encoding_to_literal[value] = + context_->GetOrCreateVarValueEncoding(var, value); + } - const auto eq_it = value_to_equal_literals.find(v); - if (eq_it != value_to_equal_literals.end()) { - absl::c_sort(eq_it->second); - for (const int lit : eq_it->second) { - context_->AddImplication(lit, encoding_lit); + // Same as with value encoding literals, we create the order encoding literals + // here. We will collect all values that appear in a lit => var >= value, and + // lit => var <= value constraints and create the corresponding constraints. + // + // TODO(user): Move model level order encoding to PresolveContext. + for (const auto& [value, lits] : tmp_le_to_literals) { + const auto it = encoded_le_literal.find(value); + if (it != encoded_le_literal.end()) continue; + const int64_t next_value = var_domain.ValueAtOrAfter(value + 1); + const int le_literal = context_->NewBoolVar("order encoding"); + solution_crush_.MaybeSetLiteralToOrderEncoding(le_literal, var, value, + /*is_le=*/true); + encoded_le_literal[value] = le_literal; + encoded_ge_literal[next_value] = NegatedRef(le_literal); + } + + for (const auto& [value, lits] : tmp_ge_to_literals) { + const auto it = encoded_ge_literal.find(value); + if (it != encoded_ge_literal.end()) continue; + const int64_t previous_value = var_domain.ValueAtOrBefore(value - 1); + const int ge_literal = context_->NewBoolVar("order encoding"); + solution_crush_.MaybeSetLiteralToOrderEncoding(ge_literal, var, value, + /*is_le=*/false); + encoded_ge_literal[value] = ge_literal; + encoded_le_literal[previous_value] = NegatedRef(ge_literal); + } + + // We process the order encoding literals next. + // + // In the following examples, x has 5 values: 0, 1, 2, 3, 4, and some order + // encoding literals. + // 0 1 2 3 4 + // x_le_0 x_le_1 x_le_3 + // x_ge_1 x_ge_3 x_ge_4 + // + // x_le_0 => not(x == 1) && x_le_1 + // x_le_1 => not(x == 2) && not(x == 3) && x_le_3 + // + // x_ge_1 => not(x == 0) + // x_ge_3 => not(x == 1) && not(x == 2) && x_ge_1 + // x_ge_4 => not(x == 3) && x_ge_3 + if (!encoded_le_literal.empty()) { + const int64_t max_ge_value = encoded_ge_literal.rbegin()->first; + DCHECK(!encoded_ge_literal.empty()); + DCHECK(fully_encode_var); + ConstraintProto* not_le = nullptr; + ConstraintProto* not_ge = context_->working_model->add_constraints(); + for (const auto [value, eq_literal] : value_encoding_to_literal) { + const int ne_literal = NegatedRef(eq_literal); + + // Lower or equal. + if (not_le != nullptr) { + not_le->mutable_bool_and()->add_literals(ne_literal); + } + const auto it_le = encoded_le_literal.find(value); + if (it_le != encoded_le_literal.end()) { + const int le_literal = it_le->second; + if (not_le != nullptr) { + not_le->mutable_bool_and()->add_literals(le_literal); + } + not_le = context_->AddEnforcedConstraint({le_literal}); } - } - const auto neq_it = value_to_not_equal_literals.find(v); - if (neq_it != value_to_not_equal_literals.end()) { - absl::c_sort(neq_it->second); - for (const int lit : neq_it->second) { - context_->AddImplication(lit, NegatedRef(encoding_lit)); + // Greater or equal. + const auto it_ge = encoded_ge_literal.find(value); + if (it_ge != encoded_ge_literal.end()) { + const int ge_literal = it_ge->second; + not_ge->add_enforcement_literal(ge_literal); + if (value != max_ge_value) { + not_ge = context_->working_model->add_constraints(); + not_ge->mutable_bool_and()->add_literals(ge_literal); + } else { + not_ge = nullptr; + } + } + if (not_ge != nullptr) { + not_ge->mutable_bool_and()->add_literals(ne_literal); } } } - absl::c_stable_sort(enforced_domains, [](const auto& a, const auto& b) { - return a.second < b.second; + // Sort the linear1_infos by constraint index to make the encoding + // deterministic. + absl::c_sort(linear_ones, [](const auto& a, const auto& b) { + return a.constraint_index < b.constraint_index; }); - for (int i = 0; i < enforced_domains.size(); ++i) { - const Domain& implied_domain = enforced_domains[i].first; - const int e = enforced_domains[i].second; - ConstraintProto* imply = context_->AddEnforcedConstraint({e}); - const Domain implied_complement = - var_domain.IntersectionWith(implied_domain.Complement()); - for (const int64_t v : implied_complement.Values()) { - imply->mutable_bool_and()->add_literals( - NegatedRef(value_to_encoding_literal[v])); + // Link all Boolean in our linear1 to the encoding literals. Note that we + // should hopefully already have detected value encoding literal before and + // this should add trivial implications for the VarEqValue and VarNeValue + // cases. Same idea for the order encoding. + for (const EncodingLinear1& info : linear_ones) { + switch (info.type) { + case EncodingLinear1Type::kVarEqValue: { + DCHECK(value_encoding_to_literal.contains(info.value)); + const int encoding_lit = value_encoding_to_literal[info.value]; + context_->AddImplication(info.enforcement_literal, encoding_lit); + break; + } + case EncodingLinear1Type::kVarNeValue: { + DCHECK(value_encoding_to_literal.contains(info.value)); + const int encoding_lit = value_encoding_to_literal[info.value]; + context_->AddImplication(info.enforcement_literal, + NegatedRef(encoding_lit)); + break; + } + case EncodingLinear1Type::kVarGeValue: { + const auto it = encoded_ge_literal.find(info.value); + CHECK(it != encoded_ge_literal.end()); + context_->AddImplication(info.enforcement_literal, it->second); + break; + } + case EncodingLinear1Type::kVarLeValue: { + const auto it = encoded_le_literal.find(info.value); + CHECK(it != encoded_le_literal.end()); + context_->AddImplication(info.enforcement_literal, it->second); + break; + } + case EncodingLinear1Type::kVarInDomain: { + ConstraintProto* imply = + context_->AddEnforcedConstraint({info.enforcement_literal}); + const Domain implied_complement = + var_domain.IntersectionWith(info.rhs.Complement()); + for (const int64_t v : implied_complement.Values()) { + DCHECK(value_encoding_to_literal.contains(v)); + imply->mutable_bool_and()->add_literals( + NegatedRef(value_encoding_to_literal[v])); + } + break; + } + case EncodingLinear1Type::kIgnore: { + break; + } + default: + LOG(FATAL) << "Unsupported Linear1Type: " << info.type; + return; } } // Detect implications between complex encodings. - // TODO(user): reduce the number of implication by performing a transitive - // reduction. - if (enforced_domains.size() < 1000 && enforced_domains.size() > 1) { - const int64_t var_size = var_domain.Size(); - std::vector domain_sizes; - domain_sizes.reserve(enforced_domains.size()); - for (const auto& [domain, e] : enforced_domains) { - domain_sizes.push_back(domain.Size()); + absl::c_sort(linear_ones, [](const auto& a, const auto& b) { + return std::tie(a.type, a.constraint_index) < + std::tie(b.type, b.constraint_index); + }); + int min_ge_index = -1; + int max_ge_index = -1; + int min_le_index = -1; + int max_le_index = -1; + int min_in_domain_index = -1; + int max_in_domain_index = -1; + for (int i = 0; i < linear_ones.size(); ++i) { + if (linear_ones[i].type == EncodingLinear1Type::kVarGeValue) { + if (min_ge_index == -1) min_ge_index = i; + max_ge_index = i; + } else if (linear_ones[i].type == EncodingLinear1Type::kVarLeValue) { + if (min_le_index == -1) min_le_index = i; + max_le_index = i; + } else if (linear_ones[i].type == EncodingLinear1Type::kVarInDomain) { + if (min_in_domain_index == -1) min_in_domain_index = i; + max_in_domain_index = i; } + } - for (int i = 0; i < enforced_domains.size(); ++i) { - const Domain& implied_domain = enforced_domains[i].first; - const int e = enforced_domains[i].second; + const auto add_incompatibility = [this, &linear_ones](int i, int j) { + DCHECK_NE(i, j); + const EncodingLinear1& info_i = linear_ones[i]; + const int e_i = info_i.enforcement_literal; + const EncodingLinear1& info_j = linear_ones[j]; + const int e_j = info_j.enforcement_literal; + if (e_i == NegatedRef(e_j)) return; + BoolArgumentProto* incompatible = + context_->working_model->add_constraints()->mutable_bool_or(); + incompatible->add_literals(NegatedRef(e_i)); + incompatible->add_literals(NegatedRef(e_j)); + context_->UpdateRuleStats( + "variables: add at_most_one between incompatible complex encodings"); + }; - for (int j = i + 1; j < enforced_domains.size(); ++j) { - // Quick sufficient test to check that the two domains overlap. - if (domain_sizes[i] + domain_sizes[j] > var_size) continue; + if (min_in_domain_index != -1) { + for (int i = min_in_domain_index; i <= max_in_domain_index; ++i) { + const EncodingLinear1& info_i = linear_ones[i]; + DCHECK_EQ(info_i.type, EncodingLinear1Type::kVarInDomain); - const Domain& other_domain = enforced_domains[j].first; - const int other_e = enforced_domains[j].second; - if (e == other_e || e == NegatedRef(other_e)) continue; - if (!other_domain.OverlapsWith(implied_domain)) { - BoolArgumentProto* incompatible = - context_->working_model->add_constraints()->mutable_bool_or(); - incompatible->add_literals(NegatedRef(e)); - incompatible->add_literals(NegatedRef(other_e)); - context_->UpdateRuleStats( - "variables: add at_most_one between incompatible complex " - "encodings"); + // Incompatibilities between x in domain and x >= ge. + if (min_ge_index != -1) { + for (int j = min_ge_index; j <= max_ge_index; ++j) { + const EncodingLinear1& info_j = linear_ones[j]; + DCHECK_EQ(info_j.type, EncodingLinear1Type::kVarGeValue); + if (info_i.rhs.Max() < info_j.value) { + add_incompatibility(i, j); + } + } + } + + // Incompatibilities between x in domain and x <= value. + if (min_le_index != -1) { + for (int j = min_le_index; j <= max_le_index; ++j) { + const EncodingLinear1& info_j = linear_ones[j]; + DCHECK_EQ(info_j.type, EncodingLinear1Type::kVarLeValue); + if (info_i.rhs.Min() > info_j.value) { + add_incompatibility(i, j); + } + } + } + + // Incompatibilites between x in domain_i and x in domain_j. + for (int j = i + 1; j <= max_in_domain_index; ++j) { + const EncodingLinear1& info_j = linear_ones[j]; + DCHECK_EQ(info_j.type, EncodingLinear1Type::kVarInDomain); + if (!info_i.rhs.OverlapsWith(info_j.rhs)) { + add_incompatibility(i, j); } } } @@ -13440,7 +13791,7 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { // Tricky: If the variable is not fully encoded, then when all // partial encoding literal are false, it must take the "best" value - // in other_values. That depend on the sign of the objective coeff. + // in other_values. That depends on the sign of the objective coeff. // // We also restrict other value so that the postsolve code below // will fix the variable to the correct value when this happen. @@ -13467,8 +13818,8 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { const int64_t coeff_in_equality = -1; linear->add_vars(var); linear->add_coeffs(coeff_in_equality); - int64_t rhs = -pivot; - for (const auto& [value, literal] : value_to_encoding_literal) { + int64_t rhs_value = -pivot; + for (const auto& [value, literal] : value_encoding_to_literal) { const int64_t coeff = value - pivot; if (coeff == 0) continue; if (RefIsPositive(literal)) { @@ -13476,13 +13827,13 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { linear->add_coeffs(coeff); } else { // (1 - var) * coeff; - rhs -= coeff; + rhs_value -= coeff; linear->add_vars(PositiveRef(literal)); linear->add_coeffs(-coeff); } } - linear->add_domain(rhs); - linear->add_domain(rhs); + linear->add_domain(rhs_value); + linear->add_domain(rhs_value); if (!context_->SubstituteVariableInObjective(var, coeff_in_equality, encoding_ct)) { context_->UpdateRuleStats( @@ -13492,14 +13843,19 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { context_->UpdateRuleStats( "variables: only used in objective and in encoding"); } else { - if (enforced_domains.empty()) { - context_->UpdateRuleStats("variables: only used in value encoding"); - } else { + if (has_non_value_encodings && num_var_in_domain == 0) { + context_->UpdateRuleStats( + "variables: only used in value and order encodings"); + } else if (has_non_value_encodings) { context_->UpdateRuleStats("variables: only used in complex encoding"); + } else { + context_->UpdateRuleStats("variables: only used in value encoding"); } } - // Clear all involved constraint. + // Clear all involved constraint. We do it in two passes to avoid + // invalidating the iterator. We also use the constraint variable graph as + // extra encodings (value, order) may have added new constraints. { std::vector to_clear; for (const int c : context_->VarToConstraints(var)) { @@ -13519,7 +13875,7 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { ConstraintProto* encoding = context_->working_model->add_constraints(); BoolArgumentProto* arg = fully_encode_var ? encoding->mutable_exactly_one() : encoding->mutable_at_most_one(); - for (const auto& [value, literal] : value_to_encoding_literal) { + for (const auto& [value, literal] : value_encoding_to_literal) { arg->add_literals(literal); } if (fully_encode_var) { @@ -13541,8 +13897,8 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { mapping_ct->mutable_linear()->add_vars(var); mapping_ct->mutable_linear()->add_coeffs(1); int64_t offset = special_value; - for (const auto& [value, literal] : value_to_encoding_literal) { - const int64_t coeff = (value - special_value); + for (const auto& [value, literal] : value_encoding_to_literal) { + const int64_t coeff = value - special_value; if (coeff == 0) continue; if (RefIsPositive(literal)) { mapping_ct->mutable_linear()->add_vars(literal); @@ -13557,7 +13913,7 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { mapping_ct->mutable_linear()->add_domain(offset); context_->MarkVariableAsRemoved(var); -} +} // NOLINT(readability/fn_size) void CpModelPresolver::TryToSimplifyDomain(int var) { DCHECK(RefIsPositive(var)); @@ -14796,10 +15152,7 @@ void ApplyVariableMapping(absl::Span mapping, CpModelProto* cp_model, std::vector* reverse_mapping) { // Remap all the variable/literal references in the constraints and the // enforcement literals in the variables. - // - // Using a absl::FunctionRef doesn't work when open-sourced. - std::function mapping_function = [&mapping, - &reverse_mapping](int* ref) { + const auto mapping_function = [&mapping, &reverse_mapping](int* ref) { const int var = PositiveRef(*ref); int image = mapping[var]; if (image < 0) { diff --git a/ortools/sat/cp_model_solver.h b/ortools/sat/cp_model_solver.h index a61c0755cd..623c86c25f 100644 --- a/ortools/sat/cp_model_solver.h +++ b/ortools/sat/cp_model_solver.h @@ -115,7 +115,7 @@ std::function NewFeasibleSolutionLogCallback( /** * Creates a callbacks that will be called on each new best objective bound - * found. + * found. It is guaranteed that the best bound is strictly improving. * * Note that this function is called before the update takes place. */ diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index 9f3b164238..0d5b6e992b 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -1173,16 +1173,17 @@ int RegisterClausesLevelZeroImport(int id, namespace { -// Fills the BinaryRelationRepository with the enforced linear constraints of -// size 1 or 2 in the model, and with the non-enforced linear constraints of -// size 2. Also expands linear constraints of size 1 enforced by two literals +// Fills several repositories of bounds of linear2 (RootLevelLinear2Bounds, +// ConditionalLinear2Bounds and ReifiedLinear2Bounds) using the linear +// constraints of size 2 and the linear constraints of size 3 with domain of +// size 1. Also expands linear constraints of size 1 enforced by two literals // into (up to) 4 binary relations enforced by only one literal. -void FillBinaryRelationRepository(const CpModelProto& model_proto, +void FillConditionalLinear2Bounds(const CpModelProto& model_proto, Model* model) { auto* integer_trail = model->GetOrCreate(); auto* encoder = model->GetOrCreate(); auto* mapping = model->GetOrCreate(); - auto* repository = model->GetOrCreate(); + auto* repository = model->GetOrCreate(); auto* root_level_lin2_bounds = model->GetOrCreate(); auto* reified_lin2_bounds = model->GetOrCreate(); @@ -1323,7 +1324,7 @@ void LoadBaseModel(const CpModelProto& model_proto, Model* model) { AddFullEncodingFromSearchBranching(model_proto, model); if (sat_solver->ModelIsUnsat()) return unsat(); - FillBinaryRelationRepository(model_proto, model); + FillConditionalLinear2Bounds(model_proto, model); if (time_limit->LimitReached()) return; diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index d3ab037d4d..933d450f73 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -742,17 +742,17 @@ void EnforcedLinear2Bounds::CollectPrecedences( } } -BinaryRelationRepository::~BinaryRelationRepository() { +ConditionalLinear2Bounds::~ConditionalLinear2Bounds() { if (!VLOG_IS_ON(1)) return; std::vector> stats; - stats.push_back({"BinaryRelationRepository/num_enforced_relations", + stats.push_back({"ConditionalLinear2Bounds/num_enforced_relations", num_enforced_relations_}); - stats.push_back({"BinaryRelationRepository/num_encoded_equivalences", + stats.push_back({"ConditionalLinear2Bounds/num_encoded_equivalences", num_encoded_equivalences_}); shared_stats_->AddStats(stats); } -void BinaryRelationRepository::Add(Literal lit, LinearExpression2 expr, +void ConditionalLinear2Bounds::Add(Literal lit, LinearExpression2 expr, IntegerValue lhs, IntegerValue rhs) { expr.SimpleCanonicalization(); if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return; @@ -765,7 +765,7 @@ void BinaryRelationRepository::Add(Literal lit, LinearExpression2 expr, {.enforcement = lit, .expr = expr, .lhs = lhs, .rhs = rhs}); } -void BinaryRelationRepository::AddPartialRelation(Literal lit, +void ConditionalLinear2Bounds::AddPartialRelation(Literal lit, IntegerVariable a, IntegerVariable b) { DCHECK_NE(a, kNoIntegerVariable); @@ -774,7 +774,7 @@ void BinaryRelationRepository::AddPartialRelation(Literal lit, Add(lit, LinearExpression2(a, b, 1, 1), 0, 0); } -void BinaryRelationRepository::Build() { +void ConditionalLinear2Bounds::Build() { DCHECK(!is_built_); is_built_ = true; std::vector> literal_key_values; @@ -869,7 +869,7 @@ void BinaryRelationRepository::Build() { bool PropagateLocalBounds( const IntegerTrail& integer_trail, const RootLevelLinear2Bounds& root_level_bounds, - const BinaryRelationRepository& repository, + const ConditionalLinear2Bounds& repository, const ImpliedBounds& implied_bounds, Literal lit, const absl::flat_hash_map& input, absl::flat_hash_map* output) { @@ -1625,7 +1625,7 @@ bool Linear2Bounds::EnqueueLowerOrEqual( } // TODO(user): also check partially-encoded bounds, e.g. (expr <= ub) => l, - // which might be in BinaryRelationRepository as ~l => (-expr <= - ub - 1). + // which might be in ConditionalLinear2Bounds as ~l => (-expr <= - ub - 1). const auto reified_bound = reified_lin2_bounds_->GetEncodedBound(expr, ub); // Already true. diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index 20a785e3c2..e014364c8d 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -43,6 +43,45 @@ namespace operations_research { namespace sat { +// This file defines several classes to manage bounds of expressions of the form +// `a*x + b*y <= upper_bound`. The `a*x + b*y` expressions are stored in a +// `LinearExpression2` object, which is often canonicalized and GCD-reduced, +// with the bound being divided by the GCD. +// +// To efficiently store and query such bounds in different contexts, we map each +// `LinearExpression2` expressions for which we have a non-trivial bound +// to a `LinearExpression2Index`, managed by the `Linear2Indices` class. +// +// Most callers of this class should use the `Linear2Bounds` class, which hides +// the complexity of the different ways such bounds are deduced and allow: +// - knowing the bound of a given expression at current level; +// - getting the literals and integer literals that can be used to explain that +// bound; +// - pushing a new bound to an expression. +// +// Other classes in this file dealing with the current level bounds: +// - `EnforcedLinear2Bounds`: Store the best relation of the form +// `{lits} => a*x + b*y <= ub` that is non-trivial at the current level. +// - `Linear2BoundsFromLinear3`: Class that keeps the best upper bound at the +// current level for `a*x + b*y` from all the linear3 relations of the +// form `a*x + b*y + c*z <= d`. +// +// Classes in this file dealing with root-level bounds or implications: +// - `RootLevelLinear2Bounds`: Holds all the non-trivial bounds of the form +// `a*x + b*y <= ub` at root level. +// - `ConditionalLinear2Bounds`: Holds all the relations of the form +// `{lits} => a*x + b*y <= ub` that are defined in the model. +// - `ReifiedLinear2Bounds`: Store all the relations of the form +// `{lits} <=> a*x + b*y <= ub` that are defined in the model. Also stores all +// the relations of the form `a*x + b*y + c*z == d`. +// +// Other classes in this file: +// - `Linear2Watcher`: Allow a propagator to be called back when a bound on a +// given linear2 changed. +// - `TransitivePrecedencesEvaluator`: Computes the transitive closure of a +// DAG of `a*x + b*y <= expr` relations that are stored in +// `RootLevelLinear2Bounds`. + DEFINE_STRONG_INDEX_TYPE(LinearExpression2Index); const LinearExpression2Index kNoLinearExpression2Index(-1); inline LinearExpression2Index NegationOf(LinearExpression2Index i) { @@ -352,8 +391,8 @@ class TransitivePrecedencesEvaluator { std::vector topological_order_; }; -// Stores all the precedences relation of the form "{lits} => a*x + b*y <= ub" -// that we could extract from the model. +// Store the best non-trivial relation of the form "{lits} => a*x + b*y <= ub" +// for which `{lits}` are assigned tp true at the current level. class EnforcedLinear2Bounds : public ReversibleInterface { public: explicit EnforcedLinear2Bounds(Model* model) @@ -510,12 +549,12 @@ class ReifiedLinear2Bounds; // // TODO(user): This is not always needed, find a way to clean this once we // don't need it. -class BinaryRelationRepository { +class ConditionalLinear2Bounds { public: - explicit BinaryRelationRepository(Model* model) + explicit ConditionalLinear2Bounds(Model* model) : reified_linear2_bounds_(model->GetOrCreate()), shared_stats_(model->GetOrCreate()) {} - ~BinaryRelationRepository(); + ~ConditionalLinear2Bounds(); int size() const { return relations_.size(); } @@ -614,7 +653,7 @@ class Linear2BoundsFromLinear3 { best_affine_ub_; }; -// TODO(user): Merge with BinaryRelationRepository. Note that this one provides +// TODO(user): Merge with ConditionalLinear2Bounds. Note that this one provides // different indexing though, so it could be kept separate. class ReifiedLinear2Bounds { public: @@ -754,7 +793,7 @@ class Linear2Bounds : public LazyReasonInterface { bool PropagateLocalBounds( const IntegerTrail& integer_trail, const RootLevelLinear2Bounds& root_level_bounds, - const BinaryRelationRepository& repository, + const ConditionalLinear2Bounds& repository, const ImpliedBounds& implied_bounds, Literal lit, const absl::flat_hash_map& input, absl::flat_hash_map* output); @@ -768,7 +807,7 @@ bool PropagateLocalBounds( class GreaterThanAtLeastOneOfDetector { public: explicit GreaterThanAtLeastOneOfDetector(Model* model) - : repository_(*model->GetOrCreate()), + : repository_(*model->GetOrCreate()), implied_bounds_(*model->GetOrCreate()), integer_trail_(*model->GetOrCreate()), shared_stats_(model->GetOrCreate()) {} @@ -816,7 +855,7 @@ class GreaterThanAtLeastOneOfDetector { IntegerVariable var, absl::Span clause, absl::Span bounds, Model* model); - BinaryRelationRepository& repository_; + ConditionalLinear2Bounds& repository_; const ImpliedBounds& implied_bounds_; IntegerTrail& integer_trail_; SharedStatistics* shared_stats_; diff --git a/ortools/sat/precedences_test.cc b/ortools/sat/precedences_test.cc index 6649c14bf9..2b7d5819f9 100644 --- a/ortools/sat/precedences_test.cc +++ b/ortools/sat/precedences_test.cc @@ -526,14 +526,14 @@ TEST(EnforcedLinear2BoundsTest, CollectPrecedences) { EXPECT_TRUE(p.empty()); } -TEST(BinaryRelationRepositoryTest, Build) { +TEST(ConditionalLinear2BoundsTest, Build) { Model model; const IntegerVariable x = model.Add(NewIntegerVariable(-100, 100)); const IntegerVariable y = model.Add(NewIntegerVariable(-100, 100)); const IntegerVariable z = model.Add(NewIntegerVariable(-100, 100)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); const Literal lit_b = Literal(model.Add(NewBooleanVariable()), true); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); RootLevelLinear2Bounds* root_level_bounds = model.GetOrCreate(); repository.Add(lit_a, LinearExpression2(NegationOf(x), y, 1, 1), 2, 8); @@ -597,8 +597,8 @@ TEST(BinaryRelationRepositoryTest, Build) { } std::vector GetRelations(Model& model) { - const BinaryRelationRepository& repository = - *model.GetOrCreate(); + const ConditionalLinear2Bounds& repository = + *model.GetOrCreate(); std::vector relations; for (int i = 0; i < repository.size(); ++i) { Relation r = repository.relation(i); @@ -613,7 +613,7 @@ std::vector GetRelations(Model& model) { return relations; } -TEST(BinaryRelationRepositoryTest, LoadCpModelAddUnaryAndBinaryRelations) { +TEST(ConditionalLinear2BoundsTest, LoadCpModelAddUnaryAndBinaryRelations) { const CpModelProto model_proto = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } variables { domain: [ 0, 1 ] } @@ -672,7 +672,7 @@ TEST(BinaryRelationRepositoryTest, LoadCpModelAddUnaryAndBinaryRelations) { 0, 10})); } -TEST(BinaryRelationRepositoryTest, +TEST(ConditionalLinear2BoundsTest, LoadCpModelAddsUnaryRelationsEnforcedByTwoLiterals_Case1) { // x in [10, 90] and "a and b => x in [20, 90]". const CpModelProto model_proto = ParseTestProto(R"pb( @@ -711,7 +711,7 @@ TEST(BinaryRelationRepositoryTest, LinearExpression2(a, NegationOf(x), 10, 1), -90, -10})); } -TEST(BinaryRelationRepositoryTest, +TEST(ConditionalLinear2BoundsTest, LoadCpModelAddsUnaryRelationsEnforcedByTwoLiterals_Case2) { // x in [10, 90] and "a and b => x in [10, 80]". const CpModelProto model_proto = ParseTestProto(R"pb( @@ -749,7 +749,7 @@ TEST(BinaryRelationRepositoryTest, 90})); } -TEST(BinaryRelationRepositoryTest, +TEST(ConditionalLinear2BoundsTest, LoadCpModelAddsUnaryRelationsEnforcedByTwoLiterals_Case3) { // x in [10, 90] and "a and not(b) => x in [20, 90]". const CpModelProto model_proto = ParseTestProto(R"pb( @@ -787,7 +787,7 @@ TEST(BinaryRelationRepositoryTest, LinearExpression2(a, NegationOf(x), 10, 1), -90, -10})); } -TEST(BinaryRelationRepositoryTest, +TEST(ConditionalLinear2BoundsTest, LoadCpModelAddsUnaryRelationsEnforcedByTwoLiterals_Case4) { // x in [10, 90] and "a and not(b) => x in [10, 80]". const CpModelProto model_proto = ParseTestProto(R"pb( @@ -825,12 +825,12 @@ TEST(BinaryRelationRepositoryTest, LinearExpression2(a, x, 10, 1), 10, 90})); } -TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_EnforcedRelation) { +TEST(ConditionalLinear2BoundsTest, PropagateLocalBounds_EnforcedRelation) { Model model; const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); RootLevelLinear2Bounds* root_level_bounds = model.GetOrCreate(); repository.Add(lit_a, LinearExpression2::Difference(y, x), 2, @@ -849,14 +849,14 @@ TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_EnforcedRelation) { std::make_pair(y, 5))); } -TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_UnenforcedRelation) { +TEST(ConditionalLinear2BoundsTest, PropagateLocalBounds_UnenforcedRelation) { Model model; RootLevelLinear2Bounds* root_level_bounds = model.GetOrCreate(); const IntegerVariable x = model.Add(NewIntegerVariable(-100, 100)); const IntegerVariable y = model.Add(NewIntegerVariable(-100, 100)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Add(lit_a, LinearExpression2(x, y, -1, 1), -5, 10); // lit_a => y => x - 5 root_level_bounds->Add(LinearExpression2(x, y, -1, 1), 2, @@ -875,7 +875,7 @@ TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_UnenforcedRelation) { std::make_pair(y, 5))); } -TEST(BinaryRelationRepositoryTest, +TEST(ConditionalLinear2BoundsTest, PropagateLocalBounds_EnforcedBoundSmallerThanLevelZeroBound) { Model model; RootLevelLinear2Bounds* root_level_bounds = @@ -884,7 +884,7 @@ TEST(BinaryRelationRepositoryTest, const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); const Literal lit_b = Literal(model.Add(NewBooleanVariable()), true); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Add(lit_a, LinearExpression2::Difference(y, x), -5, 10); // lit_a => y => x - 5 repository.Add(lit_b, LinearExpression2::Difference(y, x), 2, @@ -902,13 +902,13 @@ TEST(BinaryRelationRepositoryTest, EXPECT_THAT(output, IsEmpty()); } -TEST(BinaryRelationRepositoryTest, +TEST(ConditionalLinear2BoundsTest, PropagateLocalBounds_EnforcedBoundSmallerThanOutputBound) { Model model; const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); RootLevelLinear2Bounds* root_level_bounds = model.GetOrCreate(); repository.Add(lit_a, LinearExpression2::Difference(y, x), 2, @@ -927,12 +927,12 @@ TEST(BinaryRelationRepositoryTest, std::make_pair(y, 8))); } -TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_Infeasible) { +TEST(ConditionalLinear2BoundsTest, PropagateLocalBounds_Infeasible) { Model model; const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); RootLevelLinear2Bounds* root_level_bounds = model.GetOrCreate(); repository.Add(lit_a, LinearExpression2::Difference(y, x), 8, @@ -962,7 +962,7 @@ TEST(GreaterThanAtLeastOneOfDetectorTest, AddGreaterThanAtLeastOneOf) { const Literal lit_c = Literal(model.Add(NewBooleanVariable()), true); model.Add(ClauseConstraint({lit_a, lit_b, lit_c})); - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); repository->Add(lit_a, LinearExpression2::Difference(d, a), 2, 1000); // d >= a + 2 repository->Add(lit_b, LinearExpression2::Difference(d, b), -1, @@ -993,7 +993,7 @@ TEST(GreaterThanAtLeastOneOfDetectorTest, const Literal lit_c = Literal(model.Add(NewBooleanVariable()), true); model.Add(ClauseConstraint({lit_a, lit_b, lit_c})); - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); repository->Add(lit_a, LinearExpression2(a, d, -1, 1), 2, 1000); // d >= a + 2 repository->Add(lit_b, LinearExpression2(b, d, -1, 1), -1, diff --git a/ortools/sat/probing.cc b/ortools/sat/probing.cc index 4af125eb9a..f9c9f82111 100644 --- a/ortools/sat/probing.cc +++ b/ortools/sat/probing.cc @@ -14,8 +14,8 @@ #include "ortools/sat/probing.h" #include -#include #include +#include #include #include @@ -33,6 +33,7 @@ #include "ortools/sat/implied_bounds.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" +#include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" @@ -505,64 +506,46 @@ bool LookForTrivialSatSolution(double deterministic_time_limit, Model* model, return sat_solver->FinishPropagation(); } +FailedLiteralProbing::FailedLiteralProbing(Model* model) + : sat_solver_(model->GetOrCreate()), + implication_graph_(model->GetOrCreate()), + time_limit_(model->GetOrCreate()), + trail_(*model->GetOrCreate()), + assignment_(trail_.Assignment()), + clause_manager_(model->GetOrCreate()), + clause_id_generator_(model->GetOrCreate()), + lrat_proof_handler_(model->Mutable()), + id_(implication_graph_->PropagatorId()), + clause_propagator_id_(clause_manager_->PropagatorId()), + num_variables_(sat_solver_->NumVariables()) {} + // TODO(user): This might be broken if backtrack() propagates and go further // back. Investigate and fix any issue. -bool FailedLiteralProbingRound(ProbingOptions options, Model* model) { +bool FailedLiteralProbing::DoOneRound(ProbingOptions options) { WallTimer wall_timer; wall_timer.Start(); + options.log_info |= VLOG_IS_ON(1); // Reset the solver in case it was already used. - auto* sat_solver = model->GetOrCreate(); - if (!sat_solver->ResetToLevelZero()) return false; + if (!sat_solver_->ResetToLevelZero()) return false; // When called from Inprocessing, the implication graph should already be a // DAG, so these two calls should return right away. But we do need them to // get the topological order if this is used in isolation. - auto* implication_graph = model->GetOrCreate(); - if (!implication_graph->DetectEquivalences()) return false; - if (!sat_solver->FinishPropagation()) return false; + if (!implication_graph_->DetectEquivalences()) return false; + if (!sat_solver_->FinishPropagation()) return false; - auto* time_limit = model->GetOrCreate(); - const int initial_num_fixed = sat_solver->LiteralTrail().Index(); + const int initial_num_fixed = sat_solver_->LiteralTrail().Index(); const double initial_deterministic_time = - time_limit->GetElapsedDeterministicTime(); + time_limit_->GetElapsedDeterministicTime(); const double limit = initial_deterministic_time + options.deterministic_limit; - int num_variables = sat_solver->NumVariables(); - SparseBitset processed(LiteralIndex(2 * num_variables)); + processed_.ClearAndResize(LiteralIndex(2 * num_variables_)); - int64_t num_probed = 0; - int64_t num_explicit_fix = 0; - int64_t num_conflicts = 0; - int64_t num_new_binary = 0; - int64_t num_subsumed = 0; + if (!options.use_queue) starts_.resize(2 * num_variables_, 0); - const auto& trail = *(model->Get()); - const auto& assignment = trail.Assignment(); - auto* clause_manager = model->GetOrCreate(); - const int id = implication_graph->PropagatorId(); - const int clause_id = clause_manager->PropagatorId(); - - // This is only needed when options.use_queue is true. - struct SavedNextLiteral { - LiteralIndex literal_index; // kNoLiteralIndex if we need to backtrack. - int rank; // Cached position_in_order, we prefer lower positions. - - bool operator<(const SavedNextLiteral& o) const { return rank < o.rank; } - }; - std::vector queue; - util_intops::StrongVector position_in_order; - - // This is only needed when options use_queue is false; - util_intops::StrongVector starts; - if (!options.use_queue) starts.resize(2 * num_variables, 0); - - // We delay fixing of already assigned literal once we go back to level - // zero. - std::vector to_fix; - - // Depending on the options. we do not use the same order. + // Depending on the options, we do not use the same order. // With tree look, it is better to start with "leaf" first since we try // to reuse propagation as much as possible. This is also interesting to // do when extracting binary clauses since we will need to propagate @@ -573,389 +556,479 @@ bool FailedLiteralProbingRound(ProbingOptions options, Model* model) { // clauses, it is better to just probe the root of the binary implication // graph. This is exactly what happen when we probe using the topological // order. - int order_index(0); - std::vector probing_order = - implication_graph->ReverseTopologicalOrder(); + probing_order_ = implication_graph_->ReverseTopologicalOrder(); if (!options.use_tree_look && !options.extract_binary_clauses) { - std::reverse(probing_order.begin(), probing_order.end()); + std::reverse(probing_order_.begin(), probing_order_.end()); } // We only use this for the queue version. if (options.use_queue) { - position_in_order.assign(2 * num_variables, -1); - for (int i = 0; i < probing_order.size(); ++i) { - position_in_order[probing_order[i]] = i; + position_in_order_.assign(2 * num_variables_, -1); + for (int i = 0; i < probing_order_.size(); ++i) { + position_in_order_[probing_order_[i]] = i; } } - while (!time_limit->LimitReached() && - time_limit->GetElapsedDeterministicTime() <= limit) { + binary_clause_extracted_.assign(trail_.Index(), false); + subsumed_clauses_.clear(); + + while (!time_limit_->LimitReached() && + time_limit_->GetElapsedDeterministicTime() <= limit) { // We only enqueue literal at level zero if we don't use "tree look". if (!options.use_tree_look) { - if (!sat_solver->BacktrackAndPropagateReimplications(0)) return false; + if (!sat_solver_->BacktrackAndPropagateReimplications(0)) return false; } - // Probing works by taking a series of decisions, and by analyzing what they - // propagate. For efficiency, we only take a new decision d' if it directly - // implies the last one d. By doing this we know that d' directly or - // indirectly implies all the previous decisions, which then propagate all - // the literals on the trail up to and excluding d'. The first step is to - // find the next_decision d', which can be done in different ways depending - // on the options. + // Probing works by taking a series of decisions, and by analyzing what + // they propagate. For efficiency, we only take a new decision d' if it + // directly implies the last one d. By doing this we know that d' directly + // or indirectly implies all the previous decisions, which then propagate + // all the literals on the trail up to and excluding d'. The first step is + // to find the next_decision d', which can be done in different ways + // depending on the options. LiteralIndex next_decision = kNoLiteralIndex; - - // A first option is to use an unassigned literal which implies the last - // decision and which comes first in the probing order (which itself can be - // the topological order of the implication graph, or the reverse). - if (sat_solver->CurrentDecisionLevel() > 0 && options.use_queue) { - // TODO(user): Instead of minimizing index in topo order (which might be - // nice for binary extraction), we could try to maximize reusability in - // some way. - const Literal last_decision = - sat_solver->Decisions()[sat_solver->CurrentDecisionLevel() - 1] - .literal; - // If l => last_decision, then not(last_decision) => not(l). We can thus - // find the candidates for the next decision by looking at all the - // implications of not(last_decision). - const absl::Span list = - implication_graph->Implications(last_decision.Negated()); - const int saved_queue_size = queue.size(); - for (const Literal l : list) { - const Literal candidate = l.Negated(); - if (processed[candidate]) continue; - if (position_in_order[candidate] == -1) continue; - if (assignment.LiteralIsAssigned(candidate)) { - // candidate => last_decision => all previous decisions, which then - // propagate not(candidate). Hence candidate must be false. - if (assignment.LiteralIsFalse(candidate)) { - to_fix.push_back(Literal(candidate.Negated())); - } - continue; - } - queue.push_back({candidate.Index(), -position_in_order[candidate]}); - } - // Sort all the candidates. - std::sort(queue.begin() + saved_queue_size, queue.end()); - - // Set next_decision to the first unassigned candidate. - while (!queue.empty()) { - const LiteralIndex index = queue.back().literal_index; - queue.pop_back(); - if (index == kNoLiteralIndex) { - // This is a backtrack marker, go back one level. - CHECK_GT(sat_solver->CurrentDecisionLevel(), 0); - if (!sat_solver->BacktrackAndPropagateReimplications( - sat_solver->CurrentDecisionLevel() - 1)) - return false; - continue; - } - const Literal candidate(index); - if (processed[candidate]) continue; - if (assignment.LiteralIsAssigned(candidate)) { - // candidate => last_decision => all previous decisions, which then - // propagate not(candidate). Hence candidate must be false. - if (assignment.LiteralIsFalse(candidate)) { - to_fix.push_back(Literal(candidate.Negated())); - } - continue; - } - next_decision = candidate.Index(); - break; - } + if (sat_solver_->CurrentDecisionLevel() > 0 && options.use_queue) { + if (!ComputeNextDecisionInOrder(next_decision)) return false; } - // A second option to find the next decision is to use the first unassigned - // literal we find which implies the last decision, in no particular - // order. - if (sat_solver->CurrentDecisionLevel() > 0 && + if (sat_solver_->CurrentDecisionLevel() > 0 && next_decision == kNoLiteralIndex) { - const int level = sat_solver->CurrentDecisionLevel(); - const Literal last_decision = sat_solver->Decisions()[level - 1].literal; - const absl::Span list = - implication_graph->Implications(last_decision.Negated()); - - // If l => last_decision, then not(last_decision) => not(l). We can thus - // find the candidates for the next decision by looking at all the - // implications of not(last_decision). - int j = starts[last_decision.NegatedIndex()]; - for (int i = 0; i < list.size(); ++i, ++j) { - j %= list.size(); - const Literal candidate = Literal(list[j]).Negated(); - if (processed[candidate]) continue; - if (assignment.LiteralIsFalse(candidate)) { - // candidate => last_decision => all previous decisions, which then - // propagate not(candidate). Hence candidate must be false. - to_fix.push_back(Literal(candidate.Negated())); - continue; - } - // This shouldn't happen if extract_binary_clauses is false. - // We have an equivalence. - if (assignment.LiteralIsTrue(candidate)) continue; - next_decision = candidate.Index(); - break; - } - starts[last_decision.NegatedIndex()] = j; - if (next_decision == kNoLiteralIndex) { - if (!sat_solver->BacktrackAndPropagateReimplications(level - 1)) { - return false; - } - continue; - } + if (!GetNextDecisionInRandomOrder(next_decision)) return false; + if (next_decision == kNoLiteralIndex) continue; } - // If there is no last decision we can use any literal as first decision. - // We use the first unassigned literal in probing_order. - if (sat_solver->CurrentDecisionLevel() == 0) { - // Fix any delayed fixed literal. - for (const Literal literal : to_fix) { - if (!assignment.LiteralIsTrue(literal)) { - ++num_explicit_fix; - if (!sat_solver->AddUnitClause(literal)) return false; - } - } - to_fix.clear(); - if (!sat_solver->FinishPropagation()) return false; - - // Probe an unexplored node. - for (; order_index < probing_order.size(); ++order_index) { - const Literal candidate(probing_order[order_index]); - if (processed[candidate]) continue; - if (assignment.LiteralIsAssigned(candidate)) continue; - next_decision = candidate.Index(); - break; - } - + if (sat_solver_->CurrentDecisionLevel() == 0) { + if (!GetFirstDecision(next_decision)) return false; // The pass is finished. if (next_decision == kNoLiteralIndex) break; } // We now have a next decision, enqueue it and propagate until fix point. - ++num_probed; - processed.Set(next_decision); - CHECK_NE(next_decision, kNoLiteralIndex); - queue.push_back({kNoLiteralIndex, 0}); // Backtrack marker. - const int level = sat_solver->CurrentDecisionLevel(); - const int first_new_trail_index = - sat_solver->EnqueueDecisionAndBackjumpOnConflict( - Literal(next_decision)); - - // This is tricky, depending on the parameters, and for integer problem, - // EnqueueDecisionAndBackjumpOnConflict() might create new Booleans. - if (sat_solver->NumVariables() > num_variables) { - num_variables = sat_solver->NumVariables(); - processed.Resize(LiteralIndex(2 * num_variables)); - if (!options.use_queue) { - starts.resize(2 * num_variables, 0); - } else { - position_in_order.resize(2 * num_variables, -1); - } - } - - const int new_level = sat_solver->CurrentDecisionLevel(); - sat_solver->AdvanceDeterministicTime(time_limit); - if (sat_solver->ModelIsUnsat()) return false; - if (new_level <= level) { - ++num_conflicts; - - // Sync the queue with the new level. - if (options.use_queue) { - if (new_level == 0) { - queue.clear(); - } else { - int queue_level = level + 1; - while (queue_level > new_level) { - CHECK(!queue.empty()); - if (queue.back().literal_index == kNoLiteralIndex) --queue_level; - queue.pop_back(); - } - } - } - - // Fix `next_decision` to `false` if not already done. - // - // Even if we fixed something at level zero, next_decision might not be - // fixed! But we can fix it. It can happen because when we propagate - // with clauses, we might have `a => b` but not `not(b) => not(a)`. Like - // `a => b` and clause `(not(a), not(b), c)`, propagating `a` will set - // `c`, but propagating `not(c)` will not do anything. - // - // We "delay" the fixing if we are not at level zero so that we can - // still reuse the current propagation work via tree look. - // - // TODO(user): Can we be smarter here? Maybe we can still fix the - // literal without going back to level zero by simply enqueuing it with - // no reason? it will be backtracked over, but we will still lazily fix - // it later. - if (sat_solver->CurrentDecisionLevel() != 0 || - assignment.LiteralIsFalse(Literal(next_decision))) { - to_fix.push_back(Literal(next_decision).Negated()); - } + int first_new_trail_index; + if (!EnqueueDecisionAndBackjumpOnConflict(next_decision, options.use_queue, + first_new_trail_index)) { + return false; } // Inspect the newly propagated literals. Depending on the options, try to // extract binary clauses via hyper binary resolution and/or mark the // literals on the trail so that they do not need to be probed later. + const int new_level = sat_solver_->CurrentDecisionLevel(); if (new_level == 0) continue; const Literal last_decision = - sat_solver->Decisions()[new_level - 1].literal; - int num_new_subsumed = 0; - for (int i = first_new_trail_index; i < trail.Index(); ++i) { - const Literal l = trail[i]; + sat_solver_->Decisions()[new_level - 1].literal; + for (int i = first_new_trail_index; i < trail_.Index(); ++i) { + const Literal l = trail_[i]; if (l == last_decision) continue; - // If we can extract a binary clause that subsume the reason clause, we - // do add the binary and remove the subsumed clause. - // - // TODO(user): We could be slightly more generic and subsume some - // clauses that do not contains last_decision.Negated(). bool subsumed = false; - if (options.subsume_with_binary_clause && - trail.AssignmentType(l.Variable()) == clause_id) { - for (const Literal lit : trail.Reason(l.Variable())) { - if (lit == last_decision.Negated()) { - subsumed = true; - break; - } - } - if (subsumed) { - ++num_new_subsumed; - ++num_new_binary; - CHECK(implication_graph->AddBinaryClause(last_decision.Negated(), l)); - const int trail_index = trail.Info(l.Variable()).trail_index; - - int test = 0; - for (const Literal lit : - clause_manager->ReasonClause(trail_index)->AsSpan()) { - if (lit == l) ++test; - if (lit == last_decision.Negated()) ++test; - } - CHECK_EQ(test, 2); - clause_manager->LazyDetach(clause_manager->ReasonClause(trail_index)); - - // We need to change the reason now that the clause is cleared. - implication_graph->ChangeReason(trail_index, last_decision); - } + if (options.subsume_with_binary_clause) { + subsumed = MaybeSubsumeWithBinaryClause(last_decision, l); } - if (options.extract_binary_clauses) { - // Anything not propagated by the BinaryImplicationGraph is a "new" - // binary clause. This is because the BinaryImplicationGraph has the - // highest priority of all propagators. - // - // Note(user): This is not 100% true, since when we launch the clause - // propagation for one literal we do finish it before calling again - // the binary propagation. - // - // TODO(user): Think about trying to extract clause that will not - // get removed by transitive reduction later. If we can both extract - // a => c and b => c , ideally we don't want to extract a => c first - // if we already know that a => b. - // - // TODO(user): Similar to previous point, we could find the LCA - // of all literals in the reason for this propagation. And use this - // as a reason for later hyber binary resolution. Like we do when - // this clause subsume the reason. - if (!subsumed && trail.AssignmentType(l.Variable()) != id) { - ++num_new_binary; - CHECK(implication_graph->AddBinaryClause(last_decision.Negated(), l)); + if (!subsumed) { + MaybeExtractBinaryClause(last_decision, l); } } else { // If we don't extract binary, we don't need to explore any of - // these literal until more variables are fixed. - processed.Set(l.Index()); + // these literals until more variables are fixed. + processed_.Set(l.Index()); } } - // Inspect the watcher list for last_decision, If we have a blocking - // literal at true (implied by last decision), then we have subsumptions. - // - // The intuition behind this is that if a binary clause (a,b) subsume a - // clause, and we watch a.Negated() for this clause with a blocking - // literal b, then this watch entry will never change because we always - // propagate binary clauses first and the blocking literal will always be - // true. So after many propagations, we hope to have such configuration - // which is quite cheap to test here. if (options.subsume_with_binary_clause) { - // Tricky: If we have many "decision" and we do not extract the binary - // clause, then the fact that last_decision => literal might not be - // currently encoded in the problem clause, so if we use that relation - // to subsume, we should make sure it is added. - // - // Note that it is okay to add duplicate binary clause, we will clean that - // later. - const bool always_add_binary = sat_solver->CurrentDecisionLevel() > 1 && - !options.extract_binary_clauses; + SubsumeWithBinaryClauseUsingBlockingLiteral(last_decision); + } + } - for (const auto& w : - clause_manager->WatcherListOnFalse(last_decision.Negated())) { - if (assignment.LiteralIsTrue(w.blocking_literal)) { - if (w.clause->IsRemoved()) continue; - CHECK_NE(w.blocking_literal, last_decision.Negated()); + if (!sat_solver_->ResetToLevelZero()) return false; + if (!ProcessLiteralsToFix()) return false; + if (!subsumed_clauses_.empty()) { + for (SatClause* clause : subsumed_clauses_) { + clause_manager_->LazyDetach(clause); + } + clause_manager_->CleanUpWatchers(); + } - // Add the binary clause if needed. Note that we change the reason - // to a binary one so that we never add the same clause twice. - // - // Tricky: while last_decision would be a valid reason, we need a - // reason that was assigned before this literal, so we use the - // decision at the level where this literal was assigned which is an - // even better reason. Maybe it is just better to change all the - // reason above to a binary one so we don't have an issue here. - if (always_add_binary || - trail.AssignmentType(w.blocking_literal.Variable()) != id) { - // If the variable was true at level zero, there is no point - // adding the clause. - const auto& info = trail.Info(w.blocking_literal.Variable()); - if (info.level > 0) { - ++num_new_binary; - CHECK(implication_graph->AddBinaryClause(last_decision.Negated(), - w.blocking_literal)); + // Display stats. + const int num_fixed = sat_solver_->LiteralTrail().Index(); + const int num_newly_fixed = num_fixed - initial_num_fixed; + const double time_diff = + time_limit_->GetElapsedDeterministicTime() - initial_deterministic_time; + const bool limit_reached = time_limit_->LimitReached() || + time_limit_->GetElapsedDeterministicTime() > limit; + LOG_IF(INFO, options.log_info) + << "Probing. " + << " num_probed: " << num_probed_ << "/" << probing_order_.size() + << " num_fixed: +" << num_newly_fixed << " (" << num_fixed << "/" + << num_variables_ << ")" + << " explicit_fix:" << num_explicit_fix_ + << " num_conflicts:" << num_conflicts_ + << " new_binary_clauses: " << num_new_binary_ + << " subsumed: " << num_subsumed_ << " dtime: " << time_diff + << " wtime: " << wall_timer.Get() << (limit_reached ? " (Aborted)" : ""); + return sat_solver_->FinishPropagation(); +} - const Literal d = sat_solver->Decisions()[info.level - 1].literal; - if (d != w.blocking_literal) { - implication_graph->ChangeReason(info.trail_index, d); - } - } - } +// Sets `next_decision` to the unassigned literal which implies the last +// decision and which comes first in the probing order (which itself can be +// the topological order of the implication graph, or the reverse). +bool FailedLiteralProbing::ComputeNextDecisionInOrder( + LiteralIndex& next_decision) { + // TODO(user): Instead of minimizing index in topo order (which might be + // nice for binary extraction), we could try to maximize reusability in + // some way. + const Literal last_decision = + sat_solver_->Decisions()[sat_solver_->CurrentDecisionLevel() - 1].literal; + // If l => last_decision, then not(last_decision) => not(l). We can thus + // find the candidates for the next decision by looking at all the + // implications of not(last_decision). + const absl::Span list = + implication_graph_->Implications(last_decision.Negated()); + const int saved_queue_size = queue_.size(); + for (const Literal l : list) { + const Literal candidate = l.Negated(); + if (processed_[candidate]) continue; + if (position_in_order_[candidate] == -1) continue; + if (assignment_.LiteralIsAssigned(candidate)) { + // candidate => last_decision => all previous decisions, which then + // propagate not(candidate). Hence candidate must be false. + if (assignment_.LiteralIsFalse(candidate)) { + AddFailedLiteralToFix(candidate); + } + continue; + } + queue_.push_back({candidate.Index(), -position_in_order_[candidate]}); + } + // Sort all the candidates. + std::sort(queue_.begin() + saved_queue_size, queue_.end()); - ++num_new_subsumed; - clause_manager->LazyDetach(w.clause); + // Set next_decision to the first unassigned candidate. + while (!queue_.empty()) { + const LiteralIndex index = queue_.back().literal_index; + queue_.pop_back(); + if (index == kNoLiteralIndex) { + // This is a backtrack marker, go back one level. + CHECK_GT(sat_solver_->CurrentDecisionLevel(), 0); + if (!sat_solver_->BacktrackAndPropagateReimplications( + sat_solver_->CurrentDecisionLevel() - 1)) + return false; + continue; + } + const Literal candidate(index); + if (processed_[candidate]) continue; + if (assignment_.LiteralIsAssigned(candidate)) { + // candidate => last_decision => all previous decisions, which then + // propagate not(candidate). Hence candidate must be false. + if (assignment_.LiteralIsFalse(candidate)) { + AddFailedLiteralToFix(candidate); + } + continue; + } + next_decision = candidate.Index(); + break; + } + return true; +} + +// Sets `next_decision` to the first unassigned literal we find which implies +// the last decision, in no particular order. +bool FailedLiteralProbing::GetNextDecisionInRandomOrder( + LiteralIndex& next_decision) { + const int level = sat_solver_->CurrentDecisionLevel(); + const Literal last_decision = sat_solver_->Decisions()[level - 1].literal; + const absl::Span list = + implication_graph_->Implications(last_decision.Negated()); + + // If l => last_decision, then not(last_decision) => not(l). We can thus + // find the candidates for the next decision by looking at all the + // implications of not(last_decision). + int j = starts_[last_decision.NegatedIndex()]; + for (int i = 0; i < list.size(); ++i, ++j) { + j %= list.size(); + const Literal candidate = Literal(list[j]).Negated(); + if (processed_[candidate]) continue; + if (assignment_.LiteralIsFalse(candidate)) { + // candidate => last_decision => all previous decisions, which then + // propagate not(candidate). Hence candidate must be false. + AddFailedLiteralToFix(candidate); + continue; + } + // This shouldn't happen if extract_binary_clauses is false. + // We have an equivalence. + if (assignment_.LiteralIsTrue(candidate)) continue; + next_decision = candidate.Index(); + break; + } + starts_[last_decision.NegatedIndex()] = j; + if (next_decision == kNoLiteralIndex) { + if (!sat_solver_->BacktrackAndPropagateReimplications(level - 1)) { + return false; + } + } + return true; +} + +// Sets `next_decision` to the first unassigned literal in probing_order (if +// there is no last decision we can use any literal as first decision). +bool FailedLiteralProbing::GetFirstDecision(LiteralIndex& next_decision) { + // Fix any delayed fixed literal. + if (!ProcessLiteralsToFix()) return false; + + // Probe an unexplored node. + for (; order_index_ < probing_order_.size(); ++order_index_) { + const Literal candidate(probing_order_[order_index_]); + if (processed_[candidate]) continue; + if (assignment_.LiteralIsAssigned(candidate)) continue; + next_decision = candidate.Index(); + break; + } + return true; +} + +bool FailedLiteralProbing::EnqueueDecisionAndBackjumpOnConflict( + LiteralIndex next_decision, bool use_queue, int& first_new_trail_index) { + ++num_probed_; + processed_.Set(next_decision); + CHECK_NE(next_decision, kNoLiteralIndex); + queue_.push_back({kNoLiteralIndex, 0}); // Backtrack marker. + const int level = sat_solver_->CurrentDecisionLevel(); + + // The unit clause ID that fixes next_decision to false, if it causes a + // conflict. + ClauseId fixed_decision_unit_id = kNoClauseId; + auto conflict_callback = [&](ClauseId conflict_id, + absl::Span conflict_clause) { + if (fixed_decision_unit_id != kNoClauseId) return; + ComputeDecisionImplicationIds(Literal(next_decision), level, + /*end_level=*/0, tmp_clause_ids_); + sat_solver_->AppendClausesFixing(conflict_clause, &tmp_clause_ids_); + tmp_clause_ids_.push_back(conflict_id); + fixed_decision_unit_id = clause_id_generator_->GetNextId(); + lrat_proof_handler_->AddInferredClause(fixed_decision_unit_id, + {Literal(next_decision).Negated()}, + tmp_clause_ids_); + }; + first_new_trail_index = sat_solver_->EnqueueDecisionAndBackjumpOnConflict( + Literal(next_decision), + lrat_proof_handler_ != nullptr + ? conflict_callback + : std::optional()); + + if (first_new_trail_index == kUnsatTrailIndex) return false; + binary_clause_extracted_.resize(first_new_trail_index); + binary_clause_extracted_.resize(trail_.Index(), false); + + // This is tricky, depending on the parameters, and for integer problem, + // EnqueueDecisionAndBackjumpOnConflict() might create new Booleans. + if (sat_solver_->NumVariables() > num_variables_) { + num_variables_ = sat_solver_->NumVariables(); + processed_.Resize(LiteralIndex(2 * num_variables_)); + if (!use_queue) { + starts_.resize(2 * num_variables_, 0); + } else { + position_in_order_.resize(2 * num_variables_, -1); + } + } + + const int new_level = sat_solver_->CurrentDecisionLevel(); + sat_solver_->AdvanceDeterministicTime(time_limit_); + if (sat_solver_->ModelIsUnsat()) return false; + if (new_level <= level) { + ++num_conflicts_; + + // Sync the queue with the new level. + if (use_queue) { + if (new_level == 0) { + queue_.clear(); + } else { + int queue_level = level + 1; + while (queue_level > new_level) { + CHECK(!queue_.empty()); + if (queue_.back().literal_index == kNoLiteralIndex) --queue_level; + queue_.pop_back(); } } } - if (num_new_subsumed > 0) { - // TODO(user): We might just want to do that even more lazily by - // checking for detached clause while propagating here? and do a big - // cleanup at the end. - clause_manager->CleanUpWatchers(); - num_subsumed += num_new_subsumed; + // Fix `next_decision` to `false` if not already done. + // + // Even if we fixed something at level zero, next_decision might not be + // fixed! But we can fix it. It can happen because when we propagate + // with clauses, we might have `a => b` but not `not(b) => not(a)`. Like + // `a => b` and clause `(not(a), not(b), c)`, propagating `a` will set + // `c`, but propagating `not(c)` will not do anything. + // + // We "delay" the fixing if we are not at level zero so that we can + // still reuse the current propagation work via tree look. + // + // TODO(user): Can we be smarter here? Maybe we can still fix the + // literal without going back to level zero by simply enqueuing it with + // no reason? it will be backtracked over, but we will still lazily fix + // it later. + if (sat_solver_->CurrentDecisionLevel() != 0 || + assignment_.LiteralIsFalse(Literal(next_decision))) { + to_fix_.push_back(Literal(next_decision).Negated()); + if (lrat_proof_handler_ != nullptr) { + to_fix_unit_id_.push_back(fixed_decision_unit_id); + } } } + return true; +} - if (!sat_solver->ResetToLevelZero()) return false; - for (const Literal literal : to_fix) { - ++num_explicit_fix; - if (!sat_solver->AddUnitClause(literal)) return false; +// If we can extract a binary clause that subsume the reason clause, we do add +// the binary and remove the subsumed clause. +// +// TODO(user): We could be slightly more generic and subsume some clauses that +// do not contain last_decision.Negated(). +bool FailedLiteralProbing::MaybeSubsumeWithBinaryClause( + const Literal last_decision, const Literal l) { + const int trail_index = trail_.Info(l.Variable()).trail_index; + if (binary_clause_extracted_[trail_index] || + trail_.AssignmentType(l.Variable()) != clause_propagator_id_) { + return false; } - to_fix.clear(); - if (!sat_solver->FinishPropagation()) return false; + bool subsumed = false; + for (const Literal lit : trail_.Reason(l.Variable())) { + if (lit == last_decision.Negated()) { + subsumed = true; + break; + } + } + if (!subsumed) return false; + ++num_subsumed_; + ++num_new_binary_; + // TODO(user): add LRAT proof. + CHECK(implication_graph_->AddBinaryClause(last_decision.Negated(), l)); + binary_clause_extracted_[trail_index] = true; - // Display stats. - const int num_fixed = sat_solver->LiteralTrail().Index(); - const int num_newly_fixed = num_fixed - initial_num_fixed; - const double time_diff = - time_limit->GetElapsedDeterministicTime() - initial_deterministic_time; - const bool limit_reached = time_limit->LimitReached() || - time_limit->GetElapsedDeterministicTime() > limit; - LOG_IF(INFO, options.log_info) - << "Probing. " - << " num_probed: " << num_probed << "/" << probing_order.size() - << " num_fixed: +" << num_newly_fixed << " (" << num_fixed << "/" - << num_variables << ")" - << " explicit_fix:" << num_explicit_fix - << " num_conflicts:" << num_conflicts - << " new_binary_clauses: " << num_new_binary - << " subsumed: " << num_subsumed << " dtime: " << time_diff - << " wtime: " << wall_timer.Get() << (limit_reached ? " (Aborted)" : ""); - return sat_solver->FinishPropagation(); + int test = 0; + for (const Literal lit : + clause_manager_->ReasonClause(trail_index)->AsSpan()) { + if (lit == l) ++test; + if (lit == last_decision.Negated()) ++test; + } + CHECK_EQ(test, 2); + subsumed_clauses_.push_back(clause_manager_->ReasonClause(trail_index)); + return true; +} + +void FailedLiteralProbing::MaybeExtractBinaryClause(const Literal last_decision, + const Literal l) { + // Anything not propagated by the BinaryImplicationGraph is a "new" + // binary clause. This is because the BinaryImplicationGraph has the + // highest priority of all propagators. + // + // Note(user): This is not 100% true, since when we launch the clause + // propagation for one literal we do finish it before calling again + // the binary propagation. + // + // TODO(user): Think about trying to extract clause that will not + // get removed by transitive reduction later. If we can both extract + // a => c and b => c , ideally we don't want to extract a => c first + // if we already know that a => b. + // + // TODO(user): Similar to previous point, we could find the LCA + // of all literals in the reason for this propagation. And use this + // as a reason for later hyber binary resolution. Like we do when + // this clause subsumes the reason. + if (trail_.AssignmentType(l.Variable()) != id_) { + ++num_new_binary_; + // TODO(user): add LRAT proof. + CHECK(implication_graph_->AddBinaryClause(last_decision.Negated(), l)); + } +} + +// Inspect the watcher list for last_decision, If we have a blocking +// literal at true (implied by last decision), then we have subsumptions. +// +// The intuition behind this is that if a binary clause (a,b) subsumes a +// clause, and we watch a.Negated() for this clause with a blocking +// literal b, then this watch entry will never change because we always +// propagate binary clauses first and the blocking literal will always be +// true. So after many propagations, we hope to have such configuration +// which is quite cheap to test here. +void FailedLiteralProbing::SubsumeWithBinaryClauseUsingBlockingLiteral( + const Literal last_decision) { + for (const auto& w : + clause_manager_->WatcherListOnFalse(last_decision.Negated())) { + if (assignment_.LiteralIsTrue(w.blocking_literal)) { + if (w.clause->IsRemoved()) continue; + const auto& info = trail_.Info(w.blocking_literal.Variable()); + CHECK_NE(w.blocking_literal, last_decision.Negated()); + + // Add the binary clause if not already done. If the variable was true at + // level zero, there is no point adding the clause. + if (info.level > 0 && !binary_clause_extracted_[info.trail_index]) { + ++num_new_binary_; + // TODO(user): add LRAT proof. + CHECK(implication_graph_->AddBinaryClause(last_decision.Negated(), + w.blocking_literal)); + binary_clause_extracted_[info.trail_index] = true; + } + + ++num_subsumed_; + subsumed_clauses_.push_back(w.clause); + } + } +} + +// Sets `clause_ids` to the IDs of the implications from `l` to the decisions +// from `start_level` to `end_level` (included, going backwards). +void FailedLiteralProbing::ComputeDecisionImplicationIds( + Literal l, int start_level, int end_level, + std::vector& clause_ids) { + clause_ids.clear(); + int level = start_level; + while (level != end_level) { + const Literal decision = sat_solver_->Decisions()[--level].literal; + clause_ids.push_back( + implication_graph_->GetClauseId(l.Negated(), decision)); + l = decision; + } +} + +// Adds 'not(literal)' to `to_fix_`, assuming that 'literal' directly implies +// the current decision, which itself implies all the previous decisions, with +// some of them propagating 'not(literal)'. +void FailedLiteralProbing::AddFailedLiteralToFix(const Literal literal) { + // TODO(user): skip if literal.Negated() is already in to_fix? + to_fix_.push_back(literal.Negated()); + if (lrat_proof_handler_ == nullptr) return; + + ClauseId unit_id = trail_.GetUnitClauseId(literal.Variable()); + if (unit_id == kNoClauseId) { + // TODO(user): it should be possible to stop at the level of the + // first decision necessary to fix 'literal'. A first attempt at this + // failed. + ComputeDecisionImplicationIds(literal, sat_solver_->CurrentDecisionLevel(), + /*end_level=*/0, tmp_clause_ids_); + sat_solver_->AppendClausesFixing({literal.Negated()}, &tmp_clause_ids_); + unit_id = clause_id_generator_->GetNextId(); + lrat_proof_handler_->AddInferredClause(unit_id, {literal.Negated()}, + tmp_clause_ids_); + } + to_fix_unit_id_.push_back({unit_id}); +} + +// Fixes all the literals in to_fix_, and finish propagation. +bool FailedLiteralProbing::ProcessLiteralsToFix() { + for (int i = 0; i < to_fix_.size(); ++i) { + const Literal literal = to_fix_[i]; + if (assignment_.LiteralIsTrue(literal)) continue; + ++num_explicit_fix_; + const ClauseId clause_id = + lrat_proof_handler_ != nullptr ? to_fix_unit_id_[i] : kNoClauseId; + if (!clause_manager_->InprocessingAddUnitClause(clause_id, literal)) { + return false; + } + } + to_fix_.clear(); + to_fix_unit_id_.clear(); + return sat_solver_->FinishPropagation(); +} + +bool FailedLiteralProbingRound(ProbingOptions options, Model* model) { + return FailedLiteralProbing(model).DoOneRound(options); } } // namespace sat diff --git a/ortools/sat/probing.h b/ortools/sat/probing.h index f93ea5b1c3..b7a37d2605 100644 --- a/ortools/sat/probing.h +++ b/ortools/sat/probing.h @@ -14,6 +14,7 @@ #ifndef ORTOOLS_SAT_PROBING_H_ #define ORTOOLS_SAT_PROBING_H_ +#include #include #include #include @@ -24,10 +25,12 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/base/strong_vector.h" #include "ortools/sat/clause.h" #include "ortools/sat/implied_bounds.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" +#include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" @@ -256,6 +259,117 @@ struct ProbingOptions { // // It will add any detected binary clause (via hyper binary resolution) to // the implication graph. See the option comments for more details. +class FailedLiteralProbing { + public: + explicit FailedLiteralProbing(Model* model); + + bool DoOneRound(ProbingOptions options); + + private: + struct SavedNextLiteral { + LiteralIndex literal_index; // kNoLiteralIndex if we need to backtrack. + int rank; // Cached position_in_order, we prefer lower positions. + + bool operator<(const SavedNextLiteral& o) const { return rank < o.rank; } + }; + + // Sets `next_decision` to the unassigned literal which implies the last + // decision and which comes first in the probing order (which itself can be + // the topological order of the implication graph, or the reverse). + bool ComputeNextDecisionInOrder(LiteralIndex& next_decision); + + // Sets `next_decision` to the first unassigned literal we find which implies + // the last decision, in no particular order. + bool GetNextDecisionInRandomOrder(LiteralIndex& next_decision); + + // Sets `next_decision` to the first unassigned literal in probing_order (if + // there is no last decision we can use any literal as first decision). + bool GetFirstDecision(LiteralIndex& next_decision); + + // Enqueues `next_decision`. Backjumps and sets `next_decision` to false in + // case of conflict. Returns false if the problem was proved UNSAT. + bool EnqueueDecisionAndBackjumpOnConflict(LiteralIndex next_decision, + bool use_queue, + int& first_new_trail_index); + + // If we can extract a binary clause that subsume the reason clause, we do add + // the binary and remove the subsumed clause. + // + // TODO(user): We could be slightly more generic and subsume some clauses that + // do not contain last_decision.Negated(). + bool MaybeSubsumeWithBinaryClause(Literal last_decision, Literal l); + + void MaybeExtractBinaryClause(Literal last_decision, Literal l); + + // Inspect the watcher list for last_decision, If we have a blocking + // literal at true (implied by last decision), then we have subsumptions. + // + // The intuition behind this is that if a binary clause (a,b) subsume a + // clause, and we watch a.Negated() for this clause with a blocking + // literal b, then this watch entry will never change because we always + // propagate binary clauses first and the blocking literal will always be + // true. So after many propagations, we hope to have such configuration + // which is quite cheap to test here. + void SubsumeWithBinaryClauseUsingBlockingLiteral(Literal last_decision); + + // Sets `clause_ids` to the IDs of the implications from `l` to the decisions + // from `start_level` to `end_level` (included, going backwards). + void ComputeDecisionImplicationIds(Literal l, int start_level, int end_level, + std::vector& clause_ids); + + // Adds 'not(literal)' to `to_fix_`, assuming that 'literal' directly implies + // the current decision, which itself implies all the previous decisions, with + // some of them propagating 'not(literal)'. + void AddFailedLiteralToFix(Literal literal); + + // Fixes all the literals in to_fix_, and finish propagation. + bool ProcessLiteralsToFix(); + + SatSolver* sat_solver_; + BinaryImplicationGraph* implication_graph_; + TimeLimit* time_limit_; + const Trail& trail_; + const VariablesAssignment& assignment_; + ClauseManager* clause_manager_; + ClauseIdGenerator* clause_id_generator_; + LratProofHandler* lrat_proof_handler_; + int id_; + int clause_propagator_id_; + + int num_variables_; + std::vector probing_order_; + int order_index_ = 0; + SparseBitset processed_; + + // This is only needed when options.use_queue is true. + std::vector queue_; + util_intops::StrongVector position_in_order_; + + // This is only needed when options use_queue is false; + util_intops::StrongVector starts_; + + // We delay fixing of already assigned literals once we go back to level 0. + std::vector to_fix_; + // For each literal in to_fix_, the ID of the corresponding LRAT unit clause. + std::vector to_fix_unit_id_; + + // For each literal 'l' in the trail, whether a binary clause "d => l" has + // been extracted, with 'd' the decision at the same level as 'l'. + std::vector binary_clause_extracted_; + std::vector subsumed_clauses_; + + // Temporary vector used for LRAT proofs. + std::vector tmp_clause_ids_; + + // Stats. + int64_t num_probed_ = 0; + int64_t num_explicit_fix_ = 0; + int64_t num_conflicts_ = 0; + int64_t num_new_binary_ = 0; + int64_t num_subsumed_ = 0; +}; + +// TODO(user): remove this and use the class directly. bool FailedLiteralProbingRound(ProbingOptions options, Model* model); } // namespace sat diff --git a/ortools/sat/routing_cuts.cc b/ortools/sat/routing_cuts.cc index 03647718b5..d15e8b1d31 100644 --- a/ortools/sat/routing_cuts.cc +++ b/ortools/sat/routing_cuts.cc @@ -121,7 +121,7 @@ MinOutgoingFlowHelper::MinOutgoingFlowHelper( heads_(heads), literals_(literals), binary_relation_repository_( - *model->GetOrCreate()), + *model->GetOrCreate()), implied_bounds_(*model->GetOrCreate()), trail_(*model->GetOrCreate()), integer_trail_(*model->GetOrCreate()), @@ -1106,7 +1106,7 @@ class RouteRelationsBuilder { int num_nodes, absl::Span tails, absl::Span heads, absl::Span literals, absl::Span flat_node_dim_expressions, - const BinaryRelationRepository& binary_relation_repository) + const ConditionalLinear2Bounds& binary_relation_repository) : num_nodes_(num_nodes), num_arcs_(tails.size()), tails_(tails), @@ -1683,7 +1683,7 @@ class RouteRelationsBuilder { absl::Span tails_; absl::Span heads_; absl::Span literals_; - const BinaryRelationRepository& binary_relation_repository_; + const ConditionalLinear2Bounds& binary_relation_repository_; int num_dimensions_; absl::flat_hash_map dimension_by_var_; @@ -1705,7 +1705,7 @@ class RouteRelationsBuilder { RoutingCumulExpressions DetectDimensionsAndCumulExpressions( int num_nodes, absl::Span tails, absl::Span heads, absl::Span literals, - const BinaryRelationRepository& binary_relation_repository) { + const ConditionalLinear2Bounds& binary_relation_repository) { RoutingCumulExpressions result; RouteRelationsBuilder builder(num_nodes, tails, heads, literals, {}, binary_relation_repository); @@ -1789,7 +1789,7 @@ std::unique_ptr RouteRelationsHelper::Create( int num_nodes, absl::Span tails, absl::Span heads, absl::Span literals, absl::Span flat_node_dim_expressions, - const BinaryRelationRepository& binary_relation_repository, Model* model) { + const ConditionalLinear2Bounds& binary_relation_repository, Model* model) { CHECK(model != nullptr); if (flat_node_dim_expressions.empty()) return nullptr; RouteRelationsBuilder builder(num_nodes, tails, heads, literals, @@ -1886,10 +1886,10 @@ int ToNodeVariableIndex(IntegerVariable var) { // domains) of the enforced linear constraints (of size 2 only) in `model`. This // is the only information needed to infer the mapping from variables to nodes // in routes constraints. -BinaryRelationRepository ComputePartialBinaryRelationRepository( +ConditionalLinear2Bounds ComputePartialConditionalLinear2Bounds( const CpModelProto& model) { Model empty_model; - BinaryRelationRepository repository(&empty_model); + ConditionalLinear2Bounds repository(&empty_model); for (const ConstraintProto& ct : model.constraints()) { if (ct.constraint_case() != ConstraintProto::kLinear) continue; const absl::Span vars = ct.linear().vars(); @@ -1905,7 +1905,7 @@ BinaryRelationRepository ComputePartialBinaryRelationRepository( // Returns the number of dimensions added to the constraint. int MaybeFillRoutesConstraintNodeExpressions( - RoutesConstraintProto& routes, const BinaryRelationRepository& repository) { + RoutesConstraintProto& routes, const ConditionalLinear2Bounds& repository) { int max_node = 0; for (const int node : routes.tails()) { max_node = std::max(max_node, node); @@ -1957,8 +1957,8 @@ std::pair MaybeFillMissingRoutesConstraintNodeExpressions( if (routes_to_fill.empty()) return {0, 0}; int total_num_dimensions = 0; - const BinaryRelationRepository partial_repository = - ComputePartialBinaryRelationRepository(input_model); + const ConditionalLinear2Bounds partial_repository = + ComputePartialConditionalLinear2Bounds(input_model); for (RoutesConstraintProto* routes : routes_to_fill) { total_num_dimensions += MaybeFillRoutesConstraintNodeExpressions(*routes, partial_repository); @@ -1983,7 +1983,7 @@ class RoutingCutHelper { trail_(*model->GetOrCreate()), integer_trail_(*model->GetOrCreate()), binary_relation_repository_( - *model->GetOrCreate()), + *model->GetOrCreate()), implied_bounds_(*model->GetOrCreate()), random_(model->GetOrCreate()), encoder_(model->GetOrCreate()), @@ -1996,7 +1996,7 @@ class RoutingCutHelper { min_outgoing_flow_helper_(num_nodes, tails_, heads_, literals_, model), route_relations_helper_(RouteRelationsHelper::Create( num_nodes, tails_, heads_, literals_, flat_node_dim_expressions, - *model->GetOrCreate(), model)) {} + *model->GetOrCreate(), model)) {} int num_nodes() const { return num_nodes_; } @@ -2118,7 +2118,7 @@ class RoutingCutHelper { const SatParameters& params_; const Trail& trail_; const IntegerTrail& integer_trail_; - const BinaryRelationRepository& binary_relation_repository_; + const ConditionalLinear2Bounds& binary_relation_repository_; const ImpliedBounds& implied_bounds_; ModelRandomGenerator* random_; IntegerEncoder* encoder_; diff --git a/ortools/sat/routing_cuts.h b/ortools/sat/routing_cuts.h index c29b560174..f9bca8b915 100644 --- a/ortools/sat/routing_cuts.h +++ b/ortools/sat/routing_cuts.h @@ -61,7 +61,7 @@ struct RoutingCumulExpressions { RoutingCumulExpressions DetectDimensionsAndCumulExpressions( int num_nodes, absl::Span tails, absl::Span heads, absl::Span literals, - const BinaryRelationRepository& binary_relation_repository); + const ConditionalLinear2Bounds& binary_relation_repository); // A coeff * var + offset affine expression, where `var` is always a positive // reference (contrary to AffineExpression, where the coefficient is always @@ -130,7 +130,7 @@ class RouteRelationsHelper { int num_nodes, absl::Span tails, absl::Span heads, absl::Span literals, absl::Span flat_node_dim_expressions, - const BinaryRelationRepository& binary_relation_repository, Model* model); + const ConditionalLinear2Bounds& binary_relation_repository, Model* model); // Returns the number of "dimensions", such as time or vehicle load. int num_dimensions() const { return num_dimensions_; } @@ -541,7 +541,7 @@ class MinOutgoingFlowHelper { const std::vector& tails_; const std::vector& heads_; const std::vector& literals_; - const BinaryRelationRepository& binary_relation_repository_; + const ConditionalLinear2Bounds& binary_relation_repository_; const ImpliedBounds& implied_bounds_; const Trail& trail_; const IntegerTrail& integer_trail_; diff --git a/ortools/sat/routing_cuts_test.cc b/ortools/sat/routing_cuts_test.cc index 8aa8a09718..e3de80cf0e 100644 --- a/ortools/sat/routing_cuts_test.cc +++ b/ortools/sat/routing_cuts_test.cc @@ -183,7 +183,7 @@ TEST(MinOutgoingFlowHelperTest, CapacityConstraints) { loads.push_back(model.Add(NewIntegerVariable(0, max_capacity))); } // Capacity constraints. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (const auto& [arc, literal] : literal_by_arc) { const auto& [tail, head] = arc; // We consider that, at each node n other than the depot, n+10 items must be @@ -252,7 +252,7 @@ TEST_P(DimensionBasedMinOutgoingFlowHelperTest, BasicCapacities) { } } // Capacity constraints. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (const auto& [arc, literal] : literal_by_arc) { const auto& [tail, head] = arc; if (tail == 0 || head == 0) continue; @@ -327,7 +327,7 @@ TEST_P(DimensionBasedMinOutgoingFlowHelperTest, } } // Capacity constraints. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (int i = 0; i < 4; ++i) { const int head = heads[i]; const int tail = tails[i]; @@ -391,7 +391,7 @@ TEST(MinOutgoingFlowHelperTest, NodeExpressionWithConstant) { const IntegerVariable offset_load2 = model.Add(NewIntegerVariable(-offset, capacity - demand2 - offset)); - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); // Capacity constraint: (offset_load2 + offset) - load1 >= demand1 repository->Add(literals[0], LinearExpression2(offset_load2, load1, 1, -1), demand1 - offset, 1000); @@ -434,7 +434,7 @@ TEST(MinOutgoingFlowHelperTest, ConstantNodeExpression) { // The load of the vehicle arriving at node 2, a constant value. const IntegerValue load2 = capacity - demand2; - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); auto* implied_bounds = model.GetOrCreate(); // Capacity constraint: load2 - load1 >= demand1 implied_bounds->Add(literals[0], IntegerLiteral::GreaterOrEqual( @@ -489,7 +489,7 @@ TEST(MinOutgoingFlowHelperTest, NodeExpressionUsingArcLiteralAsVariable) { // The load of the vehicle arriving at node 3, a constant value. const IntegerValue load3 = capacity - demand3; - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); // Capacity constraint: load2 - load1 >= demand1. This expands to // (capacity - demand2 - demand3 * l) - load1 >= demand1, i.e., // -demand3 * l - load1 >= demand1 + demand2 - capacity @@ -549,7 +549,7 @@ TEST(MinOutgoingFlowHelperTest, // The load of the vehicle arriving at node 3, a constant value. const IntegerValue load3 = capacity - demand3; - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); // Capacity constraint: load2 - load1 >= demand1. This expands to // (capacity - demand2 - demand3 + demand3 * l) - load1 >= demand1, i.e., // demand3 * l - load1 >= demand1 + demand2 + demand3 - capacity @@ -610,7 +610,7 @@ TEST(MinOutgoingFlowHelperTest, ArcNodeExpressionsWithSharedVariable) { const AffineExpression load3 = AffineExpression(x, -coeff, capacity - demand3); - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); // Capacity constraint: load2 - load1 >= demand1. This expands to // (capacity - demand2 - demand3) - coeff * x - load1 >= demand1, i.e., // -coeff * x - load1 >= demand1 + demand2 + demand3 - capacity. @@ -673,7 +673,7 @@ TEST(MinOutgoingFlowHelperTest, UnaryRelationForTwoNodeExpressions) { model.GetOrCreate()->AddImplication( b, literals[0].Negated()); - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); // Capacity constraint: load2 - load1 >= demand1. This expands to // (capacity - demand2) - demand1 * x - load1 >= demand1. Since this // constraint is enforced by arc_1_2_lit we can assume it is true, which @@ -738,7 +738,7 @@ TEST(MinOutgoingFlowHelperTest, NodeMustBeInnerNode) { } // Capacity constraints. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (int i = 0; i < num_arcs; ++i) { // loads[head] - loads[tail] >= demand[arc] repository->Add( @@ -800,7 +800,7 @@ TEST(MinOutgoingFlowHelperTest, BetterUseOfUpperBound) { } // Capacity constraints. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (int i = 0; i < num_arcs; ++i) { // loads[head] - loads[tail] >= demand[arc] repository->Add( @@ -836,7 +836,7 @@ TEST(MinOutgoingFlowHelperTest, DimensionBasedMinOutgoingFlow_IsolatedNodes) { std::vector heads; std::vector literals; std::vector variables; - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); // The depot variable. variables.push_back(model.Add(NewIntegerVariable(0, 100))); for (int head = 1; head < num_nodes; ++head) { @@ -894,7 +894,7 @@ TEST(MinOutgoingFlowHelperTest, TimeWindows) { times.push_back(model.Add(NewIntegerVariable(18, 22))); // Node 3. times.push_back(model.Add(NewIntegerVariable(28, 32))); // Node 4. // Travel time constraints. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (const auto& [arc, literal] : literal_by_arc) { const auto& [tail, head] = arc; const int travel_time = 10 - tail; @@ -1023,7 +1023,7 @@ TEST(MinOutgoingFlowHelperTest, SubsetMightBeServedWithKRoutes) { } // Capacity constraints on two dimensions. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (const auto& [arc, literal] : literal_by_arc) { const auto& [tail, head] = arc; @@ -1095,7 +1095,7 @@ TEST(MinOutgoingFlowHelperTest, SubsetMightBeServedWithKRoutesRandom) { } // Capacity constraints on two dimensions. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (const auto& [arc, literal] : literal_by_arc) { const auto& [tail, head] = arc; @@ -1226,7 +1226,7 @@ TEST(MinOutgoingFlowHelperTest, } // Travel time constraint. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (int arc = 0; arc < tails.size(); ++arc) { const int tail = tails[arc]; const int head = heads[arc]; @@ -1461,7 +1461,7 @@ TEST(RouteRelationsHelperTest, Basic) { const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Add(literals[0], LinearExpression2::Difference(a, b), 50, 1000); repository.Add(literals[1], LinearExpression2::Difference(a, c), 70, 1000); repository.Add(literals[2], LinearExpression2::Difference(c, b), 40, 1000); @@ -1556,7 +1556,7 @@ TEST(RouteRelationsHelperTest, UnenforcedRelations) { const IntegerVariable b = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable c = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable d = model.Add(NewIntegerVariable(0, 100)); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); RootLevelLinear2Bounds* bounds = model.GetOrCreate(); repository.Add(literals[0], LinearExpression2::Difference(b, a), 1, 1); repository.Add(literals[1], LinearExpression2::Difference(c, b), 2, 2); @@ -1606,7 +1606,7 @@ TEST(RouteRelationsHelperTest, SeveralVariablesPerNode) { const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Add(literals[0], LinearExpression2::Difference(b, a), 50, 1000); repository.Add(literals[1], LinearExpression2::Difference(c, b), 70, 1000); repository.Add(literals[0], LinearExpression2::Difference(z, y), 5, 100); @@ -1637,7 +1637,7 @@ TEST(RouteRelationsHelperTest, ComplexVariableRelations) { // and 1, respectively. const IntegerVariable a = model.Add(NewIntegerVariable(0, 150)); const IntegerVariable b = model.Add(NewIntegerVariable(0, 1)); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); // "complex" relation with non +1/-1 coefficients. repository.Add(literals[0], LinearExpression2(b, a, 10, 1), 0, 150); repository.Build(); @@ -1672,7 +1672,7 @@ TEST(RouteRelationsHelperTest, TwoUnaryRelationsPerArc) { IntegerEncoder& encoder = *model.GetOrCreate(); encoder.AssociateToIntegerEqualValue(literals[0], a, 20); encoder.AssociateToIntegerLiteral(literals[0], {b, 50}); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Build(); const RoutingCumulExpressions cumuls = { @@ -1702,7 +1702,7 @@ TEST(RouteRelationsHelperTest, SeveralRelationsPerArc) { const IntegerVariable a = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable b = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable c = model.Add(NewIntegerVariable(0, 100)); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Add(literals[0], LinearExpression2::Difference(b, a), 50, 1000); repository.Add(literals[1], LinearExpression2::Difference(c, b), 70, 1000); // Add a second relation for some arc. @@ -1738,7 +1738,7 @@ TEST(RouteRelationsHelperTest, SeveralArcsPerLiteral) { const IntegerVariable a = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable b = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable c = model.Add(NewIntegerVariable(0, 100)); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Add(literals[0], LinearExpression2::Difference(b, a), 50, 1000); repository.Add(literals[0], LinearExpression2::Difference(c, b), 40, 1000); repository.Build(); @@ -1780,7 +1780,7 @@ TEST(RouteRelationsHelperTest, InconsistentRelationIsSkipped) { const IntegerVariable d = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable e = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable f = model.Add(NewIntegerVariable(0, 100)); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Add(literals[0], LinearExpression2::Difference(b, a), 0, 0); repository.Add(literals[1], LinearExpression2::Difference(c, b), 1, 1); repository.Add(literals[2], LinearExpression2::Difference(d, c), 2, 2); @@ -1840,7 +1840,7 @@ TEST(RouteRelationsHelperTest, InconsistentRelationWithMultipleArcsPerLiteral) { const IntegerVariable c = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable d = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable e = model.Add(NewIntegerVariable(0, 100)); - BinaryRelationRepository repository(&model); + ConditionalLinear2Bounds repository(&model); repository.Add(literals[0], LinearExpression2::Difference(b, a), 0, 0); repository.Add(literals[1], LinearExpression2::Difference(c, b), 1, 1); repository.Add(literals[2], LinearExpression2::Difference(d, c), 2, 2); @@ -2478,7 +2478,7 @@ TEST(CreateCVRPCutGeneratorTest, InfeasiblePathCuts) { flat_node_dim_expressions.push_back(AffineExpression(load)); } // Capacity constraints. - auto* repository = model.GetOrCreate(); + auto* repository = model.GetOrCreate(); for (int i = 0; i < tails.size(); ++i) { const int tail = tails[i]; const int head = heads[i]; diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index 7c9c26b6ab..2603d9cf02 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -506,14 +506,6 @@ class Trail { return GetEmptyVectorToStoreReason(Index()); } - // Explicitly overwrite the reason so that the given propagator will be - // asked for it. This is currently only used by the BinaryImplicationGraph. - void ChangeReason(int trail_index, int propagator_id) { - const BooleanVariable var = trail_[trail_index].Variable(); - info_[var].type = propagator_id; - old_type_[var] = propagator_id; - } - // Reverts the trail and underlying assignment to the given target trail // index. Note that we do not touch the assignment info. void Untrail(int target_trail_index) { diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 3b075f8054..b2b9457db2 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -632,7 +633,8 @@ bool ClauseSubsumption(absl::Span a, SatClause* b) { } // namespace -int SatSolver::EnqueueDecisionAndBackjumpOnConflict(Literal true_literal) { +int SatSolver::EnqueueDecisionAndBackjumpOnConflict( + Literal true_literal, std::optional callback) { SCOPED_TIME_STAT(&stats_); if (model_is_unsat_) return kUnsatTrailIndex; DCHECK(PropagationIsDone()); @@ -643,18 +645,18 @@ int SatSolver::EnqueueDecisionAndBackjumpOnConflict(Literal true_literal) { } EnqueueNewDecision(true_literal); - if (!FinishPropagation()) return kUnsatTrailIndex; + if (!FinishPropagation(callback)) return kUnsatTrailIndex; DCHECK(PropagationIsDone()); return last_decision_or_backtrack_trail_index_; } -bool SatSolver::FinishPropagation() { +bool SatSolver::FinishPropagation(std::optional callback) { if (model_is_unsat_) return false; int num_loop = 0; while (true) { const int old_decision_level = current_decision_level_; if (!Propagate()) { - ProcessCurrentConflict(); + ProcessCurrentConflict(callback); if (model_is_unsat_) return false; if (current_decision_level_ == old_decision_level) { CHECK(!assumptions_.empty()); @@ -756,7 +758,8 @@ bool SatSolver::ReapplyAssumptionsIfNeeded() { return (status == SatSolver::FEASIBLE); } -void SatSolver::ProcessCurrentConflict() { +void SatSolver::ProcessCurrentConflict( + std::optional callback) { SCOPED_TIME_STAT(&stats_); if (model_is_unsat_) return; @@ -817,18 +820,22 @@ void SatSolver::ProcessCurrentConflict() { std::vector* clause_ids_for_1iup = &tmp_clause_ids_for_1uip_; if (lrat_proof_handler_ != nullptr) { - ComputeLratProofForLearnedConflict(clause_ids_for_1iup); + FillLratProofForLearnedConflict(clause_ids_for_1iup); } // An empty conflict means that the problem is UNSAT. if (learned_conflict_.empty()) { + ClauseId clause_id = kNoClauseId; if (lrat_proof_handler_ != nullptr) { - if (!lrat_proof_handler_->AddInferredClause( - clause_id_generator_->GetNextId(), learned_conflict_, - *clause_ids_for_1iup)) { + clause_id = clause_id_generator_->GetNextId(); + if (!lrat_proof_handler_->AddInferredClause(clause_id, learned_conflict_, + *clause_ids_for_1iup)) { VLOG(1) << "WARNING: invalid LRAT inferred clause!"; } } + if (callback.has_value()) { + (*callback)(clause_id, learned_conflict_); + } return (void)SetModelUnsat(); } @@ -994,31 +1001,6 @@ void SatSolver::ProcessCurrentConflict() { // Minimize the learned conflict. MinimizeConflict(&learned_conflict_, clause_ids_for_minimization); - // We notify the decision before backtracking so that we can save the phase. - // The current heuristic is to try to take a trail prefix for which there is - // currently no conflict (hence just before the last decision was taken). - // - // TODO(user): It is unclear what the best heuristic is here. Both the current - // trail index or the trail before the current decision perform well, but - // using the full trail seems slightly better even though it will contain the - // current conflicting literal. - decision_policy_->BeforeConflict(trail_->Index()); - - // Backtrack and add the reason to the set of learned clause. - counters_.num_literals_learned += learned_conflict_.size(); - const int conflict_level = - trail_->Info(learned_conflict_[0].Variable()).level; - const int backjump_levels = CurrentDecisionLevel() - conflict_level; - const bool should_backjump = - !trail_->ChronologicalBacktrackingEnabled() || - (num_failures() > parameters_->chronological_backtrack_min_conflicts() && - backjump_levels > parameters_->max_backjump_levels()); - const int backtrack_level = should_backjump - ? ComputePropagationLevel(learned_conflict_) - : std::max(0, conflict_level - 1); - Backtrack(backtrack_level); - DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_)); - // Note that we need to output the learned clause before cleaning the clause // database. This is because we already backtracked and some of the clauses // that were needed to infer the conflict may not be "reasons" anymore and @@ -1052,6 +1034,34 @@ void SatSolver::ProcessCurrentConflict() { VLOG(1) << "WARNING: invalid LRAT inferred clause!"; } } + if (callback.has_value()) { + (*callback)(learned_conflict_clause_id, learned_conflict_); + } + + // We notify the decision before backtracking so that we can save the phase. + // The current heuristic is to try to take a trail prefix for which there is + // currently no conflict (hence just before the last decision was taken). + // + // TODO(user): It is unclear what the best heuristic is here. Both the current + // trail index or the trail before the current decision perform well, but + // using the full trail seems slightly better even though it will contain the + // current conflicting literal. + decision_policy_->BeforeConflict(trail_->Index()); + + // Backtrack and add the reason to the set of learned clause. + counters_.num_literals_learned += learned_conflict_.size(); + const int conflict_level = + trail_->Info(learned_conflict_[0].Variable()).level; + const int backjump_levels = CurrentDecisionLevel() - conflict_level; + const bool should_backjump = + !trail_->ChronologicalBacktrackingEnabled() || + (num_failures() > parameters_->chronological_backtrack_min_conflicts() && + backjump_levels > parameters_->max_backjump_levels()); + const int backtrack_level = should_backjump + ? ComputePropagationLevel(learned_conflict_) + : std::max(0, conflict_level - 1); + Backtrack(backtrack_level); + DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_)); // Detach any subsumed clause. They will actually be deleted on the next // clause cleanup phase. @@ -1075,7 +1085,7 @@ void SatSolver::ProcessCurrentConflict() { restart_->OnConflict(conflict_trail_index, conflict_level, conflict_lbd); } -void SatSolver::ComputeLratProofForLearnedConflict( +void SatSolver::FillLratProofForLearnedConflict( std::vector* clause_ids) { clause_ids->clear(); // First add all the unit clauses used in the reasons to infer the conflict. @@ -1552,19 +1562,20 @@ bool SatSolver::TryToMinimizeClause(SatClause* clause) { if (lrat_proof_handler_ != nullptr) { DCHECK(fixed_true_literal != kNoLiteralIndex || !fixed_false_literals.empty()); + clause_ids.clear(); if (fixed_true_literal != kNoLiteralIndex) { // If some literals of the minimized clause fix another to true, we just // need the propagating clauses to prove this (assuming that all the // minimized clause literals are false will lead to a conflict on this // 'fixed to true' literal). - GetClausesFixing({Literal(fixed_true_literal)}, &clause_ids); + AppendClausesFixing({Literal(fixed_true_literal)}, &clause_ids); } else { // If some literals of the minimized clause fix those that have been // removed to false, the propagating clauses and the original one prove // this (assuming that all the minimized clause literals are false will // lead to all the literals of the original clause fixed to false, which // is a conflict with the original clause). - GetClausesFixing(fixed_false_literals, &clause_ids); + AppendClausesFixing(fixed_false_literals, &clause_ids); clause_ids.push_back(clauses_propagator_->GetClauseId(clause)); } } @@ -1824,24 +1835,24 @@ std::vector SatSolver::GetDecisionsFixing( SCOPED_TIME_STAT(&stats_); std::vector unsat_assumptions; - is_marked_.ClearAndResize(num_variables_); + tmp_mark_.ClearAndResize(num_variables_); int trail_index = 0; for (const Literal lit : literals) { CHECK(Assignment().LiteralIsAssigned(lit)); trail_index = std::max(trail_index, trail_->Info(lit.Variable()).trail_index); - is_marked_.Set(lit.Variable()); + tmp_mark_.Set(lit.Variable()); } - // We just expand the conflict until we only have decisions. + // We just expand the reasons recursively until we only have decisions. const int limit = CurrentDecisionLevel() > 0 ? decisions_[0].trail_index : trail_->Index(); CHECK_LT(trail_index, trail_->Index()); while (true) { // Find next marked literal to expand from the trail. while (trail_index >= limit && - !is_marked_[(*trail_)[trail_index].Variable()]) { + !tmp_mark_[(*trail_)[trail_index].Variable()]) { --trail_index; } if (trail_index < limit) break; @@ -1856,7 +1867,7 @@ std::vector SatSolver::GetDecisionsFixing( for (const Literal literal : trail_->Reason(marked_literal.Variable())) { const BooleanVariable var = literal.Variable(); const int level = AssignmentLevel(var); - if (level > 0 && !is_marked_[var]) is_marked_.Set(var); + if (level > 0 && !tmp_mark_[var]) tmp_mark_.Set(var); } } } @@ -1867,34 +1878,33 @@ std::vector SatSolver::GetDecisionsFixing( return unsat_assumptions; } -void SatSolver::GetClausesFixing(absl::Span literals, - std::vector* clause_ids) { +void SatSolver::AppendClausesFixing(absl::Span literals, + std::vector* clause_ids) { SCOPED_TIME_STAT(&stats_); // Unit clauses must come first. We put them in clause_ids directly. We put // the others in non_unit_clause_ids and append them to clause_ids at the end. std::vector non_unit_clause_ids; - clause_ids->clear(); - is_marked_.ClearAndResize(num_variables_); + tmp_mark_.ClearAndResize(num_variables_); int trail_index = 0; for (const Literal lit : literals) { CHECK(Assignment().LiteralIsAssigned(lit)); trail_index = std::max(trail_index, trail_->Info(lit.Variable()).trail_index); - is_marked_.Set(lit.Variable()); + tmp_mark_.Set(lit.Variable()); } - // We just expand the conflict until we only have decisions. This is the same - // algorithm as in GetDecisionsFixing(). + // We just expand the reasons recursively until we only have decisions. This + // is the same algorithm as in GetDecisionsFixing(). const int limit = CurrentDecisionLevel() > 0 ? decisions_[0].trail_index : trail_->Index(); CHECK_LT(trail_index, trail_->Index()); while (true) { // Find next marked literal to expand from the trail. while (trail_index >= limit && - !is_marked_[(*trail_)[trail_index].Variable()]) { + !tmp_mark_[(*trail_)[trail_index].Variable()]) { --trail_index; } if (trail_index < limit) break; @@ -1910,10 +1920,9 @@ void SatSolver::GetClausesFixing(absl::Span literals, for (const Literal literal : trail_->Reason(marked_literal.Variable())) { const BooleanVariable var = literal.Variable(); const int level = AssignmentLevel(var); - if (!is_marked_[var]) { - if (level > 0) { - is_marked_.Set(var); - } else { + if (!tmp_mark_[var]) { + tmp_mark_.Set(var); + if (level == 0) { clause_ids->push_back(trail_->GetUnitClauseId(var)); } } @@ -2222,7 +2231,9 @@ void SatSolver::ProcessNewlyFixedVariables() { lrat_proof_handler_->AddInferredClause( new_clause_id, {clause->begin(), new_size}, clause_ids); lrat_proof_handler_->DeleteClauses({old_clause_id}); - clauses_propagator_->SetClauseId(clause, new_clause_id); + if (new_size > 2) { + clauses_propagator_->SetClauseId(clause, new_clause_id); + } } if (new_size == 2) { diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index fc38680c81..ac58572a1d 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +31,7 @@ #include "absl/base/attributes.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/logging.h" @@ -61,6 +63,11 @@ const int kUnsatTrailIndex = -1; // http://en.wikipedia.org/wiki/Conflict_Driven_Clause_Learning class SatSolver { public: + // Callback called when a new conflict clause is learned. The arguments are + // the ID and the literals of the learned clause. + typedef absl::FunctionRef)> + ConflictCallback; + SatSolver(); explicit SatSolver(Model* model); @@ -248,11 +255,11 @@ class SatSolver { // `literals` are fixed to their current value. std::vector GetDecisionsFixing(absl::Span literals); - // Sets `clause_ids` to the IDs of the clauses which, by unit propagation from - // some decisions, are sufficient to ensure that all literals in `literals` - // are fixed to their current value. - void GetClausesFixing(absl::Span literals, - std::vector* clause_ids); + // Appends to `clause_ids` the IDs of the clauses which, by unit propagation + // from some decisions, are sufficient to ensure that all literals in + // `literals` are fixed to their current value. + void AppendClausesFixing(absl::Span literals, + std::vector* clause_ids); // Advanced usage. The next 3 functions allow to drive the search from outside // the solver. @@ -270,13 +277,16 @@ class SatSolver { // CurrentDecisionLevel() was increased by 1 or not. // // If there is a conflict, the given decision is not applied and: - // - The conflict is learned. + // - The conflict is learned. If `conflict_callback` is provided, it is called + // with for each learned conflict, if any, before backtracking. // - The decisions are potentially backtracked to the first decision that // propagates more variables because of the newly learned conflict. // - The returned value is equal to trail_->Index() after this backtracking // and just before the new propagation (due to the conflict) which is also // performed by this function. - int EnqueueDecisionAndBackjumpOnConflict(Literal true_literal); + int EnqueueDecisionAndBackjumpOnConflict( + Literal true_literal, + std::optional callback = std::nullopt); // This function starts by calling EnqueueDecisionAndBackjumpOnConflict(). If // there is no conflict, it stops there. Otherwise, it tries to reapply all @@ -326,10 +336,13 @@ class SatSolver { return true; } - // Advanced usage. Finish the progation if it was interrupted. Note that this - // might run into conflict and will propagate again until a fixed point is - // reached or the model was proven UNSAT. Returns IsModelUnsat(). - ABSL_MUST_USE_RESULT bool FinishPropagation(); + // Advanced usage. Finish the propagation if it was interrupted. Note that + // this might run into conflict and will propagate again until a fixed point + // is reached or the model was proven UNSAT. If `callback` is provided it is + // called for each learned conflict (if any), before backtracking. Returns + // IsModelUnsat(). + ABSL_MUST_USE_RESULT bool FinishPropagation( + std::optional callback = std::nullopt); // Like Backtrack(0) but make sure the propagation is finished and return // false if unsat was detected. This also removes any assumptions level. @@ -524,13 +537,15 @@ class SatSolver { // Processes the current conflict from trail->FailingClause(). // // This learns the conflict, backtracks, enqueues the consequence of the - // learned conflict and return. When handling assumptions, this might return - // false without backtracking in case of ASSUMPTIONS_UNSAT. This is only - // exposed to allow processing a conflict detected outside normal propagation. - void ProcessCurrentConflict(); - - // Fills `clause_ids` with the LRAT proof for the learned conflict. - void ComputeLratProofForLearnedConflict(std::vector* clause_ids); + // learned conflict and return. If `callback` is provided it is called with + // the learned conflict, if any, before backtracking (there might not be any + // learned conflict if there are assumptions or if the conflict is not a + // clause -- pseudo Boolean case). When handling assumptions, this might + // return false without backtracking in case of ASSUMPTIONS_UNSAT. This is + // only exposed to allow processing a conflict detected outside normal + // propagation. + void ProcessCurrentConflict( + std::optional callback = std::nullopt); void EnsureNewClauseIndexInitialized() { clauses_propagator_->EnsureNewClauseIndexInitialized(); @@ -693,6 +708,9 @@ class SatSolver { std::vector* reason_used_to_infer_the_conflict, std::vector* subsumed_clauses); + // Fills `clause_ids` with the LRAT proof for the learned conflict. + void FillLratProofForLearnedConflict(std::vector* clause_ids); + // Fills literals with all the literals in the reasons of the literals in the // given input. The output vector will have no duplicates and will not contain // the literals already present in the input. diff --git a/ortools/sat/solution_crush.cc b/ortools/sat/solution_crush.cc index 5858d5e178..fc532522ce 100644 --- a/ortools/sat/solution_crush.cc +++ b/ortools/sat/solution_crush.cc @@ -71,6 +71,16 @@ void SolutionCrush::MaybeSetLiteralToValueEncoding(int literal, int var, } } +void SolutionCrush::MaybeSetLiteralToOrderEncoding(int literal, int var, + int64_t value, bool is_le) { + DCHECK(RefIsPositive(var)); + if (!solution_is_loaded_) return; + if (!HasValue(PositiveRef(literal)) && HasValue(var)) { + SetLiteralValue( + literal, is_le ? GetVarValue(var) <= value : GetVarValue(var) >= value); + } +} + void SolutionCrush::SetVarToLinearExpression( int new_var, absl::Span> linear, int64_t offset) { diff --git a/ortools/sat/solution_crush.h b/ortools/sat/solution_crush.h index 7545be443c..fc5426785f 100644 --- a/ortools/sat/solution_crush.h +++ b/ortools/sat/solution_crush.h @@ -81,6 +81,11 @@ class SolutionCrush { // `literal` already has a value. void MaybeSetLiteralToValueEncoding(int literal, int var, int64_t value); + // Sets the value of `literal` to "`var`'s value >=/<= `value`". Does nothing + // if `literal` already has a value. + void MaybeSetLiteralToOrderEncoding(int literal, int var, int64_t value, + bool is_le); + // Sets the value of `var` to the value of the given linear expression, if all // the variables in this expression have a value. `linear` must be a list of // (variable index, coefficient) pairs.