diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 6dcd173923..1a91eeefaa 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -328,7 +328,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":cp_model_cc_proto", - ":drat_proof_handler", ":sat_base", "//ortools/base:base_export", "//ortools/base:file", @@ -372,6 +371,7 @@ cc_library( deps = [ ":cp_model_cc_proto", ":cp_model_utils", + ":drat_checker", ":integer_base", ":model", ":sat_base", @@ -701,6 +701,7 @@ cc_library( ":linear_programming_constraint", ":linear_relaxation", ":lp_utils", + ":lrat_proof_handler", ":max_hs", ":model", ":optimization", @@ -817,7 +818,6 @@ cc_library( ":cuts", ":diffn_util", ":drat_checker", - ":drat_proof_handler", ":feasibility_jump", ":feasibility_pump", ":implied_bounds", @@ -908,7 +908,6 @@ cc_test( ":cp_model_test_utils", ":cp_model_utils", ":drat_checker", - ":drat_proof_handler", ":lp_utils", ":model", ":sat_base", @@ -1154,6 +1153,7 @@ cc_library( "//ortools/algorithms:sparse_permutation", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/container:inlined_vector", @@ -1470,7 +1470,6 @@ cc_library( hdrs = ["sat_solver.h"], deps = [ ":clause", - ":drat_proof_handler", ":enforcement", ":lrat_proof_handler", ":model", @@ -1594,7 +1593,6 @@ cc_library( hdrs = ["sat_inprocessing.h"], deps = [ ":clause", - ":drat_checker", ":linear_programming_constraint", ":lrat_proof_handler", ":model", @@ -1688,7 +1686,6 @@ cc_library( hdrs = ["clause.h"], deps = [ ":container", - ":drat_proof_handler", ":inclusion", ":lrat_proof_handler", ":model", @@ -1742,7 +1739,6 @@ cc_library( srcs = ["simplification.cc"], hdrs = ["simplification.h"], deps = [ - ":drat_proof_handler", ":sat_base", ":sat_parameters_cc_proto", ":sat_solver", @@ -4197,23 +4193,6 @@ cc_test( ], ) -cc_library( - name = "drat_proof_handler", - srcs = ["drat_proof_handler.cc"], - hdrs = ["drat_proof_handler.h"], - deps = [ - ":drat_checker", - ":drat_writer", - ":sat_base", - "//ortools/base", - "//ortools/base:file", - "//ortools/base:strong_vector", - "//ortools/util:strong_integers", - "@abseil-cpp//absl/log:check", - "@abseil-cpp//absl/types:span", - ], -) - cc_library( name = "drat_checker", srcs = ["drat_checker.cc"], @@ -4287,9 +4266,12 @@ cc_library( srcs = ["lrat_proof_handler.cc"], hdrs = ["lrat_proof_handler.h"], deps = [ + ":drat_checker", + ":drat_writer", ":lrat_checker", ":model", ":sat_base", + "//ortools/base:file", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/strings", diff --git a/ortools/sat/boolean_problem.cc b/ortools/sat/boolean_problem.cc index 958b89b8a2..f6ab7229ef 100644 --- a/ortools/sat/boolean_problem.cc +++ b/ortools/sat/boolean_problem.cc @@ -844,8 +844,7 @@ void ProbeAndSimplifyProblem(SatPostsolver* postsolver, } util_intops::StrongVector equiv_map; - ProbeAndFindEquivalentLiteral(&solver, postsolver, - /*drat_proof_handler=*/nullptr, &equiv_map); + ProbeAndFindEquivalentLiteral(&solver, postsolver, &equiv_map); // We can abort if no information is learned. if (equiv_map.empty() && solver.LiteralTrail().Index() == 0) break; diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index 35a9f0bcff..feaa3d19cc 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -40,7 +40,6 @@ #include "ortools/base/timer.h" #include "ortools/graph/strongly_connected_components.h" #include "ortools/sat/container.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/inclusion.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" @@ -398,20 +397,15 @@ void ClauseManager::Attach(SatClause* clause, Trail* trail) { void ClauseManager::InternalDetach(SatClause* clause, DeletionSourceForStat source) { - const size_t size = clause->size(); - // Double-deletion. // TODO(user): change that to a check? - if (size == 0) return; + if (clause->size() == 0) return; --num_watched_clauses_; - if (drat_proof_handler_ != nullptr && size > 2) { - drat_proof_handler_->DeleteClause({clause->begin(), size}); - } if (lrat_proof_handler_ != nullptr) { const auto it = clause_id_.find(clause); if (it != clause_id_.end()) { - lrat_proof_handler_->DeleteClauses({it->second}); + lrat_proof_handler_->DeleteClause(it->second, clause->AsSpan()); clause_id_.erase(it); } } @@ -460,9 +454,6 @@ void ClauseManager::AttachAllClauses() { bool ClauseManager::InprocessingAddUnitClause(ClauseId unit_clause_id, Literal true_literal) { DCHECK_EQ(trail_->CurrentDecisionLevel(), 0); - if (drat_proof_handler_ != nullptr) { - drat_proof_handler_->AddClause({true_literal}); - } if (trail_->Assignment().LiteralIsTrue(true_literal)) return true; trail_->EnqueueWithUnitReason(unit_clause_id, true_literal); @@ -475,9 +466,6 @@ bool ClauseManager::InprocessingAddUnitClause(ClauseId unit_clause_id, bool ClauseManager::InprocessingFixLiteral( Literal true_literal, absl::Span clause_ids) { DCHECK_EQ(trail_->CurrentDecisionLevel(), 0); - if (drat_proof_handler_ != nullptr) { - drat_proof_handler_->AddClause({true_literal}); - } if (trail_->Assignment().LiteralIsTrue(true_literal)) return true; ClauseId clause_id = kNoClauseId; @@ -530,10 +518,12 @@ bool ClauseManager::InprocessingRewriteClause( return true; } - if (drat_proof_handler_ != nullptr) { - // We must write the new clause before we delete the old one. - drat_proof_handler_->AddClause(new_clause); - drat_proof_handler_->DeleteClause(clause->AsSpan()); + if (lrat_proof_handler_ != nullptr) { + const auto it = clause_id_.find(clause); + if (it != clause_id_.end()) { + lrat_proof_handler_->DeleteClause(it->second, clause->AsSpan()); + } + SetClauseId(clause, new_clause_id); } if (all_clauses_are_attached_) { @@ -543,25 +533,13 @@ bool ClauseManager::InprocessingRewriteClause( clause->Clear(); for (const Literal l : {clause->FirstLiteral(), clause->SecondLiteral()}) { needs_cleaning_.Clear(l); - // std::erase_if is C++20, not yet fully supported on OR-Tools. - watchers_on_false_[l].erase( - std::remove_if(watchers_on_false_[l].begin(), - watchers_on_false_[l].end(), - [](const Watcher& watcher) { - return watcher.clause->IsRemoved(); - }), - watchers_on_false_[l].end()); + OpenSourceEraseIf(watchers_on_false_[l], [](const Watcher& watcher) { + return watcher.clause->IsRemoved(); + }); } } clause->Rewrite(new_clause); - if (lrat_proof_handler_ != nullptr) { - const auto it = clause_id_.find(clause); - if (it != clause_id_.end()) { - lrat_proof_handler_->DeleteClauses({it->second}); - } - SetClauseId(clause, new_clause_id); - } // And we reattach it. if (all_clauses_are_attached_) { @@ -601,12 +579,9 @@ void ClauseManager::CleanUpWatchers() { SCOPED_TIME_STAT(&stats_); for (const LiteralIndex index : needs_cleaning_.PositionsSetAtLeastOnce()) { if (!needs_cleaning_[index]) continue; - // std::erase_if is C++20, not yet fully supported on OR-Tools. - watchers_on_false_[index].erase( - std::remove_if( - watchers_on_false_[index].begin(), watchers_on_false_[index].end(), - [](const Watcher& watcher) { return watcher.clause->IsRemoved(); }), - watchers_on_false_[index].end()); + OpenSourceEraseIf(watchers_on_false_[index], [](const Watcher& watcher) { + return watcher.clause->IsRemoved(); + }); needs_cleaning_.Clear(index); } needs_cleaning_.NotifyAllClear(); @@ -789,9 +764,9 @@ bool BinaryImplicationGraph::HasNoDuplicates() { // use them here to maintains invariant? Explore this when we start cleaning our // clauses using equivalence during search. We can easily do it for every // conflict we learn instead of here. -bool BinaryImplicationGraph::AddBinaryClauseInternal(ClauseId id, Literal a, - Literal b, - bool change_reason) { +bool BinaryImplicationGraph::AddBinaryClauseInternal( + ClauseId id, Literal a, Literal b, bool change_reason, + bool delete_non_representative_id) { SCOPED_TIME_STAT(&stats_); // Tricky: If this is the first clause, the propagator will be added and @@ -799,14 +774,6 @@ bool BinaryImplicationGraph::AddBinaryClauseInternal(ClauseId id, Literal a, if (no_constraint_ever_added_) propagation_trail_index_ = trail_->Index(); no_constraint_ever_added_ = false; - if (drat_proof_handler_ != nullptr) { - // TODO(user): Like this we will duplicate all binary clause from the - // problem. However this leads to a simpler API (since we don't need to - // special case the loading of the original clauses) and we mainly use drat - // proof for testing anyway. - drat_proof_handler_->AddClause({a, b}); - } - Literal rep_a = a; Literal rep_b = b; ClauseId rep_id = kNoClauseId; @@ -839,8 +806,8 @@ bool BinaryImplicationGraph::AddBinaryClauseInternal(ClauseId id, Literal a, // Remember the non-canonical clause so we can delete it on restart. changed_reasons_on_trail_.emplace_back(std::minmax(a, b)); AddClauseId(id, a, b); - } else { - lrat_proof_handler_->DeleteClauses({id}); + } else if (delete_non_representative_id) { + lrat_proof_handler_->DeleteClause(id, {a, b}); } } AddClauseId(rep_id, rep_a, rep_b); @@ -939,9 +906,6 @@ bool BinaryImplicationGraph::FixLiteral(Literal true_literal, return false; } - if (drat_proof_handler_ != nullptr) { - drat_proof_handler_->AddClause({true_literal}); - } ClauseId new_clause_id = kNoClauseId; if (lrat_proof_handler_ != nullptr) { new_clause_id = clause_id_generator_->GetNextId(); @@ -1515,7 +1479,6 @@ class LratEquivalenceHelper { trail_(graph->trail_), implications_and_amos_(graph->implications_and_amos_), clause_id_generator_(graph->clause_id_generator_), - drat_proof_handler_(graph->drat_proof_handler_), lrat_proof_handler_(graph->lrat_proof_handler_) {} // Initializes the internal data structures to process the given component @@ -1714,9 +1677,6 @@ class LratEquivalenceHelper { void AddInferredClause(ClauseId new_clause_id, absl::Span literals, absl::Span clause_ids) { - if (drat_proof_handler_ != nullptr) { - drat_proof_handler_->AddClause(literals); - } if (lrat_proof_handler_ != nullptr) { lrat_proof_handler_->AddInferredClause(new_clause_id, literals, clause_ids); @@ -1728,7 +1688,6 @@ class LratEquivalenceHelper { util_intops::StrongVector& implications_and_amos_; ClauseIdGenerator* clause_id_generator_; - DratProofHandler* drat_proof_handler_; LratProofHandler* lrat_proof_handler_; // Temporary data structures used by the above methods: @@ -1745,14 +1704,6 @@ class LratEquivalenceHelper { std::vector tmp_literals_; }; -void BinaryImplicationGraph::SetDratProofHandler( - DratProofHandler* drat_proof_handler) { - drat_proof_handler_ = drat_proof_handler; - if (lrat_helper_ == nullptr) { - lrat_helper_ = new LratEquivalenceHelper(this); - } -} - bool BinaryImplicationGraph::DetectEquivalences(bool log_info) { // This was already called, and no new constraint where added. Note that new // fixed variable cannot create new equivalence, only new binary clauses do. @@ -1764,7 +1715,8 @@ bool BinaryImplicationGraph::DetectEquivalences(bool log_info) { if (trail_->CurrentDecisionLevel() == 0) { for (std::pair clause : changed_reasons_on_trail_) { auto it = clause_id_.find(clause); - lrat_proof_handler_->DeleteClauses({it->second}); + lrat_proof_handler_->DeleteClause(it->second, + {clause.first, clause.second}); clause_id_.erase(it); } changed_reasons_on_trail_.clear(); @@ -3157,15 +3109,9 @@ void BinaryImplicationGraph::RemoveBooleanVariable( // Notify the deletion to the proof checker and the postsolve. // Note that we want var first in these clauses for the postsolve. for (const Literal b : direct_implications_) { - if (drat_proof_handler_ != nullptr) { - drat_proof_handler_->DeleteClause({Literal(var, false), b}); - } postsolve_clauses->push_back({Literal(var, false), b}); } for (const Literal a_negated : direct_implications_of_negated_literal_) { - if (drat_proof_handler_ != nullptr) { - drat_proof_handler_->DeleteClause({Literal(var, true), a_negated}); - } postsolve_clauses->push_back({Literal(var, true), a_negated}); } diff --git a/ortools/sat/clause.h b/ortools/sat/clause.h index 61e003523f..4cfda6f011 100644 --- a/ortools/sat/clause.h +++ b/ortools/sat/clause.h @@ -33,7 +33,6 @@ #include "ortools/base/strong_vector.h" #include "ortools/graph/cliques.h" #include "ortools/sat/container.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" @@ -276,11 +275,6 @@ class ClauseManager : public SatPropagator { // Number of clauses currently watched. int64_t num_watched_clauses() const { return num_watched_clauses_; } - DratProofHandler* GetDratProofHandler() const { return drat_proof_handler_; } - void SetDratProofHandler(DratProofHandler* drat_proof_handler) { - drat_proof_handler_ = drat_proof_handler; - } - ClauseId GetClauseId(const SatClause* clause) const { const auto it = clause_id_.find(clause); return it != clause_id_.end() ? it->second : kNoClauseId; @@ -457,7 +451,6 @@ class ClauseManager : public SatPropagator { // Only contains removable clause. absl::flat_hash_map clauses_info_; - DratProofHandler* drat_proof_handler_ = nullptr; LratProofHandler* lrat_proof_handler_ = nullptr; // Temporary member used when adding LRAT inferred clauses. @@ -584,7 +577,12 @@ class BinaryImplicationGraph : public SatPropagator { bool IsEmpty() const final { return no_constraint_ever_added_; } // Adds the binary clause (a OR b), which is the same as (not a => b). - // Note that it is also equivalent to (not b => a). + // Note that it is also equivalent to (not b => a). More precisely, adds the + // binary clause (rep(a) OR rep(b)), where rep(l) is the representative of l. + // If they are different from a and b, a new inferred LRAT clause is also + // added (if an LRAT proof handler is set), with a new clause ID (and the old + // LRAT `id` clause is deleted, unless `delete_non_representative_id` is + // false). // // Preconditions: // - If we are at root node, then none of the literal should be assigned. @@ -593,8 +591,10 @@ class BinaryImplicationGraph : public SatPropagator { // - If we are at a positive decision level, we will propagate something if // we can. However, if both literal are false, we will just return false // and do nothing. In all other case, we will return true. - bool AddBinaryClause(ClauseId id, Literal a, Literal b) { - return AddBinaryClauseInternal(id, a, b, /*change_reason=*/false); + bool AddBinaryClause(ClauseId id, Literal a, Literal b, + bool delete_non_representative_id = true) { + return AddBinaryClauseInternal(id, a, b, /*change_reason=*/false, + delete_non_representative_id); } bool AddBinaryClause(Literal a, Literal b) { return AddBinaryClause(kNoClauseId, a, b); @@ -832,8 +832,6 @@ class BinaryImplicationGraph : public SatPropagator { } } - void SetDratProofHandler(DratProofHandler* drat_proof_handler); - // Adds a binary clause and changes the reason of `a` as if it were propagated // by this new clause. // This allows inprocessing to shrink clauses to binary without backtracking @@ -929,7 +927,8 @@ class BinaryImplicationGraph : public SatPropagator { friend class LratEquivalenceHelper; bool AddBinaryClauseInternal(ClauseId id, Literal a, Literal b, - bool change_reason = false); + bool change_reason = false, + bool delete_non_representative_id = true); // Marks implications_[a] for cleanup in RemoveDuplicatesAndFixedVariables(). void NotifyPossibleDuplicate(Literal a); @@ -1001,7 +1000,6 @@ class BinaryImplicationGraph : public SatPropagator { ModelRandomGenerator* random_; Trail* trail_; ClauseIdGenerator* clause_id_generator_; - DratProofHandler* drat_proof_handler_ = nullptr; LratProofHandler* lrat_proof_handler_ = nullptr; LratEquivalenceHelper* lrat_helper_ = nullptr; diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 8ad69b431a..0de2876757 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -510,6 +510,24 @@ std::string ValidateElementConstraint(const CpModelProto& model, return ""; } +std::string ValidateInverseConstraint(const CpModelProto& model, + const ConstraintProto& ct) { + if (ct.inverse().f_direct().size() != ct.inverse().f_inverse().size()) { + return absl::StrCat("Non-matching fields size in inverse: ", + ProtobufShortDebugString(ct)); + } + const InverseConstraintProto& inverse = ct.inverse(); + for (const auto* vars : {&inverse.f_direct(), &inverse.f_inverse()}) { + for (const int var : *vars) { + if (!VariableIndexIsValid(model, var)) { + return absl::StrCat("Invalid variable index in inverse constraint: ", + var); + } + } + } + return ""; +} + std::string ValidateTableConstraint(const CpModelProto& model, const ConstraintProto& ct) { const TableConstraintProto& arg = ct.table(); @@ -1147,10 +1165,7 @@ std::string ValidateCpModel(const CpModelProto& model, bool after_presolve) { RETURN_IF_NOT_EMPTY(ValidateIntModConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kInverse: - if (ct.inverse().f_direct().size() != ct.inverse().f_inverse().size()) { - return absl::StrCat("Non-matching fields size in inverse: ", - ProtobufShortDebugString(ct)); - } + RETURN_IF_NOT_EMPTY(ValidateInverseConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kAllDiff: for (const LinearExpressionProto& expr : ct.all_diff().exprs()) { diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 811dd7c2e6..d747725ba3 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -8761,8 +8761,7 @@ bool CpModelPresolver::PresolvePureSatPart() { // Probe + find equivalent literals. // TODO(user): Use a derived time limit in the probing phase. - ProbeAndFindEquivalentLiteral(sat_solver, &sat_postsolver, - /*drat_proof_handler=*/nullptr, &equiv_map, + ProbeAndFindEquivalentLiteral(sat_solver, &sat_postsolver, &equiv_map, logger_); if (sat_solver->ModelIsUnsat()) return false; @@ -13171,13 +13170,15 @@ void CpModelPresolver::ProcessVariableOnlyUsedInEncoding(int var) { return; } - const auto presolve_one_constraint = [this](ConstraintProto* ct) { - CHECK(ct->has_exactly_one()); - PresolveExactlyOne(ct); - }; - - TryToReplaceVariableByItsEncoding(var, presolve_one_constraint, context_, + int new_exo_to_presolve_index = -1; + TryToReplaceVariableByItsEncoding(var, new_exo_to_presolve_index, context_, solution_crush_); + if (new_exo_to_presolve_index != -1) { + if (PresolveExactlyOne(context_->working_model->mutable_constraints( + new_exo_to_presolve_index))) { + context_->UpdateConstraintVariableUsage(new_exo_to_presolve_index); + } + } } void CpModelPresolver::TryToSimplifyDomain(int var) { diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 25499d4f2a..7453ee558f 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -72,7 +72,6 @@ #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/diffn_util.h" #include "ortools/sat/drat_checker.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/feasibility_jump.h" #include "ortools/sat/feasibility_pump.h" #include "ortools/sat/integer.h" @@ -135,20 +134,24 @@ ABSL_FLAG(bool, cp_model_fingerprint_model, true, "Fingerprint the model."); ABSL_FLAG(std::string, cp_model_drat_output, "", "If non-empty, a proof in DRAT format will be written to this file. " - "This will only be used for pure-SAT problems. And, as of September " - "2025, only works with a single worker, no presolve, and symmetry " - "level 0 or 1."); + "This only works in the same conditions as the --cp_model_lrat_check " + "flag, and only for pure SAT models."); ABSL_FLAG(bool, cp_model_drat_check, false, "If true, a proof in DRAT format will be stored in memory and " - "checked if the problem is UNSAT. This will only be used for " - "pure-SAT problems. And, as of September 2025, only works with a " - "single worker, no presolve, and symmetry level 0 or 1."); + "checked if the problem is UNSAT. This only works in the same " + "conditions as the --cp_model_lrat_check flag, and only for pure SAT " + "models."); ABSL_FLAG(bool, cp_model_lrat_check, false, - "If true, inferred clauses are checker with an LRAT checker as they " - "are learned. As of October 2025, this is currently being " - "implemented and is not working yet."); + "If true, inferred clauses are checked with an LRAT checker as they " + "are learned. As of November 2025, this only works with a single " + "worker and symmetry level 0 or 1. This also works with presolve, if " + "find_clauses_that_are_exactly_one is false and " + "merge_at_most_one_work_limit is 0. However, in this case, the " + "presolved problem is assumed to be correct, without proof. If the " + "model is not pure SAT, the checks are only partial (some clauses " + "can be assumed without proof)."); ABSL_FLAG(double, cp_model_max_drat_time_in_seconds, std::numeric_limits::infinity(), @@ -1032,13 +1035,32 @@ bool SolutionHintIsCompleteAndFeasible( } } +std::unique_ptr MaybeCreateLratProofHandler(Model* model) { + const bool check_lrat = absl::GetFlag(FLAGS_cp_model_lrat_check); + const bool check_drat = absl::GetFlag(FLAGS_cp_model_drat_check); + File* drat_output = nullptr; + if (!absl::GetFlag(FLAGS_cp_model_drat_output).empty()) { + CHECK_OK(file::Open(absl::GetFlag(FLAGS_cp_model_drat_output), "w", + &drat_output, file::Defaults())); + } + if (!check_lrat && !check_drat && drat_output == nullptr) return nullptr; + + // TODO(user): pass the [presolved] model proto to the handler, so that + // it can map internal problem clause IDs to constraint indices in the + // original model. This will be needed to write the LRAT proof in a file that + // can be checked with an external LRAT checker, expecting the standard LRAT + // ASCII file format (which requires problem clauses IDs between 1 and n). + return std::make_unique(model, check_lrat, check_drat, + drat_output, + /*in_binary_drat_format=*/false); +} + // Encapsulate a full CP-SAT solve without presolve in the SubSolver API. class FullProblemSolver : public SubSolver { public: FullProblemSolver(absl::string_view name, const SatParameters& local_parameters, bool split_in_chunks, - SharedClasses* shared, DratProofHandler* drat_proof_handler, - bool use_lrat_checker, bool stop_at_first_solution = false) + SharedClasses* shared, bool stop_at_first_solution = false) : SubSolver(name, stop_at_first_solution ? FIRST_SOLUTION : FULL_PROBLEM), shared_(shared), split_in_chunks_(split_in_chunks), @@ -1059,13 +1081,12 @@ class FullProblemSolver : public SubSolver { // by registering the SharedStatistics class with LNS local model. shared_->RegisterSharedClassesInLocalModel(&local_model_); - if (drat_proof_handler != nullptr) { - local_model_.GetOrCreate()->SetDratProofHandler( - drat_proof_handler); - } - if (use_lrat_checker) { - lrat_proof_handler_ = std::make_unique(&local_model_); - local_model_.Register(lrat_proof_handler_.get()); + std::unique_ptr lrat_proof_handler = + MaybeCreateLratProofHandler(&local_model_); + if (lrat_proof_handler != nullptr) { + local_model_.Register(lrat_proof_handler.get()); + local_model_.TakeOwnership(lrat_proof_handler.release()); + shared_->lrat_proof_status->NewSubSolver(); } // Setup the local logger, in multi-thread log_search_progress should be @@ -1084,13 +1105,21 @@ class FullProblemSolver : public SubSolver { shared_->stat_tables->AddLpStat(name(), &local_model_); shared_->stat_tables->AddSearchStat(name(), &local_model_); shared_->stat_tables->AddClausesStat(name(), &local_model_); - // TODO(user): move this to a final response post-processor? - if (lrat_proof_handler_ != nullptr && - local_model_.GetOrCreate()->ModelIsUnsat() && - !lrat_proof_handler_->Check()) { - auto* logger = local_model_.GetOrCreate(); - SOLVER_LOG(logger, - absl::StrFormat("ERROR: LRAT proof invalid for %s", name())); + LratProofHandler* lrat_proof_handler = + local_model_.Mutable(); + if (lrat_proof_handler != nullptr) { + WallTimer timer; + timer.Start(); + const bool valid = local_model_.GetOrCreate()->ModelIsUnsat() + ? lrat_proof_handler->Check(absl::GetFlag( + FLAGS_cp_model_max_drat_time_in_seconds)) + : lrat_proof_handler->Valid(); + shared_->lrat_proof_status->NewSubsolverProofStatus( + valid ? DratChecker::Status::VALID : DratChecker::Status::INVALID, + lrat_proof_handler->lrat_check_enabled(), + lrat_proof_handler->drat_check_enabled(), + lrat_proof_handler->num_assumed_clauses(), timer.Get()); + lrat_proof_handler->AddStats(); } } @@ -1225,7 +1254,6 @@ class FullProblemSolver : public SubSolver { const bool split_in_chunks_; const bool stop_at_first_solution_; Model local_model_; - std::unique_ptr lrat_proof_handler_; // The first chunk is special. It is the one in which we load the model and // try to follow the hint. @@ -1773,6 +1801,14 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { const SatParameters& params = *global_model->GetOrCreate(); if (global_model->GetOrCreate()->LimitReached()) return; + if (absl::GetFlag(FLAGS_cp_model_drat_check) || + !absl::GetFlag(FLAGS_cp_model_drat_output).empty()) { + LOG(WARNING) + << "DRAT check and output are skipped when using several workers"; + absl::SetFlag(&FLAGS_cp_model_drat_check, false); + absl::SetFlag(&FLAGS_cp_model_drat_output, ""); + } + // If specified by the user, we might disable some parameters based on their // name. SubsolverNameFilter name_filter(params); @@ -1836,8 +1872,7 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { num_shared_tree_workers)) { full_worker_subsolvers.push_back(std::make_unique( local_params.name(), local_params, - /*split_in_chunks=*/params.interleave_search(), shared, - /*drat_proof_handler=*/nullptr, /*use_lrat_checker=*/false)); + /*split_in_chunks=*/params.interleave_search(), shared)); } } @@ -1860,8 +1895,7 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { full_worker_subsolvers.push_back(std::make_unique( local_params.name(), local_params, - /*split_in_chunks=*/params.interleave_search(), shared, - /*drat_proof_handler=*/nullptr, /*use_lrat_checker=*/false)); + /*split_in_chunks=*/params.interleave_search(), shared)); } // Add FeasibilityPumpSolver if enabled. @@ -2209,7 +2243,6 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { std::make_unique( local_params.name(), local_params, /*split_in_chunks=*/local_params.interleave_search(), shared, - /*drat_proof_handler=*/nullptr, /*use_lrat_checker=*/false, /*stop_on_first_solution=*/true)); } } @@ -2382,64 +2415,6 @@ void MergeParamsWithFlagsAndDefaults(SatParameters* params) { } } -std::unique_ptr MaybeCreateDratProofHandler( - const CpModelProto& model_proto, - SharedResponseManager* shared_response_manager, SolverLogger* logger) { - std::unique_ptr drat_proof_handler; - if (!absl::GetFlag(FLAGS_cp_model_drat_output).empty() || - absl::GetFlag(FLAGS_cp_model_drat_check)) { - if (!absl::GetFlag(FLAGS_cp_model_drat_output).empty()) { - File* output; - CHECK_OK(file::Open(absl::GetFlag(FLAGS_cp_model_drat_output), "w", - &output, file::Defaults())); - drat_proof_handler = std::make_unique( - /*in_binary_format=*/false, output, - absl::GetFlag(FLAGS_cp_model_drat_check)); - } else { - drat_proof_handler = std::make_unique(); - } - if (!LoadCpModelInDratProofHandler(model_proto, drat_proof_handler.get())) { - SOLVER_LOG(logger, - "Model is not pure SAT: cannot output nor check DRAT proof"); - return nullptr; - } - } else { - return nullptr; - } - shared_response_manager->AddFinalResponsePostprocessor( - [logger, drat_proof_handler = - drat_proof_handler.get()](CpSolverResponse* response) { - if (absl::GetFlag(FLAGS_cp_model_drat_check) && - response->status() == CpSolverStatus::INFEASIBLE) { - WallTimer drat_timer; - drat_timer.Start(); - DratChecker::Status drat_status = drat_proof_handler->Check( - absl::GetFlag(FLAGS_cp_model_max_drat_time_in_seconds)); - switch (drat_status) { - case DratChecker::UNKNOWN: - SOLVER_LOG(logger, "DRAT_status: UNKNOWN"); - break; - case DratChecker::VALID: - SOLVER_LOG(logger, "DRAT_status: VALID"); - break; - case DratChecker::INVALID: - SOLVER_LOG(logger, "DRAT_status: INVALID"); - break; - default: - // Should not happen. - break; - } - SOLVER_LOG(logger, "DRAT_walltime: ", drat_timer.Get()); - } else { - // Always log a DRAT status to make it easier to extract it from a - // multirun result with awk. - SOLVER_LOG(logger, "DRAT_status: NA"); - SOLVER_LOG(logger, "DRAT_walltime: NA"); - } - }); - return drat_proof_handler; -} - void FixVariablesToHintValue(const PartialVariableAssignment& solution_hint, PresolveContext* context, SolverLogger* logger) { SOLVER_LOG(logger, "Fixing ", solution_hint.vars().size(), @@ -2641,9 +2616,6 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { SOLVER_LOG(logger, ""); SOLVER_LOG(logger, "Initial ", CpModelStats(model_proto)); - std::unique_ptr drat_proof_handler = - MaybeCreateDratProofHandler(model_proto, shared_response_manager, logger); - // Presolve and expansions. SOLVER_LOG(logger, ""); SOLVER_LOG(logger, @@ -3116,8 +3088,7 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { // the multi-thread architecture. std::vector> subsolvers; subsolvers.push_back(std::make_unique( - "main", params, /*split_in_chunks=*/false, &shared, - drat_proof_handler.get(), absl::GetFlag(FLAGS_cp_model_lrat_check))); + "main", params, /*split_in_chunks=*/false, &shared)); LaunchSubsolvers(params, &shared, subsolvers, {}); } } diff --git a/ortools/sat/cp_model_solver.h b/ortools/sat/cp_model_solver.h index 623c86c25f..3dc66e9751 100644 --- a/ortools/sat/cp_model_solver.h +++ b/ortools/sat/cp_model_solver.h @@ -28,6 +28,7 @@ OR_DLL ABSL_DECLARE_FLAG(bool, cp_model_dump_response); OR_DLL ABSL_DECLARE_FLAG(bool, cp_model_drat_check); OR_DLL ABSL_DECLARE_FLAG(bool, cp_model_lrat_check); +OR_DLL ABSL_DECLARE_FLAG(double, cp_model_max_drat_time_in_seconds); #endif namespace operations_research { diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index 0d5b6e992b..f8861bd359 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -27,6 +27,7 @@ #include "ortools/base/logging.h" #include "ortools/base/timer.h" +#include "ortools/sat/lrat_proof_handler.h" #if !defined(__PORTABLE_PLATFORM__) #include "ortools/base/helpers.h" #include "ortools/base/options.h" @@ -820,73 +821,87 @@ void RegisterVariableBoundsLevelZeroImport( auto* trail = model->GetOrCreate(); auto* sat_solver = model->GetOrCreate(); auto* mapping = model->GetOrCreate(); + auto* lrat_proof_handler = model->Mutable(); + auto* clause_id_generator = model->GetOrCreate(); const int id = shared_bounds_manager->RegisterNewId(); - const auto& import_level_zero_bounds = [&model_proto, shared_bounds_manager, - name = name, sat_solver, - integer_trail, trail, id, mapping]() { - std::vector model_variables; - std::vector new_lower_bounds; - std::vector new_upper_bounds; - shared_bounds_manager->GetChangedBounds( - id, &model_variables, &new_lower_bounds, &new_upper_bounds); - for (int i = 0; i < model_variables.size(); ++i) { - const int model_var = model_variables[i]; + const auto& import_level_zero_bounds = + [&model_proto, shared_bounds_manager, name = name, sat_solver, + integer_trail, trail, lrat_proof_handler, clause_id_generator, id, + mapping]() { + std::vector model_variables; + std::vector new_lower_bounds; + std::vector new_upper_bounds; + shared_bounds_manager->GetChangedBounds( + id, &model_variables, &new_lower_bounds, &new_upper_bounds); + for (int i = 0; i < model_variables.size(); ++i) { + const int model_var = model_variables[i]; - // If this is a Boolean, fix it if not already done. - // Note that it is important not to use AddUnitClause() as we do not - // want to propagate after each addition. - if (mapping->IsBoolean(model_var)) { - Literal lit = mapping->Literal(model_var); - if (new_upper_bounds[i] == 0) lit = lit.Negated(); - if (trail->Assignment().LiteralIsTrue(lit)) continue; - if (trail->Assignment().LiteralIsFalse(lit)) { - sat_solver->NotifyThatModelIsUnsat(); - return false; + // If this is a Boolean, fix it if not already done. + // Note that it is important not to use AddUnitClause() as we do not + // want to propagate after each addition. + if (mapping->IsBoolean(model_var)) { + Literal lit = mapping->Literal(model_var); + if (new_upper_bounds[i] == 0) lit = lit.Negated(); + if (trail->Assignment().LiteralIsTrue(lit)) continue; + ClauseId clause_id = kNoClauseId; + if (lrat_proof_handler != nullptr) { + clause_id = clause_id_generator->GetNextId(); + lrat_proof_handler->AddSharedClause(clause_id, {lit}); + } + if (trail->Assignment().LiteralIsFalse(lit)) { + if (lrat_proof_handler != nullptr) { + // Add the UNSAT proof. + lrat_proof_handler->AddInferredClause( + clause_id_generator->GetNextId(), {}, + {clause_id, trail->GetUnitClauseId(lit.Variable())}); + } + sat_solver->NotifyThatModelIsUnsat(); + return false; + } + trail->EnqueueWithUnitReason(clause_id, lit); + continue; + } + + // Deal with integer. + if (!mapping->IsInteger(model_var)) continue; + const IntegerVariable var = mapping->Integer(model_var); + const IntegerValue new_lb(new_lower_bounds[i]); + const IntegerValue new_ub(new_upper_bounds[i]); + const IntegerValue old_lb = integer_trail->LowerBound(var); + const IntegerValue old_ub = integer_trail->UpperBound(var); + const bool changed_lb = new_lb > old_lb; + const bool changed_ub = new_ub < old_ub; + if (!changed_lb && !changed_ub) continue; + + if (VLOG_IS_ON(3)) { + const IntegerVariableProto& var_proto = + model_proto.variables(model_var); + const std::string& var_name = + var_proto.name().empty() + ? absl::StrCat("anonymous_var(", model_var, ")") + : var_proto.name(); + LOG(INFO) << " '" << name << "' imports new bounds for " + << var_name << ": from [" << old_lb << ", " << old_ub + << "] to [" << new_lb << ", " << new_ub << "]"; + } + + if (changed_lb && + !integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(var, new_lb), {}, {})) { + return false; + } + if (changed_ub && + !integer_trail->Enqueue(IntegerLiteral::LowerOrEqual(var, new_ub), + {}, {})) { + return false; + } } - trail->EnqueueWithUnitReason(lit); - continue; - } - // Deal with integer. - if (!mapping->IsInteger(model_var)) continue; - const IntegerVariable var = mapping->Integer(model_var); - const IntegerValue new_lb(new_lower_bounds[i]); - const IntegerValue new_ub(new_upper_bounds[i]); - const IntegerValue old_lb = integer_trail->LowerBound(var); - const IntegerValue old_ub = integer_trail->UpperBound(var); - const bool changed_lb = new_lb > old_lb; - const bool changed_ub = new_ub < old_ub; - if (!changed_lb && !changed_ub) continue; - - if (VLOG_IS_ON(3)) { - const IntegerVariableProto& var_proto = - model_proto.variables(model_var); - const std::string& var_name = - var_proto.name().empty() - ? absl::StrCat("anonymous_var(", model_var, ")") - : var_proto.name(); - LOG(INFO) << " '" << name << "' imports new bounds for " << var_name - << ": from [" << old_lb << ", " << old_ub << "] to [" - << new_lb << ", " << new_ub << "]"; - } - - if (changed_lb && - !integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(var, new_lb), - {}, {})) { - return false; - } - if (changed_ub && - !integer_trail->Enqueue(IntegerLiteral::LowerOrEqual(var, new_ub), {}, - {})) { - return false; - } - } - - // Note that we will propagate if they are new bounds separately. - // See BeforeTakingDecision(). - return true; - }; + // Note that we will propagate if they are new bounds separately. + // See BeforeTakingDecision(). + return true; + }; model->GetOrCreate()->callbacks.push_back( import_level_zero_bounds); } @@ -1122,7 +1137,7 @@ int RegisterClausesLevelZeroImport(int id, for (const auto& [ref1, ref2] : new_binary_clauses) { const Literal l1 = mapping->Literal(ref1); const Literal l2 = mapping->Literal(ref2); - if (!sat_solver->AddProblemClause({l1, l2})) { + if (!sat_solver->AddProblemClause({l1, l2}, /*shared=*/true)) { return false; } } @@ -1146,8 +1161,8 @@ int RegisterClausesLevelZeroImport(int id, local_clause[i] = mapping->Literal(shared_clause[i]); } if (!sat_solver->AddProblemClause( - absl::MakeSpan(local_clause) - .subspan(0, shared_clause.size()))) { + absl::MakeSpan(local_clause).subspan(0, shared_clause.size()), + /*shared=*/true)) { return false; } } @@ -1389,6 +1404,9 @@ void LoadBaseModel(const CpModelProto& model_proto, Model* model) { SOLVER_LOG(logger, "BUG: We will wrongly report INFEASIBLE now."); return unsat(); } + if (model->Mutable() != nullptr) { + model->Mutable()->EndProblemClauses(); + } model->GetOrCreate() ->AddAllImplicationsBetweenAssociatedLiterals(); @@ -2223,7 +2241,8 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model) response(global_model->GetOrCreate()), shared_tree_manager(global_model->GetOrCreate()), ls_hints(global_model->GetOrCreate()), - progress_logger(global_model->GetOrCreate()) { + progress_logger(global_model->GetOrCreate()), + lrat_proof_status(global_model->GetOrCreate()) { const SatParameters& params = *global_model->GetOrCreate(); if (params.share_level_zero_bounds()) { @@ -2329,6 +2348,8 @@ void SharedClasses::LogFinalStatistics() { // Extra logging if needed. Note that these are mainly activated on // --vmodule *some_file*=1 and are here for development. stats->Log(logger); + + lrat_proof_status->Log(logger); } } // namespace sat diff --git a/ortools/sat/cp_model_solver_helpers.h b/ortools/sat/cp_model_solver_helpers.h index a3a4a9a225..0a5aa17aab 100644 --- a/ortools/sat/cp_model_solver_helpers.h +++ b/ortools/sat/cp_model_solver_helpers.h @@ -56,6 +56,7 @@ struct SharedClasses { SharedTreeManager* const shared_tree_manager; SharedLsSolutionRepository* const ls_hints; SolverProgressLogger* const progress_logger; + SharedLratProofStatus* const lrat_proof_status; // These can be nullptr depending on the options. std::unique_ptr bounds; diff --git a/ortools/sat/cp_model_solver_test.cc b/ortools/sat/cp_model_solver_test.cc index b55b08a95e..e5a223885e 100644 --- a/ortools/sat/cp_model_solver_test.cc +++ b/ortools/sat/cp_model_solver_test.cc @@ -32,8 +32,6 @@ #include "ortools/sat/cp_model_solver_helpers.h" #include "ortools/sat/cp_model_test_utils.h" #include "ortools/sat/cp_model_utils.h" -#include "ortools/sat/drat_checker.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/lp_utils.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" @@ -5461,35 +5459,23 @@ TEST(PresolveCpModelTest, SolutionCrushBug) { } TEST(CpModelSolverTest, DratProofIsValidForRandom3Sat) { + SatParameters params; + params.set_num_workers(1); + params.set_cp_model_presolve(false); + params.set_symmetry_level(0); + params.set_linearization_level(0); + params.set_debug_crash_if_lrat_check_fails(true); + absl::SetFlag(&FLAGS_cp_model_drat_check, true); + absl::SetFlag(&FLAGS_cp_model_max_drat_time_in_seconds, 60); + int num_infeasible = 0; for (int i = 0; i < 100; ++i) { - Model model; - SatSolver& solver = *model.GetOrCreate(); - auto drat_proof_handler = std::make_unique(); - solver.SetDratProofHandler(drat_proof_handler.get()); - const int kNumVariables = 100; CpModelProto model_proto = Random3SatProblem(kNumVariables); - drat_proof_handler->SetNumVariables(model_proto.variables_size()); - for (const ConstraintProto& ct : model_proto.constraints()) { - if (ct.constraint_case() == ConstraintProto::ConstraintCase::kBoolOr) { - std::vector clause; - for (const int ref : ct.bool_or().literals()) { - clause.push_back( - Literal(BooleanVariable(PositiveRef(ref)), RefIsPositive(ref))); - } - drat_proof_handler->AddProblemClause(clause); - } - } - - LoadCpModel(model_proto, &model); - SolveLoadedCpModel(model_proto, &model); - if (model.GetOrCreate()->GetResponse().status() == - CpSolverStatus::INFEASIBLE) { + CpSolverResponse response = SolveWithParameters(model_proto, params); + if (response.status() == CpSolverStatus::INFEASIBLE) { ++num_infeasible; - EXPECT_EQ(drat_proof_handler->Check(/*max_time_in_seconds=*/60), - DratChecker::Status::VALID); } } LOG(INFO) << "num_infeasible: " << num_infeasible; diff --git a/ortools/sat/cp_model_utils.cc b/ortools/sat/cp_model_utils.cc index 327bf053de..678cc67781 100644 --- a/ortools/sat/cp_model_utils.cc +++ b/ortools/sat/cp_model_utils.cc @@ -35,7 +35,6 @@ #include "google/protobuf/text_format.h" #include "ortools/base/stl_util.h" #include "ortools/sat/cp_model.pb.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/sat_base.h" #include "ortools/util/saturated_arithmetic.h" #include "ortools/util/sorted_interval_list.h" @@ -1145,20 +1144,6 @@ bool ConvertCpModelProtoToWCnf(const CpModelProto& cp_model, std::string* out) { return true; } -bool LoadCpModelInDratProofHandler(const CpModelProto& cp_model, - DratProofHandler* drat_proof_handler) { - const int num_vars = cp_model.variables().size(); - int num_clauses = 0; - if (!ModelIsPureSat(cp_model, &num_clauses)) return false; - - drat_proof_handler->SetNumVariables(num_vars); - ConvertSatCpModelProtoToClauses( - cp_model, [&drat_proof_handler](const std::vector& clause) { - drat_proof_handler->AddProblemClause(clause); - }); - return true; -} - int CombineSeed(int base_seed, int64_t delta) { CHECK_GE(delta, 0); const uint64_t fp = FingerprintSingleField(delta, kDefaultFingerprintSeed); diff --git a/ortools/sat/cp_model_utils.h b/ortools/sat/cp_model_utils.h index a57ee9e607..8d8979677f 100644 --- a/ortools/sat/cp_model_utils.h +++ b/ortools/sat/cp_model_utils.h @@ -16,13 +16,10 @@ #include #include -#include #include #include #include -#include "ortools/sat/drat_proof_handler.h" - #if !defined(__PORTABLE_PLATFORM__) #include "ortools/base/helpers.h" #endif // !defined(__PORTABLE_PLATFORM__) @@ -435,11 +432,6 @@ bool ConvertCpModelProtoToCnf(const CpModelProto& cp_model, std::string* out); // https://maxsat-evaluations.github.io/2022/rules.html bool ConvertCpModelProtoToWCnf(const CpModelProto& cp_model, std::string* out); -// Loads the model in the DratProofHandler and returns true if successful. -// Returns false if the model is not pure SAT. -bool LoadCpModelInDratProofHandler(const CpModelProto& cp_model, - DratProofHandler* drat_proof_handler); - // We assume delta >= 0 and we only use the low bit of delta. int CombineSeed(int base_seed, int64_t delta); diff --git a/ortools/sat/drat_checker.h b/ortools/sat/drat_checker.h index 37706eca5d..d666e79639 100644 --- a/ortools/sat/drat_checker.h +++ b/ortools/sat/drat_checker.h @@ -57,7 +57,7 @@ class DratChecker { void AddProblemClause(absl::Span clause); // Adds a clause which is inferred from the problem clauses and the previously - // inferred clauses (that are have not been deleted). inferred clauses must be + // inferred clauses (that are have not been deleted). Inferred clauses must be // added after the problem clauses. Clauses with the Reverse Asymmetric // Tautology (RAT) property for literal l must start with this literal. The // given clause must not contain a literal and its negation. Must not be diff --git a/ortools/sat/drat_proof_handler.cc b/ortools/sat/drat_proof_handler.cc deleted file mode 100644 index 7058d70d30..0000000000 --- a/ortools/sat/drat_proof_handler.cc +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2010-2025 Google LLC -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ortools/sat/drat_proof_handler.h" - -#include -#include -#include -#include -#if !defined(__PORTABLE_PLATFORM__) -#include "ortools/base/file.h" -#endif // !defined(__PORTABLE_PLATFORM__) -#include "absl/log/check.h" -#include "absl/types/span.h" -#include "ortools/base/strong_vector.h" -#include "ortools/sat/drat_checker.h" -#include "ortools/sat/drat_writer.h" -#include "ortools/sat/sat_base.h" -#include "ortools/util/strong_integers.h" - -namespace operations_research { -namespace sat { - -DratProofHandler::DratProofHandler() - : variable_index_(0), drat_checker_(new DratChecker()) {} - -DratProofHandler::DratProofHandler(bool in_binary_format, File* output, - bool check) - : variable_index_(0), - drat_writer_(new DratWriter(in_binary_format, output)) { - if (check) { - drat_checker_ = std::make_unique(); - } -} - -void DratProofHandler::ApplyMapping( - const util_intops::StrongVector& - mapping) { - util_intops::StrongVector new_mapping; - for (BooleanVariable v(0); v < mapping.size(); ++v) { - const BooleanVariable image = mapping[v]; - if (image != kNoBooleanVariable) { - if (image >= new_mapping.size()) - new_mapping.resize(image.value() + 1, kNoBooleanVariable); - CHECK_EQ(new_mapping[image], kNoBooleanVariable); - new_mapping[image] = - v < reverse_mapping_.size() ? reverse_mapping_[v] : v; - CHECK_NE(new_mapping[image], kNoBooleanVariable); - } - } - std::swap(new_mapping, reverse_mapping_); -} - -void DratProofHandler::SetNumVariables(int num_variables) { - CHECK_GE(num_variables, reverse_mapping_.size()); - while (reverse_mapping_.size() < num_variables) { - reverse_mapping_.push_back(BooleanVariable(variable_index_++)); - } -} - -void DratProofHandler::AddOneVariable() { - reverse_mapping_.push_back(BooleanVariable(variable_index_++)); -} - -void DratProofHandler::AddProblemClause(absl::Span clause) { - if (drat_checker_ != nullptr) { - drat_checker_->AddProblemClause(clause); - } -} - -void DratProofHandler::AddClause(absl::Span clause) { - MapClause(clause); - if (drat_checker_ != nullptr) { - drat_checker_->AddInferredClause(values_); - } - if (drat_writer_ != nullptr) { - drat_writer_->AddClause(values_); - } -} - -void DratProofHandler::DeleteClause(absl::Span clause) { - MapClause(clause); - if (drat_checker_ != nullptr) { - drat_checker_->DeleteClause(values_); - } - if (drat_writer_ != nullptr) { - drat_writer_->DeleteClause(values_); - } -} - -DratChecker::Status DratProofHandler::Check(double max_time_in_seconds) { - if (drat_checker_ != nullptr) { - // The empty clause is not explicitly added by the solver. - drat_checker_->AddInferredClause({}); - return drat_checker_->Check(max_time_in_seconds); - } - return DratChecker::Status::UNKNOWN; -} - -void DratProofHandler::MapClause(absl::Span clause) { - values_.clear(); - for (const Literal l : clause) { - CHECK_LT(l.Variable(), reverse_mapping_.size()); - const Literal original_literal = - Literal(reverse_mapping_[l.Variable()], l.IsPositive()); - values_.push_back(original_literal); - } - - // The sorting is such that new variables appear first. This is important for - // BVA since DRAT-trim only check the RAT property with respect to the first - // variable of the clause. - std::sort(values_.begin(), values_.end(), [](Literal a, Literal b) { - return std::abs(a.SignedValue()) > std::abs(b.SignedValue()); - }); -} - -} // namespace sat -} // namespace operations_research diff --git a/ortools/sat/drat_proof_handler.h b/ortools/sat/drat_proof_handler.h deleted file mode 100644 index 958483af3a..0000000000 --- a/ortools/sat/drat_proof_handler.h +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2010-2025 Google LLC -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef ORTOOLS_SAT_DRAT_PROOF_HANDLER_H_ -#define ORTOOLS_SAT_DRAT_PROOF_HANDLER_H_ - -#include -#include - -#if !defined(__PORTABLE_PLATFORM__) -#include "ortools/base/file.h" -#endif // !defined(__PORTABLE_PLATFORM__) -#include "absl/types/span.h" -#include "ortools/base/strong_vector.h" -#include "ortools/sat/drat_checker.h" -#include "ortools/sat/drat_writer.h" -#include "ortools/sat/sat_base.h" - -namespace operations_research { -namespace sat { - -// DRAT is a SAT proof format that allows a simple program to check that the -// problem is really UNSAT. The description of the format and a checker are -// available at: // http://www.cs.utexas.edu/~marijn/drat-trim/ -// -// Note that DRAT proofs are often huge (can be GB), and take about as much time -// to check as it takes for the solver to find the proof in the first place! -// -// This class is used to build the SAT proof, and can either save it to disk, -// and/or store it in memory (in which case the proof can be checked when it is -// complete). -class DratProofHandler { - public: - // Use this constructor to store the DRAT proof in memory. The proof will not - // be written to disk, and can be checked with Check() when it is complete. - DratProofHandler(); - // Use this constructor to write the DRAT proof to disk, and to optionally - // store it in memory as well (in which case the proof can be checked with - // Check() when it is complete). - DratProofHandler(bool in_binary_format, File* output, bool check = false); - ~DratProofHandler() = default; - - // During the presolve step, variable get deleted and the set of non-deleted - // variable is remapped in a dense set. This allows to keep track of that and - // always output the DRAT clauses in term of the original variables. Must be - // called before adding or deleting clauses AddClause() or DeleteClause(). - // - // TODO(user): This is exactly the same mechanism as in the SatPostsolver - // class. Factor out the code. - void ApplyMapping(const util_intops::StrongVector& mapping); - - // This need to be called when new variables are created. - void SetNumVariables(int num_variables); - void AddOneVariable(); - - // Adds a clause of the UNSAT problem. This must be called before any call to - // AddClause() or DeleteClause(), in order to be able to check the DRAT proof - // with the Check() method when it is complete. - void AddProblemClause(absl::Span clause); - - // Writes a new clause to the DRAT output. The output clause is sorted so that - // newer variables always comes first. This is needed because in the DRAT - // format, the clause is checked for the RAT property with only its first - // literal. Must not be called after Check(). - void AddClause(absl::Span clause); - - // Writes a "deletion" information about a clause that has been added before - // to the DRAT output. Note that it is also possible to delete a clause from - // the problem. Must not be called after Check(). - // - // Because of a limitation a the DRAT-trim tool, it seems the order of the - // literals during addition and deletion should be EXACTLY the same. Because - // of this we get warnings for problem clauses. - void DeleteClause(absl::Span clause); - - // Returns VALID if the DRAT proof is correct, INVALID if it is not correct, - // or UNKNOWN if proof checking was not enabled (by choosing the right - // constructor) or timed out. This requires the problem clauses to be - // specified with AddProblemClause(), before the proof itself. - // - // WARNING: no new clause must be added or deleted after this method has been - // called. - DratChecker::Status Check(double max_time_in_seconds); - - private: - void MapClause(absl::Span clause); - - // We need to keep track of the variable newly created. - int variable_index_; - - // Temporary vector used for sorting the outputted clauses. - std::vector values_; - - // This mapping will be applied to all clause passed to AddClause() or - // DeleteClause() so that they are in term of the original problem. - util_intops::StrongVector reverse_mapping_; - - std::unique_ptr drat_checker_; - std::unique_ptr drat_writer_; -}; - -} // namespace sat -} // namespace operations_research - -#endif // ORTOOLS_SAT_DRAT_PROOF_HANDLER_H_ diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 0a3261a227..74e1558736 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -1372,6 +1372,8 @@ bool IntegerSearchHelper::BeforeTakingDecision() { if (integer_trail_->HasPendingRootLevelDeduction()) { sat_solver_->Backtrack(0); if (!sat_solver_->Propagate()) { + // This adds the UNSAT proof to the LRAT handler, if any. + sat_solver_->ProcessCurrentConflict(); sat_solver_->NotifyThatModelIsUnsat(); return false; } @@ -1398,6 +1400,8 @@ bool IntegerSearchHelper::BeforeTakingDecision() { integer_trail_->num_enqueues() > saved_integer_index || integer_trail_->HasPendingRootLevelDeduction()) { if (!sat_solver_->Propagate()) { + // This adds the UNSAT proof to the LRAT handler, if any. + sat_solver_->ProcessCurrentConflict(); sat_solver_->NotifyThatModelIsUnsat(); return false; } diff --git a/ortools/sat/lp_utils.cc b/ortools/sat/lp_utils.cc index 34cbbac6da..7bab178299 100644 --- a/ortools/sat/lp_utils.cc +++ b/ortools/sat/lp_utils.cc @@ -190,13 +190,14 @@ double GetIntegralityMultiplier(const MPModelProto& mp_model, } DCHECK_NE(var_coeff, 0.0); - // The constraint bound need to be infinite or integer. + // Makes sure that the constraint bound is infinite or integer. for (const double bound : {ct.lower_bound(), ct.upper_bound()}) { if (!std::isfinite(bound)) continue; - if (std::abs(std::round(bound * multiplier) - bound * multiplier) > - tolerance * multiplier) { - return 0.0; - } + + const double scaled_bound = multiplier * bound; + multiplier *= + FindRationalFactor(scaled_bound, /*limit=*/100, multiplier * tolerance); + if (multiplier == 0 || multiplier > max_multiplier) return 0.0; } return std::abs(multiplier * var_coeff); } diff --git a/ortools/sat/lrat_checker.cc b/ortools/sat/lrat_checker.cc index 23916c267d..f0e8bbda37 100644 --- a/ortools/sat/lrat_checker.cc +++ b/ortools/sat/lrat_checker.cc @@ -30,7 +30,7 @@ namespace operations_research { namespace sat { -LratChecker::~LratChecker() { +void LratChecker::AddStats() const { if (!VLOG_IS_ON(1)) return; stats_->AddStats( {{"LratChecker/num_problem_clauses", num_problem_clauses_}, diff --git a/ortools/sat/lrat_checker.h b/ortools/sat/lrat_checker.h index df4ad88fa0..271fc7de4b 100644 --- a/ortools/sat/lrat_checker.h +++ b/ortools/sat/lrat_checker.h @@ -38,7 +38,6 @@ class LratChecker { public: explicit LratChecker(Model* model) : stats_(model->GetOrCreate()) {} - ~LratChecker(); // The clause IDs used in a proof that a clause has a Resolution Asymmetric // Tautology (RAT) property. See AddInferredClause() for more details. @@ -95,6 +94,9 @@ class LratChecker { // already been deleted or has never been added. void DeleteClauses(absl::Span clause_ids); + // Returns true if all the operations made so far were valid. + bool Valid() const { return valid_; } + // Returns true if the unsatisfiability proof is valid and complete, i.e. // whether the empty clause has been successfully inferred. bool Check() { @@ -104,6 +106,8 @@ class LratChecker { return complete_; } + void AddStats() const; + // Returns the reason of the first failed operation, or an empty string if all // operations were successful. std::string_view error_message() const { return error_message_; } diff --git a/ortools/sat/lrat_proof_handler.cc b/ortools/sat/lrat_proof_handler.cc index 579ce4578c..4a19fafd94 100644 --- a/ortools/sat/lrat_proof_handler.cc +++ b/ortools/sat/lrat_proof_handler.cc @@ -13,12 +13,18 @@ #include "ortools/sat/lrat_proof_handler.h" +#include +#include #include +#include #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "ortools/base/file.h" +#include "ortools/sat/drat_checker.h" +#include "ortools/sat/drat_writer.h" #include "ortools/sat/lrat_checker.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" @@ -26,8 +32,30 @@ namespace operations_research { namespace sat { -LratProofHandler::LratProofHandler(Model* model) - : lrat_checker_(std::make_unique(model)), +namespace { +std::vector SortClauseForDrat(absl::Span clause) { + // The sorting is such that new variables appear first. This is important for + // BVA since DRAT-trim only check the RAT property with respect to the first + // variable of the clause. + std::vector sorted_clause(clause.begin(), clause.end()); + std::sort(sorted_clause.begin(), sorted_clause.end(), + [](Literal a, Literal b) { + return std::abs(a.SignedValue()) > std::abs(b.SignedValue()); + }); + return sorted_clause; +} +} // namespace + +LratProofHandler::LratProofHandler(Model* model, bool check_lrat, + bool check_drat, File* drat_output, + bool in_binary_drat_format) + : lrat_checker_(check_lrat ? std::make_unique(model) + : nullptr), + drat_checker_(check_drat ? std::make_unique() : nullptr), + drat_writer_( + drat_output != nullptr + ? std::make_unique(in_binary_drat_format, drat_output) + : nullptr), debug_crash_on_error_(model->GetOrCreate() ->debug_crash_if_lrat_check_fails()) {} @@ -35,7 +63,27 @@ bool LratProofHandler::AddProblemClause(ClauseId id, absl::Span clause) { VLOG(1) << "AddProblemClause: id=" << id << " literals=" << absl::StrJoin(clause, ","); - return CheckResult(lrat_checker_->AddProblemClause(id, clause)); + if (all_problem_clauses_loaded_ && debug_crash_on_error_) { + LOG(FATAL) << "LRAT error: problem clauses must not be added after " + "EndProblemClauses()"; + } + if (lrat_checker_ != nullptr) { + return CheckResult(lrat_checker_->AddProblemClause(id, clause)); + } + if (drat_checker_ != nullptr) { + drat_checker_->AddProblemClause(SortClauseForDrat(clause)); + } + return true; +} + +void LratProofHandler::EndProblemClauses() { + all_problem_clauses_loaded_ = true; + if (drat_checker_ != nullptr) { + for (const auto& clause : clauses_inferred_during_problem_loading_) { + drat_checker_->AddInferredClause(clause); + } + clauses_inferred_during_problem_loading_.clear(); + } } bool LratProofHandler::AddInferredClause( @@ -46,8 +94,36 @@ bool LratProofHandler::AddInferredClause( << " literals=" << absl::StrJoin(clause, ",") << " unit_ids=" << absl::StrJoin(unit_ids, ",") << " rat={" << absl::StrJoin(rat, " ") << "}"; - return CheckResult( - lrat_checker_->AddInferredClause(id, clause, unit_ids, rat)); + if (lrat_checker_ != nullptr) { + return CheckResult( + lrat_checker_->AddInferredClause(id, clause, unit_ids, rat)); + } + if (drat_checker_ != nullptr) { + if (all_problem_clauses_loaded_) { + drat_checker_->AddInferredClause(SortClauseForDrat(clause)); + } else { + clauses_inferred_during_problem_loading_.push_back( + SortClauseForDrat(clause)); + } + } + if (drat_writer_ != nullptr) { + drat_writer_->AddClause(clause); + } + return true; +} + +bool LratProofHandler::AddSharedClause(ClauseId id, + absl::Span clause) { + VLOG(1) << "AddSharedClause: id=" << id + << " literals=" << absl::StrJoin(clause, ","); + if (lrat_checker_ != nullptr) { + return CheckResult(lrat_checker_->AddProblemClause(id, clause)); + } + if (drat_checker_ != nullptr) { + LOG(ERROR) << "Shared clauses are not supported by the DRAT checker."; + return false; + } + return true; } bool LratProofHandler::AddAssumedClause(ClauseId id, @@ -57,24 +133,70 @@ bool LratProofHandler::AddAssumedClause(ClauseId id, if (debug_crash_on_error_) { LOG(FATAL) << "LRAT error: assumed clauses are not supposed to happen"; } - return CheckResult(lrat_checker_->AddProblemClause(id, clause)); + ++num_assumed_clauses_; + if (lrat_checker_ != nullptr) { + return CheckResult(lrat_checker_->AddProblemClause(id, clause)); + } + if (drat_checker_ != nullptr) { + // The DRAT checker requires all problem clauses first, followed by inferred + // clauses only. + LOG(ERROR) << "Assumed clauses are not supported by the DRAT checker."; + return false; + } + return true; } -void LratProofHandler::DeleteClauses(absl::Span clause_ids) { - VLOG(1) << "DeleteClauses: clause_ids=" << absl::StrJoin(clause_ids, " "); - lrat_checker_->DeleteClauses(clause_ids); +void LratProofHandler::DeleteClause(ClauseId id, + absl::Span clause) { + VLOG(1) << "DeleteClause: id=" << id + << " literals=" << absl::StrJoin(clause, ","); + if (drat_checker_ != nullptr) { + drat_checker_->DeleteClause(clause); + } + if (drat_writer_ != nullptr) { + drat_writer_->DeleteClause(clause); + } + if (lrat_checker_ != nullptr) { + lrat_checker_->DeleteClauses({id}); + } } -bool LratProofHandler::Check() const { - return CheckResult(lrat_checker_->Check()); +DratChecker::Status LratProofHandler::Valid() const { + if (lrat_checker_ != nullptr) { + return CheckResult(lrat_checker_->Valid()) ? DratChecker::Status::VALID + : DratChecker::Status::INVALID; + } + return DratChecker::Status::UNKNOWN; +} + +DratChecker::Status LratProofHandler::Check( + double max_drat_check_time_in_seconds) { + DratChecker::Status status = DratChecker::Status::UNKNOWN; + if (lrat_checker_ != nullptr) { + status = CheckResult(lrat_checker_->Check()) ? DratChecker::Status::VALID + : DratChecker::Status::INVALID; + } + if (status != DratChecker::Status::INVALID && drat_checker_ != nullptr) { + drat_checker_->Check(max_drat_check_time_in_seconds); + if (status == DratChecker::Status::INVALID && debug_crash_on_error_) { + LOG(FATAL) << "DRAT check failed"; + } + } + return status; } bool LratProofHandler::CheckResult(bool result) const { - if (debug_crash_on_error_ && !result) { + if (debug_crash_on_error_ && !result && lrat_checker_ != nullptr) { LOG(FATAL) << "LRAT error: " << lrat_checker_->error_message(); } return result; } +void LratProofHandler::AddStats() const { + if (lrat_checker_ != nullptr) { + lrat_checker_->AddStats(); + } +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/lrat_proof_handler.h b/ortools/sat/lrat_proof_handler.h index a29ee05bc2..43644b0ed9 100644 --- a/ortools/sat/lrat_proof_handler.h +++ b/ortools/sat/lrat_proof_handler.h @@ -14,9 +14,15 @@ #ifndef ORTOOLS_SAT_LRAT_PROOF_HANDLER_H_ #define ORTOOLS_SAT_LRAT_PROOF_HANDLER_H_ +#include +#include #include +#include #include "absl/types/span.h" +#include "ortools/base/file.h" +#include "ortools/sat/drat_checker.h" +#include "ortools/sat/drat_writer.h" #include "ortools/sat/lrat_checker.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" @@ -28,34 +34,69 @@ namespace sat { // and/or by saving it to a file. class LratProofHandler { public: - // TODO(user): Add a constructor to save the proof to a file in addition - // to or instead of using the LratChecker. - explicit LratProofHandler(Model* model); + explicit LratProofHandler(Model* model, bool check_lrat, bool check_drat, + File* drat_output, bool in_binary_drat_format); + + bool lrat_check_enabled() const { return lrat_checker_ != nullptr; } + bool drat_check_enabled() const { return drat_checker_ != nullptr; } + bool drat_output_enabled() const { return drat_writer_ != nullptr; } // Adds a clause of the problem. See LratChecker for more details. bool AddProblemClause(ClauseId id, absl::Span clause); + // No more problem clauses must be added after this call. + void EndProblemClauses(); + // Adds a clause which is inferred from the problem clauses and/or the // previously inferred clauses. See LratChecker for more details. bool AddInferredClause(ClauseId id, absl::Span clause, absl::Span unit_ids, absl::Span rat = {}); - // Adds a clause which is assumed to be true, without proof. + // Adds a clause which was inferred by another worker. Returns true if + // successful (the operation can fail if LRAT checks are enabled, and the ID + // is already used by another clause). + bool AddSharedClause(ClauseId id, absl::Span clause); + + // Adds a clause which is assumed to be true, without proof. Returns true if + // successful (the operation fails if DRAT checks are enabled, or if LRAT + // checks are enabled and the ID is already used by another clause). bool AddAssumedClause(ClauseId id, absl::Span clause); - // Deletes problem or inferred clauses. - void DeleteClauses(absl::Span clause_ids); + // Deletes a problem or inferred clause. The clause literals are only needed + // when checking DRAT. + void DeleteClause(ClauseId id, absl::Span clause); - // Returns true if the unsatisfiability proof is valid and complete, i.e. - // whether the empty clause has been successfully inferred. - bool Check() const; + // Returns VALID if all the inferred clauses were successfully checked with + // LRAT. Returns INVALID if at least one of them was not. Returns UNKNOWN if + // LRAT checks are not enabled. + DratChecker::Status Valid() const; + + // Returns VALID if the unsatisfiability proof is valid and complete, i.e. + // whether the empty clause has been successfully inferred. Returns INVALID if + // it is not. Returns UNKNOWN if the check timed out (this can only occur + // with DRAT checks), or if neither LRAT nor DRAT checks were enabled. + DratChecker::Status Check(double max_drat_check_time_in_seconds = + std::numeric_limits::infinity()); + + void AddStats() const; + + int64_t num_assumed_clauses() const { return num_assumed_clauses_; } private: bool CheckResult(bool result) const; std::unique_ptr lrat_checker_; + std::unique_ptr drat_checker_; + std::unique_ptr drat_writer_; + + bool all_problem_clauses_loaded_ = false; + int64_t num_assumed_clauses_ = 0; bool debug_crash_on_error_; + + // Only used when checking DRAT, because the DRAT checker does not support + // interleaving problem and inferred clauses. + std::vector> clauses_inferred_during_problem_loading_; }; } // namespace sat diff --git a/ortools/sat/probing.cc b/ortools/sat/probing.cc index 1d3f947655..0adac2ebf9 100644 --- a/ortools/sat/probing.cc +++ b/ortools/sat/probing.cc @@ -104,6 +104,8 @@ bool Prober::ProbeOneVariableInternal(BooleanVariable b) { if (!implied_bounds_->ProcessIntegerTrail(decision)) return false; product_detector_->ProcessTrailAtLevelOne(); integer_trail_->AppendNewBounds(&new_integer_bounds_); + to_fix_at_true_.clear(); + new_literals_implied_by_decision_.clear(); for (int i = saved_index + 1; i < trail_.Index(); ++i) { const Literal l = trail_[i]; @@ -163,7 +165,6 @@ bool Prober::ProbeOneVariableInternal(BooleanVariable b) { return false; } } - to_fix_at_true_.clear(); if (!sat_solver_->FinishPropagation()) return false; for (const Literal l : new_literals_implied_by_decision_) { // Some variables can be fixed by the above loop. @@ -175,12 +176,20 @@ bool Prober::ProbeOneVariableInternal(BooleanVariable b) { tmp_binary_clause_ids_.at(std::minmax(decision.Negated(), l)); } num_new_binary_++; - if (!implication_graph_->AddBinaryClause(clause_id, decision.Negated(), - l)) { + // Tricky: by default AddBinaryClause() can delete the LRAT `clause_id` + // and create a new ID for a similar clause between the representatives. + // But `clause_id`, registered in tmp_binary_clause_ids_, can be needed in + // the next iteration for the proof of a new fixed literal. Hence we must + // not delete it here. Instead, it is deleted at the end of this method, + // with the other non-longer needed clauses. + // TODO(user): can we maintain a one to one correspondence of clauses + // in LRAT and in the binary implication graph to avoid this? + if (!implication_graph_->AddBinaryClause( + clause_id, decision.Negated(), l, + /*delete_non_representative_id=*/false)) { return false; } } - new_literals_implied_by_decision_.clear(); if (!sat_solver_->FinishPropagation()) return false; } if (lrat_proof_handler_ != nullptr) { @@ -189,7 +198,8 @@ bool Prober::ProbeOneVariableInternal(BooleanVariable b) { for (const auto& [binary_clause, clause_id] : tmp_binary_clause_ids_) { if (implication_graph_->GetClauseId(binary_clause.first, binary_clause.second) != clause_id) { - lrat_proof_handler_->DeleteClauses({clause_id}); + lrat_proof_handler_->DeleteClause( + clause_id, {binary_clause.first, binary_clause.second}); } } } diff --git a/ortools/sat/sat_inprocessing.cc b/ortools/sat/sat_inprocessing.cc index 16e78ad191..c332021ff4 100644 --- a/ortools/sat/sat_inprocessing.cc +++ b/ortools/sat/sat_inprocessing.cc @@ -37,7 +37,6 @@ #include "ortools/base/timer.h" #include "ortools/graph/connected_components.h" #include "ortools/sat/clause.h" -#include "ortools/sat/drat_checker.h" #include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/probing.h" @@ -402,7 +401,7 @@ bool Inprocessing::RemoveFixedAndEquivalentVariables(bool log_info) { int64_t num_removed_literals = 0; int64_t num_inspected_literals = 0; - // We need this temporary vector for the DRAT proof settings, otherwise + // We need this temporary vector for the LRAT proof settings, otherwise // we could just have done an in-place transformation. std::vector new_clause; @@ -421,9 +420,6 @@ bool Inprocessing::RemoveFixedAndEquivalentVariables(bool log_info) { if (assignment_.LiteralIsTrue(l)) { DCHECK(lrat_proof_handler_ == nullptr || trail_->GetUnitClauseId(l.Variable()) != kNoClauseId); - if (clause_manager_->GetDratProofHandler() != nullptr) { - clause_manager_->GetDratProofHandler()->AddClause({l}); - } clause_manager_->LazyDelete(clause, DeletionSourceForStat::FIXED_AT_TRUE); num_removed_literals += clause->size(); @@ -508,7 +504,7 @@ bool Inprocessing::SubsumeAndStrenghtenRound(bool log_info) { int64_t num_inspected_signatures = 0; int64_t num_inspected_literals = 0; - // We need this temporary vector for the DRAT proof settings, otherwise + // We need this temporary vector for the LRAT proof settings, otherwise // we could just have done an in-place transformation. std::vector new_clause; @@ -2041,7 +2037,18 @@ class LratGateCongruenceHelper { ~LratGateCongruenceHelper() { if (lrat_proof_handler_ != nullptr) { - lrat_proof_handler_->DeleteClauses(to_delete_); + if (lrat_proof_handler_->drat_check_enabled() || + lrat_proof_handler_->drat_output_enabled()) { + for (int i = 0; i < to_delete_.size(); ++i) { + lrat_proof_handler_->DeleteClause( + to_delete_[i], + {clauses_to_delete_[i].first, clauses_to_delete_[i].second}); + } + } else { + for (const ClauseId id : to_delete_) { + lrat_proof_handler_->DeleteClause(id, {}); + } + } } } @@ -2088,6 +2095,11 @@ class LratGateCongruenceHelper { child_clauses.child_implies_parent = child_implies_rep; to_delete_.push_back(rep_implies_child); to_delete_.push_back(child_implies_rep); + if (lrat_proof_handler_->drat_check_enabled() || + lrat_proof_handler_->drat_output_enabled()) { + clauses_to_delete_.push_back({representative.Negated(), child}); + clauses_to_delete_.push_back({child.Negated(), representative}); + } } if (!literals.empty()) { // Make sure the parent links in union_find_ are shorten too, to keep the @@ -2226,6 +2238,9 @@ class LratGateCongruenceHelper { absl::flat_hash_map parent_equivalence_; // Equivalence clauses which are not needed after the current round. std::vector to_delete_; + // The literals of the clauses in `to_delete_`. Only needed when checking + // DRAT. + std::vector> clauses_to_delete_; }; } // namespace diff --git a/ortools/sat/sat_inprocessing.h b/ortools/sat/sat_inprocessing.h index c49bc34ca3..74fab5d716 100644 --- a/ortools/sat/sat_inprocessing.h +++ b/ortools/sat/sat_inprocessing.h @@ -30,7 +30,6 @@ #include "absl/types/span.h" #include "ortools/base/strong_vector.h" #include "ortools/sat/clause.h" -#include "ortools/sat/drat_checker.h" #include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index c5b6951ff8..5c417787a0 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -38,7 +38,6 @@ #include "ortools/port/proto_utils.h" #include "ortools/port/sysinfo.h" #include "ortools/sat/clause.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/enforcement.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" @@ -218,7 +217,8 @@ bool SatSolver::AddTernaryClause(Literal a, Literal b, Literal c) { // better to make sure we always have "clean" clause in the solver rather than // to over-optimize this. In particular, presolve might be disabled or // incomplete, so such unclean clause might find their way here. -bool SatSolver::AddProblemClause(absl::Span literals) { +bool SatSolver::AddProblemClause(absl::Span literals, + bool shared) { SCOPED_TIME_STAT(&stats_); DCHECK_EQ(CurrentDecisionLevel(), 0); if (model_is_unsat_) return false; @@ -241,10 +241,11 @@ bool SatSolver::AddProblemClause(absl::Span literals) { } } - return AddProblemClauseInternal(tmp_literals_); + return AddProblemClauseInternal(tmp_literals_, shared); } -bool SatSolver::AddProblemClauseInternal(absl::Span literals) { +bool SatSolver::AddProblemClauseInternal(absl::Span literals, + bool shared) { SCOPED_TIME_STAT(&stats_); if (DEBUG_MODE && CurrentDecisionLevel() == 0) { for (const Literal l : literals) { @@ -254,23 +255,22 @@ bool SatSolver::AddProblemClauseInternal(absl::Span literals) { ClauseId id = kNoClauseId; if (lrat_proof_handler_ != nullptr) { id = clause_id_generator_->GetNextId(); - lrat_proof_handler_->AddProblemClause(id, literals); + if (shared) { + lrat_proof_handler_->AddSharedClause(id, literals); + } else { + lrat_proof_handler_->AddProblemClause(id, literals); + } } if (literals.empty()) return SetModelUnsat(); if (literals.size() == 1) { - if (drat_proof_handler_ != nullptr) { - // Note that we will output problem unit clauses twice, but that is a - // small price to pay for having a single variable fixing API. - drat_proof_handler_->AddClause({literals[0]}); - } trail_->EnqueueWithUnitReason(id, literals[0]); } else if (literals.size() == 2) { // TODO(user): Make sure the presolve do not generate such clauses. if (literals[0] == literals[1]) { // Literal must be true. - trail_->EnqueueWithUnitReason(literals[0]); + trail_->EnqueueWithUnitReason(id, literals[0]); } else if (literals[0] == literals[1].Negated()) { // Always true. return true; @@ -288,6 +288,8 @@ bool SatSolver::AddProblemClauseInternal(absl::Span literals) { // tigger computation (like the LP) even if no domain changed since the last // call. We do not want to do that. if (!PropagationIsDone() && !Propagate()) { + // This adds the UNSAT proof to the LRAT handler, if any. + ProcessCurrentConflict(); return SetModelUnsat(); } return true; @@ -982,9 +984,6 @@ void SatSolver::ProcessCurrentConflict( // database. This is because we already backtracked and some of the clauses // that were needed to infer the conflict may not be "reasons" anymore and // may be deleted. - if (drat_proof_handler_ != nullptr) { - drat_proof_handler_->AddClause(learned_conflict_); - } ClauseId learned_conflict_clause_id = kNoClauseId; if (lrat_proof_handler_ != nullptr) { if (!clause_ids_for_minimization->empty()) { @@ -1529,7 +1528,6 @@ bool SatSolver::TryToMinimizeClause(SatClause* clause) { const int variable_level = LiteralTrail().Info(literal.Variable()).level; if (variable_level == 0) { - ProcessNewlyFixedVariablesForDratProof(); DCHECK(lrat_proof_handler_ == nullptr || trail_->GetUnitClauseId(literal.Variable()) != kNoClauseId); counters_.minimization_num_true++; @@ -1647,12 +1645,9 @@ bool SatSolver::TryToMinimizeClause(SatClause* clause) { if (CurrentDecisionLevel() == 0) { // Ensure nothing is fixed at level 0 in case more propagation happened // after backtracking. - // std::erase_if is C++20, not yet fully supported on OR-Tools. - candidate.erase( - std::remove_if( - candidate.begin(), candidate.end(), - [&](Literal l) { return trail_->Assignment().LiteralIsFalse(l); }), - candidate.end()); + OpenSourceEraseIf(candidate, [&](Literal l) { + return trail_->Assignment().LiteralIsFalse(l); + }); if (absl::c_any_of(clause->AsSpan(), [&](Literal l) { return trail_->Assignment().LiteralIsTrue(l); })) { @@ -2235,36 +2230,12 @@ std::string SatSolver::RunningStatisticsString() const { num_variables_.value() - num_processed_fixed_variables_); } -void SatSolver::ProcessNewlyFixedVariablesForDratProof() { - if (drat_proof_handler_ == nullptr) return; - if (CurrentDecisionLevel() != 0) return; - - // We need to output the literals that are fixed so we can remove all - // clauses that contains them. Note that this doesn't seems to be needed - // for drat-trim. - // - // TODO(user): Ideally we could output such literal as soon as they are fixed, - // but this is not that easy to do. Spend some time to find a cleaner - // alternative? Currently this works, but: - // - We will output some fixed literals twice since we already output learnt - // clauses of size one. - // - We need to call this function when needed. - Literal temp; - for (; drat_num_processed_fixed_variables_ < trail_->Index(); - ++drat_num_processed_fixed_variables_) { - temp = (*trail_)[drat_num_processed_fixed_variables_]; - drat_proof_handler_->AddClause({&temp, 1}); - } -} - void SatSolver::ProcessNewlyFixedVariables() { SCOPED_TIME_STAT(&stats_); DCHECK_EQ(CurrentDecisionLevel(), 0); int num_detached_clauses = 0; int num_binary = 0; - ProcessNewlyFixedVariablesForDratProof(); - // We remove the clauses that are always true and the fixed literals from the // others. Note that none of the clause should be all false because we should // have detected a conflict before this is called. @@ -2284,12 +2255,6 @@ void SatSolver::ProcessNewlyFixedVariables() { const size_t new_size = clause->size(); if (new_size == old_size) continue; - if (drat_proof_handler_ != nullptr) { - CHECK_GT(new_size, 0); - drat_proof_handler_->AddClause({clause->begin(), new_size}); - drat_proof_handler_->DeleteClause({clause->begin(), old_size}); - } - ClauseId new_clause_id = kNoClauseId; if (lrat_proof_handler_ != nullptr) { std::vector& clause_ids = tmp_clause_ids_for_minimization_; @@ -2304,8 +2269,14 @@ void SatSolver::ProcessNewlyFixedVariables() { new_clause_id = clause_id_generator_->GetNextId(); lrat_proof_handler_->AddInferredClause( new_clause_id, {clause->begin(), new_size}, clause_ids); - lrat_proof_handler_->DeleteClauses({old_clause_id}); if (new_size > 2) { + // If the new size is 2 the LRAT clause is deleted as part of the + // LazyDelete(clause, PROMOTED_TO_BINARY) call below. Also the SatClause + // ID must not be changed to the new ID in this case, otherwise we would + // get a SatClause and a binary clause with the same ID, leading to a + // double delete. + lrat_proof_handler_->DeleteClause(old_clause_id, + {clause->begin(), old_size}); clauses_propagator_->SetClauseId(clause, new_clause_id); } } diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index c7576bf49e..53dbe45ee3 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -38,7 +38,6 @@ #include "ortools/base/logging.h" #include "ortools/base/timer.h" #include "ortools/sat/clause.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/enforcement.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" @@ -126,12 +125,14 @@ class SatSolver { // We call this a "problem" clause just because we will never delete such // clause unless it is proven to always be satisfied. So this can be called // with the initial clause of a problem, but also an inferred clause that we - // don't want to delete. + // don't want to delete (`shared` must be true iff the clause was inferred by + // another solver, from the same initial clauses). // // TODO(user): Rename this to AddClause() ? Also get rid of the specialized // AddUnitClause(), AddBinaryClause() and AddTernaryClause() since they // just end up calling this? - bool AddProblemClause(absl::Span literals); + bool AddProblemClause(absl::Span literals, + bool shared = false); // Adds a pseudo-Boolean constraint to the problem. Returns false if the // problem is detected to be UNSAT. If the constraint is always true, this @@ -504,12 +505,6 @@ class SatSolver { void SaveDebugAssignment(); void LoadDebugSolution(absl::Span solution); - void SetDratProofHandler(DratProofHandler* drat_proof_handler) { - drat_proof_handler_ = drat_proof_handler; - clauses_propagator_->SetDratProofHandler(drat_proof_handler_); - binary_implication_graph_->SetDratProofHandler(drat_proof_handler_); - } - // This function is here to deal with the case where a SAT/CP model is found // to be trivially UNSAT while the user is constructing the model. Instead of // having to test the status of all the lines adding a constraint, one can @@ -663,7 +658,8 @@ class SatSolver { // Add a problem clause. The clause is assumed to be "cleaned", that is no // duplicate variables (not strictly required) and not empty. - bool AddProblemClauseInternal(absl::Span literals); + bool AddProblemClauseInternal(absl::Span literals, + bool shared = false); // This is used by all the Add*LinearConstraint() functions. It detects // infeasible/trivial constraints or clause constraints and takes the proper @@ -694,9 +690,6 @@ class SatSolver { // Update the propagators_ list with the relevant propagators. void InitializePropagators(); - // Output to the DRAT proof handler any newly fixed variables. - void ProcessNewlyFixedVariablesForDratProof(); - // Returns the maximum trail_index of the literals in the given clause. // All the literals must be assigned. Returns -1 if the clause is empty. int ComputeMaxTrailIndex(absl::Span clause) const; @@ -877,9 +870,6 @@ class SatSolver { int num_processed_fixed_variables_ = 0; double deterministic_time_of_last_fixed_variables_cleanup_ = 0.0; - // Used in ProcessNewlyFixedVariablesForDratProof(). - int drat_num_processed_fixed_variables_ = 0; - Counters counters_; // Solver information. @@ -950,7 +940,6 @@ class SatSolver { // it is necessary to keep track of the last time the time was advanced. double deterministic_time_at_last_advanced_time_limit_ = 0; - DratProofHandler* drat_proof_handler_ = nullptr; LratProofHandler* lrat_proof_handler_ = nullptr; mutable StatsGroup stats_; diff --git a/ortools/sat/simplification.cc b/ortools/sat/simplification.cc index 58436ae6e5..e0d2c5b3d4 100644 --- a/ortools/sat/simplification.cc +++ b/ortools/sat/simplification.cc @@ -33,7 +33,6 @@ #include "ortools/base/strong_vector.h" #include "ortools/base/timer.h" #include "ortools/graph/strongly_connected_components.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" @@ -205,11 +204,6 @@ void SatPresolver::AddClause(absl::Span clause) { signatures_.push_back(ComputeSignatureOfClauseVariables(ci)); DCHECK_EQ(signatures_.size(), clauses_.size()); - if (drat_proof_handler_ != nullptr && changed) { - drat_proof_handler_->AddClause(clause_ref); - drat_proof_handler_->DeleteClause(clause); - } - const Literal max_literal = clause_ref.back(); const int required_size = std::max(max_literal.Index().value(), max_literal.NegatedIndex().value()) + @@ -245,8 +239,6 @@ void SatPresolver::RebuildLiteralToClauses() { } void SatPresolver::AddClauseInternal(std::vector* clause) { - if (drat_proof_handler_ != nullptr) drat_proof_handler_->AddClause(*clause); - DCHECK(std::is_sorted(clause->begin(), clause->end())); DCHECK_GT(clause->size(), 0) << "TODO(user): Unsat during presolve?"; const ClauseIndex ci(clauses_.size()); @@ -507,7 +499,6 @@ void SatPresolver::SimpleBva(LiteralIndex l) { bva_pq_elements_[x_false.value()].literal = x_false; // Add the new clauses. - if (drat_proof_handler_ != nullptr) drat_proof_handler_->AddOneVariable(); for (const LiteralIndex lit : m_lit_) { tmp_new_clause_ = {Literal(lit), Literal(x_true)}; AddClauseInternal(&tmp_new_clause_); @@ -612,10 +603,6 @@ bool SatPresolver::ProcessClauseToSimplifyOthersUsingLiteral( } else { DCHECK_NE(opposite_literal, lit.Index()); if (clauses_[ci].empty()) return false; // UNSAT. - if (drat_proof_handler_ != nullptr) { - // TODO(user): remove the old clauses_[ci] afterwards. - drat_proof_handler_->AddClause(clauses_[ci]); - } // Recompute signature. signatures_[ci] = ComputeSignatureOfClauseVariables(ci); @@ -692,10 +679,6 @@ bool SatPresolver::ProcessClauseToSimplifyOthers(ClauseIndex clause_index) { &num_inspected_literals_)) { DCHECK_EQ(opposite_literal, lit.NegatedIndex()); if (clauses_[ci].empty()) return false; // UNSAT. - if (drat_proof_handler_ != nullptr) { - // TODO(user): remove the old clauses_[ci] afterwards. - drat_proof_handler_->AddClause(clauses_[ci]); - } // Recompute signature. signatures_[ci] = ComputeSignatureOfClauseVariables(ci); @@ -850,9 +833,6 @@ void SatPresolver::Remove(ClauseIndex ci) { UpdateBvaPriorityQueue(Literal(e.Variable(), true).Index()); UpdateBvaPriorityQueue(Literal(e.Variable(), false).Index()); } - if (drat_proof_handler_ != nullptr) { - drat_proof_handler_->DeleteClause(clauses_[ci]); - } gtl::STLClearObject(&clauses_[ci]); } @@ -1185,7 +1165,6 @@ class PropagationGraph { void ProbeAndFindEquivalentLiteral( SatSolver* solver, SatPostsolver* postsolver, - DratProofHandler* drat_proof_handler, util_intops::StrongVector* mapping, SolverLogger* logger) { WallTimer timer; @@ -1255,9 +1234,6 @@ void ProbeAndFindEquivalentLiteral( ? Literal(rep) : Literal(rep).Negated(); if (!solver->AddUnitClause(true_lit)) return; - if (drat_proof_handler != nullptr) { - drat_proof_handler->AddClause({true_lit}); - } } } for (LiteralIndex i(0); i < size; ++i) { @@ -1269,9 +1245,6 @@ void ProbeAndFindEquivalentLiteral( ? Literal(i) : Literal(i).Negated(); if (!solver->AddUnitClause(true_lit)) return; - if (drat_proof_handler != nullptr) { - drat_proof_handler->AddClause({true_lit}); - } } } else if (assignment.LiteralIsAssigned(Literal(i))) { if (!assignment.LiteralIsAssigned(Literal(rep))) { @@ -1279,16 +1252,10 @@ void ProbeAndFindEquivalentLiteral( ? Literal(rep) : Literal(rep).Negated(); if (!solver->AddUnitClause(true_lit)) return; - if (drat_proof_handler != nullptr) { - drat_proof_handler->AddClause({true_lit}); - } } } else if (rep != i) { ++num_equiv; postsolver->Add(Literal(i), {Literal(i), Literal(rep).Negated()}); - if (drat_proof_handler != nullptr) { - drat_proof_handler->AddClause({Literal(i), Literal(rep).Negated()}); - } } } } diff --git a/ortools/sat/simplification.h b/ortools/sat/simplification.h index ecdb54a8f2..581b4e0244 100644 --- a/ortools/sat/simplification.h +++ b/ortools/sat/simplification.h @@ -28,7 +28,6 @@ #include "absl/types/span.h" #include "ortools/base/adjustable_priority_queue.h" #include "ortools/base/strong_vector.h" -#include "ortools/sat/drat_proof_handler.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" @@ -152,10 +151,7 @@ class SatPresolver { typedef int32_t ClauseIndex; explicit SatPresolver(SatPostsolver* postsolver, SolverLogger* logger) - : postsolver_(postsolver), - num_trivial_clauses_(0), - drat_proof_handler_(nullptr), - logger_(logger) {} + : postsolver_(postsolver), num_trivial_clauses_(0), logger_(logger) {} // This type is neither copyable nor movable. SatPresolver(const SatPresolver&) = delete; @@ -230,10 +226,6 @@ class SatPresolver { // Visible for testing. Just applies the BVA step of the presolve. void PresolveWithBva(); - void SetDratProofHandler(DratProofHandler* drat_proof_handler) { - drat_proof_handler_ = drat_proof_handler; - } - private: // Internal function used by ProcessClauseToSimplifyOthers(). bool ProcessClauseToSimplifyOthersUsingLiteral(ClauseIndex clause_index, @@ -378,7 +370,6 @@ class SatPresolver { int num_trivial_clauses_; SatParameters parameters_; - DratProofHandler* drat_proof_handler_; TimeLimit* time_limit_ = nullptr; SolverLogger* logger_; }; @@ -439,7 +430,6 @@ int ComputeResolvantSize(Literal x, const std::vector& a, // constraints. void ProbeAndFindEquivalentLiteral( SatSolver* solver, SatPostsolver* postsolver, - DratProofHandler* drat_proof_handler, util_intops::StrongVector* mapping, SolverLogger* = nullptr); diff --git a/ortools/sat/solution_crush.cc b/ortools/sat/solution_crush.cc index aac30854af..bb49a23ae9 100644 --- a/ortools/sat/solution_crush.cc +++ b/ortools/sat/solution_crush.cc @@ -25,6 +25,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" @@ -62,12 +63,6 @@ void SolutionCrush::Resize(int new_size) { var_values_.resize(new_size, 0); } -std::optional SolutionCrush::GetHintedValue(int var) const { - if (!solution_is_loaded_) return std::nullopt; - if (!HasValue(var)) return std::nullopt; - return var_values_[var]; -} - void SolutionCrush::MaybeSetLiteralToValueEncoding(int literal, int var, int64_t value) { DCHECK(RefIsPositive(var)); @@ -241,6 +236,36 @@ void SolutionCrush::SetOrUpdateVarToDomain(int var, const Domain& domain) { } } +void SolutionCrush::SetOrUpdateVarToDomain( + int var, const Domain& domain, + const absl::btree_map& encoding, bool has_objective, + bool minimize) { + if (!solution_is_loaded_) return; + if (HasValue(var)) { + const int64_t old_value = GetVarValue(var); + if (domain.Contains(old_value)) return; + int64_t new_value = old_value; + if (!has_objective) { + new_value = domain.ClosestValue(old_value); + } else if (minimize) { + new_value = domain.ValueAtOrBefore(old_value); + } else { + new_value = domain.ValueAtOrAfter(old_value); + } + SetVarValue(var, new_value); + VLOG(3) << "SetOrUpdateVarToDomain: " << var << ", old_value: " << old_value + << ", new_value: " << new_value + << ", domain: " << domain.ToString(); + DCHECK(encoding.contains(new_value)) + << "domain: " << domain.ToString() << "old_value: " << old_value + << " new_value: " << new_value; + const int encoding_lit = encoding.at(new_value); + SetLiteralValue(encoding_lit, true); + } else if (domain.IsFixed()) { + SetVarValue(var, domain.FixedValue()); + } +} + void SolutionCrush::UpdateLiteralsToFalseIfDifferent(int lit1, int lit2) { // Set lit1 and lit2 to false if "lit1 - lit2 == 0" is violated. const int sign1 = RefIsPositive(lit1) ? 1 : -1; diff --git a/ortools/sat/solution_crush.h b/ortools/sat/solution_crush.h index 4db87a83e1..dbc95ece36 100644 --- a/ortools/sat/solution_crush.h +++ b/ortools/sat/solution_crush.h @@ -20,6 +20,7 @@ #include #include +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/types/span.h" @@ -59,8 +60,6 @@ class SolutionCrush { SolutionCrush& operator=(const SolutionCrush&) = delete; SolutionCrush& operator=(SolutionCrush&&) = delete; - std::optional GetHintedValue(int var) const; - bool SolutionIsLoaded() const { return solution_is_loaded_; } // Visible for testing. @@ -152,6 +151,15 @@ class SolutionCrush { // value. Otherwise does nothing. void SetOrUpdateVarToDomain(int var, const Domain& domain); + // If `var` already has a value, updates it to be within the given domain + // following the given encoding and the status of the variable w.r.t. the + // objective. Otherwise, if the domain is fixed, sets the value of `var` to + // this fixed value. Otherwise does nothing. In the process, update the + // encoding literals to reflect the new value of `var`. + void SetOrUpdateVarToDomain(int var, const Domain& domain, + const absl::btree_map& encoding, + bool has_objective, bool minimize); + // Updates the value of the given literals to false if their current values // are different (or does nothing otherwise). void UpdateLiteralsToFalseIfDifferent(int lit1, int lit2); diff --git a/ortools/sat/subsolver.h b/ortools/sat/subsolver.h index 09ca82168f..15dc48abb0 100644 --- a/ortools/sat/subsolver.h +++ b/ortools/sat/subsolver.h @@ -101,12 +101,12 @@ class SubSolver { // called sequentially. Subclasses do not need to call this. void AddTaskDuration(double duration_in_seconds) { ++num_finished_tasks_; - if (duration_in_seconds > 0) { - wall_time_ += duration_in_seconds; - timing_.AddTimeInSec(duration_in_seconds); - } + duration_in_seconds = std::max(0.0, duration_in_seconds); + wall_time_ += duration_in_seconds; + timing_.AddTimeInSec(duration_in_seconds); } + // Note that this is protected by the global execution mutex and so it is // called sequentially. Subclasses do not need to call this. void NotifySelection() { ++num_scheduled_tasks_; } diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index d749085aa4..4d3bb5680b 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -32,6 +32,7 @@ #include "ortools/base/logging.h" #include "ortools/base/timer.h" +#include "ortools/sat/drat_checker.h" #if !defined(__PORTABLE_PLATFORM__) #include "ortools/base/helpers.h" #include "ortools/base/options.h" @@ -1585,5 +1586,67 @@ void SharedStatistics::Log(SolverLogger* logger) { SOLVER_LOG(logger, ""); } +SharedLratProofStatus::SharedLratProofStatus() + : num_subsolvers_(0), + num_valid_proofs_(0), + num_invalid_proofs_(0), + num_unknown_proofs_(0), + lrat_check_enabled_(false), + drat_check_enabled_(false), + num_assumed_clauses_(0), + walltime_in_seconds_(0.0) {} + +void SharedLratProofStatus::NewSubSolver() { + absl::MutexLock mutex_lock(mutex_); + num_subsolvers_++; +} + +void SharedLratProofStatus::NewSubsolverProofStatus( + DratChecker::Status status, bool lrat_check_enabled, + bool drat_check_enabled, int num_assumed_clauses, + double walltime_in_seconds) { + absl::MutexLock mutex_lock(mutex_); + if (status == DratChecker::Status::VALID) { + num_valid_proofs_++; + } else if (status == DratChecker::Status::INVALID) { + num_invalid_proofs_++; + } else if (status == DratChecker::Status::UNKNOWN) { + num_unknown_proofs_++; + } + lrat_check_enabled_ |= lrat_check_enabled; + drat_check_enabled_ |= drat_check_enabled; + num_assumed_clauses_ += num_assumed_clauses; + if (drat_check_enabled) { + walltime_in_seconds_ += walltime_in_seconds; + } +} + +void SharedLratProofStatus::Log(SolverLogger* logger) { + absl::MutexLock mutex_lock(mutex_); + if (lrat_check_enabled_ || drat_check_enabled_) { + if (num_valid_proofs_ == num_subsolvers_) { + if (num_assumed_clauses_ > 0) { + SOLVER_LOG(logger, "LRAT_status: VALID_WITH_ASSUMED_CLAUSES"); + } else { + SOLVER_LOG(logger, "LRAT_status: VALID"); + } + } else if (num_invalid_proofs_ > 0) { + SOLVER_LOG(logger, "LRAT_status: INVALID"); + } else { + SOLVER_LOG(logger, "LRAT_status: UNKNOWN"); + } + if (drat_check_enabled_) { + SOLVER_LOG(logger, "DRAT_walltime: ", walltime_in_seconds_); + } + } else { + // Always log an LRAT status to make it easier to extract it from a + // multirun result with awk. + SOLVER_LOG(logger, "LRAT_status: NA"); + if (drat_check_enabled_) { + SOLVER_LOG(logger, "DRAT_walltime: NA"); + } + } +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index 123bfd54ba..f87c59d074 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +43,7 @@ #include "ortools/base/stl_util.h" #include "ortools/base/timer.h" #include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/drat_checker.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_parameters.pb.h" @@ -1286,6 +1288,32 @@ void SharedSolutionRepository::Synchronize( num_queried_at_last_sync_ = num_queried_; } +// Thread-safe. +class SharedLratProofStatus { + public: + SharedLratProofStatus(); + + void NewSubSolver(); + + void NewSubsolverProofStatus(DratChecker::Status status, + bool lrat_check_enabled, bool drat_check_enabled, + int num_assumed_clauses, + double walltime_in_seconds); + + void Log(SolverLogger* logger); + + private: + absl::Mutex mutex_; + int num_subsolvers_ ABSL_GUARDED_BY(mutex_); + int num_valid_proofs_ ABSL_GUARDED_BY(mutex_); + int num_invalid_proofs_ ABSL_GUARDED_BY(mutex_); + int num_unknown_proofs_ ABSL_GUARDED_BY(mutex_); + bool lrat_check_enabled_ ABSL_GUARDED_BY(mutex_); + bool drat_check_enabled_ ABSL_GUARDED_BY(mutex_); + int num_assumed_clauses_ ABSL_GUARDED_BY(mutex_); + double walltime_in_seconds_ ABSL_GUARDED_BY(mutex_); +}; + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 984f8ced27..5725458135 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -49,6 +49,14 @@ namespace operations_research { namespace sat { +// Removes all elements for which `pred` returns true. +// This implementation provides std::erase_if for C++17. +template +void OpenSourceEraseIf(Container& c, Pred pred) { + auto it = std::remove_if(c.begin(), c.end(), pred); + c.erase(it, c.end()); +} + // A simple class with always IdentityMap[t] == t. // This is to avoid allocating vector with std::iota() in some Apis. template diff --git a/ortools/sat/variable_expand.cc b/ortools/sat/variable_expand.cc index 4cd6d3411f..bfc13d7e08 100644 --- a/ortools/sat/variable_expand.cc +++ b/ortools/sat/variable_expand.cc @@ -15,7 +15,6 @@ #include #include -#include #include #include #include @@ -122,7 +121,7 @@ std::pair, EncodingLinear1Status> ProcessLinear1( if (complement.IsEmpty()) { return {std::nullopt, EncodingLinear1Status::kIgnore}; } else if (complement.IsFixed()) { - CHECK(var_domain.Contains(complement.FixedValue())); + DCHECK(var_domain.Contains(complement.FixedValue())); lin.type = EncodingLinear1Type::kVarNeValue; lin.value = complement.FixedValue(); } else if (rhs.Min() > complement.Max()) { @@ -146,19 +145,16 @@ void AbslStringify(Sink& sink, const EncodingLinear1& lin) { } // namespace -ValueEncoding::ValueEncoding(int var, PresolveContext* context, - SolutionCrush& solution_crush) - : var_(var), - var_domain_(context->DomainOf(var)), - context_(context), - solution_crush_(solution_crush) {} +ValueEncoding::ValueEncoding(int var, PresolveContext* context) + : var_(var), var_domain_(context->DomainOf(var)), context_(context) {} void ValueEncoding::AddValueToEncode(int64_t value) { DCHECK(!is_closed_); encoded_values_.push_back(value); } -void ValueEncoding::CloseEncodedValues(ObjectiveStatus objective_status) { +void ValueEncoding::CanonicalizeEncodedValuesAndAddEscapeValue( + bool var_in_objective, bool var_has_positive_objective_coefficient) { if (!is_closed_) { gtl::STLSortAndRemoveDuplicates(&encoded_values_); // Add an escape value to the existing encoded values when the encoding is @@ -172,29 +168,12 @@ void ValueEncoding::CloseEncodedValues(ObjectiveStatus objective_status) { if (encoded_values_.size() < var_domain_.Size()) { const Domain residual = var_domain_.IntersectionWith( Domain::FromValues(encoded_values_).Complement()); - int64_t escape_value = 0; - switch (objective_status) { - case ObjectiveStatus::kNotInObjective: - escape_value = residual.SmallestValue(); - break; - case ObjectiveStatus::kInObjectiveAndMinimization: - escape_value = residual.Min(); - break; - case ObjectiveStatus::kInObjectiveAndMaximization: - escape_value = residual.Max(); - break; - } + const int64_t escape_value = + !var_in_objective + ? residual.SmallestValue() + : (var_has_positive_objective_coefficient ? residual.Min() + : residual.Max()); encoded_values_.push_back(escape_value); - - // Add the hinted value if it exists and does not appear in the encoded - // values. - const std::optional hint = solution_crush_.GetHintedValue(var_); - if (hint.has_value() && residual.Contains(hint.value()) && - hint.value() != escape_value) { - encoded_values_.push_back(hint.value()); - } - - absl::c_sort(encoded_values_); } is_closed_ = true; is_fully_encoded_ = encoded_values_.size() == var_domain_.Size(); @@ -236,7 +215,7 @@ void ValueEncoding::CreateAllValueEncodingLiterals() { int ValueEncoding::literal(int64_t value) const { DCHECK(is_closed_); const auto it = encoding_.find(value); - CHECK(it != encoding_.end()); + DCHECK(it != encoding_.end()); return it->second; } @@ -296,8 +275,8 @@ void OrderEncoding::CreateAllOrderEncodingLiterals( const ValueEncoding& values) { CollectAllOrderEncodingValues(); for (const auto& [value, literal] : encoded_le_literal_) { - CHECK(values.encoding().contains(value)); - CHECK(values.encoding().contains(var_domain_.ValueAtOrAfter(value + 1))) + DCHECK(values.encoding().contains(value)); + DCHECK(values.encoding().contains(var_domain_.ValueAtOrAfter(value + 1))) << "Cannot find " << var_domain_.ValueAtOrAfter(value + 1) << " for var <= " << value; } @@ -328,7 +307,7 @@ void OrderEncoding::CreateAllOrderEncodingLiterals( encoded_le_literal_.find(var_domain_.ValueAtOrBefore(value - 1)); if (it_ge != encoded_le_literal_.end()) { const int ge_literal = NegatedRef(it_ge->second); - CHECK(not_ge != nullptr); + DCHECK(not_ge != nullptr); not_ge->add_enforcement_literal(ge_literal); if (value != max_ge_value) { not_ge = context_->working_model->add_constraints(); @@ -346,16 +325,16 @@ void OrderEncoding::CreateAllOrderEncodingLiterals( } int OrderEncoding::ge_literal(int64_t value) const { - CHECK_GT(value, var_domain_.Min()); + DCHECK_GT(value, var_domain_.Min()); const auto it = encoded_le_literal_.find(var_domain_.ValueAtOrBefore(value - 1)); - CHECK(it != encoded_le_literal_.end()); + DCHECK(it != encoded_le_literal_.end()); return NegatedRef(it->second); } int OrderEncoding::le_literal(int64_t value) const { const auto it = encoded_le_literal_.find(value); - CHECK(it != encoded_le_literal_.end()); + DCHECK(it != encoded_le_literal_.end()); return it->second; } @@ -388,16 +367,17 @@ bool ProcessEncodingConstraints( int var, PresolveContext* context, ValueEncoding& values, OrderEncoding& order, std::vector>& linear_ones_by_type, - std::vector& constraint_indices, ObjectiveStatus& objective_status) { + std::vector& constraint_indices, bool& var_in_objective, + bool& var_has_positive_objective_coefficient) { const Domain& var_domain = context->DomainOf(var); constraint_indices.clear(); - objective_status = ObjectiveStatus::kNotInObjective; + var_in_objective = false; + var_has_positive_objective_coefficient = false; for (const int c : context->VarToConstraints(var)) { if (c == kObjectiveConstraint) { const int64_t obj_coeff = context->ObjectiveMap().at(var); - objective_status = obj_coeff > 0 - ? ObjectiveStatus::kInObjectiveAndMinimization - : ObjectiveStatus::kInObjectiveAndMaximization; + var_in_objective = true; + var_has_positive_objective_coefficient = obj_coeff > 0; continue; } if (c < 0) continue; @@ -460,23 +440,25 @@ bool ProcessEncodingConstraints( linear_ones_by_type[static_cast(lin.type)].push_back(lin); } - values.CloseEncodedValues(objective_status); + values.CanonicalizeEncodedValuesAndAddEscapeValue( + var_in_objective, var_has_positive_objective_coefficient); return true; } -void TryToReplaceVariableByItsEncoding( - int var, std::function presolve_one_constraint, - PresolveContext* context, SolutionCrush& solution_crush) { +void TryToReplaceVariableByItsEncoding(int var, int& new_exo_to_presolve_index, + PresolveContext* context, + SolutionCrush& solution_crush) { const Domain var_domain = context->DomainOf(var); std::vector> linear_ones_by_type( kNumEncodingLinear1Types); - ValueEncoding values(var, context, solution_crush); + ValueEncoding values(var, context); OrderEncoding order(var, context, solution_crush); - ObjectiveStatus objective_status = ObjectiveStatus::kNotInObjective; + bool var_in_objective = false; + bool var_has_positive_objective_coefficient = false; std::vector constraint_indices; - if (!ProcessEncodingConstraints(var, context, values, order, - linear_ones_by_type, constraint_indices, - objective_status)) { + if (!ProcessEncodingConstraints( + var, context, values, order, linear_ones_by_type, constraint_indices, + var_in_objective, var_has_positive_objective_coefficient)) { return; } @@ -516,7 +498,7 @@ void TryToReplaceVariableByItsEncoding( } VLOG(2) << "ProcessVariableOnlyUsedInEncoding(): var(" << var - << "): " << var_domain + << "): " << var_domain << ", size: " << var_domain.Size() << ", #encoded_values: " << values.encoded_values().size() << ", #ordered_values: " << order.num_encoded_values() << ", #var_eq_value: " << lin_eq.size() @@ -524,7 +506,9 @@ void TryToReplaceVariableByItsEncoding( << ", #var_ge_value: " << lin_ge.size() << ", #var_le_value: " << lin_le.size() << ", #var_in_domain: " << lin_domain.size() - << ", objective_status: " << objective_status + << ", var_in_objective: " << var_in_objective + << ", var_has_positive_objective_coefficient: " + << var_has_positive_objective_coefficient << ", #implied_literals_in_complex_domains: " << num_implied_literals_in_complex_domains; if (!lin_domain.empty() && (!values.is_fully_encoded() || @@ -539,6 +523,10 @@ void TryToReplaceVariableByItsEncoding( } values.CreateAllValueEncodingLiterals(); + // Fix the hinted value if needed. + solution_crush.SetOrUpdateVarToDomain( + var, Domain::FromValues(values.encoded_values()), values.encoding(), + var_in_objective, var_has_positive_objective_coefficient); order.CreateAllOrderEncodingLiterals(values); // Link all Boolean in our linear1 to the encoding literals. @@ -619,14 +607,13 @@ void TryToReplaceVariableByItsEncoding( // Update the objective if needed. Note that this operation can fail if // the new expression result in potential overflow. - if (objective_status != ObjectiveStatus::kNotInObjective) { + if (var_in_objective) { // We substract the min or the max of the variable from all // coefficients. This should reduce the objective size and helps with // the bounds. - const int64_t base_value = - objective_status == ObjectiveStatus::kInObjectiveAndMinimization - ? var_domain.Min() - : var_domain.Max(); + const int64_t base_value = var_has_positive_objective_coefficient + ? var_domain.Min() + : var_domain.Max(); // Tricky: We cannot just choose an arbitrary value if the objective has // a restrictive domain! if (!values.is_fully_encoded() && @@ -693,7 +680,7 @@ void TryToReplaceVariableByItsEncoding( } } if (!values.is_fully_encoded()) { - VLOG(1) << "Reduce domain size: " << var_domain.Size() << " to " + VLOG(2) << "Reduce domain size: " << var_domain.Size() << " to " << values.encoded_values().size() << ": " << var_domain << " -> " << Domain::FromValues(values.encoded_values()); context->UpdateRuleStats("variables: reduce domain to encoded values"); @@ -715,17 +702,17 @@ void TryToReplaceVariableByItsEncoding( } // This must be done after we removed all the constraint containing var. - ConstraintProto* encoding = context->working_model->add_constraints(); - BoolArgumentProto* arg = encoding->mutable_exactly_one(); + new_exo_to_presolve_index = context->working_model->constraints_size(); + ConstraintProto* exo = context->working_model->add_constraints(); + BoolArgumentProto* arg = exo->mutable_exactly_one(); for (const auto& [value, literal] : values.encoding()) { arg->add_literals(literal); } - presolve_one_constraint(encoding); context->UpdateNewConstraintsVariableUsage(); if (context->ModelIsUnsat()) return; // To simplify the postsolve, we output a single constraint to infer X from - // the bi: X = sum bi * (Vi - special_value) + special_value + // the bi: X = sum bi * (Vi - min_value) + min_value const int64_t var_min = var_domain.Min(); ConstraintProto* mapping_ct = context->NewMappingConstraint(__FILE__, __LINE__); diff --git a/ortools/sat/variable_expand.h b/ortools/sat/variable_expand.h index 548d6a8f21..1f4db009cc 100644 --- a/ortools/sat/variable_expand.h +++ b/ortools/sat/variable_expand.h @@ -15,7 +15,6 @@ #define ORTOOLS_SAT_VARIABLE_EXPAND_H_ #include -#include #include #include "absl/container/btree_map.h" @@ -27,34 +26,24 @@ namespace operations_research { namespace sat { -enum class ObjectiveStatus { - kNotInObjective, - kInObjectiveAndMinimization, - kInObjectiveAndMaximization, -}; - -template -void AbslStringify(Sink& sink, ObjectiveStatus obj_status) { - switch (obj_status) { - case ObjectiveStatus::kNotInObjective: - sink.Append("kNotInObjective"); - return; - case ObjectiveStatus::kInObjectiveAndMinimization: - sink.Append("kInObjectiveAndMinimization"); - return; - case ObjectiveStatus::kInObjectiveAndMaximization: - sink.Append("kInObjectiveAndMaximization"); - return; - } -} - class ValueEncoding { public: - ValueEncoding(int var, PresolveContext* context, - SolutionCrush& solution_crush); - // Build the set of observed values. + ValueEncoding(int var, PresolveContext* context); + // Build the set of observed values. This cannot be called after + // CanonicalizeEncodedValuesAndAddEscapeValues() has been called. void AddValueToEncode(int64_t value); - void CloseEncodedValues(ObjectiveStatus objective_status); + + // This method is called after all values from lit => var ==/!= value have + // been added. It canonicalizes the encoded values and adds an escape value + // if needed. It there is an objective, the escape value is the min or the max + // of the residual domain of the variable depending on the objective + // coefficient of the variable. It there are no objective, the escape value is + // the smallest value of the residual domain. + // With this escape value, we can safely reduce the domain of the variable to + // observed + escape values, and add an exactly_one constraint on all the + // literals involved. + void CanonicalizeEncodedValuesAndAddEscapeValue( + bool var_in_objective, bool var_has_positive_objective_coefficient); // Getters on the observed values. bool empty() const; @@ -71,7 +60,6 @@ class ValueEncoding { const int var_; const Domain var_domain_; PresolveContext* context_; - SolutionCrush& solution_crush_; std::vector encoded_values_; absl::btree_map encoding_; bool is_closed_ = false; @@ -101,9 +89,9 @@ class OrderEncoding { absl::btree_map encoded_le_literal_; }; -void TryToReplaceVariableByItsEncoding( - int var, std::function presolve_one_constraint, - PresolveContext* context, SolutionCrush& solution_crush); +void TryToReplaceVariableByItsEncoding(int var, int& new_exo_to_presolve_index, + PresolveContext* context, + SolutionCrush& solution_crush); } // namespace sat } // namespace operations_research