[CP-SAT] remove bug in cuts; better glue clauses sharing

This commit is contained in:
Laurent Perron
2024-05-15 11:48:06 +02:00
parent f1a9881283
commit 1750d19b51
11 changed files with 327 additions and 63 deletions

View File

@@ -138,10 +138,14 @@ cc_library(
"//ortools/util:time_limit",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/random",
"@com_google_absl//absl/random:bit_gen_ref",
@@ -352,6 +356,7 @@ cc_library(
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
@@ -817,6 +822,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/random:bit_gen_ref",
"@com_google_absl//absl/random:distributions",

View File

@@ -217,20 +217,22 @@ SatClause* ClauseManager::ReasonClause(int trail_index) const {
}
bool ClauseManager::AddClause(absl::Span<const Literal> literals) {
return AddClause(literals, trail_);
return AddClause(literals, trail_, -1);
}
bool ClauseManager::AddClause(absl::Span<const Literal> literals,
Trail* trail) {
bool ClauseManager::AddClause(absl::Span<const Literal> literals, Trail* trail,
int lbd) {
SatClause* clause = SatClause::Create(literals);
clauses_.push_back(clause);
if (add_clause_callback_ != nullptr) add_clause_callback_(lbd, literals);
return AttachAndPropagate(clause, trail);
}
SatClause* ClauseManager::AddRemovableClause(
const std::vector<Literal>& literals, Trail* trail) {
const std::vector<Literal>& literals, Trail* trail, int lbd) {
SatClause* clause = SatClause::Create(literals);
clauses_.push_back(clause);
if (add_clause_callback_ != nullptr) add_clause_callback_(lbd, literals);
CHECK(AttachAndPropagate(clause, trail));
return clause;
}
@@ -558,7 +560,9 @@ bool BinaryImplicationGraph::AddBinaryClause(Literal a, Literal b) {
is_dag_ = false;
num_implications_ += 2;
if (enable_sharing_ && add_callback_ != nullptr) add_callback_(a, b);
if (enable_sharing_ && add_binary_callback_ != nullptr) {
add_binary_callback_(a, b);
}
const auto& assignment = trail_->Assignment();
if (trail_->CurrentDecisionLevel() == 0) {

View File

@@ -28,6 +28,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/random/bit_gen_ref.h"
#include "absl/types/span.h"
@@ -179,13 +180,13 @@ class ClauseManager : public SatPropagator {
SatClause* ReasonClause(int trail_index) const;
// Adds a new clause and perform initial propagation for this clause only.
bool AddClause(absl::Span<const Literal> literals, Trail* trail);
bool AddClause(absl::Span<const Literal> literals, Trail* trail, int lbd);
bool AddClause(absl::Span<const Literal> literals);
// Same as AddClause() for a removable clause. This is only called on learned
// conflict, so this should never have all its literals at false (CHECKED).
SatClause* AddRemovableClause(const std::vector<Literal>& literals,
Trail* trail);
Trail* trail, int lbd);
// Lazily detach the given clause. The deletion will actually occur when
// CleanUpWatchers() is called. The latter needs to be called before any other
@@ -325,6 +326,12 @@ class ClauseManager : public SatPropagator {
return watchers_on_false_[false_literal];
}
// Registers the callback invoked with (lbd, literals) every time a clause is
// created through AddClause() or AddRemovableClause(). The callback runs
// before the new clause is attached and propagated. The lbd value is whatever
// the caller of AddClause()/AddRemovableClause() supplied; the two-argument
// free overload AddClause(literals) forwards -1.
//
// Setting a new callback replaces the previous one.
void SetAddClauseCallback(
absl::AnyInvocable<void(int lbd, absl::Span<const Literal>)>
add_clause_callback) {
add_clause_callback_ = std::move(add_clause_callback);
}
private:
// Attaches the given clause. This eventually propagates a literal which is
// enqueued on the trail. Returns false if a contradiction was encountered.
@@ -379,6 +386,9 @@ class ClauseManager : public SatPropagator {
absl::flat_hash_map<SatClause*, ClauseInfo> clauses_info_;
DratProofHandler* drat_proof_handler_ = nullptr;
absl::AnyInvocable<void(int lbd, absl::Span<const Literal>)>
add_clause_callback_ = nullptr;
};
// A binary clause. This is used by BinaryClauseManager.
@@ -530,9 +540,8 @@ class BinaryImplicationGraph : public SatPropagator {
// where we keep new implications and add them in batches.
void EnableSharing(bool enable) { enable_sharing_ = enable; }
void SetAdditionCallback(std::function<void(Literal, Literal)> f) {
add_callback_ = f;
add_binary_callback_ = f;
}
// An at most one constraint of size n is a compact way to encode n * (n - 1)
// implications. This must only be called at level zero.
//
@@ -680,8 +689,8 @@ class BinaryImplicationGraph : public SatPropagator {
return num_redundant_implications_;
}
// Returns the number of current implications. Note that a => b and not(b) =>
// not(a) are counted separately since they appear separately in our
// Returns the number of current implications. Note that a => b and not(b)
// => not(a) are counted separately since they appear separately in our
// propagation lists. The number of size 2 clauses that represent the same
// thing is half this number.
int64_t num_implications() const { return num_implications_; }
@@ -933,7 +942,7 @@ class BinaryImplicationGraph : public SatPropagator {
int num_processed_fixed_variables_ = 0;
bool enable_sharing_ = true;
std::function<void(Literal, Literal)> add_callback_ = nullptr;
std::function<void(Literal, Literal)> add_binary_callback_ = nullptr;
};
extern template std::vector<Literal>

View File

@@ -14,6 +14,7 @@
#include "ortools/sat/cp_model_solver.h"
#include <algorithm>
#include <array>
#include <atomic>
#include <cstdint>
#include <cstdlib>
@@ -1414,8 +1415,7 @@ void RegisterObjectiveBoundsImport(
import_objective_bounds);
}
// Registers a callback that will export binary clauses discovered during
// search.
// Registers a callback that will export good clauses discovered during search.
void RegisterClausesExport(int id, SharedClausesManager* shared_clauses_manager,
Model* model) {
auto* mapping = model->GetOrCreate<CpModelMapping>();
@@ -1433,6 +1433,30 @@ void RegisterClausesExport(int id, SharedClausesManager* shared_clauses_manager,
};
model->GetOrCreate<BinaryImplicationGraph>()->SetAdditionCallback(
share_binary_clause);
if (!model->GetOrCreate<SatParameters>()->share_glue_clauses()) {
return;
}
auto* clause_stream = shared_clauses_manager->GetClauseStream(id);
// Note that this callback takes no locks, everything operates on this
// worker's own clause_stream, clauses are exported from there by SyncClauses
// at level zero.
auto share_clause = [mapping, clause_stream](
int lbd, absl::Span<const Literal> literals) {
if (lbd <= 0 || lbd > 2 || literals.size() <= 2 ||
literals.size() > UniqueClauseStream::kMaxClauseSize) {
return;
}
std::vector<int> clause;
for (const Literal& lit : literals) {
const int var =
mapping->GetProtoVariableFromBooleanVariable(lit.Variable());
if (var == -1) return;
clause.push_back(lit.IsPositive() ? var : NegatedRef(var));
}
clause_stream->Add(std::move(clause));
};
model->GetOrCreate<ClauseManager>()->SetAddClauseCallback(
std::move(share_clause));
}
// Registers a callback to import new clauses stored in the
@@ -1448,8 +1472,14 @@ int RegisterClausesLevelZeroImport(int id,
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
auto* sat_solver = model->GetOrCreate<SatSolver>();
auto* implications = model->GetOrCreate<BinaryImplicationGraph>();
bool share_glue_clauses =
model->GetOrCreate<SatParameters>()->share_glue_clauses();
auto* clause_stream = share_glue_clauses
? shared_clauses_manager->GetClauseStream(id)
: nullptr;
const auto& import_level_zero_clauses = [shared_clauses_manager, id, mapping,
sat_solver, implications]() {
sat_solver, implications,
clause_stream]() {
std::vector<std::pair<int, int>> new_binary_clauses;
shared_clauses_manager->GetUnseenBinaryClauses(id, &new_binary_clauses);
implications->EnableSharing(false);
@@ -1461,6 +1491,20 @@ int RegisterClausesLevelZeroImport(int id,
}
}
implications->EnableSharing(true);
if (clause_stream == nullptr) return true;
std::array<Literal, UniqueClauseStream::kMaxClauseSize> local_clause;
for (const absl::Span<const int> shared_clause :
shared_clauses_manager->SyncClauses(id)) {
// Check this clause was not already learned by this worker.
if (!clause_stream->BlockClause(shared_clause)) continue;
for (int i = 0; i < shared_clause.size(); ++i) {
local_clause[i] = mapping->Literal(shared_clause[i]);
}
if (!sat_solver->AddProblemClause(
absl::MakeSpan(local_clause).subspan(0, shared_clause.size()))) {
return false;
}
}
return true;
};
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(

View File

@@ -89,7 +89,7 @@ void CutTerm::Complement(absl::int128* rhs) {
expr_offset = bound_diff - expr_offset;
// Note that this is not involutive because of floating point error. Fix?
lp_value = ToDouble(bound_diff) - lp_value;
lp_value = static_cast<double>(bound_diff.value()) - lp_value;
coeff = -coeff;
}
@@ -152,13 +152,13 @@ bool CutData::AppendOneTerm(IntegerVariable var, IntegerValue coeff,
entry.expr_coeffs[0] = -IntegerValue(1);
entry.expr_offset = ub;
entry.coeff = -coeff;
entry.lp_value = ToDouble(ub) - lp_value;
entry.lp_value = static_cast<double>(ub.value()) - lp_value;
} else {
// C = (X - LB) + LB
entry.expr_coeffs[0] = IntegerValue(1);
entry.expr_offset = -lb;
entry.coeff = coeff;
entry.lp_value = lp_value - ToDouble(lb);
entry.lp_value = lp_value - static_cast<double>(lb.value());
}
terms.push_back(entry);
return true;
@@ -660,7 +660,7 @@ double IntegerRoundingCutHelper::GetScaledViolation(
// Even before we finish the adjust, we can have a lower bound on the
// activity loss using this divisor, and so we can abort early. This is
// similar to what is done below.
double max_violation = ToDouble(initial_rhs_remainder);
double max_violation = static_cast<double>(initial_rhs_remainder.value());
for (int i = 0; i < cut.num_relevant_entries; ++i) {
const CutTerm& entry = cut.terms[i];
const IntegerValue remainder = PositiveRemainder(entry.coeff, divisor);
@@ -668,7 +668,8 @@ double IntegerRoundingCutHelper::GetScaledViolation(
if (remainder <= initial_rhs_remainder) {
// We do not know exactly f() yet, but it will always round to the
// floor of the division by divisor in this case.
max_violation -= ToDouble(remainder) * entry.lp_value;
max_violation -=
static_cast<double>(remainder.value()) * entry.lp_value;
if (max_violation <= 1e-3) return 0.0;
continue;
}
@@ -1049,16 +1050,16 @@ struct LargeCoeffFirst {
struct SmallContribFirst {
bool operator()(const CutTerm& a, const CutTerm& b) const {
const double contrib_a = a.lp_value * AsDouble(a.coeff);
const double contrib_b = b.lp_value * AsDouble(b.coeff);
const double contrib_a = a.lp_value * static_cast<double>(a.coeff.value());
const double contrib_b = b.lp_value * static_cast<double>(b.coeff.value());
return contrib_a < contrib_b;
}
};
struct LargeContribFirst {
bool operator()(const CutTerm& a, const CutTerm& b) const {
const double contrib_a = a.lp_value * AsDouble(a.coeff);
const double contrib_b = b.lp_value * AsDouble(b.coeff);
const double contrib_a = a.lp_value * static_cast<double>(a.coeff.value());
const double contrib_b = b.lp_value * static_cast<double>(b.coeff.value());
return contrib_a > contrib_b;
}
};
@@ -1069,15 +1070,19 @@ struct LargeContribFirst {
// lead to the same formula as for Booleans.
struct KnapsackAdd {
bool operator()(const CutTerm& a, const CutTerm& b) const {
const double contrib_a = a.LpDistToMaxValue() / AsDouble(a.coeff);
const double contrib_b = b.LpDistToMaxValue() / AsDouble(b.coeff);
const double contrib_a =
a.LpDistToMaxValue() / static_cast<double>(a.coeff.value());
const double contrib_b =
b.LpDistToMaxValue() / static_cast<double>(b.coeff.value());
return contrib_a < contrib_b;
}
};
struct KnapsackRemove {
bool operator()(const CutTerm& a, const CutTerm& b) const {
const double contrib_a = a.LpDistToMaxValue() / AsDouble(a.coeff);
const double contrib_b = b.LpDistToMaxValue() / AsDouble(b.coeff);
const double contrib_a =
a.LpDistToMaxValue() / static_cast<double>(a.coeff.value());
const double contrib_b =
b.LpDistToMaxValue() / static_cast<double>(b.coeff.value());
return contrib_a > contrib_b;
}
};
@@ -1352,14 +1357,6 @@ bool CoverCutHelper::TrySingleNodeFlow(const CutData& input_ct,
ImpliedBoundsProcessor* ib_processor) {
InitializeCut(input_ct);
bool has_large_coeff = false;
for (const CutTerm& term : cut_.terms) {
if (IntTypeAbs(term.coeff) > 1'000'000) {
has_large_coeff = true;
break;
}
}
// TODO(user): Change the heuristic to depends on the lp_value of the implied
// bounds. This way we can exactly match what happen in FlowCoverCutHelper and
// remove the code there.
@@ -1393,6 +1390,14 @@ bool CoverCutHelper::TrySingleNodeFlow(const CutData& input_ct,
return false;
}
bool has_large_coeff = false;
for (const CutTerm& term : cut_.terms) {
if (IntTypeAbs(term.coeff) > 1'000'000) {
has_large_coeff = true;
break;
}
}
// TODO(user): Shouldn't we just use rounding f() with maximum coeff to allows
// lift of all other terms? but then except for the heuristic the cut is
// really similar to the cover cut.

View File

@@ -237,8 +237,9 @@ class ImpliedBoundsProcessor {
IntegerVariable bool_var = kNoIntegerVariable;
double SlackLpValue(IntegerValue lb) const {
const double bool_term = ToDouble(implied_bound - lb) * bool_lp_value;
return var_lp_value - ToDouble(lb) - bool_term;
const double bool_term =
static_cast<double>((implied_bound - lb).value()) * bool_lp_value;
return var_lp_value - static_cast<double>(lb.value()) - bool_term;
}
std::string DebugString() const {

View File

@@ -1849,9 +1849,15 @@ bool LinearProgrammingConstraint::ScalingCanOverflow(
const std::vector<std::pair<glop::RowIndex, double>>& multipliers,
int64_t overflow_cap) const {
int64_t bound = 0;
const int64_t factor = int64_t{1} << power;
const double factor_as_double = static_cast<double>(factor);
if (take_objective_into_account) {
bound = CapAdd(bound, CapProd(factor, objective_infinity_norm_.value()));
if (bound >= overflow_cap) return true;
}
for (const auto [row, double_coeff] : multipliers) {
const double magnitude =
std::abs(std::round(std::ldexp(double_coeff, power)));
std::abs(std::round(double_coeff * factor_as_double));
if (std::isnan(magnitude)) return true;
if (magnitude >= static_cast<double>(std::numeric_limits<int64_t>::max())) {
return true;
@@ -1860,11 +1866,6 @@ bool LinearProgrammingConstraint::ScalingCanOverflow(
infinity_norms_[row].value()));
if (bound >= overflow_cap) return true;
}
if (take_objective_into_account) {
bound = CapAdd(
bound, CapProd(int64_t{1} << power, objective_infinity_norm_.value()));
if (bound >= overflow_cap) return true;
}
return bound >= overflow_cap;
}
@@ -1936,8 +1937,9 @@ LinearProgrammingConstraint::ScaleLpMultiplier(
// Scale the multipliers by *scaling.
// Note that we use the exact same formula as in ScalingCanOverflow().
int64_t gcd = scaling->value();
const double scaling_as_double = static_cast<double>(scaling->value());
for (const auto [row, double_coeff] : tmp_cp_multipliers_) {
const IntegerValue coeff(std::round(std::ldexp(double_coeff, power)));
const IntegerValue coeff(std::round(double_coeff * scaling_as_double));
if (coeff != 0) {
gcd = std::gcd(gcd, std::abs(coeff.value()));
integer_multipliers.push_back({row, coeff});

View File

@@ -23,7 +23,7 @@ option csharp_namespace = "Google.OrTools.Sat";
// Contains the definitions for all the sat algorithm parameters and their
// default values.
//
// NEXT TAG: 285
// NEXT TAG: 286
message SatParameters {
// In some context, like in a portfolio of search, it makes sense to name a
// given parameters set for logging purpose.
@@ -647,6 +647,10 @@ message SatParameters {
// Allows sharing of new learned binary clause between workers.
optional bool share_binary_clauses = 203 [default = true];
// Allows sharing of short glue clauses between workers.
// Implicitly disabled if share_binary_clauses is false.
optional bool share_glue_clauses = 285 [default = false];
// ==========================================================================
// Debugging parameters
// ==========================================================================

View File

@@ -273,7 +273,7 @@ bool SatSolver::AddProblemClauseInternal(absl::Span<const Literal> literals) {
AddBinaryClauseInternal(literals[0], literals[1]);
}
} else {
if (!clauses_propagator_->AddClause(literals, trail_)) {
if (!clauses_propagator_->AddClause(literals, trail_, /*lbd=*/-1)) {
return SetModelUnsat();
}
}
@@ -432,14 +432,14 @@ int SatSolver::AddLearnedClauseAndEnqueueUnitPropagation(
--num_learned_clause_before_cleanup_;
SatClause* clause =
clauses_propagator_->AddRemovableClause(literals, trail_);
clauses_propagator_->AddRemovableClause(literals, trail_, lbd);
// BumpClauseActivity() must be called after clauses_info_[clause] has
// been created or it will have no effect.
(*clauses_propagator_->mutable_clauses_info())[clause].lbd = lbd;
BumpClauseActivity(clause);
} else {
CHECK(clauses_propagator_->AddClause(literals, trail_));
CHECK(clauses_propagator_->AddClause(literals, trail_, lbd));
}
return lbd;
}

View File

@@ -13,11 +13,15 @@
#include "ortools/sat/synchronization.h"
#include <sys/types.h>
#include <algorithm>
#include <atomic>
#include <cctype>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <deque>
#include <functional>
#include <limits>
@@ -25,6 +29,7 @@
#include <utility>
#include <vector>
#include "absl/hash/hash.h"
#include "ortools/base/logging.h"
#include "ortools/base/timer.h"
#if !defined(__PORTABLE_PLATFORM__)
@@ -1034,6 +1039,88 @@ int SharedBoundsManager::NumBoundsExported(const std::string& worker_name) {
return it->second;
}
// Buffers `clause` for a future batch unless the bloom filter reports it as
// already seen (or its size is outside the supported range). Returns true iff
// the clause was buffered.
bool UniqueClauseStream::Add(absl::Span<const int> clause) {
  // BlockClause() both validates the size and records the clause in the
  // filter, so a rejected clause never touches the buckets below.
  if (!BlockClause(clause)) return false;

  // Bucket 0 holds clauses of size 3, bucket 1 those of size 4, and so on.
  std::vector<int>& bucket = clauses_by_size_[clause.size() - 3];
  num_buffered_literals_ += clause.size();
  bucket.insert(bucket.end(), clause.begin(), clause.end());
  return true;
}
// Marks `clause` as seen in the bloom filter. Returns true if at least one of
// its filter bits was previously unset, i.e. the clause is (probably) new.
// Clauses with at most 2 or more than kMaxClauseSize literals are rejected.
bool UniqueClauseStream::BlockClause(absl::Span<const int> clause) {
  if (clause.size() <= 2 || clause.size() > kMaxClauseSize) return false;

  // All 4 bits live in the same page to guarantee at most 1 page fault per
  // insertion. The page is selected with a dedicated hash seed (-1).
  const size_t page_start =
      (HashClause(clause, -1) % kBloomFilterPages) * kBitsPerPage;

  bool any_bit_was_clear = false;
  // We could use more bits per hash if this is ever too slow.
  for (int seed = 0; seed < 4; ++seed) {
    const size_t bit = page_start + HashClause(clause, seed) % kBitsPerPage;
    if (!filter_.test(bit)) any_bit_was_clear = true;
    filter_.set(bit, true);
  }
  return any_bit_was_clear;
}
// Removes buffered clauses from this stream, smallest sizes first, and returns
// them as a single batch holding at most `num_literals` literals in total.
//
// clauses_by_size_ has one bucket per clause size in [3, kMaxClauseSize], so
// the size loop must be inclusive of kMaxClauseSize. With an exclusive bound,
// clauses of exactly kMaxClauseSize literals are accepted by Add() and counted
// in num_buffered_literals_, but are never drained from the buffer.
CompactVectorVector<int> UniqueClauseStream::NextBatch(int num_literals) {
  CompactVectorVector<int> buffer;
  // The smallest clause has 3 literals, which bounds the number of clauses.
  buffer.reserve(num_literals / 3, num_literals);
  int total_literals = 0;
  for (int size = 3; size <= kMaxClauseSize; ++size) {
    const int size_index = size - 3;
    while (total_literals + size <= num_literals &&
           !clauses_by_size_[size_index].empty()) {
      // Copy the clause into the batch before dropping it from the bucket.
      buffer.Add(NextClause(size));
      PopClause(size);
      total_literals += size;
    }
  }
  return buffer;
}
// Moves buffered clauses, smallest sizes first, from this stream to
// `upstream`, removing at most `max_literals` literals' worth of clauses from
// the local buffer. A clause that `upstream` rejects as already seen is still
// removed locally and still consumes budget. Returns the number of clauses
// that `upstream` accepted as new.
//
// clauses_by_size_ has one bucket per clause size in [3, kMaxClauseSize], so
// the size loop must be inclusive of kMaxClauseSize; otherwise clauses of
// exactly kMaxClauseSize literals would stay in this stream forever.
int UniqueClauseStream::MoveClausesTo(UniqueClauseStream& upstream,
                                      int max_literals) {
  int num_exported_clauses = 0;
  for (int size = 3; size <= kMaxClauseSize; ++size) {
    const int size_index = size - 3;
    while (!clauses_by_size_[size_index].empty() && max_literals >= size) {
      max_literals -= size;
      // Add() copies the literals, so the span may be popped right after.
      if (upstream.Add(NextClause(size))) {
        ++num_exported_clauses;
      }
      PopClause(size);
    }
  }
  return num_exported_clauses;
}
// Hashes `clause` together with `hash_seed`. Per-literal hashes are combined
// with XOR, so the result does not depend on the order of the literals.
size_t UniqueClauseStream::HashClause(absl::Span<const int> clause,
                                      size_t hash_seed) {
  size_t result = absl::HashOf(hash_seed, clause.size());
  for (const int literal : clause) {
    result ^= absl::HashOf(literal, hash_seed);
  }
  return result;
}
// Returns a view on the last buffered clause of the given size, i.e. the final
// `size` literals of the matching bucket. The bucket must not be empty.
absl::Span<const int> UniqueClauseStream::NextClause(int size) const {
  const std::vector<int>& bucket = clauses_by_size_[size - 3];
  return absl::MakeConstSpan(bucket).subspan(bucket.size() - size, size);
}
// Drops the last buffered clause of the given size (the one NextClause()
// returns) and updates the buffered literal count accordingly.
void UniqueClauseStream::PopClause(int size) {
  std::vector<int>& bucket = clauses_by_size_[size - 3];
  bucket.erase(bucket.end() - size, bucket.end());
  num_buffered_literals_ -= size;
}
// Stores the synchronization mode. When `always_synchronize` is true,
// AddBinaryClause() makes each newly added binary clause visible right away
// instead of waiting for the next call to Synchronize().
SharedClausesManager::SharedClausesManager(bool always_synchronize)
: always_synchronize_(always_synchronize) {}
@@ -1041,7 +1128,9 @@ int SharedClausesManager::RegisterNewId() {
absl::MutexLock mutex_lock(&mutex_);
const int id = id_to_last_processed_binary_clause_.size();
id_to_last_processed_binary_clause_.resize(id + 1, 0);
id_to_last_processed_batch_.resize(id + 1, 0);
id_to_clauses_exported_.resize(id + 1, 0);
id_to_clause_stream_.emplace_back();
return id;
}
@@ -1059,7 +1148,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
const auto [unused_it, inserted] = added_binary_clauses_set_.insert(p);
if (inserted) {
added_binary_clauses_.push_back(p);
if (always_synchronize_) ++last_visible_clause_;
if (always_synchronize_) ++last_visible_binary_clause_;
id_to_clauses_exported_[id]++;
// Small optim. If the worker is already up to date with clauses to import,
@@ -1071,16 +1160,35 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
}
}
std::vector<absl::Span<const int>> SharedClausesManager::SyncClauses(int id) {
absl::MutexLock mutex_lock(&mutex_);
UniqueClauseStream& worker_clauses = id_to_clause_stream_[id];
id_to_clauses_exported_[id] +=
worker_clauses.MoveClausesTo(all_clauses_, kLiteralsPerBatch);
std::vector<absl::Span<const int>> result;
for (int i = id_to_last_processed_batch_[id]; i < batches_.size(); ++i) {
for (int j = 0; j < batches_[i].size(); ++j) {
result.push_back(batches_[i][j]);
}
}
// TODO: tobyodavies - We should delete old clauses that have been consumed by
// all workers. This will be subtle as the returned spans must remain valid
// until the *next* call to SyncClauses() after they are returned.
id_to_last_processed_batch_[id] = batches_.size();
return result;
}
void SharedClausesManager::GetUnseenBinaryClauses(
int id, std::vector<std::pair<int, int>>* new_clauses) {
new_clauses->clear();
absl::MutexLock mutex_lock(&mutex_);
const int last_binary_clause_seen = id_to_last_processed_binary_clause_[id];
if (last_binary_clause_seen >= last_visible_clause_) return;
if (last_binary_clause_seen >= last_visible_binary_clause_) return;
new_clauses->assign(added_binary_clauses_.begin() + last_binary_clause_seen,
added_binary_clauses_.begin() + last_visible_clause_);
id_to_last_processed_binary_clause_[id] = last_visible_clause_;
new_clauses->assign(
added_binary_clauses_.begin() + last_binary_clause_seen,
added_binary_clauses_.begin() + last_visible_binary_clause_);
id_to_last_processed_binary_clause_[id] = last_visible_binary_clause_;
}
void SharedClausesManager::LogStatistics(SolverLogger* logger) {
@@ -1102,8 +1210,11 @@ void SharedClausesManager::LogStatistics(SolverLogger* logger) {
void SharedClausesManager::Synchronize() {
absl::MutexLock mutex_lock(&mutex_);
last_visible_clause_ = added_binary_clauses_.size();
// TODO(user): We could cleanup added_binary_clauses_ periodically.
last_visible_binary_clause_ = added_binary_clauses_.size();
if (all_clauses_.num_buffered_literals() >= kLiteralsPerBatch) {
batches_.push_back(all_clauses_.NextBatch(kLiteralsPerBatch));
}
// TODO(user): We could cleanup clauses that have been consumed.
}
void SharedStatistics::AddStats(

View File

@@ -14,7 +14,10 @@
#ifndef OR_TOOLS_SAT_SYNCHRONIZATION_H_
#define OR_TOOLS_SAT_SYNCHRONIZATION_H_
#include <array>
#include <atomic>
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
@@ -31,7 +34,6 @@
#include "absl/random/random.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "ortools/base/logging.h"
#include "ortools/base/stl_util.h"
@@ -570,8 +572,60 @@ class SharedBoundsManager {
int export_counter_ = 0;
};
// This class holds all the binary clauses that were found and shared by the
// workers.
// Emit a stream of clauses in batches without duplicates.
//
// This class is thread-compatible, the idea is to have one per worker plus a
// global one to deduplicate between workers.
//
// Note that this uses literal as encoded in a cp_model.proto. Thus, the
// literals can be negative numbers.
class UniqueClauseStream {
public:
// Largest clause size (in literals) accepted by Add() and BlockClause().
// Clauses with 2 literals or fewer are rejected as well.
static constexpr int kMaxClauseSize = 8;
// The bloom filter is 1MiB.
static constexpr int kBloomFilterBits = 1024 * 1024 * 8;
// The filter is split into pages of 4096 bytes; all the bits set for a given
// clause fall in a single page.
static constexpr int kBitsPerPage = 4096 * 8;
static constexpr int kBloomFilterPages = kBloomFilterBits / kBitsPerPage;
UniqueClauseStream() = default;
// Move only - this is an expensive class to copy.
UniqueClauseStream(const UniqueClauseStream&) = delete;
UniqueClauseStream(UniqueClauseStream&&) = default;
// Adds the clause to a future batch and returns true if the clause was new.
// Otherwise returns false if the clause was previously added or blocked.
// Also returns false, without buffering anything, when the clause size is
// not in [3, kMaxClauseSize].
bool Add(absl::Span<const int> clause);
// Inserts the clause into the bloom filter without adding it to the buffer.
// Returns true if the clause is new, i.e. at least one of its filter bits
// was not set yet. Returns false when the clause size is not in
// [3, kMaxClauseSize].
bool BlockClause(absl::Span<const int> clause);
// Returns the sum of sizes of all buffered clauses.
int num_buffered_literals() const { return num_buffered_literals_; }
// Returns a set of clauses totalling up to num_literals and removes them from
// the internal buffer. Smaller clauses are taken first.
CompactVectorVector<int> NextBatch(int num_literals);
// Moves clauses from this stream to upstream, smaller clauses first, removing
// at most max_literals literals' worth of clauses from the internal buffer.
// Returns the number of clauses that upstream accepted as new.
int MoveClausesTo(UniqueClauseStream& upstream, int max_literals);
// Computes a hash that is independent of the order of literals in the clause.
static size_t HashClause(absl::Span<const int> clause, size_t hash_seed = 0);
private:
// Returns a view on the last buffered clause of the given size.
absl::Span<const int> NextClause(int size) const;
// Removes the last buffered clause of the given size.
void PopClause(int size);
// NOTE(review): no definition of NumClauses() is visible in this change;
// confirm it is defined and used, or remove the declaration.
int NumClauses(int size) const;
// Bloom filter bits recording the clauses seen by Add() / BlockClause().
std::bitset<kBloomFilterBits> filter_;
int num_buffered_literals_ = 0;
// clauses_by_size_[s - 3] holds the literals of all buffered clauses of size
// s, stored back to back.
std::array<std::vector<int>, kMaxClauseSize - 2> clauses_by_size_;
};
// This class holds clauses found and shared by workers.
// It is exact for binary clauses, but approximate for longer ones.
//
// It is thread-safe.
//
@@ -582,6 +636,10 @@ class SharedClausesManager {
explicit SharedClausesManager(bool always_synchronize);
void AddBinaryClause(int id, int lit1, int lit2);
// Imports all clauses from the given id into the shared pool.
// Returns new clauses.
std::vector<absl::Span<const int>> SyncClauses(int id);
// Fills new_clauses with
// {{lit1 of clause1, lit2 of clause1},
// {lit1 of clause2, lit2 of clause2},
@@ -593,26 +651,46 @@ class SharedClausesManager {
int RegisterNewId();
void SetWorkerNameForId(int id, const std::string& worker_name);
// Returns the clause stream registered for worker `id`.
// A worker can add or remove clauses from its own clause set.
// It must not do so concurrently with any call to SyncClauses.
// This manager retains ownership of the returned UniqueClauseStream.
UniqueClauseStream* GetClauseStream(int id) {
absl::ReaderMutexLock mutex_lock(&mutex_);
return &id_to_clause_stream_[id];
}
// Search statistics.
void LogStatistics(SolverLogger* logger);
// Unlocks waiting binary clauses for workers if always_synchronize is false.
// Periodically starts a new sharing round, making glue clauses visible.
void Synchronize();
private:
// 1024 literals is 4KiB, i.e. 1 page.
static constexpr int kLiteralsPerBatch = 1024;
absl::Mutex mutex_;
// Cache to avoid adding the same clause twice.
// Binary clauses:
// Cache to avoid adding the same binary clause twice.
absl::flat_hash_set<std::pair<int, int>> added_binary_clauses_set_
ABSL_GUARDED_BY(mutex_);
std::vector<std::pair<int, int>> added_binary_clauses_
ABSL_GUARDED_BY(mutex_);
std::vector<int> id_to_last_processed_binary_clause_ ABSL_GUARDED_BY(mutex_);
std::vector<int64_t> id_to_clauses_exported_;
int last_visible_clause_ ABSL_GUARDED_BY(mutex_) = 0;
int last_visible_binary_clause_ ABSL_GUARDED_BY(mutex_) = 0;
// Longer clauses:
UniqueClauseStream all_clauses_ ABSL_GUARDED_BY(mutex_);
std::vector<int> id_to_last_processed_batch_ ABSL_GUARDED_BY(mutex_);
std::deque<CompactVectorVector<int>> batches_ ABSL_GUARDED_BY(mutex_);
std::deque<UniqueClauseStream> id_to_clause_stream_ ABSL_GUARDED_BY(mutex_);
const bool always_synchronize_ = true;
// Used for reporting statistics.
// Stats:
std::vector<int64_t> id_to_clauses_exported_;
absl::flat_hash_map<int, std::string> id_to_worker_name_;
};