[CP-SAT] more presolve on small linear constraints; share clauses of size 2 across workers

This commit is contained in:
Laurent Perron
2022-01-26 10:07:13 +01:00
parent a191bcc51c
commit 22e1f47e0a
8 changed files with 305 additions and 45 deletions

View File

@@ -99,6 +99,7 @@ cc_library(
"//ortools/base",
"//ortools/util:bitset",
"//ortools/util:logging",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/synchronization",
],

View File

@@ -6325,17 +6325,17 @@ void CpModelPresolver::DetectDominatedLinearConstraints() {
bool abort = false;
for (int i = 0; i < lin.vars().size(); ++i) {
const int var = lin.vars(i);
if (!RefIsPositive(var)) {
const int64_t coeff = lin.coeffs(i);
if (!RefIsPositive(var) || coeff == 0) {
// This shouldn't happen except in potential corner cases were the
// constraints were not canonicalized before this point. We just skip
// such constraint.
abort = true;
break;
}
implied = implied
.AdditionWith(
context_->DomainOf(var).MultiplicationBy(lin.coeffs(i)))
.RelaxIfTooComplex();
implied =
implied.AdditionWith(context_->DomainOf(var).MultiplicationBy(coeff))
.RelaxIfTooComplex();
}
if (abort) continue;
@@ -6358,6 +6358,12 @@ void CpModelPresolver::DetectDominatedLinearConstraints() {
coeff_map[subset_lin.vars(i)] = subset_lin.coeffs(i);
}
// We have a perfect match if 'factor_a * subset == factor_b * superset' on
// the common positions. Note that assuming subset has been gcd reduced,
// there is not point considering factor_b != 1.
bool perfect_match = true;
int64_t factor = 0;
// Lets compute the implied domain of the linear expression
// "superset - subset". Note that we actually do not need exact inclusion
// for this algorithm to work, but it is an heuristic to not try it with
@@ -6371,7 +6377,23 @@ void CpModelPresolver::DetectDominatedLinearConstraints() {
int64_t coeff = superset_lin.coeffs(i);
const auto it = coeff_map.find(var);
if (it != coeff_map.end()) {
coeff -= it->second;
const int64_t subset_coeff = it->second;
if (perfect_match) {
if (coeff % subset_coeff == 0) {
const int64_t div = coeff / subset_coeff;
if (factor == 0) {
// Note that factor can be negative.
factor = div;
} else if (factor != div) {
perfect_match = false;
}
} else {
perfect_match = false;
}
}
// TODO(user): compute the factor first in case it is != 1 ?
coeff -= subset_coeff;
}
if (coeff == 0) continue;
@@ -6384,17 +6406,6 @@ void CpModelPresolver::DetectDominatedLinearConstraints() {
const Domain subset_ct_domain = ReadDomainFromProto(subset_lin);
const Domain superset_ct_domain = ReadDomainFromProto(superset_lin);
// TODO(user): when we have equality constraint, we might try substitution.
if (subset_ct_domain.IsFixed()) {
// This should be easy since we can just simplify the superset.
// Especially if we have a true inclusion.
context_->UpdateRuleStats("TODO linear inclusion: subset is equality");
}
if (superset_ct_domain.IsFixed()) {
// This one could make sense if subset is large.
context_->UpdateRuleStats("TODO linear inclusion: superset is equality");
}
// Case 1: superset is redundant.
// We process this one first as it let us remove the longest constraint.
const Domain implied_superset_domain =
@@ -6429,6 +6440,51 @@ void CpModelPresolver::DetectDominatedLinearConstraints() {
detector.StopProcessingCurrentSubset();
return;
}
// When we have equality constraint, we might try substitution. For now we
// only try that when we have a perfect inclusion with the same coefficients
// after multiplication by factor.
if (perfect_match) {
CHECK_NE(factor, 0);
if (subset_ct_domain.IsFixed()) {
// Rewrite the constraint by removing subset from it and updating
// the domain to domain - factor * subset_domain.
//
// This seems always beneficial, although we might miss some
// oportunities for constraint included in the superset if we do that
// too early.
context_->UpdateRuleStats("linear inclusion: subset is equality");
int new_size = 0;
auto* mutable_linear =
context_->working_model->mutable_constraints(superset_c)
->mutable_linear();
for (int i = 0; i < mutable_linear->vars().size(); ++i) {
const int var = mutable_linear->vars(i);
const int64_t coeff = mutable_linear->coeffs(i);
const auto it = coeff_map.find(var);
if (it != coeff_map.end()) {
CHECK_EQ(factor * it->second, coeff);
continue;
}
mutable_linear->set_vars(new_size, var);
mutable_linear->set_coeffs(new_size, coeff);
++new_size;
}
mutable_linear->mutable_vars()->Truncate(new_size);
mutable_linear->mutable_coeffs()->Truncate(new_size);
FillDomainInProto(superset_ct_domain.AdditionWith(
subset_ct_domain.MultiplicationBy(-factor)),
mutable_linear);
constraint_indices_to_clean.push_back(superset_c);
detector.StopProcessingCurrentSuperset();
return;
}
if (superset_ct_domain.IsFixed()) {
// This one could make sense if subset is large vs superset.
context_->UpdateRuleStats(
"TODO linear inclusion: superset is equality");
}
}
});
for (const int c : constraint_indices_to_clean) {

View File

@@ -1008,6 +1008,57 @@ void RegisterObjectiveBoundsImport(
import_objective_bounds);
}
// Registers a callback that will export non-problem clauses added during
// search.
void RegisterClausesExport(int id, SharedClausesManager* shared_clauses_manager,
Model* model) {
auto* mapping = model->GetOrCreate<CpModelMapping>();
auto* sat_solver = model->GetOrCreate<SatSolver>();
const auto& share_binary_clause = [mapping, id, shared_clauses_manager](
Literal l1, Literal l2) {
const int var1 =
mapping->GetProtoVariableFromBooleanVariable(l1.Variable());
if (var1 == -1) return;
const int var2 =
mapping->GetProtoVariableFromBooleanVariable(l2.Variable());
if (var2 == -1) return;
const int lit1 = l1.IsPositive() ? var1 : NegatedRef(var1);
const int lit2 = l2.IsPositive() ? var2 : NegatedRef(var2);
shared_clauses_manager->AddBinaryClause(id, lit1, lit2);
};
sat_solver->SetShareBinaryClauseCallback(share_binary_clause);
}
// Registers a callback to import new clauses stored in the
// shared_clausess_manager. These clauses are imported at level 0 of the search
// in the linear scan minimize function.
// it returns the id of the worker in the shared clause manager.
//
// TODO(user): Can we import them in the core worker ?
int RegisterClausesLevelZeroImport(int id,
SharedClausesManager* shared_clauses_manager,
Model* model) {
CHECK(shared_clauses_manager != nullptr);
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
SatSolver* sat_solver = model->GetOrCreate<SatSolver>();
const auto& import_level_zero_clauses = [shared_clauses_manager, id, mapping,
sat_solver]() {
std::vector<std::pair<int, int>> new_binary_clauses;
shared_clauses_manager->GetUnseenBinaryClauses(id, &new_binary_clauses);
for (const auto [ref1, ref2] : new_binary_clauses) {
const Literal l1 = mapping->Literal(ref1);
const Literal l2 = mapping->Literal(ref2);
if (!sat_solver->AddBinaryClause(l1, l2)) {
return false;
}
}
return true;
};
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(
import_level_zero_clauses);
return id;
}
void LoadBaseModel(const CpModelProto& model_proto, Model* model) {
auto* shared_response_manager = model->GetOrCreate<SharedResponseManager>();
CHECK(shared_response_manager != nullptr);
@@ -1954,6 +2005,7 @@ struct SharedClasses {
SharedRelaxationSolutionRepository* relaxation_solutions;
SharedLPSolutionRepository* lp_solutions;
SharedIncompleteSolutionManager* incomplete_solutions;
SharedClausesManager* clauses;
Model* global_model;
bool SearchIsDone() {
@@ -1995,6 +2047,10 @@ class FullProblemSolver : public SubSolver {
local_model_->Register<SharedIncompleteSolutionManager>(
shared->incomplete_solutions);
}
if (shared->clauses != nullptr) {
local_model_->Register<SharedClausesManager>(shared->clauses);
}
}
bool TaskIsAvailable() override {
@@ -2024,6 +2080,15 @@ class FullProblemSolver : public SubSolver {
*shared_->model_proto, shared_->bounds, local_model_.get());
}
if (shared_->clauses != nullptr) {
const int id = shared_->clauses->RegisterNewId();
shared_->clauses->SetWorkerNameForId(id, local_model_->Name());
RegisterClausesLevelZeroImport(id, shared_->clauses,
local_model_.get());
RegisterClausesExport(id, shared_->clauses, local_model_.get());
}
if (local_model_->GetOrCreate<SatParameters>()->repair_hint()) {
MinimizeL1DistanceWithHint(*shared_->model_proto, local_model_.get());
} else {
@@ -2523,7 +2588,7 @@ class LnsSolver : public SubSolver {
// TODO(user): This do not seem to work if they are symmetries loaded
// into SAT. For now we just disable this if there is any symmetry.
// See for instance spot5_1401.fzn. Be smarter about that:
// - If there are connected compo in the inital model, we should only
// - If there are connected compo in the initial model, we should only
// compute generator that do not cross component? or are component
// interchange useful? probably not.
// - It should be fine if all our generator are fully or not at
@@ -2658,6 +2723,11 @@ void SolveCpModelParallel(const CpModelProto& model_proto,
shared_incomplete_solutions.get());
}
std::unique_ptr<SharedClausesManager> shared_clauses;
if (parameters.share_binary_clauses()) {
shared_clauses = absl::make_unique<SharedClausesManager>();
}
SharedClasses shared;
shared.model_proto = &model_proto;
shared.wall_timer = global_model->GetOrCreate<WallTimer>();
@@ -2667,6 +2737,7 @@ void SolveCpModelParallel(const CpModelProto& model_proto,
shared.relaxation_solutions = shared_relaxation_solutions.get();
shared.lp_solutions = shared_lp_solutions.get();
shared.incomplete_solutions = shared_incomplete_solutions.get();
shared.clauses = shared_clauses.get();
shared.global_model = global_model;
// The list of all the SubSolver that will be used in this parallel search.
@@ -2863,13 +2934,30 @@ void SolveCpModelParallel(const CpModelProto& model_proto,
NonDeterministicLoop(subsolvers, num_search_workers);
}
if (parameters.log_subsolver_statistics()) {
// Log statistics.
if (logger->LoggingIsEnabled()) {
if (parameters.log_subsolver_statistics()) {
SOLVER_LOG(logger, "");
SOLVER_LOG(logger, "Sub-solver search statistics:");
for (const auto& subsolver : subsolvers) {
const std::string stats = subsolver->StatisticsString();
if (stats.empty()) continue;
SOLVER_LOG(logger,
absl::StrCat(" '", subsolver->name(), "':\n", stats));
}
}
SOLVER_LOG(logger, "");
SOLVER_LOG(logger, "Sub-solver search statistics:");
for (const auto& subsolver : subsolvers) {
const std::string stats = subsolver->StatisticsString();
if (stats.empty()) continue;
SOLVER_LOG(logger, absl::StrCat(" '", subsolver->name(), "':\n", stats));
shared.response->DisplayImprovementStatistics();
if (shared.bounds) {
SOLVER_LOG(logger, "");
shared.bounds->LogStatistics(logger);
}
if (shared.clauses) {
SOLVER_LOG(logger, "");
shared.clauses->LogStatistics(logger);
}
}
}
@@ -3454,10 +3542,9 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
QuickSolveWithHint(new_cp_model_proto, model);
}
SolveLoadedCpModel(new_cp_model_proto, model);
}
if (logger->LoggingIsEnabled()) {
if (params.num_search_workers() <= 1) {
// Sequential logging of LP statistics.
if (logger->LoggingIsEnabled()) {
const auto& lps =
*model->GetOrCreate<LinearProgrammingConstraintCollection>();
if (!lps.empty()) {
@@ -3467,11 +3554,6 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) {
}
}
}
if (params.num_search_workers() > 1) {
SOLVER_LOG(logger, "");
shared_response_manager->DisplayImprovementStatistics();
}
}
return shared_response_manager->GetResponse();

View File

@@ -24,7 +24,7 @@ option csharp_namespace = "Google.OrTools.Sat";
// Contains the definitions for all the sat algorithm parameters and their
// default values.
//
// NEXT TAG: 203
// NEXT TAG: 204
message SatParameters {
// In some context, like in a portfolio of search, it makes sense to name a
// given parameters set for logging purpose.
@@ -950,6 +950,9 @@ message SatParameters {
// Allows sharing of the bounds of modified variables at level 0.
optional bool share_level_zero_bounds = 114 [default = true];
// Allows sharing of new learned binary clause between workers.
optional bool share_binary_clauses = 203 [default = true];
// LNS parameters.
optional bool use_lns_only = 101 [default = false];

View File

@@ -375,6 +375,9 @@ int SatSolver::AddLearnedClauseAndEnqueueUnitPropagation(
if (track_binary_clauses_) {
CHECK(binary_clauses_.Add(BinaryClause(literals[0], literals[1])));
}
if (shared_binary_clauses_callback_ != nullptr) {
shared_binary_clauses_callback_(literals[0], literals[1]);
}
CHECK(binary_implication_graph_->AddBinaryClauseDuringSearch(literals[0],
literals[1]));
// In case this is the first binary clauses.
@@ -457,7 +460,7 @@ void SatSolver::AddBinaryClauseInternal(Literal a, Literal b) {
}
}
bool SatSolver::ClauseIsValidUnderDebugAssignement(
bool SatSolver::ClauseIsValidUnderDebugAssignment(
const std::vector<Literal>& clause) const {
for (Literal l : clause) {
if (l.Variable() >= debug_assignment_.NumberOfVariables() ||
@@ -583,7 +586,7 @@ bool SatSolver::PropagateAndStopAfterOneConflictResolution() {
// An empty conflict means that the problem is UNSAT.
if (learned_conflict_.empty()) return SetModelUnsat();
DCHECK(IsConflictValid(learned_conflict_));
DCHECK(ClauseIsValidUnderDebugAssignement(learned_conflict_));
DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_));
// Update the activity of all the variables in the first UIP clause.
// Also update the activity of the last level variables expanded (and
@@ -726,7 +729,7 @@ bool SatSolver::PropagateAndStopAfterOneConflictResolution() {
// this way. Second, more variables may be marked (in is_marked_) and
// MinimizeConflict() can take advantage of that. Because of this, the
// LBD of the learned conflict can change.
DCHECK(ClauseIsValidUnderDebugAssignement(learned_conflict_));
DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_));
if (!binary_implication_graph_->IsEmpty()) {
if (parameters_->binary_minimization_algorithm() ==
SatParameters::BINARY_MINIMIZATION_FIRST) {
@@ -782,7 +785,7 @@ bool SatSolver::PropagateAndStopAfterOneConflictResolution() {
// Backtrack and add the reason to the set of learned clause.
counters_.num_literals_learned += learned_conflict_.size();
Backtrack(ComputeBacktrackLevel(learned_conflict_));
DCHECK(ClauseIsValidUnderDebugAssignement(learned_conflict_));
DCHECK(ClauseIsValidUnderDebugAssignment(learned_conflict_));
// Note that we need to output the learned clause before cleaning the clause
// database. This is because we already backtracked and some of the clauses

View File

@@ -274,7 +274,7 @@ class SatSolver {
void Backtrack(int target_level);
// Advanced usage. This is meant to restore the solver to a "proper" state
// after a solve was interupted due to a limit reached.
// after a solve was interrupted due to a limit reached.
//
// Without assumption (i.e. if AssumptionLevel() is 0), this will revert all
// decisions and make sure that all the fixed literals are propagated. In
@@ -286,7 +286,7 @@ class SatSolver {
// case it will return false.
bool RestoreSolverToAssumptionLevel();
// Advanced usage. Finish the progation if it was interupted. Note that this
// Advanced usage. Finish the progation if it was interrupted. Note that this
// might run into conflict and will propagate again until a fixed point is
// reached or the model was proven UNSAT. Returns IsModelUnsat().
bool FinishPropagation();
@@ -381,7 +381,7 @@ class SatSolver {
// The idea is that if we know that a given assignment is satisfiable, then
// all the learned clauses or PB constraints must be satisfiable by it. In
// debug mode, and after this is called, all the learned clauses are tested to
// satisfy this saved assignement.
// satisfy this saved assignment.
void SaveDebugAssignment();
// Returns true iff the loaded problem only contains clauses.
@@ -418,6 +418,12 @@ class SatSolver {
// use propagation to try to minimize some clauses from the database.
void MinimizeSomeClauses(int decisions_budget);
// Sets the export function to the shared clauses manager.
void SetShareBinaryClauseCallback(const std::function<void(Literal, Literal)>&
shared_binary_clauses_callback) {
shared_binary_clauses_callback_ = shared_binary_clauses_callback;
}
// Advance the given time limit with all the deterministic time that was
// elapsed since last call.
void AdvanceDeterministicTime(TimeLimit* limit) {
@@ -452,7 +458,7 @@ class SatSolver {
// See SaveDebugAssignment(). Note that these functions only consider the
// variables at the time the debug_assignment_ was saved. If new variables
// were added since that time, they will be considered unassigned.
bool ClauseIsValidUnderDebugAssignement(
bool ClauseIsValidUnderDebugAssignment(
const std::vector<Literal>& clause) const;
bool PBConstraintIsValidUnderDebugAssignment(
const std::vector<LiteralWithCoeff>& cst, const Coefficient rhs) const;
@@ -832,6 +838,10 @@ class SatSolver {
DratProofHandler* drat_proof_handler_;
mutable StatsGroup stats_;
std::function<void(Literal, Literal)> shared_binary_clauses_callback_ =
nullptr;
DISALLOW_COPY_AND_ASSIGN(SatSolver);
};

View File

@@ -21,6 +21,7 @@
#include "ortools/sat/cp_model_mapping.h"
#endif // __PORTABLE_PLATFORM__
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/random.h"
#include "ortools/base/integral_types.h"
@@ -750,11 +751,8 @@ void SharedBoundsManager::ReportPotentialNewBounds(
changed_variables_since_last_synchronize_.Set(var);
num_improvements++;
}
// TODO(user): Display number of bound improvements cumulatively per
// workers at the end of the search.
if (num_improvements > 0) {
VLOG(2) << worker_name << " exports " << num_improvements
<< " modifications";
bounds_exported_[worker_name] += num_improvements;
}
}
@@ -840,5 +838,72 @@ void SharedBoundsManager::GetChangedBounds(
id_to_changed_variables_[id].ClearAll();
}
void SharedBoundsManager::LogStatistics(SolverLogger* logger) {
absl::MutexLock mutex_lock(&mutex_);
if (!bounds_exported_.empty()) {
SOLVER_LOG(logger, "Improving variable bounds shared per subsolver:");
for (const auto& entry : bounds_exported_) {
SOLVER_LOG(logger, " '", entry.first, "': ", entry.second);
}
}
}
int SharedClausesManager::RegisterNewId() {
absl::MutexLock mutex_lock(&mutex_);
const int id = id_to_last_processed_binary_clause_.size();
id_to_last_processed_binary_clause_.resize(id + 1, 0);
id_to_clauses_exported_.resize(id + 1, 0);
return id;
}
void SharedClausesManager::SetWorkerNameForId(int id,
const std::string& worker_name) {
absl::MutexLock mutex_lock(&mutex_);
id_to_worker_name_[id] = worker_name;
}
void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
absl::MutexLock mutex_lock(&mutex_);
if (lit2 < lit1) std::swap(lit1, lit2);
const auto p = std::make_pair(lit1, lit2);
const auto [unused_it, inserted] = added_binary_clauses_set_.insert(p);
if (inserted) {
added_binary_clauses_.push_back(p);
id_to_clauses_exported_[id]++;
// Small optim. If the worker is already up to date with clauses to import,
// we can mark this new clause as already seen.
if (id_to_last_processed_binary_clause_[id] ==
added_binary_clauses_.size() - 1) {
id_to_last_processed_binary_clause_[id]++;
}
}
}
void SharedClausesManager::GetUnseenBinaryClauses(
int id, std::vector<std::pair<int, int>>* new_clauses) {
new_clauses->clear();
absl::MutexLock mutex_lock(&mutex_);
const int last_binary_clause_seen = id_to_last_processed_binary_clause_[id];
new_clauses->assign(added_binary_clauses_.begin() + last_binary_clause_seen,
added_binary_clauses_.end());
id_to_last_processed_binary_clause_[id] = added_binary_clauses_.size();
}
void SharedClausesManager::LogStatistics(SolverLogger* logger) {
absl::MutexLock mutex_lock(&mutex_);
absl::btree_map<std::string, int64_t> name_to_clauses;
for (int id = 0; id < id_to_clauses_exported_.size(); ++id) {
if (id_to_clauses_exported_[id] == 0) continue;
name_to_clauses[id_to_worker_name_[id]] = id_to_clauses_exported_[id];
}
if (!name_to_clauses.empty()) {
SOLVER_LOG(logger, "Clauses shared per subsolver:");
for (const auto& entry : name_to_clauses) {
SOLVER_LOG(logger, " '", entry.first, "': ", entry.second);
}
}
}
} // namespace sat
} // namespace operations_research

View File

@@ -447,6 +447,8 @@ class SharedBoundsManager {
// state.
void Synchronize();
void LogStatistics(SolverLogger* logger);
private:
const int num_variables_;
const CpModelProto& model_proto_;
@@ -464,6 +466,44 @@ class SharedBoundsManager {
std::vector<int64_t> synchronized_upper_bounds_ ABSL_GUARDED_BY(mutex_);
std::deque<SparseBitset<int64_t>> id_to_changed_variables_
ABSL_GUARDED_BY(mutex_);
std::map<std::string, int> bounds_exported_ ABSL_GUARDED_BY(mutex_);
};
// This class holds all the binary clauses that were found and shared by the
// workers.
//
// It is thread-safe.
//
// Note that this uses literal as encoded in a cp_model.proto. The literals can
// thus be negative numbers.
class SharedClausesManager {
public:
void AddBinaryClause(int id, int lit1, int lit2);
// Fills flat_clauses with
// (lit1 of clause1, lit2 of clause1, lit1 of clause 2, lit2 of clause2 ...)
void GetUnseenBinaryClauses(int id,
std::vector<std::pair<int, int>>* new_clauses);
int RegisterNewId();
void SetWorkerNameForId(int id, const std::string& worker_name);
// Search statistics.
void LogStatistics(SolverLogger* logger);
private:
absl::Mutex mutex_;
// Cache to avoid adding the same clause twice.
absl::flat_hash_set<std::pair<int, int>> added_binary_clauses_set_
ABSL_GUARDED_BY(mutex_);
std::vector<std::pair<int, int>> added_binary_clauses_
ABSL_GUARDED_BY(mutex_);
std::vector<int64_t> id_to_last_processed_binary_clause_
ABSL_GUARDED_BY(mutex_);
std::vector<int64_t> id_to_clauses_exported_;
// Used for reporting statistics.
absl::flat_hash_map<int, std::string> id_to_worker_name_;
};
template <typename ValueType>