[CP-SAT] remove bug in cuts; better glue clauses sharing
This commit is contained in:
@@ -138,10 +138,14 @@ cc_library(
|
||||
"//ortools/util:time_limit",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/cleanup",
|
||||
"@com_google_absl//absl/container:btree",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/random",
|
||||
"@com_google_absl//absl/random:bit_gen_ref",
|
||||
@@ -352,6 +356,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:btree",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
@@ -817,6 +822,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/functional:any_invocable",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/random:bit_gen_ref",
|
||||
"@com_google_absl//absl/random:distributions",
|
||||
|
||||
@@ -217,20 +217,22 @@ SatClause* ClauseManager::ReasonClause(int trail_index) const {
|
||||
}
|
||||
|
||||
bool ClauseManager::AddClause(absl::Span<const Literal> literals) {
|
||||
return AddClause(literals, trail_);
|
||||
return AddClause(literals, trail_, -1);
|
||||
}
|
||||
|
||||
bool ClauseManager::AddClause(absl::Span<const Literal> literals,
|
||||
Trail* trail) {
|
||||
bool ClauseManager::AddClause(absl::Span<const Literal> literals, Trail* trail,
|
||||
int lbd) {
|
||||
SatClause* clause = SatClause::Create(literals);
|
||||
clauses_.push_back(clause);
|
||||
if (add_clause_callback_ != nullptr) add_clause_callback_(lbd, literals);
|
||||
return AttachAndPropagate(clause, trail);
|
||||
}
|
||||
|
||||
SatClause* ClauseManager::AddRemovableClause(
|
||||
const std::vector<Literal>& literals, Trail* trail) {
|
||||
const std::vector<Literal>& literals, Trail* trail, int lbd) {
|
||||
SatClause* clause = SatClause::Create(literals);
|
||||
clauses_.push_back(clause);
|
||||
if (add_clause_callback_ != nullptr) add_clause_callback_(lbd, literals);
|
||||
CHECK(AttachAndPropagate(clause, trail));
|
||||
return clause;
|
||||
}
|
||||
@@ -558,7 +560,9 @@ bool BinaryImplicationGraph::AddBinaryClause(Literal a, Literal b) {
|
||||
is_dag_ = false;
|
||||
num_implications_ += 2;
|
||||
|
||||
if (enable_sharing_ && add_callback_ != nullptr) add_callback_(a, b);
|
||||
if (enable_sharing_ && add_binary_callback_ != nullptr) {
|
||||
add_binary_callback_(a, b);
|
||||
}
|
||||
|
||||
const auto& assignment = trail_->Assignment();
|
||||
if (trail_->CurrentDecisionLevel() == 0) {
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/functional/any_invocable.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/random/bit_gen_ref.h"
|
||||
#include "absl/types/span.h"
|
||||
@@ -179,13 +180,13 @@ class ClauseManager : public SatPropagator {
|
||||
SatClause* ReasonClause(int trail_index) const;
|
||||
|
||||
// Adds a new clause and perform initial propagation for this clause only.
|
||||
bool AddClause(absl::Span<const Literal> literals, Trail* trail);
|
||||
bool AddClause(absl::Span<const Literal> literals, Trail* trail, int lbd);
|
||||
bool AddClause(absl::Span<const Literal> literals);
|
||||
|
||||
// Same as AddClause() for a removable clause. This is only called on learned
|
||||
// conflict, so this should never have all its literal at false (CHECKED).
|
||||
SatClause* AddRemovableClause(const std::vector<Literal>& literals,
|
||||
Trail* trail);
|
||||
Trail* trail, int lbd);
|
||||
|
||||
// Lazily detach the given clause. The deletion will actually occur when
|
||||
// CleanUpWatchers() is called. The later needs to be called before any other
|
||||
@@ -325,6 +326,12 @@ class ClauseManager : public SatPropagator {
|
||||
return watchers_on_false_[false_literal];
|
||||
}
|
||||
|
||||
void SetAddClauseCallback(
|
||||
absl::AnyInvocable<void(int lbd, absl::Span<const Literal>)>
|
||||
add_clause_callback) {
|
||||
add_clause_callback_ = std::move(add_clause_callback);
|
||||
}
|
||||
|
||||
private:
|
||||
// Attaches the given clause. This eventually propagates a literal which is
|
||||
// enqueued on the trail. Returns false if a contradiction was encountered.
|
||||
@@ -379,6 +386,9 @@ class ClauseManager : public SatPropagator {
|
||||
absl::flat_hash_map<SatClause*, ClauseInfo> clauses_info_;
|
||||
|
||||
DratProofHandler* drat_proof_handler_ = nullptr;
|
||||
|
||||
absl::AnyInvocable<void(int lbd, absl::Span<const Literal>)>
|
||||
add_clause_callback_ = nullptr;
|
||||
};
|
||||
|
||||
// A binary clause. This is used by BinaryClauseManager.
|
||||
@@ -530,9 +540,8 @@ class BinaryImplicationGraph : public SatPropagator {
|
||||
// were we keep new implication and add them in batches.
|
||||
void EnableSharing(bool enable) { enable_sharing_ = enable; }
|
||||
void SetAdditionCallback(std::function<void(Literal, Literal)> f) {
|
||||
add_callback_ = f;
|
||||
add_binary_callback_ = f;
|
||||
}
|
||||
|
||||
// An at most one constraint of size n is a compact way to encode n * (n - 1)
|
||||
// implications. This must only be called at level zero.
|
||||
//
|
||||
@@ -680,8 +689,8 @@ class BinaryImplicationGraph : public SatPropagator {
|
||||
return num_redundant_implications_;
|
||||
}
|
||||
|
||||
// Returns the number of current implications. Note that a => b and not(b) =>
|
||||
// not(a) are counted separately since they appear separately in our
|
||||
// Returns the number of current implications. Note that a => b and not(b)
|
||||
// => not(a) are counted separately since they appear separately in our
|
||||
// propagation lists. The number of size 2 clauses that represent the same
|
||||
// thing is half this number.
|
||||
int64_t num_implications() const { return num_implications_; }
|
||||
@@ -933,7 +942,7 @@ class BinaryImplicationGraph : public SatPropagator {
|
||||
int num_processed_fixed_variables_ = 0;
|
||||
|
||||
bool enable_sharing_ = true;
|
||||
std::function<void(Literal, Literal)> add_callback_ = nullptr;
|
||||
std::function<void(Literal, Literal)> add_binary_callback_ = nullptr;
|
||||
};
|
||||
|
||||
extern template std::vector<Literal>
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ortools/sat/cp_model_solver.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
@@ -1414,8 +1415,7 @@ void RegisterObjectiveBoundsImport(
|
||||
import_objective_bounds);
|
||||
}
|
||||
|
||||
// Registers a callback that will export binary clauses discovered during
|
||||
// search.
|
||||
// Registers a callback that will export good clauses discovered during search.
|
||||
void RegisterClausesExport(int id, SharedClausesManager* shared_clauses_manager,
|
||||
Model* model) {
|
||||
auto* mapping = model->GetOrCreate<CpModelMapping>();
|
||||
@@ -1433,6 +1433,30 @@ void RegisterClausesExport(int id, SharedClausesManager* shared_clauses_manager,
|
||||
};
|
||||
model->GetOrCreate<BinaryImplicationGraph>()->SetAdditionCallback(
|
||||
share_binary_clause);
|
||||
if (!model->GetOrCreate<SatParameters>()->share_glue_clauses()) {
|
||||
return;
|
||||
}
|
||||
auto* clause_stream = shared_clauses_manager->GetClauseStream(id);
|
||||
// Note that this callback takes no locks, everything operates on this
|
||||
// worker's own clause_stream, clauses are exported from there by SyncClauses
|
||||
// at level zero.
|
||||
auto share_clause = [mapping, clause_stream](
|
||||
int lbd, absl::Span<const Literal> literals) {
|
||||
if (lbd <= 0 || lbd > 2 || literals.size() <= 2 ||
|
||||
literals.size() > UniqueClauseStream::kMaxClauseSize) {
|
||||
return;
|
||||
}
|
||||
std::vector<int> clause;
|
||||
for (const Literal& lit : literals) {
|
||||
const int var =
|
||||
mapping->GetProtoVariableFromBooleanVariable(lit.Variable());
|
||||
if (var == -1) return;
|
||||
clause.push_back(lit.IsPositive() ? var : NegatedRef(var));
|
||||
}
|
||||
clause_stream->Add(std::move(clause));
|
||||
};
|
||||
model->GetOrCreate<ClauseManager>()->SetAddClauseCallback(
|
||||
std::move(share_clause));
|
||||
}
|
||||
|
||||
// Registers a callback to import new clauses stored in the
|
||||
@@ -1448,8 +1472,14 @@ int RegisterClausesLevelZeroImport(int id,
|
||||
CpModelMapping* const mapping = model->GetOrCreate<CpModelMapping>();
|
||||
auto* sat_solver = model->GetOrCreate<SatSolver>();
|
||||
auto* implications = model->GetOrCreate<BinaryImplicationGraph>();
|
||||
bool share_glue_clauses =
|
||||
model->GetOrCreate<SatParameters>()->share_glue_clauses();
|
||||
auto* clause_stream = share_glue_clauses
|
||||
? shared_clauses_manager->GetClauseStream(id)
|
||||
: nullptr;
|
||||
const auto& import_level_zero_clauses = [shared_clauses_manager, id, mapping,
|
||||
sat_solver, implications]() {
|
||||
sat_solver, implications,
|
||||
clause_stream]() {
|
||||
std::vector<std::pair<int, int>> new_binary_clauses;
|
||||
shared_clauses_manager->GetUnseenBinaryClauses(id, &new_binary_clauses);
|
||||
implications->EnableSharing(false);
|
||||
@@ -1461,6 +1491,20 @@ int RegisterClausesLevelZeroImport(int id,
|
||||
}
|
||||
}
|
||||
implications->EnableSharing(true);
|
||||
if (clause_stream == nullptr) return true;
|
||||
std::array<Literal, UniqueClauseStream::kMaxClauseSize> local_clause;
|
||||
for (const absl::Span<const int> shared_clause :
|
||||
shared_clauses_manager->SyncClauses(id)) {
|
||||
// Check this clause was not already learned by this worker.
|
||||
if (!clause_stream->BlockClause(shared_clause)) continue;
|
||||
for (int i = 0; i < shared_clause.size(); ++i) {
|
||||
local_clause[i] = mapping->Literal(shared_clause[i]);
|
||||
}
|
||||
if (!sat_solver->AddProblemClause(
|
||||
absl::MakeSpan(local_clause).subspan(0, shared_clause.size()))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(
|
||||
|
||||
@@ -89,7 +89,7 @@ void CutTerm::Complement(absl::int128* rhs) {
|
||||
expr_offset = bound_diff - expr_offset;
|
||||
|
||||
// Note that this is not involutive because of floating point error. Fix?
|
||||
lp_value = ToDouble(bound_diff) - lp_value;
|
||||
lp_value = static_cast<double>(bound_diff.value()) - lp_value;
|
||||
coeff = -coeff;
|
||||
}
|
||||
|
||||
@@ -152,13 +152,13 @@ bool CutData::AppendOneTerm(IntegerVariable var, IntegerValue coeff,
|
||||
entry.expr_coeffs[0] = -IntegerValue(1);
|
||||
entry.expr_offset = ub;
|
||||
entry.coeff = -coeff;
|
||||
entry.lp_value = ToDouble(ub) - lp_value;
|
||||
entry.lp_value = static_cast<double>(ub.value()) - lp_value;
|
||||
} else {
|
||||
// C = (X - LB) + LB
|
||||
entry.expr_coeffs[0] = IntegerValue(1);
|
||||
entry.expr_offset = -lb;
|
||||
entry.coeff = coeff;
|
||||
entry.lp_value = lp_value - ToDouble(lb);
|
||||
entry.lp_value = lp_value - static_cast<double>(lb.value());
|
||||
}
|
||||
terms.push_back(entry);
|
||||
return true;
|
||||
@@ -660,7 +660,7 @@ double IntegerRoundingCutHelper::GetScaledViolation(
|
||||
// Even before we finish the adjust, we can have a lower bound on the
|
||||
// activily loss using this divisor, and so we can abort early. This is
|
||||
// similar to what is done below.
|
||||
double max_violation = ToDouble(initial_rhs_remainder);
|
||||
double max_violation = static_cast<double>(initial_rhs_remainder.value());
|
||||
for (int i = 0; i < cut.num_relevant_entries; ++i) {
|
||||
const CutTerm& entry = cut.terms[i];
|
||||
const IntegerValue remainder = PositiveRemainder(entry.coeff, divisor);
|
||||
@@ -668,7 +668,8 @@ double IntegerRoundingCutHelper::GetScaledViolation(
|
||||
if (remainder <= initial_rhs_remainder) {
|
||||
// We do not know exactly f() yet, but it will always round to the
|
||||
// floor of the division by divisor in this case.
|
||||
max_violation -= ToDouble(remainder) * entry.lp_value;
|
||||
max_violation -=
|
||||
static_cast<double>(remainder.value()) * entry.lp_value;
|
||||
if (max_violation <= 1e-3) return 0.0;
|
||||
continue;
|
||||
}
|
||||
@@ -1049,16 +1050,16 @@ struct LargeCoeffFirst {
|
||||
|
||||
struct SmallContribFirst {
|
||||
bool operator()(const CutTerm& a, const CutTerm& b) const {
|
||||
const double contrib_a = a.lp_value * AsDouble(a.coeff);
|
||||
const double contrib_b = b.lp_value * AsDouble(b.coeff);
|
||||
const double contrib_a = a.lp_value * static_cast<double>(a.coeff.value());
|
||||
const double contrib_b = b.lp_value * static_cast<double>(b.coeff.value());
|
||||
return contrib_a < contrib_b;
|
||||
}
|
||||
};
|
||||
|
||||
struct LargeContribFirst {
|
||||
bool operator()(const CutTerm& a, const CutTerm& b) const {
|
||||
const double contrib_a = a.lp_value * AsDouble(a.coeff);
|
||||
const double contrib_b = b.lp_value * AsDouble(b.coeff);
|
||||
const double contrib_a = a.lp_value * static_cast<double>(a.coeff.value());
|
||||
const double contrib_b = b.lp_value * static_cast<double>(b.coeff.value());
|
||||
return contrib_a > contrib_b;
|
||||
}
|
||||
};
|
||||
@@ -1069,15 +1070,19 @@ struct LargeContribFirst {
|
||||
// lead to the same formula as for Booleans.
|
||||
struct KnapsackAdd {
|
||||
bool operator()(const CutTerm& a, const CutTerm& b) const {
|
||||
const double contrib_a = a.LpDistToMaxValue() / AsDouble(a.coeff);
|
||||
const double contrib_b = b.LpDistToMaxValue() / AsDouble(b.coeff);
|
||||
const double contrib_a =
|
||||
a.LpDistToMaxValue() / static_cast<double>(a.coeff.value());
|
||||
const double contrib_b =
|
||||
b.LpDistToMaxValue() / static_cast<double>(b.coeff.value());
|
||||
return contrib_a < contrib_b;
|
||||
}
|
||||
};
|
||||
struct KnapsackRemove {
|
||||
bool operator()(const CutTerm& a, const CutTerm& b) const {
|
||||
const double contrib_a = a.LpDistToMaxValue() / AsDouble(a.coeff);
|
||||
const double contrib_b = b.LpDistToMaxValue() / AsDouble(b.coeff);
|
||||
const double contrib_a =
|
||||
a.LpDistToMaxValue() / static_cast<double>(a.coeff.value());
|
||||
const double contrib_b =
|
||||
b.LpDistToMaxValue() / static_cast<double>(b.coeff.value());
|
||||
return contrib_a > contrib_b;
|
||||
}
|
||||
};
|
||||
@@ -1352,14 +1357,6 @@ bool CoverCutHelper::TrySingleNodeFlow(const CutData& input_ct,
|
||||
ImpliedBoundsProcessor* ib_processor) {
|
||||
InitializeCut(input_ct);
|
||||
|
||||
bool has_large_coeff = false;
|
||||
for (const CutTerm& term : cut_.terms) {
|
||||
if (IntTypeAbs(term.coeff) > 1'000'000) {
|
||||
has_large_coeff = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(user): Change the heuristic to depends on the lp_value of the implied
|
||||
// bounds. This way we can exactly match what happen in FlowCoverCutHelper and
|
||||
// remove the code there.
|
||||
@@ -1393,6 +1390,14 @@ bool CoverCutHelper::TrySingleNodeFlow(const CutData& input_ct,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool has_large_coeff = false;
|
||||
for (const CutTerm& term : cut_.terms) {
|
||||
if (IntTypeAbs(term.coeff) > 1'000'000) {
|
||||
has_large_coeff = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(user): Shouldn't we just use rounding f() with maximum coeff to allows
|
||||
// lift of all other terms? but then except for the heuristic the cut is
|
||||
// really similar to the cover cut.
|
||||
|
||||
@@ -237,8 +237,9 @@ class ImpliedBoundsProcessor {
|
||||
IntegerVariable bool_var = kNoIntegerVariable;
|
||||
|
||||
double SlackLpValue(IntegerValue lb) const {
|
||||
const double bool_term = ToDouble(implied_bound - lb) * bool_lp_value;
|
||||
return var_lp_value - ToDouble(lb) - bool_term;
|
||||
const double bool_term =
|
||||
static_cast<double>((implied_bound - lb).value()) * bool_lp_value;
|
||||
return var_lp_value - static_cast<double>(lb.value()) - bool_term;
|
||||
}
|
||||
|
||||
std::string DebugString() const {
|
||||
|
||||
@@ -1849,9 +1849,15 @@ bool LinearProgrammingConstraint::ScalingCanOverflow(
|
||||
const std::vector<std::pair<glop::RowIndex, double>>& multipliers,
|
||||
int64_t overflow_cap) const {
|
||||
int64_t bound = 0;
|
||||
const int64_t factor = int64_t{1} << power;
|
||||
const double factor_as_double = static_cast<double>(factor);
|
||||
if (take_objective_into_account) {
|
||||
bound = CapAdd(bound, CapProd(factor, objective_infinity_norm_.value()));
|
||||
if (bound >= overflow_cap) return true;
|
||||
}
|
||||
for (const auto [row, double_coeff] : multipliers) {
|
||||
const double magnitude =
|
||||
std::abs(std::round(std::ldexp(double_coeff, power)));
|
||||
std::abs(std::round(double_coeff * factor_as_double));
|
||||
if (std::isnan(magnitude)) return true;
|
||||
if (magnitude >= static_cast<double>(std::numeric_limits<int64_t>::max())) {
|
||||
return true;
|
||||
@@ -1860,11 +1866,6 @@ bool LinearProgrammingConstraint::ScalingCanOverflow(
|
||||
infinity_norms_[row].value()));
|
||||
if (bound >= overflow_cap) return true;
|
||||
}
|
||||
if (take_objective_into_account) {
|
||||
bound = CapAdd(
|
||||
bound, CapProd(int64_t{1} << power, objective_infinity_norm_.value()));
|
||||
if (bound >= overflow_cap) return true;
|
||||
}
|
||||
return bound >= overflow_cap;
|
||||
}
|
||||
|
||||
@@ -1936,8 +1937,9 @@ LinearProgrammingConstraint::ScaleLpMultiplier(
|
||||
// Scale the multipliers by *scaling.
|
||||
// Note that we use the exact same formula as in ScalingCanOverflow().
|
||||
int64_t gcd = scaling->value();
|
||||
const double scaling_as_double = static_cast<double>(scaling->value());
|
||||
for (const auto [row, double_coeff] : tmp_cp_multipliers_) {
|
||||
const IntegerValue coeff(std::round(std::ldexp(double_coeff, power)));
|
||||
const IntegerValue coeff(std::round(double_coeff * scaling_as_double));
|
||||
if (coeff != 0) {
|
||||
gcd = std::gcd(gcd, std::abs(coeff.value()));
|
||||
integer_multipliers.push_back({row, coeff});
|
||||
|
||||
@@ -23,7 +23,7 @@ option csharp_namespace = "Google.OrTools.Sat";
|
||||
// Contains the definitions for all the sat algorithm parameters and their
|
||||
// default values.
|
||||
//
|
||||
// NEXT TAG: 285
|
||||
// NEXT TAG: 286
|
||||
message SatParameters {
|
||||
// In some context, like in a portfolio of search, it makes sense to name a
|
||||
// given parameters set for logging purpose.
|
||||
@@ -647,6 +647,10 @@ message SatParameters {
|
||||
// Allows sharing of new learned binary clause between workers.
|
||||
optional bool share_binary_clauses = 203 [default = true];
|
||||
|
||||
// Allows sharing of short glue clauses between workers.
|
||||
// Implicitly disabled if share_binary_clauses is false.
|
||||
optional bool share_glue_clauses = 285 [default = false];
|
||||
|
||||
// ==========================================================================
|
||||
// Debugging parameters
|
||||
// ==========================================================================
|
||||
|
||||
@@ -273,7 +273,7 @@ bool SatSolver::AddProblemClauseInternal(absl::Span<const Literal> literals) {
|
||||
AddBinaryClauseInternal(literals[0], literals[1]);
|
||||
}
|
||||
} else {
|
||||
if (!clauses_propagator_->AddClause(literals, trail_)) {
|
||||
if (!clauses_propagator_->AddClause(literals, trail_, /*lbd=*/-1)) {
|
||||
return SetModelUnsat();
|
||||
}
|
||||
}
|
||||
@@ -432,14 +432,14 @@ int SatSolver::AddLearnedClauseAndEnqueueUnitPropagation(
|
||||
--num_learned_clause_before_cleanup_;
|
||||
|
||||
SatClause* clause =
|
||||
clauses_propagator_->AddRemovableClause(literals, trail_);
|
||||
clauses_propagator_->AddRemovableClause(literals, trail_, lbd);
|
||||
|
||||
// BumpClauseActivity() must be called after clauses_info_[clause] has
|
||||
// been created or it will have no effect.
|
||||
(*clauses_propagator_->mutable_clauses_info())[clause].lbd = lbd;
|
||||
BumpClauseActivity(clause);
|
||||
} else {
|
||||
CHECK(clauses_propagator_->AddClause(literals, trail_));
|
||||
CHECK(clauses_propagator_->AddClause(literals, trail_, lbd));
|
||||
}
|
||||
return lbd;
|
||||
}
|
||||
|
||||
@@ -13,11 +13,15 @@
|
||||
|
||||
#include "ortools/sat/synchronization.h"
|
||||
|
||||
#include <sys/types.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cctype>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <ctime>
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
@@ -25,6 +29,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/hash/hash.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/base/timer.h"
|
||||
#if !defined(__PORTABLE_PLATFORM__)
|
||||
@@ -1034,6 +1039,88 @@ int SharedBoundsManager::NumBoundsExported(const std::string& worker_name) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
bool UniqueClauseStream::Add(absl::Span<const int> clause) {
|
||||
const int index = clause.size() - 3;
|
||||
if (BlockClause(clause)) {
|
||||
num_buffered_literals_ += clause.size();
|
||||
clauses_by_size_[index].insert(clauses_by_size_[index].end(),
|
||||
clause.begin(), clause.end());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool UniqueClauseStream::BlockClause(absl::Span<const int> clause) {
|
||||
if (clause.size() > kMaxClauseSize) return false;
|
||||
if (clause.size() <= 2) return false;
|
||||
bool is_new = false;
|
||||
// We set 4 bits all in the same page to guarantee at most 1 page fault per
|
||||
// insertion.
|
||||
const size_t page_offset =
|
||||
(HashClause(clause, -1) % kBloomFilterPages) * kBitsPerPage;
|
||||
// We could use more bits per hash if this is ever too slow.
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
const size_t bit = page_offset + HashClause(clause, i) % kBitsPerPage;
|
||||
is_new = is_new || !filter_.test(bit);
|
||||
filter_.set(bit, true);
|
||||
}
|
||||
return is_new;
|
||||
}
|
||||
|
||||
CompactVectorVector<int> UniqueClauseStream::NextBatch(int num_literals) {
|
||||
CompactVectorVector<int> buffer;
|
||||
buffer.reserve(num_literals / 3, num_literals);
|
||||
int total_literals = 0;
|
||||
for (int size = 3; size < kMaxClauseSize; ++size) {
|
||||
const int size_index = size - 3;
|
||||
while (total_literals + size <= num_literals &&
|
||||
!clauses_by_size_[size_index].empty()) {
|
||||
buffer.Add(NextClause(size));
|
||||
PopClause(size);
|
||||
total_literals += size;
|
||||
}
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
int UniqueClauseStream::MoveClausesTo(UniqueClauseStream& upstream,
|
||||
int max_literals) {
|
||||
int num_exported_clauses = 0;
|
||||
for (int size = 3; size < kMaxClauseSize; ++size) {
|
||||
const int size_index = size - 3;
|
||||
while (!clauses_by_size_[size_index].empty() && max_literals >= size) {
|
||||
max_literals -= size;
|
||||
if (upstream.Add(NextClause(size))) {
|
||||
++num_exported_clauses;
|
||||
}
|
||||
PopClause(size);
|
||||
}
|
||||
}
|
||||
return num_exported_clauses;
|
||||
}
|
||||
|
||||
size_t UniqueClauseStream::HashClause(absl::Span<const int> clause,
|
||||
size_t hash_seed) {
|
||||
size_t hash = absl::HashOf(hash_seed, clause.size());
|
||||
for (int i = 0; i < clause.size(); ++i) {
|
||||
hash ^= absl::HashOf(clause[i], hash_seed);
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
|
||||
absl::Span<const int> UniqueClauseStream::NextClause(int size) const {
|
||||
const int index = size - 3;
|
||||
return absl::MakeConstSpan(clauses_by_size_[index])
|
||||
.subspan(clauses_by_size_[index].size() - size, size);
|
||||
}
|
||||
|
||||
void UniqueClauseStream::PopClause(int size) {
|
||||
const int index = size - 3;
|
||||
num_buffered_literals_ -= size;
|
||||
clauses_by_size_[index].erase(clauses_by_size_[index].end() - size,
|
||||
clauses_by_size_[index].end());
|
||||
}
|
||||
|
||||
SharedClausesManager::SharedClausesManager(bool always_synchronize)
|
||||
: always_synchronize_(always_synchronize) {}
|
||||
|
||||
@@ -1041,7 +1128,9 @@ int SharedClausesManager::RegisterNewId() {
|
||||
absl::MutexLock mutex_lock(&mutex_);
|
||||
const int id = id_to_last_processed_binary_clause_.size();
|
||||
id_to_last_processed_binary_clause_.resize(id + 1, 0);
|
||||
id_to_last_processed_batch_.resize(id + 1, 0);
|
||||
id_to_clauses_exported_.resize(id + 1, 0);
|
||||
id_to_clause_stream_.emplace_back();
|
||||
return id;
|
||||
}
|
||||
|
||||
@@ -1059,7 +1148,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
|
||||
const auto [unused_it, inserted] = added_binary_clauses_set_.insert(p);
|
||||
if (inserted) {
|
||||
added_binary_clauses_.push_back(p);
|
||||
if (always_synchronize_) ++last_visible_clause_;
|
||||
if (always_synchronize_) ++last_visible_binary_clause_;
|
||||
id_to_clauses_exported_[id]++;
|
||||
|
||||
// Small optim. If the worker is already up to date with clauses to import,
|
||||
@@ -1071,16 +1160,35 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<absl::Span<const int>> SharedClausesManager::SyncClauses(int id) {
|
||||
absl::MutexLock mutex_lock(&mutex_);
|
||||
UniqueClauseStream& worker_clauses = id_to_clause_stream_[id];
|
||||
id_to_clauses_exported_[id] +=
|
||||
worker_clauses.MoveClausesTo(all_clauses_, kLiteralsPerBatch);
|
||||
std::vector<absl::Span<const int>> result;
|
||||
for (int i = id_to_last_processed_batch_[id]; i < batches_.size(); ++i) {
|
||||
for (int j = 0; j < batches_[i].size(); ++j) {
|
||||
result.push_back(batches_[i][j]);
|
||||
}
|
||||
}
|
||||
// TODO: tobyodavies - We should delete old clauses that have been consumed by
|
||||
// all workers. This will be subtle as the returned spans must remain valid
|
||||
// until the *next* call to SyncClauses() after they are returned.
|
||||
id_to_last_processed_batch_[id] = batches_.size();
|
||||
return result;
|
||||
}
|
||||
|
||||
void SharedClausesManager::GetUnseenBinaryClauses(
|
||||
int id, std::vector<std::pair<int, int>>* new_clauses) {
|
||||
new_clauses->clear();
|
||||
absl::MutexLock mutex_lock(&mutex_);
|
||||
const int last_binary_clause_seen = id_to_last_processed_binary_clause_[id];
|
||||
if (last_binary_clause_seen >= last_visible_clause_) return;
|
||||
if (last_binary_clause_seen >= last_visible_binary_clause_) return;
|
||||
|
||||
new_clauses->assign(added_binary_clauses_.begin() + last_binary_clause_seen,
|
||||
added_binary_clauses_.begin() + last_visible_clause_);
|
||||
id_to_last_processed_binary_clause_[id] = last_visible_clause_;
|
||||
new_clauses->assign(
|
||||
added_binary_clauses_.begin() + last_binary_clause_seen,
|
||||
added_binary_clauses_.begin() + last_visible_binary_clause_);
|
||||
id_to_last_processed_binary_clause_[id] = last_visible_binary_clause_;
|
||||
}
|
||||
|
||||
void SharedClausesManager::LogStatistics(SolverLogger* logger) {
|
||||
@@ -1102,8 +1210,11 @@ void SharedClausesManager::LogStatistics(SolverLogger* logger) {
|
||||
|
||||
void SharedClausesManager::Synchronize() {
|
||||
absl::MutexLock mutex_lock(&mutex_);
|
||||
last_visible_clause_ = added_binary_clauses_.size();
|
||||
// TODO(user): We could cleanup added_binary_clauses_ periodically.
|
||||
last_visible_binary_clause_ = added_binary_clauses_.size();
|
||||
if (all_clauses_.num_buffered_literals() >= kLiteralsPerBatch) {
|
||||
batches_.push_back(all_clauses_.NextBatch(kLiteralsPerBatch));
|
||||
}
|
||||
// TODO(user): We could cleanup clauses that have been consumed.
|
||||
}
|
||||
|
||||
void SharedStatistics::AddStats(
|
||||
|
||||
@@ -14,7 +14,10 @@
|
||||
#ifndef OR_TOOLS_SAT_SYNCHRONIZATION_H_
|
||||
#define OR_TOOLS_SAT_SYNCHRONIZATION_H_
|
||||
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <bitset>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
@@ -31,7 +34,6 @@
|
||||
#include "absl/random/random.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/base/stl_util.h"
|
||||
@@ -570,8 +572,60 @@ class SharedBoundsManager {
|
||||
int export_counter_ = 0;
|
||||
};
|
||||
|
||||
// This class holds all the binary clauses that were found and shared by the
|
||||
// workers.
|
||||
// Emit a stream of clauses in batches without duplicates.
|
||||
//
|
||||
// This class is thread-compatible, the idea is to have one per worker plus a
|
||||
// global one to deduplicate between workers.
|
||||
//
|
||||
// Note that this uses literal as encoded in a cp_model.proto. Thus, the
|
||||
// literals can be negative numbers.
|
||||
class UniqueClauseStream {
|
||||
public:
|
||||
static constexpr int kMaxClauseSize = 8;
|
||||
// The bloom filter is 1MiB.
|
||||
static constexpr int kBloomFilterBits = 1024 * 1024 * 8;
|
||||
static constexpr int kBitsPerPage = 4096 * 8;
|
||||
static constexpr int kBloomFilterPages = kBloomFilterBits / kBitsPerPage;
|
||||
|
||||
UniqueClauseStream() = default;
|
||||
// Move only - this is an expensive class to copy.
|
||||
UniqueClauseStream(const UniqueClauseStream&) = delete;
|
||||
UniqueClauseStream(UniqueClauseStream&&) = default;
|
||||
|
||||
// Adds the clause to a future batch and returns true if the clause was new.
|
||||
// Otherwise returns false if the clause was previously added or blocked.
|
||||
bool Add(absl::Span<const int> clause);
|
||||
|
||||
// Inserts the clause into the bloom filter without adding it to the buffer,
|
||||
// Returns true if the clause is new.
|
||||
bool BlockClause(absl::Span<const int> clause);
|
||||
|
||||
// Returns the sum of sizes of all buffered clauses.
|
||||
int num_buffered_literals() const { return num_buffered_literals_; }
|
||||
|
||||
// Returns a set of clauses totalling up to num_literals and removes them from
|
||||
// the internal buffer.
|
||||
CompactVectorVector<int> NextBatch(int num_literals);
|
||||
|
||||
// Adds all clauses from this stream to upstream and removes them from the
|
||||
// internal buffer.
|
||||
int MoveClausesTo(UniqueClauseStream& upstream, int max_literals);
|
||||
|
||||
// Computes a hash that is independent of the order of literals in the clause.
|
||||
static size_t HashClause(absl::Span<const int> clause, size_t hash_seed = 0);
|
||||
|
||||
private:
|
||||
absl::Span<const int> NextClause(int size) const;
|
||||
void PopClause(int size);
|
||||
int NumClauses(int size) const;
|
||||
|
||||
std::bitset<kBloomFilterBits> filter_;
|
||||
int num_buffered_literals_ = 0;
|
||||
std::array<std::vector<int>, kMaxClauseSize - 2> clauses_by_size_;
|
||||
};
|
||||
|
||||
// This class holds clauses found and shared by workers.
|
||||
// It is exact for binary clauses, but approximate for longer ones.
|
||||
//
|
||||
// It is thread-safe.
|
||||
//
|
||||
@@ -582,6 +636,10 @@ class SharedClausesManager {
|
||||
explicit SharedClausesManager(bool always_synchronize);
|
||||
void AddBinaryClause(int id, int lit1, int lit2);
|
||||
|
||||
// Imports all clauses from the given id into the shared pool.
|
||||
// Returns new clauses.
|
||||
std::vector<absl::Span<const int>> SyncClauses(int id);
|
||||
|
||||
// Fills new_clauses with
|
||||
// {{lit1 of clause1, lit2 of clause1},
|
||||
// {lit1 of clause2, lit2 of clause2},
|
||||
@@ -593,26 +651,46 @@ class SharedClausesManager {
|
||||
int RegisterNewId();
|
||||
void SetWorkerNameForId(int id, const std::string& worker_name);
|
||||
|
||||
// A worker can add or remove clauses from its own clause set.
|
||||
// It must not do so concurrently with any call to SyncClauses.
|
||||
// Retains ownership of the returned ClauseFilter.
|
||||
UniqueClauseStream* GetClauseStream(int id) {
|
||||
absl::ReaderMutexLock mutex_lock(&mutex_);
|
||||
return &id_to_clause_stream_[id];
|
||||
}
|
||||
|
||||
// Search statistics.
|
||||
void LogStatistics(SolverLogger* logger);
|
||||
|
||||
// Unlocks waiting binary clauses for workers if always_synchronize is false.
|
||||
// Periodically starts a new sharing round, making glue clauses visible.
|
||||
void Synchronize();
|
||||
|
||||
private:
|
||||
// 1024 literals is 4KiB, i.e. 1 page.
|
||||
static constexpr int kLiteralsPerBatch = 1024;
|
||||
|
||||
absl::Mutex mutex_;
|
||||
|
||||
// Cache to avoid adding the same clause twice.
|
||||
// Binary clauses:
|
||||
// Cache to avoid adding the same binary clause twice.
|
||||
absl::flat_hash_set<std::pair<int, int>> added_binary_clauses_set_
|
||||
ABSL_GUARDED_BY(mutex_);
|
||||
std::vector<std::pair<int, int>> added_binary_clauses_
|
||||
ABSL_GUARDED_BY(mutex_);
|
||||
std::vector<int> id_to_last_processed_binary_clause_ ABSL_GUARDED_BY(mutex_);
|
||||
std::vector<int64_t> id_to_clauses_exported_;
|
||||
int last_visible_clause_ ABSL_GUARDED_BY(mutex_) = 0;
|
||||
int last_visible_binary_clause_ ABSL_GUARDED_BY(mutex_) = 0;
|
||||
|
||||
// Longer clauses:
|
||||
UniqueClauseStream all_clauses_ ABSL_GUARDED_BY(mutex_);
|
||||
std::vector<int> id_to_last_processed_batch_ ABSL_GUARDED_BY(mutex_);
|
||||
std::deque<CompactVectorVector<int>> batches_ ABSL_GUARDED_BY(mutex_);
|
||||
std::deque<UniqueClauseStream> id_to_clause_stream_ ABSL_GUARDED_BY(mutex_);
|
||||
|
||||
const bool always_synchronize_ = true;
|
||||
|
||||
// Used for reporting statistics.
|
||||
// Stats:
|
||||
std::vector<int64_t> id_to_clauses_exported_;
|
||||
absl::flat_hash_map<int, std::string> id_to_worker_name_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user