diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index ed10c3abc2..b2d13e4c4c 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -99,6 +99,7 @@ cc_library( "//ortools/base", "//ortools/util:bitset", "//ortools/util:logging", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/synchronization", ], diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index d59c67a489..0b8b3b3665 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -6325,17 +6325,17 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { bool abort = false; for (int i = 0; i < lin.vars().size(); ++i) { const int var = lin.vars(i); - if (!RefIsPositive(var)) { + const int64_t coeff = lin.coeffs(i); + if (!RefIsPositive(var) || coeff == 0) { // This shouldn't happen except in potential corner cases were the // constraints were not canonicalized before this point. We just skip // such constraint. abort = true; break; } - implied = implied - .AdditionWith( - context_->DomainOf(var).MultiplicationBy(lin.coeffs(i))) - .RelaxIfTooComplex(); + implied = + implied.AdditionWith(context_->DomainOf(var).MultiplicationBy(coeff)) + .RelaxIfTooComplex(); } if (abort) continue; @@ -6358,6 +6358,12 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { coeff_map[subset_lin.vars(i)] = subset_lin.coeffs(i); } + // We have a perfect match if 'factor_a * subset == factor_b * superset' on + // the common positions. Note that assuming subset has been gcd reduced, + // there is not point considering factor_b != 1. + bool perfect_match = true; + int64_t factor = 0; + // Lets compute the implied domain of the linear expression // "superset - subset". Note that we actually do not need exact inclusion // for this algorithm to work, but it is an heuristic to not try it with @@ -6371,7 +6377,23 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { int64_t coeff = superset_lin.coeffs(i); const auto it = coeff_map.find(var); if (it != coeff_map.end()) { - coeff -= it->second; + const int64_t subset_coeff = it->second; + if (perfect_match) { + if (coeff % subset_coeff == 0) { + const int64_t div = coeff / subset_coeff; + if (factor == 0) { + // Note that factor can be negative. + factor = div; + } else if (factor != div) { + perfect_match = false; + } + } else { + perfect_match = false; + } + } + + // TODO(user): compute the factor first in case it is != 1 ? + coeff -= subset_coeff; } if (coeff == 0) continue; @@ -6384,17 +6406,6 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { const Domain subset_ct_domain = ReadDomainFromProto(subset_lin); const Domain superset_ct_domain = ReadDomainFromProto(superset_lin); - // TODO(user): when we have equality constraint, we might try substitution. - if (subset_ct_domain.IsFixed()) { - // This should be easy since we can just simplify the superset. - // Especially if we have a true inclusion. - context_->UpdateRuleStats("TODO linear inclusion: subset is equality"); - } - if (superset_ct_domain.IsFixed()) { - // This one could make sense if subset is large. - context_->UpdateRuleStats("TODO linear inclusion: superset is equality"); - } - // Case 1: superset is redundant. // We process this one first as it let us remove the longest constraint. const Domain implied_superset_domain = @@ -6429,6 +6440,51 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { detector.StopProcessingCurrentSubset(); return; } + + // When we have equality constraint, we might try substitution. For now we + // only try that when we have a perfect inclusion with the same coefficients + // after multiplication by factor. + if (perfect_match) { + CHECK_NE(factor, 0); + if (subset_ct_domain.IsFixed()) { + // Rewrite the constraint by removing subset from it and updating + // the domain to domain - factor * subset_domain. + // + // This seems always beneficial, although we might miss some + // oportunities for constraint included in the superset if we do that + // too early. + context_->UpdateRuleStats("linear inclusion: subset is equality"); + int new_size = 0; + auto* mutable_linear = + context_->working_model->mutable_constraints(superset_c) + ->mutable_linear(); + for (int i = 0; i < mutable_linear->vars().size(); ++i) { + const int var = mutable_linear->vars(i); + const int64_t coeff = mutable_linear->coeffs(i); + const auto it = coeff_map.find(var); + if (it != coeff_map.end()) { + CHECK_EQ(factor * it->second, coeff); + continue; + } + mutable_linear->set_vars(new_size, var); + mutable_linear->set_coeffs(new_size, coeff); + ++new_size; + } + mutable_linear->mutable_vars()->Truncate(new_size); + mutable_linear->mutable_coeffs()->Truncate(new_size); + FillDomainInProto(superset_ct_domain.AdditionWith( + subset_ct_domain.MultiplicationBy(-factor)), + mutable_linear); + constraint_indices_to_clean.push_back(superset_c); + detector.StopProcessingCurrentSuperset(); + return; + } + if (superset_ct_domain.IsFixed()) { + // This one could make sense if subset is large vs superset. + context_->UpdateRuleStats( + "TODO linear inclusion: superset is equality"); + } + } }); for (const int c : constraint_indices_to_clean) { diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 151394a72d..5108870619 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1008,6 +1008,57 @@ void RegisterObjectiveBoundsImport( import_objective_bounds); } +// Registers a callback that will export non-problem clauses added during +// search. +void RegisterClausesExport(int id, SharedClausesManager* shared_clauses_manager, + Model* model) { + auto* mapping = model->GetOrCreate(); + auto* sat_solver = model->GetOrCreate(); + const auto& share_binary_clause = [mapping, id, shared_clauses_manager]( + Literal l1, Literal l2) { + const int var1 = + mapping->GetProtoVariableFromBooleanVariable(l1.Variable()); + if (var1 == -1) return; + const int var2 = + mapping->GetProtoVariableFromBooleanVariable(l2.Variable()); + if (var2 == -1) return; + const int lit1 = l1.IsPositive() ? var1 : NegatedRef(var1); + const int lit2 = l2.IsPositive() ? var2 : NegatedRef(var2); + shared_clauses_manager->AddBinaryClause(id, lit1, lit2); + }; + sat_solver->SetShareBinaryClauseCallback(share_binary_clause); +} + +// Registers a callback to import new clauses stored in the +// shared_clausess_manager. These clauses are imported at level 0 of the search +// in the linear scan minimize function. +// it returns the id of the worker in the shared clause manager. +// +// TODO(user): Can we import them in the core worker ? +int RegisterClausesLevelZeroImport(int id, + SharedClausesManager* shared_clauses_manager, + Model* model) { + CHECK(shared_clauses_manager != nullptr); + CpModelMapping* const mapping = model->GetOrCreate(); + SatSolver* sat_solver = model->GetOrCreate(); + const auto& import_level_zero_clauses = [shared_clauses_manager, id, mapping, + sat_solver]() { + std::vector> new_binary_clauses; + shared_clauses_manager->GetUnseenBinaryClauses(id, &new_binary_clauses); + for (const auto [ref1, ref2] : new_binary_clauses) { + const Literal l1 = mapping->Literal(ref1); + const Literal l2 = mapping->Literal(ref2); + if (!sat_solver->AddBinaryClause(l1, l2)) { + return false; + } + } + return true; + }; + model->GetOrCreate()->callbacks.push_back( + import_level_zero_clauses); + return id; +} + void LoadBaseModel(const CpModelProto& model_proto, Model* model) { auto* shared_response_manager = model->GetOrCreate(); CHECK(shared_response_manager != nullptr); @@ -1954,6 +2005,7 @@ struct SharedClasses { SharedRelaxationSolutionRepository* relaxation_solutions; SharedLPSolutionRepository* lp_solutions; SharedIncompleteSolutionManager* incomplete_solutions; + SharedClausesManager* clauses; Model* global_model; bool SearchIsDone() { @@ -1995,6 +2047,10 @@ class FullProblemSolver : public SubSolver { local_model_->Register( shared->incomplete_solutions); } + + if (shared->clauses != nullptr) { + local_model_->Register(shared->clauses); + } } bool TaskIsAvailable() override { @@ -2024,6 +2080,15 @@ class FullProblemSolver : public SubSolver { *shared_->model_proto, shared_->bounds, local_model_.get()); } + if (shared_->clauses != nullptr) { + const int id = shared_->clauses->RegisterNewId(); + shared_->clauses->SetWorkerNameForId(id, local_model_->Name()); + + RegisterClausesLevelZeroImport(id, shared_->clauses, + local_model_.get()); + RegisterClausesExport(id, shared_->clauses, local_model_.get()); + } + if (local_model_->GetOrCreate()->repair_hint()) { MinimizeL1DistanceWithHint(*shared_->model_proto, local_model_.get()); } else { @@ -2523,7 +2588,7 @@ class LnsSolver : public SubSolver { // TODO(user): This do not seem to work if they are symmetries loaded // into SAT. For now we just disable this if there is any symmetry. // See for instance spot5_1401.fzn. Be smarter about that: - // - If there are connected compo in the inital model, we should only + // - If there are connected compo in the initial model, we should only // compute generator that do not cross component? or are component // interchange useful? probably not. // - It should be fine if all our generator are fully or not at @@ -2658,6 +2723,11 @@ void SolveCpModelParallel(const CpModelProto& model_proto, shared_incomplete_solutions.get()); } + std::unique_ptr shared_clauses; + if (parameters.share_binary_clauses()) { + shared_clauses = absl::make_unique(); + } + SharedClasses shared; shared.model_proto = &model_proto; shared.wall_timer = global_model->GetOrCreate(); @@ -2667,6 +2737,7 @@ void SolveCpModelParallel(const CpModelProto& model_proto, shared.relaxation_solutions = shared_relaxation_solutions.get(); shared.lp_solutions = shared_lp_solutions.get(); shared.incomplete_solutions = shared_incomplete_solutions.get(); + shared.clauses = shared_clauses.get(); shared.global_model = global_model; // The list of all the SubSolver that will be used in this parallel search. @@ -2863,13 +2934,30 @@ void SolveCpModelParallel(const CpModelProto& model_proto, NonDeterministicLoop(subsolvers, num_search_workers); } - if (parameters.log_subsolver_statistics()) { + // Log statistics. + if (logger->LoggingIsEnabled()) { + if (parameters.log_subsolver_statistics()) { + SOLVER_LOG(logger, ""); + SOLVER_LOG(logger, "Sub-solver search statistics:"); + for (const auto& subsolver : subsolvers) { + const std::string stats = subsolver->StatisticsString(); + if (stats.empty()) continue; + SOLVER_LOG(logger, + absl::StrCat(" '", subsolver->name(), "':\n", stats)); + } + } + SOLVER_LOG(logger, ""); - SOLVER_LOG(logger, "Sub-solver search statistics:"); - for (const auto& subsolver : subsolvers) { - const std::string stats = subsolver->StatisticsString(); - if (stats.empty()) continue; - SOLVER_LOG(logger, absl::StrCat(" '", subsolver->name(), "':\n", stats)); + shared.response->DisplayImprovementStatistics(); + + if (shared.bounds) { + SOLVER_LOG(logger, ""); + shared.bounds->LogStatistics(logger); + } + + if (shared.clauses) { + SOLVER_LOG(logger, ""); + shared.clauses->LogStatistics(logger); } } } @@ -3454,10 +3542,9 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { QuickSolveWithHint(new_cp_model_proto, model); } SolveLoadedCpModel(new_cp_model_proto, model); - } - if (logger->LoggingIsEnabled()) { - if (params.num_search_workers() <= 1) { + // Sequential logging of LP statistics. + if (logger->LoggingIsEnabled()) { const auto& lps = *model->GetOrCreate(); if (!lps.empty()) { @@ -3467,11 +3554,6 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { } } } - - if (params.num_search_workers() > 1) { - SOLVER_LOG(logger, ""); - shared_response_manager->DisplayImprovementStatistics(); - } } return shared_response_manager->GetResponse(); diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index f8b852b7c5..a72ce6a7d3 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -24,7 +24,7 @@ option csharp_namespace = "Google.OrTools.Sat"; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 203 +// NEXT TAG: 204 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -950,6 +950,9 @@ message SatParameters { // Allows sharing of the bounds of modified variables at level 0. optional bool share_level_zero_bounds = 114 [default = true]; + // Allows sharing of new learned binary clause between workers. + optional bool share_binary_clauses = 203 [default = true]; + // LNS parameters. optional bool use_lns_only = 101 [default = false]; diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index a3197f6e26..6bcffe12f1 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -375,6 +375,9 @@ int SatSolver::AddLearnedClauseAndEnqueueUnitPropagation( if (track_binary_clauses_) { CHECK(binary_clauses_.Add(BinaryClause(literals[0], literals[1]))); } + if (shared_binary_clauses_callback_ != nullptr) { + shared_binary_clauses_callback_(literals[0], literals[1]); + } CHECK(binary_implication_graph_->AddBinaryClauseDuringSearch(literals[0], literals[1])); // In case this is the first binary clauses. @@ -457,7 +460,7 @@ void SatSolver::AddBinaryClauseInternal(Literal a, Literal b) { } } -bool SatSolver::ClauseIsValidUnderDebugAssignement( +bool SatSolver::ClauseIsValidUnderDebugAssignment( const std::vector& clause) const { for (Literal l : clause) { if (l.Variable() >= debug_assignment_.NumberOfVariables() || @@ -583,7 +586,7 @@ bool SatSolver::PropagateAndStopAfterOneConflictResolution() { // An empty conflict means that the problem is UNSAT. if (learned_conflict_.empty()) return SetModelUnsat(); DCHECK(IsConflictValid(learned_conflict_)); - DCHECK(ClauseIsValidUnderDebugAssignement(learned_conflict_)); + DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_)); // Update the activity of all the variables in the first UIP clause. // Also update the activity of the last level variables expanded (and @@ -726,7 +729,7 @@ bool SatSolver::PropagateAndStopAfterOneConflictResolution() { // this way. Second, more variables may be marked (in is_marked_) and // MinimizeConflict() can take advantage of that. Because of this, the // LBD of the learned conflict can change. - DCHECK(ClauseIsValidUnderDebugAssignement(learned_conflict_)); + DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_)); if (!binary_implication_graph_->IsEmpty()) { if (parameters_->binary_minimization_algorithm() == SatParameters::BINARY_MINIMIZATION_FIRST) { @@ -782,7 +785,7 @@ bool SatSolver::PropagateAndStopAfterOneConflictResolution() { // Backtrack and add the reason to the set of learned clause. counters_.num_literals_learned += learned_conflict_.size(); Backtrack(ComputeBacktrackLevel(learned_conflict_)); - DCHECK(ClauseIsValidUnderDebugAssignement(learned_conflict_)); + DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_)); // Note that we need to output the learned clause before cleaning the clause // database. This is because we already backtracked and some of the clauses diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 51c84a7e47..db94d878a6 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -274,7 +274,7 @@ class SatSolver { void Backtrack(int target_level); // Advanced usage. This is meant to restore the solver to a "proper" state - // after a solve was interupted due to a limit reached. + // after a solve was interrupted due to a limit reached. // // Without assumption (i.e. if AssumptionLevel() is 0), this will revert all // decisions and make sure that all the fixed literals are propagated. In @@ -286,7 +286,7 @@ class SatSolver { // case it will return false. bool RestoreSolverToAssumptionLevel(); - // Advanced usage. Finish the progation if it was interupted. Note that this + // Advanced usage. Finish the progation if it was interrupted. Note that this // might run into conflict and will propagate again until a fixed point is // reached or the model was proven UNSAT. Returns IsModelUnsat(). bool FinishPropagation(); @@ -381,7 +381,7 @@ class SatSolver { // The idea is that if we know that a given assignment is satisfiable, then // all the learned clauses or PB constraints must be satisfiable by it. In // debug mode, and after this is called, all the learned clauses are tested to - // satisfy this saved assignement. + // satisfy this saved assignment. void SaveDebugAssignment(); // Returns true iff the loaded problem only contains clauses. @@ -418,6 +418,12 @@ class SatSolver { // use propagation to try to minimize some clauses from the database. void MinimizeSomeClauses(int decisions_budget); + // Sets the export function to the shared clauses manager. + void SetShareBinaryClauseCallback(const std::function& + shared_binary_clauses_callback) { + shared_binary_clauses_callback_ = shared_binary_clauses_callback; + } + // Advance the given time limit with all the deterministic time that was // elapsed since last call. void AdvanceDeterministicTime(TimeLimit* limit) { @@ -452,7 +458,7 @@ class SatSolver { // See SaveDebugAssignment(). Note that these functions only consider the // variables at the time the debug_assignment_ was saved. If new variables // were added since that time, they will be considered unassigned. - bool ClauseIsValidUnderDebugAssignement( + bool ClauseIsValidUnderDebugAssignment( const std::vector& clause) const; bool PBConstraintIsValidUnderDebugAssignment( const std::vector& cst, const Coefficient rhs) const; @@ -832,6 +838,10 @@ class SatSolver { DratProofHandler* drat_proof_handler_; mutable StatsGroup stats_; + + std::function shared_binary_clauses_callback_ = + nullptr; + DISALLOW_COPY_AND_ASSIGN(SatSolver); }; diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 2c11be4143..7d9022ad8b 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -21,6 +21,7 @@ #include "ortools/sat/cp_model_mapping.h" #endif // __PORTABLE_PLATFORM__ +#include "absl/container/btree_map.h" #include "absl/container/flat_hash_set.h" #include "absl/random/random.h" #include "ortools/base/integral_types.h" @@ -750,11 +751,8 @@ void SharedBoundsManager::ReportPotentialNewBounds( changed_variables_since_last_synchronize_.Set(var); num_improvements++; } - // TODO(user): Display number of bound improvements cumulatively per - // workers at the end of the search. if (num_improvements > 0) { - VLOG(2) << worker_name << " exports " << num_improvements - << " modifications"; + bounds_exported_[worker_name] += num_improvements; } } @@ -840,5 +838,72 @@ void SharedBoundsManager::GetChangedBounds( id_to_changed_variables_[id].ClearAll(); } +void SharedBoundsManager::LogStatistics(SolverLogger* logger) { + absl::MutexLock mutex_lock(&mutex_); + if (!bounds_exported_.empty()) { + SOLVER_LOG(logger, "Improving variable bounds shared per subsolver:"); + for (const auto& entry : bounds_exported_) { + SOLVER_LOG(logger, " '", entry.first, "': ", entry.second); + } + } +} + +int SharedClausesManager::RegisterNewId() { + absl::MutexLock mutex_lock(&mutex_); + const int id = id_to_last_processed_binary_clause_.size(); + id_to_last_processed_binary_clause_.resize(id + 1, 0); + id_to_clauses_exported_.resize(id + 1, 0); + return id; +} + +void SharedClausesManager::SetWorkerNameForId(int id, + const std::string& worker_name) { + absl::MutexLock mutex_lock(&mutex_); + id_to_worker_name_[id] = worker_name; +} + +void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { + absl::MutexLock mutex_lock(&mutex_); + if (lit2 < lit1) std::swap(lit1, lit2); + + const auto p = std::make_pair(lit1, lit2); + const auto [unused_it, inserted] = added_binary_clauses_set_.insert(p); + if (inserted) { + added_binary_clauses_.push_back(p); + id_to_clauses_exported_[id]++; + // Small optim. If the worker is already up to date with clauses to import, + // we can mark this new clause as already seen. + if (id_to_last_processed_binary_clause_[id] == + added_binary_clauses_.size() - 1) { + id_to_last_processed_binary_clause_[id]++; + } + } +} + +void SharedClausesManager::GetUnseenBinaryClauses( + int id, std::vector>* new_clauses) { + new_clauses->clear(); + absl::MutexLock mutex_lock(&mutex_); + const int last_binary_clause_seen = id_to_last_processed_binary_clause_[id]; + new_clauses->assign(added_binary_clauses_.begin() + last_binary_clause_seen, + added_binary_clauses_.end()); + id_to_last_processed_binary_clause_[id] = added_binary_clauses_.size(); +} + +void SharedClausesManager::LogStatistics(SolverLogger* logger) { + absl::MutexLock mutex_lock(&mutex_); + absl::btree_map name_to_clauses; + for (int id = 0; id < id_to_clauses_exported_.size(); ++id) { + if (id_to_clauses_exported_[id] == 0) continue; + name_to_clauses[id_to_worker_name_[id]] = id_to_clauses_exported_[id]; + } + if (!name_to_clauses.empty()) { + SOLVER_LOG(logger, "Clauses shared per subsolver:"); + for (const auto& entry : name_to_clauses) { + SOLVER_LOG(logger, " '", entry.first, "': ", entry.second); + } + } +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index ddd76a7c00..c37926e285 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -447,6 +447,8 @@ class SharedBoundsManager { // state. void Synchronize(); + void LogStatistics(SolverLogger* logger); + private: const int num_variables_; const CpModelProto& model_proto_; @@ -464,6 +466,44 @@ class SharedBoundsManager { std::vector synchronized_upper_bounds_ ABSL_GUARDED_BY(mutex_); std::deque> id_to_changed_variables_ ABSL_GUARDED_BY(mutex_); + std::map bounds_exported_ ABSL_GUARDED_BY(mutex_); +}; + +// This class holds all the binary clauses that were found and shared by the +// workers. +// +// It is thread-safe. +// +// Note that this uses literal as encoded in a cp_model.proto. The literals can +// thus be negative numbers. +class SharedClausesManager { + public: + void AddBinaryClause(int id, int lit1, int lit2); + + // Fills flat_clauses with + // (lit1 of clause1, lit2 of clause1, lit1 of clause 2, lit2 of clause2 ...) + void GetUnseenBinaryClauses(int id, + std::vector>* new_clauses); + + int RegisterNewId(); + void SetWorkerNameForId(int id, const std::string& worker_name); + + // Search statistics. + void LogStatistics(SolverLogger* logger); + + private: + absl::Mutex mutex_; + // Cache to avoid adding the same clause twice. + absl::flat_hash_set> added_binary_clauses_set_ + ABSL_GUARDED_BY(mutex_); + std::vector> added_binary_clauses_ + ABSL_GUARDED_BY(mutex_); + std::vector id_to_last_processed_binary_clause_ + ABSL_GUARDED_BY(mutex_); + std::vector id_to_clauses_exported_; + + // Used for reporting statistics. + absl::flat_hash_map id_to_worker_name_; }; template