[CP-SAT] print a solution after a SIGTERM; improve precedences

This commit is contained in:
Laurent Perron
2025-06-20 15:11:37 +02:00
parent 7a4222652e
commit c14e54cf82
19 changed files with 934 additions and 490 deletions

View File

@@ -13,6 +13,7 @@
#include "ortools/sat/2d_distances_propagator.h"
#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
@@ -69,10 +70,12 @@ void Precedences2DPropagator::UpdateVarLookups() {
void Precedences2DPropagator::CollectNewPairsOfBoxesWithNonTrivialDistance() {
const absl::Span<const LinearExpression2> exprs =
non_trivial_bounds_->GetLinear2WithPotentialNonTrivalBounds();
if (exprs.size() != num_known_linear2_) {
VLOG(2) << "CollectPairsOfBoxesWithNonTrivialDistance called, num_exprs: "
<< exprs.size();
if (exprs.size() == num_known_linear2_) {
return;
}
VLOG(2) << "CollectPairsOfBoxesWithNonTrivialDistance called, num_exprs: "
<< exprs.size();
const int previous_num_pairs = non_trivial_pairs_.size();
for (; num_known_linear2_ < exprs.size(); ++num_known_linear2_) {
const LinearExpression2& positive_expr = exprs[num_known_linear2_];
LinearExpression2 negated_expr = positive_expr;
@@ -111,7 +114,31 @@ void Precedences2DPropagator::CollectNewPairsOfBoxesWithNonTrivialDistance() {
}
}
gtl::STLSortAndRemoveDuplicates(&non_trivial_pairs_);
// Sort the new pairs.
std::sort(non_trivial_pairs_.begin() + previous_num_pairs,
non_trivial_pairs_.end());
// Remove duplicates from new pairs.
non_trivial_pairs_.erase(
std::unique(non_trivial_pairs_.begin() + previous_num_pairs,
non_trivial_pairs_.end()),
non_trivial_pairs_.end());
// Merge with the old pairs keeping sorted.
std::inplace_merge(non_trivial_pairs_.begin(),
non_trivial_pairs_.begin() + previous_num_pairs,
non_trivial_pairs_.end());
// Remove newly-added duplicates.
non_trivial_pairs_.erase(
std::unique(non_trivial_pairs_.begin(), non_trivial_pairs_.end()),
non_trivial_pairs_.end());
// Result should be sorted and without duplicates.
DCHECK(std::is_sorted(non_trivial_pairs_.begin(), non_trivial_pairs_.end()));
DCHECK(std::adjacent_find(non_trivial_pairs_.begin(),
non_trivial_pairs_.end()) ==
non_trivial_pairs_.end());
}
bool Precedences2DPropagator::Propagate() {

View File

@@ -815,7 +815,6 @@ cc_library(
deps = [
":cp_model_cc_proto",
":cp_model_utils",
":integer",
":integer_base",
":linear_constraint",
":model",
@@ -2056,6 +2055,7 @@ cc_library(
deps = [
":clause",
":cp_constraints",
":cp_model_mapping",
":integer",
":integer_base",
":model",
@@ -4023,13 +4023,17 @@ cc_binary(
"//ortools/base:path",
"//ortools/util:file_util",
"//ortools/util:logging",
"//ortools/util:sigint",
"//ortools/util:sorted_interval_list",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/flags:flag",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/log:flags",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:str_format",
"@abseil-cpp//absl/synchronization",
"@abseil-cpp//absl/types:span",
"@protobuf",
],
)

View File

@@ -24,7 +24,6 @@
#include "ortools/base/strong_vector.h"
#include "ortools/sat/cp_model.pb.h"
#include "ortools/sat/cp_model_utils.h"
#include "ortools/sat/integer.h"
#include "ortools/sat/integer_base.h"
#include "ortools/sat/linear_constraint.h"
#include "ortools/sat/model.h"

View File

@@ -793,40 +793,6 @@ void LogSubsolverNames(absl::Span<const std::unique_ptr<SubSolver>> subsolvers,
SOLVER_LOG(logger, "");
}
void LogFinalStatistics(SharedClasses* shared) {
if (!shared->logger->LoggingIsEnabled()) return;
shared->logger->FlushPendingThrottledLogs(/*ignore_rates=*/true);
SOLVER_LOG(shared->logger, "");
shared->stat_tables->Display(shared->logger);
shared->response->DisplayImprovementStatistics();
std::vector<std::vector<std::string>> table;
table.push_back({"Solution repositories", "Added", "Queried", "Synchro"});
shared->response->SolutionPool().AddTableStats(&table);
table.push_back(shared->ls_hints->TableLineStats());
if (shared->lp_solutions != nullptr) {
table.push_back(shared->lp_solutions->TableLineStats());
}
if (shared->incomplete_solutions != nullptr) {
table.push_back(shared->incomplete_solutions->TableLineStats());
}
SOLVER_LOG(shared->logger, FormatTable(table));
if (shared->bounds) {
shared->bounds->LogStatistics(shared->logger);
}
if (shared->clauses) {
shared->clauses->LogStatistics(shared->logger);
}
// Extra logging if needed. Note that these are mainly activated on
// --vmodule *some_file*=1 and are here for development.
shared->stats->Log(shared->logger);
}
void LaunchSubsolvers(const SatParameters& params, SharedClasses* shared,
std::vector<std::unique_ptr<SubSolver>>& subsolvers,
absl::Span<const std::string> ignored) {
@@ -868,7 +834,7 @@ void LaunchSubsolvers(const SatParameters& params, SharedClasses* shared,
for (int i = 0; i < subsolvers.size(); ++i) {
subsolvers[i].reset();
}
LogFinalStatistics(shared);
shared->LogFinalStatistics();
}
bool VarIsFixed(const CpModelProto& model_proto, int i) {
@@ -1124,13 +1090,18 @@ class FullProblemSolver : public SubSolver {
shared_->model_proto, shared_->bounds.get(), &local_model_);
}
if (shared_->linear2_bounds != nullptr) {
RegisterLinear2BoundsImport(shared_->linear2_bounds.get(),
&local_model_);
}
// Note that this is done after the loading, so we will never export
// problem clauses.
if (shared_->clauses != nullptr) {
const int id = shared_->clauses->RegisterNewId(
local_model_.Name(),
/*may_terminate_early=*/stop_at_first_solution_ &&
local_model_.GetOrCreate<CpModelProto>()->has_objective());
shared_->clauses->SetWorkerNameForId(id, local_model_.Name());
local_model_.GetOrCreate<CpModelProto>()->has_objective());
RegisterClausesLevelZeroImport(id, shared_->clauses.get(),
&local_model_);

View File

@@ -847,6 +847,59 @@ void RegisterVariableBoundsLevelZeroImport(
import_level_zero_bounds);
}
void RegisterLinear2BoundsImport(SharedLinear2Bounds* shared_linear2_bounds,
Model* model) {
CHECK(shared_linear2_bounds != nullptr);
auto* cp_model_mapping = model->GetOrCreate<CpModelMapping>();
auto* root_linear2 = model->GetOrCreate<RootLevelLinear2Bounds>();
auto* sat_solver = model->GetOrCreate<SatSolver>();
const int import_id =
shared_linear2_bounds->RegisterNewImportId(model->Name());
const auto& import_function = [import_id, shared_linear2_bounds, root_linear2,
cp_model_mapping, sat_solver, model]() {
const auto new_bounds =
shared_linear2_bounds->NewlyUpdatedBounds(import_id);
int num_imported = 0;
for (const auto& [proto_expr, bounds] : new_bounds) {
// Lets create the corresponding LinearExpression2.
LinearExpression2 expr;
for (const int i : {0, 1}) {
expr.vars[i] = cp_model_mapping->Integer(proto_expr.vars[i]);
expr.coeffs[i] = proto_expr.coeffs[i];
}
const auto [lb, ub] = bounds;
const auto [lb_added, ub_added] = root_linear2->Add(expr, lb, ub);
if (!lb_added && !ub_added) continue;
++num_imported;
// TODO(user): Is it a good idea to add the linear constraint ?
// We might have many redundant linear2 relations that don't need
// propagation when we have chains of precedences. The root_linear2 should
// be up-to-date with transitive closure to avoid adding such relations
// (recompute it at level zero before this?).
//
// TODO(user): use IntegerValure directly in
// AddWeightedSumGreaterOrEqual() or use a lower-level API.
const std::vector<int64_t> coeffs = {expr.coeffs[0].value(),
expr.coeffs[1].value()};
if (lb_added) {
AddWeightedSumGreaterOrEqual({}, absl::MakeSpan(expr.vars, 2), coeffs,
lb.value(), model);
if (sat_solver->ModelIsUnsat()) return false;
}
if (ub_added) {
AddWeightedSumLowerOrEqual({}, absl::MakeSpan(expr.vars, 2), coeffs,
ub.value(), model);
if (sat_solver->ModelIsUnsat()) return false;
}
}
shared_linear2_bounds->NotifyNumImported(import_id, num_imported);
return true;
};
model->GetOrCreate<LevelZeroCallbackHelper>()->callbacks.push_back(
import_function);
}
// Registers a callback that will report improving objective best bound.
// It will be called each time new objective bound are propagated at level zero.
void RegisterObjectiveBestBoundExport(
@@ -2086,6 +2139,10 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model)
bounds->LoadDebugSolution(response->DebugSolution());
}
if (params.share_linear2_bounds()) {
linear2_bounds = std::make_unique<SharedLinear2Bounds>();
}
// Create extra shared classes if needed. Note that while these parameters
// are true by default, we disable them if we don't have enough workers for
// them in AdaptGlobalParameters().
@@ -2120,7 +2177,7 @@ void SharedClasses::RegisterSharedClassesInLocalModel(Model* local_model) {
local_model->Register<SharedStatTables>(stat_tables);
// TODO(user): Use parameters and not the presence/absence of these class
// to decide when to use them.
// to decide when to use them? this is not clear.
if (lp_solutions != nullptr) {
local_model->Register<SharedLPSolutionRepository>(lp_solutions.get());
}
@@ -2134,6 +2191,9 @@ void SharedClasses::RegisterSharedClassesInLocalModel(Model* local_model) {
if (clauses != nullptr) {
local_model->Register<SharedClausesManager>(clauses.get());
}
if (linear2_bounds != nullptr) {
local_model->Register<SharedLinear2Bounds>(linear2_bounds.get());
}
}
bool SharedClasses::SearchIsDone() {
@@ -2146,5 +2206,37 @@ bool SharedClasses::SearchIsDone() {
return false;
}
void SharedClasses::LogFinalStatistics() {
if (!logger->LoggingIsEnabled()) return;
logger->FlushPendingThrottledLogs(/*ignore_rates=*/true);
SOLVER_LOG(logger, "");
stat_tables->Display(logger);
response->DisplayImprovementStatistics();
std::vector<std::vector<std::string>> table;
table.push_back({"Solution repositories", "Added", "Queried", "Synchro"});
response->SolutionPool().AddTableStats(&table);
table.push_back(ls_hints->TableLineStats());
if (lp_solutions != nullptr) {
table.push_back(lp_solutions->TableLineStats());
}
if (incomplete_solutions != nullptr) {
table.push_back(incomplete_solutions->TableLineStats());
}
SOLVER_LOG(logger, FormatTable(table));
// TODO(user): we can combine the "bounds table" into one for shorter logs.
if (bounds != nullptr) bounds->LogStatistics(logger);
if (linear2_bounds != nullptr) linear2_bounds->LogStatistics(logger);
if (clauses != nullptr) clauses->LogStatistics(logger);
// Extra logging if needed. Note that these are mainly activated on
// --vmodule *some_file*=1 and are here for development.
stats->Log(logger);
}
} // namespace sat
} // namespace operations_research

View File

@@ -60,12 +60,15 @@ struct SharedClasses {
std::unique_ptr<SharedLPSolutionRepository> lp_solutions;
std::unique_ptr<SharedIncompleteSolutionManager> incomplete_solutions;
std::unique_ptr<SharedClausesManager> clauses;
std::unique_ptr<SharedLinear2Bounds> linear2_bounds;
// call local_model->Register() on most of the class here, this allow to
// more easily depends on one of the shared class deep within the solver.
void RegisterSharedClassesInLocalModel(Model* local_model);
bool SearchIsDone();
void LogFinalStatistics();
};
// Loads a CpModelProto inside the given model.
@@ -119,6 +122,11 @@ int RegisterClausesLevelZeroImport(int id,
SharedClausesManager* shared_clauses_manager,
Model* model);
// This will register a level zero callback to imports new linear2 from the
// SharedLinear2Bounds.
void RegisterLinear2BoundsImport(SharedLinear2Bounds* shared_linear2_bounds,
Model* model);
void PostsolveResponseWrapper(const SatParameters& params,
int num_variable_in_original_model,
const CpModelProto& mapping_proto,

View File

@@ -84,22 +84,18 @@ bool LinearExpression2::NegateForCanonicalization() {
}
bool LinearExpression2::CanonicalizeAndUpdateBounds(IntegerValue& lb,
IntegerValue& ub,
bool allow_negation) {
IntegerValue& ub) {
SimpleCanonicalization();
if (coeffs[0] == 0 || coeffs[1] == 0) return false; // abort.
bool negated = false;
if (allow_negation) {
negated = NegateForCanonicalization();
if (negated) {
// We need to be able to negate without overflow.
CHECK_GE(lb, kMinIntegerValue);
CHECK_LE(ub, kMaxIntegerValue);
std::swap(lb, ub);
lb = -lb;
ub = -ub;
}
const bool negated = NegateForCanonicalization();
if (negated) {
// We need to be able to negate without overflow.
CHECK_GE(lb, kMinIntegerValue);
CHECK_LE(ub, kMaxIntegerValue);
std::swap(lb, ub);
lb = -lb;
ub = -ub;
}
// Do gcd division.
@@ -144,8 +140,7 @@ std::pair<BestBinaryRelationBounds::AddResult,
BestBinaryRelationBounds::AddResult>
BestBinaryRelationBounds::Add(LinearExpression2 expr, IntegerValue lb,
IntegerValue ub) {
const bool negated =
expr.CanonicalizeAndUpdateBounds(lb, ub, /*allow_negation=*/true);
const bool negated = expr.CanonicalizeAndUpdateBounds(lb, ub);
// We only store proper linear2.
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) {
@@ -184,7 +179,7 @@ BestBinaryRelationBounds::Add(LinearExpression2 expr, IntegerValue lb,
RelationStatus BestBinaryRelationBounds::GetStatus(LinearExpression2 expr,
IntegerValue lb,
IntegerValue ub) const {
expr.CanonicalizeAndUpdateBounds(lb, ub, /*allow_negation=*/true);
expr.CanonicalizeAndUpdateBounds(lb, ub);
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) {
return RelationStatus::IS_UNKNOWN;
}
@@ -245,14 +240,4 @@ BestBinaryRelationBounds::GetSortedNonTrivialBounds() const {
return root_relations_sorted;
}
void BestBinaryRelationBounds::AppendAllExpressionContaining(
Bitset64<IntegerVariable>::ConstView var_set,
std::vector<LinearExpression2>* result) const {
for (const auto& [expr, unused] : best_bounds_) {
if (!var_set[PositiveVariable(expr.vars[0])]) continue;
if (!var_set[PositiveVariable(expr.vars[1])]) continue;
result->push_back(expr);
}
}
} // namespace operations_research::sat

View File

@@ -384,8 +384,7 @@ struct LinearExpression2 {
// accordingly. This is the same as SimpleCanonicalization(), DivideByGcd()
// and the NegateForCanonicalization() with a proper updates of the bounds.
// Returns whether the expression was negated.
bool CanonicalizeAndUpdateBounds(IntegerValue& lb, IntegerValue& ub,
bool allow_negation = false);
bool CanonicalizeAndUpdateBounds(IntegerValue& lb, IntegerValue& ub);
// Divides the expression by the gcd of both coefficients, and returns it.
// Note that we always return something >= 1 even if both coefficients are
@@ -493,7 +492,7 @@ class BestBinaryRelationBounds {
IntegerValue GetUpperBound(LinearExpression2 expr) const;
// Same as GetUpperBound() but assume the expression is already canonicalized.
// This is slighlty faster.
// This is slightly faster.
IntegerValue UpperBoundWhenCanonicalized(LinearExpression2 expr) const;
int64_t num_bounds() const { return best_bounds_.size(); }
@@ -504,11 +503,6 @@ class BestBinaryRelationBounds {
std::vector<std::tuple<LinearExpression2, IntegerValue, IntegerValue>>
GetSortedNonTrivialBounds() const;
// Note that this is non-deterministic and in O(num_relations).
void AppendAllExpressionContaining(
Bitset64<IntegerVariable>::ConstView var_set,
std::vector<LinearExpression2>* result) const;
private:
// The best bound on the given "canonicalized" expression.
absl::flat_hash_map<LinearExpression2, std::pair<IntegerValue, IntegerValue>>

View File

@@ -17,6 +17,7 @@
#include <algorithm>
#include <deque>
#include <limits>
#include <string>
#include <tuple>
#include <utility>
@@ -54,6 +55,53 @@
namespace operations_research {
namespace sat {
LinearExpression2Index Linear2WithPotentialNonTrivalBounds::AddOrGet(
LinearExpression2 original_expr) {
LinearExpression2 expr = original_expr;
DCHECK(expr.IsCanonicalized());
DCHECK_EQ(expr.DivideByGcd(), 1);
DCHECK_NE(expr.coeffs[0], 0);
DCHECK_NE(expr.coeffs[1], 0);
const bool negated = expr.NegateForCanonicalization();
auto [it, inserted] = expr_to_index_.insert({expr, exprs_.size()});
if (inserted) {
CHECK_LT(2 * exprs_.size() + 1,
std::numeric_limits<LinearExpression2Index>::max());
exprs_.push_back(expr);
}
const LinearExpression2Index result =
negated ? NegationOf(LinearExpression2Index(2 * it->second))
: LinearExpression2Index(2 * it->second);
if (!inserted) return result;
// Update our special coeff=1 lookup table.
if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) {
// +2 to handle possible negation.
const int new_size =
std::max(expr.vars[0].value(), expr.vars[1].value()) + 2;
if (new_size > coeff_one_var_lookup_.size()) {
coeff_one_var_lookup_.resize(new_size);
}
LinearExpression2 neg_expr = original_expr;
neg_expr.Negate();
coeff_one_var_lookup_[original_expr.vars[0]].push_back(result);
coeff_one_var_lookup_[original_expr.vars[1]].push_back(result);
coeff_one_var_lookup_[neg_expr.vars[1]].push_back(NegationOf(result));
coeff_one_var_lookup_[neg_expr.vars[0]].push_back(NegationOf(result));
}
// Update our per-variable and per-pair lookup tables.
IntegerVariable var1 = PositiveVariable(expr.vars[0]);
IntegerVariable var2 = PositiveVariable(expr.vars[1]);
if (var1 > var2) std::swap(var1, var2);
var_pair_to_bounds_[{var1, var2}].push_back(result);
var_to_bounds_[var1].push_back(result);
var_to_bounds_[var2].push_back(result);
return result;
}
void Linear2Watcher::NotifyBoundChanged(LinearExpression2 expr) {
DCHECK(expr.IsCanonicalized());
DCHECK_EQ(expr.DivideByGcd(), 1);
@@ -75,115 +123,51 @@ int64_t Linear2Watcher::VarTimestamp(IntegerVariable var) {
return var < var_timestamp_.size() ? var_timestamp_[var] : 0;
}
std::pair<bool, bool> RootLevelLinear2Bounds::Add(LinearExpression2 expr,
IntegerValue lb,
IntegerValue ub) {
using AddResult = BestBinaryRelationBounds::AddResult;
const IntegerValue zero_level_lb = integer_trail_->LevelZeroLowerBound(expr);
bool RootLevelLinear2Bounds::AddUpperBound(LinearExpression2Index index,
IntegerValue ub) {
const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index);
const IntegerValue zero_level_ub = integer_trail_->LevelZeroUpperBound(expr);
if (lb <= zero_level_lb && ub >= zero_level_ub) {
return {false, false};
}
// Don't store one of the bounds if it is trivial.
if (lb <= zero_level_lb) {
lb = kMinIntegerValue;
}
if (ub >= zero_level_ub) {
ub = kMaxIntegerValue;
return false;
}
expr.CanonicalizeAndUpdateBounds(lb, ub);
const auto [status_lb, status_ub] = root_level_relations_.Add(expr, lb, ub);
if (best_upper_bounds_.size() <= index) {
best_upper_bounds_.resize(index.value() + 1, kMaxIntegerValue);
}
if (ub >= best_upper_bounds_[index]) {
return false;
}
best_upper_bounds_[index] = ub;
const bool lb_restricted =
status_lb == AddResult::ADDED || status_lb == AddResult::UPDATED;
const bool ub_restricted =
status_ub == AddResult::ADDED || status_ub == AddResult::UPDATED;
if (!lb_restricted && !ub_restricted) return {false, false};
non_trivial_bounds_->AddOrGet(expr);
++num_updates_;
linear2_watcher_->NotifyBoundChanged(expr);
// Update our special coeff=1 lookup table.
if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) {
// +2 to handle possible negation.
const int new_size =
std::max(expr.vars[0].value(), expr.vars[1].value()) + 2;
if (new_size > coeff_one_var_lookup_.size()) {
coeff_one_var_lookup_.resize(new_size);
}
if (status_lb == AddResult::ADDED) {
// First time added to root_level_relations_.
coeff_one_var_lookup_[NegationOf(expr.vars[0])].push_back(
NegationOf(expr.vars[1]));
coeff_one_var_lookup_[NegationOf(expr.vars[1])].push_back(
NegationOf(expr.vars[0]));
}
if (status_ub == AddResult::ADDED) {
coeff_one_var_lookup_[expr.vars[0]].push_back(expr.vars[1]);
coeff_one_var_lookup_[expr.vars[1]].push_back(expr.vars[0]);
// Share.
//
// TODO(user): It seems we could change the canonicalization to only use
// positive variable? that would simplify a bit the code here and not make it
// worse elsewhere?
if (shared_linear2_bounds_ != nullptr) {
const IntegerValue lb = -LevelZeroUpperBound(NegationOf(index));
const int proto_var0 =
cp_model_mapping_->GetProtoVariableFromIntegerVariable(
PositiveVariable(expr.vars[0]));
const int proto_var1 =
cp_model_mapping_->GetProtoVariableFromIntegerVariable(
PositiveVariable(expr.vars[1]));
if (proto_var0 >= 0 && proto_var1 >= 0) {
// This is also a relation between cp_model proto variable. Share it!
// Note that since expr is canonicalized, this one should too.
SharedLinear2Bounds::Key key;
key.vars[0] = proto_var0;
key.coeffs[0] =
VariableIsPositive(expr.vars[0]) ? expr.coeffs[0] : -expr.coeffs[0];
key.vars[1] = proto_var1;
key.coeffs[1] =
VariableIsPositive(expr.vars[1]) ? expr.coeffs[1] : -expr.coeffs[1];
shared_linear2_bounds_->Add(shared_linear2_bounds_id_, key, lb, ub);
}
}
// Update our per-variable and per-pair lookup tables.
IntegerVariable var1 = PositiveVariable(expr.vars[0]);
IntegerVariable var2 = PositiveVariable(expr.vars[1]);
if (var1 > var2) std::swap(var1, var2);
auto [it_var, inserted] = var_to_bounds_vector_index_.insert({expr, {0, 0}});
for (const IntegerVariable var : {var1, var2}) {
auto& var_bounds = var_to_bounds_[var];
if (inserted) {
if (var == var1) {
it_var->second.first = var_bounds.size();
} else {
it_var->second.second = var_bounds.size();
}
var_bounds.push_back({expr, lb, ub});
} else {
const int index =
(var == var1) ? it_var->second.first : it_var->second.second;
DCHECK_LT(index, var_bounds.size());
std::tuple<LinearExpression2, IntegerValue, IntegerValue>& var_bound =
var_bounds[index];
if (status_lb == AddResult::ADDED || status_lb == AddResult::UPDATED) {
std::get<1>(var_bound) = lb;
}
if (status_ub == AddResult::ADDED || status_ub == AddResult::UPDATED) {
std::get<2>(var_bound) = ub;
}
}
}
auto [it_pair, pair_inserted] =
var_pair_to_bounds_vector_index_.insert({expr, 0});
DCHECK_EQ(inserted, pair_inserted);
auto& pair_bounds = var_pair_to_bounds_[{var1, var2}];
if (pair_inserted) {
it_pair->second = pair_bounds.size();
pair_bounds.push_back({expr, lb, ub});
} else {
const int index = it_pair->second;
DCHECK_LT(index, pair_bounds.size());
std::tuple<LinearExpression2, IntegerValue, IntegerValue>& pair_bound =
pair_bounds[index];
if (status_lb == AddResult::ADDED || status_lb == AddResult::UPDATED) {
std::get<1>(pair_bound) = lb;
}
if (status_ub == AddResult::ADDED || status_ub == AddResult::UPDATED) {
std::get<2>(pair_bound) = ub;
}
}
return {lb_restricted, ub_restricted};
}
IntegerValue RootLevelLinear2Bounds::LevelZeroUpperBound(
LinearExpression2 expr) const {
// TODO(user): Remove the expression from the root_level_relations_ if the
// zero-level bound got more restrictive.
return std::min(integer_trail_->LevelZeroUpperBound(expr),
root_level_relations_.GetUpperBound(expr));
return true;
}
RootLevelLinear2Bounds::~RootLevelLinear2Bounds() {
@@ -209,38 +193,38 @@ RelationStatus RootLevelLinear2Bounds::GetLevelZeroStatus(
}
IntegerValue RootLevelLinear2Bounds::GetUpperBoundNoTrail(
LinearExpression2 expr) const {
DCHECK_EQ(expr.DivideByGcd(), 1);
DCHECK(expr.IsCanonicalized());
return root_level_relations_.UpperBoundWhenCanonicalized(expr);
LinearExpression2Index index) const {
if (best_upper_bounds_.size() <= index) {
return kMaxIntegerValue;
}
return best_upper_bounds_[index];
}
std::vector<std::pair<LinearExpression2, IntegerValue>>
RootLevelLinear2Bounds::GetSortedNonTrivialUpperBounds() const {
std::vector<std::pair<LinearExpression2, IntegerValue>> result =
root_level_relations_.GetSortedNonTrivialUpperBounds();
int new_size = 0;
for (int i = 0; i < result.size(); ++i) {
const auto& [expr, ub] = result[i];
std::vector<std::pair<LinearExpression2, IntegerValue>> result;
for (LinearExpression2Index index = LinearExpression2Index{0};
index < best_upper_bounds_.size(); ++index) {
const IntegerValue ub = best_upper_bounds_[index];
if (ub == kMaxIntegerValue) continue;
const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index);
if (ub < integer_trail_->LevelZeroUpperBound(expr)) {
result[new_size] = {expr, ub};
++new_size;
result.push_back({expr, ub});
}
}
result.resize(new_size);
std::sort(result.begin(), result.end());
return result;
}
// Return a list of (lb <= expr <= ub), with expr.vars[0] = var, where at
// least one of the bounds is non-trivial and the potential other non-trivial
// bound is tight.
std::vector<std::tuple<LinearExpression2, IntegerValue, IntegerValue>>
RootLevelLinear2Bounds::GetAllBoundsContainingVariable(
IntegerVariable var) const {
std::vector<std::tuple<LinearExpression2, IntegerValue, IntegerValue>> result;
auto it = var_to_bounds_.find(PositiveVariable(var));
if (it == var_to_bounds_.end()) return {};
for (const auto& [expr, lb, ub] : it->second) {
for (const LinearExpression2Index index :
non_trivial_bounds_->GetAllLinear2ContainingVariable(var)) {
const IntegerValue lb = -GetUpperBoundNoTrail(NegationOf(index));
const IntegerValue ub = GetUpperBoundNoTrail(index);
const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index);
const IntegerValue trail_lb = integer_trail_->LevelZeroLowerBound(expr);
const IntegerValue trail_ub = integer_trail_->LevelZeroUpperBound(expr);
if (lb <= trail_lb && ub >= trail_ub) continue;
@@ -271,12 +255,11 @@ std::vector<std::tuple<LinearExpression2, IntegerValue, IntegerValue>>
RootLevelLinear2Bounds::GetAllBoundsContainingVariables(
IntegerVariable var1, IntegerVariable var2) const {
std::vector<std::tuple<LinearExpression2, IntegerValue, IntegerValue>> result;
std::pair<IntegerVariable, IntegerVariable> key = {PositiveVariable(var1),
PositiveVariable(var2)};
if (key.first > key.second) std::swap(key.first, key.second);
auto it = var_pair_to_bounds_.find(key);
if (it == var_pair_to_bounds_.end()) return {};
for (const auto& [expr, lb, ub] : it->second) {
for (const LinearExpression2Index index :
non_trivial_bounds_->GetAllLinear2ContainingVariables(var1, var2)) {
const IntegerValue lb = -GetUpperBoundNoTrail(NegationOf(index));
const IntegerValue ub = GetUpperBoundNoTrail(index);
const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index);
const IntegerValue trail_lb = integer_trail_->LevelZeroLowerBound(expr);
const IntegerValue trail_ub = integer_trail_->LevelZeroUpperBound(expr);
if (lb <= trail_lb && ub >= trail_ub) continue;
@@ -304,10 +287,25 @@ RootLevelLinear2Bounds::GetAllBoundsContainingVariables(
return result;
}
void RootLevelLinear2Bounds::AppendAllExpressionContaining(
Bitset64<IntegerVariable>::ConstView var_set,
std::vector<LinearExpression2>* result) const {
root_level_relations_.AppendAllExpressionContaining(var_set, result);
std::vector<IntegerVariable>
RootLevelLinear2Bounds::GetVariablesInSimpleRelation(
IntegerVariable var) const {
std::vector<IntegerVariable> result;
for (const LinearExpression2Index index :
non_trivial_bounds_->GetAllLinear2ContainingVariableWithCoeffOne(var)) {
const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index);
const IntegerVariable other =
(expr.vars[0] == var ? expr.vars[1] : expr.vars[0]);
DCHECK_EQ(expr.coeffs[0], 1);
DCHECK_EQ(expr.coeffs[1], 1);
DCHECK((expr.vars[0] == var && expr.vars[1] == other) ||
(expr.vars[0] == other && expr.vars[1] == var));
if (GetUpperBoundNoTrail(index) <
integer_trail_->LevelZeroUpperBound(expr)) {
result.push_back(other);
}
}
return result;
}
EnforcedLinear2Bounds::~EnforcedLinear2Bounds() {
@@ -319,13 +317,8 @@ EnforcedLinear2Bounds::~EnforcedLinear2Bounds() {
}
void EnforcedLinear2Bounds::PushConditionalRelation(
absl::Span<const Literal> enforcements, LinearExpression2 expr,
absl::Span<const Literal> enforcements, LinearExpression2Index index,
IntegerValue rhs) {
expr.SimpleCanonicalization();
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) {
return;
}
// This must be currently true.
if (DEBUG_MODE) {
for (const Literal l : enforcements) {
@@ -334,24 +327,25 @@ void EnforcedLinear2Bounds::PushConditionalRelation(
}
if (enforcements.empty() || trail_->CurrentDecisionLevel() == 0) {
root_level_bounds_->AddUpperBound(expr, rhs);
root_level_bounds_->AddUpperBound(index, rhs);
return;
}
const IntegerValue gcd = expr.DivideByGcd();
rhs = FloorRatio(rhs, gcd);
if (rhs >= root_level_bounds_->LevelZeroUpperBound(expr)) return;
if (rhs >= root_level_bounds_->LevelZeroUpperBound(index)) return;
const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index);
linear2_watcher_->NotifyBoundChanged(expr);
++num_conditional_relation_updates_;
const int new_index = conditional_stack_.size();
const auto [it, inserted] = conditional_relations_.insert({expr, new_index});
if (inserted) {
non_trivial_bounds_->AddOrGet(expr);
if (conditional_relations_.size() <= index) {
conditional_relations_.resize(index.value() + 1, -1);
}
if (conditional_relations_[index] == -1) {
conditional_relations_[index] = new_index;
CreateLevelEntryIfNeeded();
conditional_stack_.emplace_back(/*prev_entry=*/-1, rhs, expr, enforcements);
conditional_stack_.emplace_back(/*prev_entry=*/-1, rhs, index,
enforcements);
if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) {
const int new_size =
@@ -363,13 +357,13 @@ void EnforcedLinear2Bounds::PushConditionalRelation(
conditional_var_lookup_[expr.vars[1]].push_back(expr.vars[0]);
}
} else {
const int prev_entry = it->second;
const int prev_entry = conditional_relations_[index];
if (rhs >= conditional_stack_[prev_entry].rhs) return;
// Update.
it->second = new_index;
conditional_relations_[index] = new_index;
CreateLevelEntryIfNeeded();
conditional_stack_.emplace_back(prev_entry, rhs, expr, enforcements);
conditional_stack_.emplace_back(prev_entry, rhs, index, enforcements);
}
}
@@ -392,15 +386,15 @@ void EnforcedLinear2Bounds::SetLevel(int level) {
if (back.prev_entry != -1) {
conditional_relations_[back.key] = back.prev_entry;
} else {
conditional_relations_.erase(back.key);
conditional_relations_[back.key] = -1;
const LinearExpression2 expr =
non_trivial_bounds_->GetExpression(back.key);
if (back.key.coeffs[0] == 1 && back.key.coeffs[1] == 1) {
DCHECK_EQ(conditional_var_lookup_[back.key.vars[0]].back(),
back.key.vars[1]);
DCHECK_EQ(conditional_var_lookup_[back.key.vars[1]].back(),
back.key.vars[0]);
conditional_var_lookup_[back.key.vars[0]].pop_back();
conditional_var_lookup_[back.key.vars[1]].pop_back();
if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) {
DCHECK_EQ(conditional_var_lookup_[expr.vars[0]].back(), expr.vars[1]);
DCHECK_EQ(conditional_var_lookup_[expr.vars[1]].back(), expr.vars[0]);
conditional_var_lookup_[expr.vars[0]].pop_back();
conditional_var_lookup_[expr.vars[1]].pop_back();
}
}
conditional_stack_.pop_back();
@@ -410,42 +404,42 @@ void EnforcedLinear2Bounds::SetLevel(int level) {
}
void EnforcedLinear2Bounds::AddReasonForUpperBoundLowerThan(
LinearExpression2 expr, IntegerValue ub,
LinearExpression2Index index, IntegerValue ub,
std::vector<Literal>* literal_reason,
std::vector<IntegerLiteral>* /*unused*/) const {
expr.SimpleCanonicalization();
if (ub >= root_level_bounds_->LevelZeroUpperBound(expr)) return;
const IntegerValue gcd = expr.DivideByGcd();
const auto it = conditional_relations_.find(expr);
DCHECK(it != conditional_relations_.end());
if (ub >= root_level_bounds_->LevelZeroUpperBound(index)) return;
DCHECK_LT(index, conditional_relations_.size());
const int entry_index = conditional_relations_[index];
DCHECK_NE(entry_index, -1);
const ConditionalEntry& entry = conditional_stack_[it->second];
const ConditionalEntry& entry = conditional_stack_[entry_index];
if (DEBUG_MODE) {
for (const Literal l : entry.enforcements) {
CHECK(trail_->Assignment().LiteralIsTrue(l));
}
}
DCHECK_LE(CapProdI(gcd, entry.rhs), ub);
DCHECK_LE(entry.rhs, ub);
for (const Literal l : entry.enforcements) {
literal_reason->push_back(l.Negated());
}
}
IntegerValue EnforcedLinear2Bounds::GetUpperBoundFromEnforced(
LinearExpression2 expr) const {
DCHECK_EQ(expr.DivideByGcd(), 1);
DCHECK(expr.IsCanonicalized());
const auto it = conditional_relations_.find(expr);
if (it == conditional_relations_.end()) {
LinearExpression2Index index) const {
if (index >= conditional_relations_.size()) {
return kMaxIntegerValue;
}
const int entry_index = conditional_relations_[index];
if (entry_index == -1) {
return kMaxIntegerValue;
} else {
const ConditionalEntry& entry = conditional_stack_[it->second];
const ConditionalEntry& entry = conditional_stack_[entry_index];
if (DEBUG_MODE) {
for (const Literal l : entry.enforcements) {
CHECK(trail_->Assignment().LiteralIsTrue(l));
}
}
DCHECK_LT(entry.rhs, root_level_bounds_->LevelZeroUpperBound(expr));
DCHECK_LT(entry.rhs, root_level_bounds_->LevelZeroUpperBound(index));
return entry.rhs;
}
}
@@ -569,7 +563,7 @@ void TransitivePrecedencesEvaluator::Build() {
}
VLOG(2) << "Full precedences. Work=" << work
<< " Relations=" << root_level_bounds_->num_bounds();
<< " Relations=" << root_relations_sorted.size();
}
void TransitivePrecedencesEvaluator::ComputeFullPrecedences(
@@ -738,16 +732,6 @@ void EnforcedLinear2Bounds::CollectPrecedences(
}
}
void EnforcedLinear2Bounds::AppendAllExpressionContaining(
Bitset64<IntegerVariable>::ConstView var_set,
std::vector<LinearExpression2>* result) const {
for (const auto& entry : conditional_stack_) {
if (!var_set[PositiveVariable(entry.key.vars[0])]) continue;
if (!var_set[PositiveVariable(entry.key.vars[1])]) continue;
result->push_back(entry.key);
}
}
namespace {
void AppendLowerBoundReasonIfValid(IntegerVariable var,
@@ -1828,6 +1812,7 @@ Linear2BoundsFromLinear3::Linear2BoundsFromLinear3(Model* model)
bool Linear2BoundsFromLinear3::AddAffineUpperBound(LinearExpression2 expr,
AffineExpression affine_ub) {
expr.SimpleCanonicalization();
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return false;
// At level zero, just add it to root_level_bounds_.
if (trail_->CurrentDecisionLevel() == 0) {
@@ -1900,16 +1885,6 @@ void Linear2BoundsFromLinear3::AddReasonForUpperBoundLowerThan(
integer_reason->push_back(affine.LowerOrEqual(CapProdI(ub + 1, divisor) - 1));
}
void Linear2BoundsFromLinear3::AppendAllExpressionContaining(
Bitset64<IntegerVariable>::ConstView var_set,
std::vector<LinearExpression2>* result) const {
for (const auto& [expr, unused] : best_affine_ub_) {
if (!var_set[PositiveVariable(expr.vars[0])]) continue;
if (!var_set[PositiveVariable(expr.vars[1])]) continue;
result->push_back(expr);
}
}
IntegerValue Linear2Bounds::UpperBound(LinearExpression2 expr) const {
expr.SimpleCanonicalization();
if (expr.coeffs[0] == 0) {
@@ -1918,8 +1893,11 @@ IntegerValue Linear2Bounds::UpperBound(LinearExpression2 expr) const {
DCHECK_NE(expr.coeffs[1], 0);
const IntegerValue gcd = expr.DivideByGcd();
IntegerValue ub = integer_trail_->UpperBound(expr);
ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(expr));
ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(expr));
const LinearExpression2Index index = non_trivial_bounds_->GetIndex(expr);
if (index != kNoLinearExpression2Index) {
ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(index));
ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(index));
}
ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(expr));
return CapProdI(gcd, ub);
}
@@ -1932,8 +1910,12 @@ IntegerValue Linear2Bounds::NonTrivialUpperBoundForGcd1(
}
DCHECK_NE(expr.coeffs[1], 0);
DCHECK_EQ(1, expr.DivideByGcd());
IntegerValue ub = root_level_bounds_->GetUpperBoundNoTrail(expr);
ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(expr));
IntegerValue ub = kMaxIntegerValue;
const LinearExpression2Index index = non_trivial_bounds_->GetIndex(expr);
if (index != kNoLinearExpression2Index) {
ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(index));
ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(index));
}
ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(expr));
return ub;
}
@@ -1942,20 +1924,25 @@ void Linear2Bounds::AddReasonForUpperBoundLowerThan(
LinearExpression2 expr, IntegerValue ub,
std::vector<Literal>* literal_reason,
std::vector<IntegerLiteral>* integer_reason) const {
expr.SimpleCanonicalization();
const IntegerValue gcd = expr.DivideByGcd();
ub = FloorRatio(ub, gcd);
DCHECK_LE(UpperBound(expr), ub);
// Explanation are by order of preference, with no reason needed first.
if (root_level_bounds_->LevelZeroUpperBound(expr) <= ub) {
if (integer_trail_->LevelZeroUpperBound(expr) <= ub) {
return;
}
expr.SimpleCanonicalization();
const IntegerValue gcd = expr.DivideByGcd();
ub = FloorRatio(ub, gcd);
const LinearExpression2Index index = non_trivial_bounds_->GetIndex(expr);
// This one is a single literal.
if (enforced_bounds_->GetUpperBoundFromEnforced(expr) <= ub) {
return enforced_bounds_->AddReasonForUpperBoundLowerThan(
expr, ub, literal_reason, integer_reason);
if (index != kNoLinearExpression2Index) {
if (root_level_bounds_->GetUpperBoundNoTrail(index) <= ub) {
return;
}
if (enforced_bounds_->GetUpperBoundFromEnforced(index) <= ub) {
return enforced_bounds_->AddReasonForUpperBoundLowerThan(
index, ub, literal_reason, integer_reason);
}
}
// This one is a single var upper bound.
@@ -1975,16 +1962,5 @@ void Linear2Bounds::AddReasonForUpperBoundLowerThan(
integer_reason);
}
absl::Span<const LinearExpression2>
Linear2Bounds::GetAllExpressionsWithPotentialNonTrivialBounds(
Bitset64<IntegerVariable>::ConstView var_set) const {
tmp_expressions_.clear();
root_level_bounds_->AppendAllExpressionContaining(var_set, &tmp_expressions_);
enforced_bounds_->AppendAllExpressionContaining(var_set, &tmp_expressions_);
linear3_bounds_->AppendAllExpressionContaining(var_set, &tmp_expressions_);
gtl::STLSortAndRemoveDuplicates(&tmp_expressions_);
return tmp_expressions_;
}
} // namespace sat
} // namespace operations_research

View File

@@ -14,10 +14,10 @@
#ifndef OR_TOOLS_SAT_PRECEDENCES_H_
#define OR_TOOLS_SAT_PRECEDENCES_H_
#include <algorithm>
#include <cstdint>
#include <deque>
#include <functional>
#include <limits>
#include <tuple>
#include <utility>
#include <vector>
@@ -31,6 +31,7 @@
#include "absl/types/span.h"
#include "ortools/base/strong_vector.h"
#include "ortools/graph/graph.h"
#include "ortools/sat/cp_model_mapping.h"
#include "ortools/sat/integer.h"
#include "ortools/sat/integer_base.h"
#include "ortools/sat/model.h"
@@ -70,23 +71,14 @@ class Linear2WithPotentialNonTrivalBounds {
// Returns a never-changing index for the given linear expression.
// The expression must already be canonicalized and divided by its GCD.
LinearExpression2Index AddOrGet(LinearExpression2 expr) {
DCHECK(expr.IsCanonicalized());
DCHECK_EQ(expr.DivideByGcd(), 1);
const bool negated = expr.NegateForCanonicalization();
auto [it, inserted] = expr_to_index_.insert({expr, exprs_.size()});
if (inserted) {
CHECK_LT(2 * exprs_.size() + 1,
std::numeric_limits<LinearExpression2Index>::max());
exprs_.push_back(expr);
}
const LinearExpression2Index positive_index(2 * it->second);
if (negated) {
return NegationOf(positive_index);
} else {
return positive_index;
}
}
LinearExpression2Index AddOrGet(LinearExpression2 expr);
// Returns a never-changing index for the given linear expression if it is
// potentially non-trivial, otherwise returns kNoLinearExpression2Index. The
// expression must already be canonicalized and divided by its GCD.
LinearExpression2Index GetIndex(LinearExpression2 expr) const;
LinearExpression2 GetExpression(LinearExpression2Index index) const;
// Return all positive linear2 expressions that have a potentially non-trivial
// bound. When calling this code it is often a good idea to check both the
@@ -97,9 +89,45 @@ class Linear2WithPotentialNonTrivalBounds {
return exprs_;
}
// Return a list of all potentially non-trivial LinearExpression2Indexes
// containing a given variable.
absl::Span<const LinearExpression2Index> GetAllLinear2ContainingVariable(
IntegerVariable var) const;
// Return a list of all potentially non-trivial LinearExpression2Indexes
// containing a given pair of variables.
absl::Span<const LinearExpression2Index> GetAllLinear2ContainingVariables(
IntegerVariable var1, IntegerVariable var2) const;
// For a given variable `var`, return all linear expressions with both
// coefficients 1 that have a potentially non trivial upper bound. For
// convenience it also returns the other variable to cheaply build the
// linear2. Note that using negation one can also recover x + y >= lb and x -
// y <= ub.
absl::Span<const LinearExpression2Index>
GetAllLinear2ContainingVariableWithCoeffOne(IntegerVariable var) const {
if (var >= coeff_one_var_lookup_.size()) return {};
return coeff_one_var_lookup_[var];
}
private:
util_intops::StrongVector<LinearExpression2Index, LinearExpression2> exprs_;
std::vector<LinearExpression2> exprs_;
absl::flat_hash_map<LinearExpression2, int> expr_to_index_;
// Lookup table to find all the LinearExpression2 with a given variable and
// having both coefficient 1.
util_intops::StrongVector<IntegerVariable,
std::vector<LinearExpression2Index>>
coeff_one_var_lookup_;
// Map to implement GetAllBoundsContainingVariable().
absl::flat_hash_map<IntegerVariable,
absl::InlinedVector<LinearExpression2Index, 2>>
var_to_bounds_;
// Map to implement GetAllBoundsContainingVariables().
absl::flat_hash_map<std::pair<IntegerVariable, IntegerVariable>,
absl::InlinedVector<LinearExpression2Index, 1>>
var_pair_to_bounds_;
};
// Simple "watcher" class that will be notified if a linear2 bound changed. It
@@ -138,7 +166,13 @@ class RootLevelLinear2Bounds {
linear2_watcher_(model->GetOrCreate<Linear2Watcher>()),
shared_stats_(model->GetOrCreate<SharedStatistics>()),
non_trivial_bounds_(
model->GetOrCreate<Linear2WithPotentialNonTrivalBounds>()) {}
model->GetOrCreate<Linear2WithPotentialNonTrivalBounds>()),
cp_model_mapping_(model->GetOrCreate<CpModelMapping>()),
shared_linear2_bounds_(model->Mutable<SharedLinear2Bounds>()),
shared_linear2_bounds_id_(
shared_linear2_bounds_ == nullptr
? 0
: shared_linear2_bounds_->RegisterNewId(model->Name())) {}
~RootLevelLinear2Bounds();
@@ -147,16 +181,49 @@ class RootLevelLinear2Bounds {
// Returns a pair saying whether the lower/upper bounds for this expr became
// more restricted than what was currently stored.
std::pair<bool, bool> Add(LinearExpression2 expr, IntegerValue lb,
IntegerValue ub);
IntegerValue ub) {
const bool negated = expr.CanonicalizeAndUpdateBounds(lb, ub);
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return {false, false};
const LinearExpression2Index index = non_trivial_bounds_->AddOrGet(expr);
bool ub_changed = AddUpperBound(index, ub);
bool lb_changed = AddUpperBound(NegationOf(index), -lb);
if (negated) {
std::swap(lb_changed, ub_changed);
}
return {lb_changed, ub_changed};
}
bool AddUpperBound(LinearExpression2Index index, IntegerValue ub);
// Same as above, but only update the upper bound.
bool AddUpperBound(LinearExpression2 expr, IntegerValue ub) {
return Add(expr, kMinIntegerValue, ub).second;
expr.SimpleCanonicalization();
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return false;
const IntegerValue gcd = expr.DivideByGcd();
ub = FloorRatio(ub, gcd);
return AddUpperBound(non_trivial_bounds_->AddOrGet(expr), ub);
}
IntegerValue LevelZeroUpperBound(LinearExpression2 expr) const;
IntegerValue LevelZeroUpperBound(LinearExpression2 expr) const {
expr.SimpleCanonicalization();
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) {
return integer_trail_->LevelZeroUpperBound(expr);
}
const IntegerValue gcd = expr.DivideByGcd();
const LinearExpression2Index index = non_trivial_bounds_->GetIndex(expr);
if (index == kNoLinearExpression2Index) {
return integer_trail_->LevelZeroUpperBound(expr);
}
return CapProdI(gcd, LevelZeroUpperBound(index));
}
int64_t num_bounds() const { return root_level_relations_.num_bounds(); }
IntegerValue LevelZeroUpperBound(LinearExpression2Index index) const {
const LinearExpression2 expr = non_trivial_bounds_->GetExpression(index);
// TODO(user): Remove the expression from the root_level_relations_ if
// the zero-level bound got more restrictive.
return std::min(integer_trail_->LevelZeroUpperBound(expr),
GetUpperBoundNoTrail(index));
}
// Return a list of (expr <= ub) sorted by expr. They are guaranteed to be
// better than the trivial upper bound.
@@ -183,11 +250,8 @@ class RootLevelLinear2Bounds {
// For a given variable `var`, return all variables `other` so that
// LinearExpression2(var, other, 1, 1) has a non trivial upper bound.
// Note that using negation one can also recover x + y >= lb and x - y <= ub.
absl::Span<const IntegerVariable> GetVariablesInSimpleRelation(
IntegerVariable var) const {
if (var >= coeff_one_var_lookup_.size()) return {};
return coeff_one_var_lookup_[var];
}
std::vector<IntegerVariable> GetVariablesInSimpleRelation(
IntegerVariable var) const;
RelationStatus GetLevelZeroStatus(LinearExpression2 expr, IntegerValue lb,
IntegerValue ub) const;
@@ -197,47 +261,21 @@ class RootLevelLinear2Bounds {
// behavior from LevelZeroUpperBound() that would return the implied
// zero-level bound from the trail for trivial ones. `expr` must be
// canonicalized and gcd-reduced.
IntegerValue GetUpperBoundNoTrail(LinearExpression2 expr) const;
void AppendAllExpressionContaining(
Bitset64<IntegerVariable>::ConstView var_set,
std::vector<LinearExpression2>* result) const;
IntegerValue GetUpperBoundNoTrail(LinearExpression2Index index) const;
private:
IntegerTrail* integer_trail_;
Linear2Watcher* linear2_watcher_;
SharedStatistics* shared_stats_;
Linear2WithPotentialNonTrivalBounds* non_trivial_bounds_;
CpModelMapping* cp_model_mapping_;
SharedLinear2Bounds* shared_linear2_bounds_; // Might be nullptr.
// Lookup table to find all the LinearExpression2 with a given variable and
// having both coefficient 1.
util_intops::StrongVector<IntegerVariable, std::vector<IntegerVariable>>
coeff_one_var_lookup_;
const int shared_linear2_bounds_id_;
// TODO(user): use data structures that consume less memory. A single
// std::vector<LinearExpression2> and hash maps having the index as value
// could be enough.
absl::flat_hash_map<
IntegerVariable,
absl::InlinedVector<
std::tuple<LinearExpression2, IntegerValue, IntegerValue>, 2>>
var_to_bounds_;
// Map to implement GetAllBoundsContainingVariables().
absl::flat_hash_map<
std::pair<IntegerVariable, IntegerVariable>,
absl::InlinedVector<
std::tuple<LinearExpression2, IntegerValue, IntegerValue>, 1>>
var_pair_to_bounds_;
// Data structure to quickly update var_to_bounds_. Return the index where
// this linear expression appear in the vector for the first and second
// variable.
absl::flat_hash_map<LinearExpression2, std::pair<int, int>>
var_to_bounds_vector_index_;
absl::flat_hash_map<LinearExpression2, int> var_pair_to_bounds_vector_index_;
util_intops::StrongVector<LinearExpression2Index, IntegerValue>
best_upper_bounds_;
// TODO(user): Also push them to a global shared repository after
// remapping IntegerVariable to proto indices.
BestBinaryRelationBounds root_level_relations_;
int64_t num_updates_ = 0;
};
@@ -338,7 +376,17 @@ class EnforcedLinear2Bounds : public ReversibleInterface {
// If expr is not a proper linear2 expression (e.g. 0*x + y, y + y, y - y) it
// will be ignored.
void PushConditionalRelation(absl::Span<const Literal> enforcements,
LinearExpression2 expr, IntegerValue rhs);
LinearExpression2Index index, IntegerValue rhs);
void PushConditionalRelation(absl::Span<const Literal> enforcements,
LinearExpression2 expr, IntegerValue rhs) {
expr.SimpleCanonicalization();
if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return;
const IntegerValue gcd = expr.DivideByGcd();
rhs = FloorRatio(rhs, gcd);
return PushConditionalRelation(enforcements,
non_trivial_bounds_->AddOrGet(expr), rhs);
}
// Called each time we change decision level.
void SetLevel(int level) final;
@@ -365,18 +413,13 @@ class EnforcedLinear2Bounds : public ReversibleInterface {
// Low-level function that returns the upper bound if there is some enforced
// relations only. Otherwise always returns kMaxIntegerValue.
// `expr` must be canonicalized and gcd-reduced.
IntegerValue GetUpperBoundFromEnforced(LinearExpression2 expr) const;
IntegerValue GetUpperBoundFromEnforced(LinearExpression2Index index) const;
void AddReasonForUpperBoundLowerThan(
LinearExpression2 expr, IntegerValue ub,
LinearExpression2Index index, IntegerValue ub,
std::vector<Literal>* literal_reason,
std::vector<IntegerLiteral>* integer_reason) const;
// Note: might contain duplicate expressions.
void AppendAllExpressionContaining(
Bitset64<IntegerVariable>::ConstView var_set,
std::vector<LinearExpression2>* result) const;
private:
void CreateLevelEntryIfNeeded();
@@ -395,13 +438,13 @@ class EnforcedLinear2Bounds : public ReversibleInterface {
// TODO(user): this kind of reversible hash_map is already implemented in
// other part of the code. Consolidate.
struct ConditionalEntry {
ConditionalEntry(int p, IntegerValue r, LinearExpression2 k,
ConditionalEntry(int p, IntegerValue r, LinearExpression2Index k,
absl::Span<const Literal> e)
: prev_entry(p), rhs(r), key(k), enforcements(e.begin(), e.end()) {}
int prev_entry;
IntegerValue rhs;
LinearExpression2 key;
LinearExpression2Index key;
absl::InlinedVector<Literal, 4> enforcements;
};
std::vector<ConditionalEntry> conditional_stack_;
@@ -409,7 +452,7 @@ class EnforcedLinear2Bounds : public ReversibleInterface {
// This is always stored in the form (expr <= rhs).
// The conditional relations contains indices in the conditional_stack_.
absl::flat_hash_map<LinearExpression2, int> conditional_relations_;
util_intops::StrongVector<LinearExpression2Index, int> conditional_relations_;
// Store for each variable x, the variables y that appears alongside it in
// lit => x + y <= ub. Note that conditional_var_lookup_ is updated on
@@ -510,11 +553,6 @@ class Linear2BoundsFromLinear3 {
// will replace it and returns true, otherwise it returns false.
bool AddAffineUpperBound(LinearExpression2 expr, AffineExpression affine_ub);
// Warning, the order will not be deterministic.
void AppendAllExpressionContaining(
Bitset64<IntegerVariable>::ConstView var_set,
std::vector<LinearExpression2>* result) const;
// Most users should just use Linear2Bounds::UpperBound() instead.
//
// Returns the upper bound only if there is some relations coming from a
@@ -601,7 +639,9 @@ class Linear2Bounds {
: integer_trail_(model->GetOrCreate<IntegerTrail>()),
root_level_bounds_(model->GetOrCreate<RootLevelLinear2Bounds>()),
enforced_bounds_(model->GetOrCreate<EnforcedLinear2Bounds>()),
linear3_bounds_(model->GetOrCreate<Linear2BoundsFromLinear3>()) {}
linear3_bounds_(model->GetOrCreate<Linear2BoundsFromLinear3>()),
non_trivial_bounds_(
model->GetOrCreate<Linear2WithPotentialNonTrivalBounds>()) {}
// Returns the best known upper-bound of the given LinearExpression2 at the
// current decision level. If its explanation is needed, it can be queried
@@ -616,31 +656,12 @@ class Linear2Bounds {
// don't want the trivial bounds.
IntegerValue NonTrivialUpperBoundForGcd1(LinearExpression2 expr) const;
// Returns all known expressions with potentially non-trivial bounds that
// involves two variable whose positive version is marked in 'vars'.
absl::Span<const LinearExpression2>
GetAllExpressionsWithPotentialNonTrivialBounds(
Bitset64<IntegerVariable>::ConstView var_set) const;
// Returns a temporary bitset, cleared, and resized for all existing
// variables.
//
// If we have many class calling
// GetAllExpressionsWithPotentialNonTrivialBounds() it is important that not
// all of them have a O(num_variables) vector when the same one can be used.
SparseBitset<IntegerVariable>* GetTemporyClearedAndResizedBitset() {
tmp_bitset_.ClearAndResize(integer_trail_->NumIntegerVariables());
return &tmp_bitset_;
}
private:
IntegerTrail* integer_trail_;
RootLevelLinear2Bounds* root_level_bounds_;
EnforcedLinear2Bounds* enforced_bounds_;
Linear2BoundsFromLinear3* linear3_bounds_;
mutable std::vector<LinearExpression2> tmp_expressions_;
SparseBitset<IntegerVariable> tmp_bitset_;
Linear2WithPotentialNonTrivalBounds* non_trivial_bounds_;
};
// Detects if at least one of a subset of linear of size 2 or 1, touching the
@@ -1000,6 +1021,58 @@ inline std::function<void(Model*)> ConditionalLowerOrEqualWithOffset(
};
}
inline LinearExpression2Index Linear2WithPotentialNonTrivalBounds::GetIndex(
LinearExpression2 expr) const {
DCHECK(expr.IsCanonicalized());
DCHECK_EQ(expr.DivideByGcd(), 1);
const bool negated = expr.NegateForCanonicalization();
auto it = expr_to_index_.find(expr);
if (it == expr_to_index_.end()) return kNoLinearExpression2Index;
const LinearExpression2Index positive_index(2 * it->second);
if (negated) {
return NegationOf(positive_index);
} else {
return positive_index;
}
}
inline LinearExpression2 Linear2WithPotentialNonTrivalBounds::GetExpression(
LinearExpression2Index index) const {
DCHECK_NE(index, kNoLinearExpression2Index);
const int lookup_index = index.value() / 2;
DCHECK_LT(lookup_index, exprs_.size());
if (Linear2IsPositive(index)) {
return exprs_[lookup_index];
} else {
LinearExpression2 result = exprs_[lookup_index];
result.Negate();
return result;
}
}
inline absl::Span<const LinearExpression2Index>
Linear2WithPotentialNonTrivalBounds::GetAllLinear2ContainingVariable(
IntegerVariable var) const {
const IntegerVariable positive_var = PositiveVariable(var);
auto it = var_to_bounds_.find(positive_var);
if (it == var_to_bounds_.end()) return {};
return it->second;
}
inline absl::Span<const LinearExpression2Index>
Linear2WithPotentialNonTrivalBounds::GetAllLinear2ContainingVariables(
IntegerVariable var1, IntegerVariable var2) const {
IntegerVariable positive_var1 = PositiveVariable(var1);
IntegerVariable positive_var2 = PositiveVariable(var2);
if (positive_var1 > positive_var2) {
std::swap(positive_var1, positive_var2);
}
auto it = var_pair_to_bounds_.find({positive_var1, positive_var2});
if (it == var_pair_to_bounds_.end()) return {};
return it->second;
}
} // namespace sat
} // namespace operations_research

View File

@@ -190,6 +190,8 @@ TEST(EnforcedLinear2BoundsTest, ConditionalRelations) {
auto* lin2_bounds = model.GetOrCreate<Linear2Bounds>();
auto* integer_trail = model.GetOrCreate<IntegerTrail>();
auto* precedences = model.GetOrCreate<EnforcedLinear2Bounds>();
auto* non_trivial_bounds =
model.GetOrCreate<Linear2WithPotentialNonTrivalBounds>();
const std::vector<IntegerVariable> vars = AddVariables(integer_trail);
const Literal l(model.Add(NewBooleanVariable()), true);
@@ -200,26 +202,25 @@ TEST(EnforcedLinear2BoundsTest, ConditionalRelations) {
precedences->PushConditionalRelation({l}, LinearExpression2(a, b, 1, 1), 15);
precedences->PushConditionalRelation({l}, LinearExpression2(a, b, 1, 1), 20);
LinearExpression2 expr_a_plus_b =
LinearExpression2::Difference(a, NegationOf(b));
expr_a_plus_b.SimpleCanonicalization();
// We only keep the best one.
EXPECT_EQ(
lin2_bounds->UpperBound(LinearExpression2::Difference(a, NegationOf(b))),
15);
EXPECT_EQ(lin2_bounds->UpperBound(expr_a_plus_b), 15);
std::vector<Literal> literal_reason;
std::vector<IntegerLiteral> integer_reason;
precedences->AddReasonForUpperBoundLowerThan(
LinearExpression2::Difference(a, NegationOf(b)), 15, &literal_reason,
non_trivial_bounds->AddOrGet(expr_a_plus_b), 15, &literal_reason,
&integer_reason);
EXPECT_THAT(literal_reason, ElementsAre(l.Negated()));
// Backtrack works.
EXPECT_TRUE(sat_solver->ResetToLevelZero());
EXPECT_EQ(
lin2_bounds->UpperBound(LinearExpression2::Difference(a, NegationOf(b))),
200);
EXPECT_EQ(lin2_bounds->UpperBound(expr_a_plus_b), 200);
literal_reason.clear();
integer_reason.clear();
precedences->AddReasonForUpperBoundLowerThan(
LinearExpression2::Difference(a, NegationOf(b)), kMaxIntegerValue,
non_trivial_bounds->AddOrGet(expr_a_plus_b), kMaxIntegerValue,
&literal_reason, &integer_reason);
EXPECT_THAT(literal_reason, IsEmpty());
}

View File

@@ -24,7 +24,7 @@ option java_multiple_files = true;
// Contains the definitions for all the sat algorithm parameters and their
// default values.
//
// NEXT TAG: 326
// NEXT TAG: 327
message SatParameters {
// In some context, like in a portfolio of search, it makes sense to name a
// given parameters set for logging purpose.
@@ -703,6 +703,13 @@ message SatParameters {
// Allows sharing of the bounds of modified variables at level 0.
optional bool share_level_zero_bounds = 114 [default = true];
// Allows sharing of the bounds on linear2 discovered at level 0. This is
// mainly interesting on scheduling type of problems when we branch on
// precedences.
//
// Warning: This currently non-deterministic.
optional bool share_linear2_bounds = 326 [default = false];
// Allows sharing of new learned binary clause between workers.
optional bool share_binary_clauses = 203 [default = true];

View File

@@ -16,9 +16,11 @@
#include <functional>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/flags/usage.h"
@@ -30,6 +32,8 @@
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/text_format.h"
#include "ortools/base/helpers.h"
@@ -45,6 +49,7 @@
#include "ortools/sat/synchronization.h"
#include "ortools/util/file_util.h"
#include "ortools/util/logging.h"
#include "ortools/util/sigint.h"
#include "ortools/util/sorted_interval_list.h"
ABSL_FLAG(
@@ -102,8 +107,69 @@ std::string ExtractName(absl::string_view full_filename) {
return filename;
}
void LogInPbCompetitionFormat(int num_variables, bool has_objective,
Model* model, SatParameters* parameters) {
class LastSolutionPrinter {
public:
// Note that is prints the solution in the PB competition format.
void MaybePrintLastSolution() {
absl::MutexLock lock(&mutex_);
if (last_solution_printed_) return;
last_solution_printed_ = true;
if (last_solution_.empty()) {
std::cout << "s UNKNOWN" << std::endl;
} else {
std::cout << "s SATISFIABLE" << std::endl;
std::string line;
for (int i = 0; i < num_variables_; ++i) {
if (last_solution_[i]) {
absl::StrAppend(&line, "x", i + 1, " ");
} else {
absl::StrAppend(&line, "-x", i + 1, " ");
}
if (line.size() >= 75) {
std::cout << "v " << line << std::endl;
line.clear();
}
}
if (!line.empty()) {
std::cout << "v " << line << std::endl;
}
}
}
void set_num_variables(int num_variables) { num_variables_ = num_variables; }
void set_last_solution(absl::Span<const int64_t> solution) {
absl::MutexLock lock(&mutex_);
if (last_solution_printed_) return;
last_solution_.assign(solution.begin(), solution.end());
}
// Returns false if the solution has already been printed, else mark it as
// printed by caller code.
bool mark_last_solution_printed() {
const absl::MutexLock lock(&mutex_);
if (last_solution_printed_) {
return false;
}
last_solution_printed_ = true;
return true;
}
private:
int num_variables_ = 0;
std::vector<int64_t> last_solution_ ABSL_GUARDED_BY(mutex_);
bool last_solution_printed_ ABSL_GUARDED_BY(mutex_) = false;
absl::Mutex mutex_;
};
void LogInPbCompetitionFormat(
int num_variables, bool has_objective, Model* model,
SatParameters* parameters,
std::shared_ptr<LastSolutionPrinter> last_solution_printer) {
CHECK(last_solution_printer != nullptr);
last_solution_printer->set_num_variables(num_variables);
const auto log_callback = [](const std::string& multi_line_input) {
if (multi_line_input.empty()) {
std::cout << "c" << std::endl;
@@ -118,55 +184,60 @@ void LogInPbCompetitionFormat(int num_variables, bool has_objective,
model->GetOrCreate<SolverLogger>()->AddInfoLoggingCallback(log_callback);
parameters->set_log_to_stdout(false);
const auto response_callback = [](const CpSolverResponse& r) {
const auto response_callback = [last_solution_printer](
const CpSolverResponse& r) {
std::cout << "o " << static_cast<int64_t>(r.objective_value()) << std::endl;
last_solution_printer->set_last_solution(r.solution());
};
model->Add(NewFeasibleSolutionObserver(response_callback));
const auto final_response_callback = [num_variables,
has_objective](CpSolverResponse* r) {
switch (r->status()) {
case CpSolverStatus::OPTIMAL:
if (has_objective) {
std::cout << "s OPTIMUM FOUND " << std::endl;
} else {
std::cout << "s SATISFIABLE" << std::endl;
const auto final_response_callback =
[num_variables, has_objective,
last_solution_printer](CpSolverResponse* r) {
if (!last_solution_printer->mark_last_solution_printed()) return;
switch (r->status()) {
case CpSolverStatus::OPTIMAL:
if (has_objective) {
std::cout << "s OPTIMUM FOUND " << std::endl;
} else {
std::cout << "s SATISFIABLE" << std::endl;
}
break;
case CpSolverStatus::FEASIBLE:
std::cout << "s SATISFIABLE" << std::endl;
break;
case CpSolverStatus::INFEASIBLE:
std::cout << "s UNSATISFIABLE" << std::endl;
break;
case CpSolverStatus::MODEL_INVALID:
std::cout << "s UNSUPPORTED" << std::endl;
break;
case CpSolverStatus::UNKNOWN:
std::cout << "s UNKNOWN" << std::endl;
break;
default:
break;
}
break;
case CpSolverStatus::FEASIBLE:
std::cout << "s SATISFIABLE" << std::endl;
break;
case CpSolverStatus::INFEASIBLE:
std::cout << "s UNSATISFIABLE" << std::endl;
break;
case CpSolverStatus::MODEL_INVALID:
std::cout << "s UNSUPPORTED" << std::endl;
break;
case CpSolverStatus::UNKNOWN:
std::cout << "s UNKNOWN" << std::endl;
break;
default:
break;
}
if (r->status() == CpSolverStatus::OPTIMAL ||
r->status() == CpSolverStatus::FEASIBLE) {
std::string line;
for (int i = 0; i < num_variables; ++i) {
if (r->solution(i)) {
absl::StrAppend(&line, "x", i + 1, " ");
} else {
absl::StrAppend(&line, "-x", i + 1, " ");
if (r->status() == CpSolverStatus::OPTIMAL ||
r->status() == CpSolverStatus::FEASIBLE) {
std::string line;
for (int i = 0; i < num_variables; ++i) {
if (r->solution(i)) {
absl::StrAppend(&line, "x", i + 1, " ");
} else {
absl::StrAppend(&line, "-x", i + 1, " ");
}
if (line.size() >= 75) {
std::cout << "v " << line << std::endl;
line.clear();
}
}
if (!line.empty()) {
std::cout << "v " << line << std::endl;
}
}
if (line.size() >= 75) {
std::cout << "v " << line << std::endl;
line.clear();
}
}
if (!line.empty()) {
std::cout << "v " << line << std::endl;
}
}
};
};
model->GetOrCreate<SharedResponseManager>()->AddFinalResponsePostprocessor(
final_response_callback);
}
@@ -186,7 +257,8 @@ void SetInterleavedWorkers(SatParameters* parameters) {
bool LoadProblem(const std::string& filename, absl::string_view hint_file,
absl::string_view domain_file, CpModelProto* cp_model,
Model* model, SatParameters* parameters) {
Model* model, SatParameters* parameters,
std::shared_ptr<LastSolutionPrinter> last_solution_printer) {
if (absl::EndsWith(filename, ".opb") ||
absl::EndsWith(filename, ".opb.bz2") ||
absl::EndsWith(filename, ".opb.gz") || absl::EndsWith(filename, ".wbo") ||
@@ -217,7 +289,7 @@ bool LoadProblem(const std::string& filename, absl::string_view hint_file,
const int num_variables =
reader.model_is_supported() ? reader.num_variables() : 1;
LogInPbCompetitionFormat(num_variables, cp_model->has_objective(), model,
parameters);
parameters, last_solution_printer);
}
if (absl::GetFlag(FLAGS_force_interleave_search)) {
SetInterleavedWorkers(parameters);
@@ -310,9 +382,13 @@ int Run() {
google::protobuf::Arena arena;
CpModelProto* cp_model =
google::protobuf::Arena::Create<CpModelProto>(&arena);
std::shared_ptr<LastSolutionPrinter> last_solution_printer;
if (absl::GetFlag(FLAGS_competition_mode)) {
last_solution_printer = std::make_shared<LastSolutionPrinter>();
}
if (!LoadProblem(absl::GetFlag(FLAGS_input), absl::GetFlag(FLAGS_hint_file),
absl::GetFlag(FLAGS_domain_file), cp_model, &model,
&parameters)) {
&parameters, last_solution_printer)) {
if (!absl::GetFlag(FLAGS_competition_mode)) {
LOG(FATAL) << "Cannot load file '" << absl::GetFlag(FLAGS_input) << "'.";
}
@@ -329,6 +405,14 @@ int Run() {
FingerprintRepeatedField(r.solution(), kDefaultFingerprintSeed));
}));
}
if (absl::GetFlag(FLAGS_competition_mode)) {
model.GetOrCreate<SigtermHandler>()->Register([last_solution_printer]() {
last_solution_printer->MaybePrintLastSolution();
exit(EXIT_SUCCESS);
});
}
const CpSolverResponse response = SolveCpModel(*cp_model, &model);
if (!absl::GetFlag(FLAGS_output).empty()) {

View File

@@ -1386,14 +1386,27 @@ int UniqueClauseStream::NumLiteralsOfSize(int size) const {
SharedClausesManager::SharedClausesManager(bool always_synchronize)
: always_synchronize_(always_synchronize) {}
int SharedClausesManager::RegisterNewId(bool may_terminate_early) {
int SharedClausesManager::RegisterNewId(absl::string_view worker_name,
bool may_terminate_early) {
absl::MutexLock mutex_lock(&mutex_);
num_full_workers_ += may_terminate_early ? 0 : 1;
const int id = id_to_last_processed_binary_clause_.size();
id_to_last_processed_binary_clause_.resize(id + 1, 0);
id_to_last_returned_batch_.resize(id + 1, -1);
id_to_last_finished_batch_.resize(id + 1, -1);
id_to_clauses_exported_.resize(id + 1, 0);
id_to_num_exported_.resize(id + 1, 0);
id_to_worker_name_.resize(id + 1);
id_to_worker_name_[id] = worker_name;
return id;
}
int SharedLinear2Bounds::RegisterNewId(std::string worker_name) {
absl::MutexLock mutex_lock(&mutex_);
const int id = id_to_worker_name_.size();
id_to_stats_.resize(id + 1);
id_to_worker_name_.resize(id + 1);
id_to_worker_name_[id] = worker_name;
return id;
}
@@ -1401,12 +1414,6 @@ bool SharedClausesManager::ShouldReadBatch(int reader_id, int writer_id) {
return reader_id != writer_id;
}
void SharedClausesManager::SetWorkerNameForId(int id,
absl::string_view worker_name) {
absl::MutexLock mutex_lock(&mutex_);
id_to_worker_name_[id] = worker_name;
}
void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
if (lit2 < lit1) std::swap(lit1, lit2);
const auto p = std::make_pair(lit1, lit2);
@@ -1416,7 +1423,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
if (inserted) {
added_binary_clauses_.push_back(p);
if (always_synchronize_) ++last_visible_binary_clause_;
id_to_clauses_exported_[id]++;
id_to_num_exported_[id]++;
// Small optim. If the worker is already up to date with clauses to import,
// we can mark this new clause as already seen.
@@ -1429,7 +1436,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) {
void SharedClausesManager::AddBatch(int id, CompactVectorVector<int> batch) {
absl::MutexLock mutex_lock(&mutex_);
id_to_clauses_exported_[id] += batch.size();
id_to_num_exported_[id] += batch.size();
pending_batches_.push_back(std::move(batch));
}
@@ -1463,16 +1470,44 @@ void SharedClausesManager::GetUnseenBinaryClauses(
void SharedClausesManager::LogStatistics(SolverLogger* logger) {
absl::MutexLock mutex_lock(&mutex_);
absl::btree_map<std::string, int64_t> name_to_clauses;
for (int id = 0; id < id_to_clauses_exported_.size(); ++id) {
if (id_to_clauses_exported_[id] == 0) continue;
name_to_clauses[id_to_worker_name_[id]] = id_to_clauses_exported_[id];
absl::btree_map<std::string, int64_t> name_to_table_line;
for (int id = 0; id < id_to_num_exported_.size(); ++id) {
if (id_to_num_exported_[id] == 0) continue;
name_to_table_line[id_to_worker_name_[id]] = id_to_num_exported_[id];
}
if (!name_to_clauses.empty()) {
if (!name_to_table_line.empty()) {
std::vector<std::vector<std::string>> table;
table.push_back({"Clauses shared", "Num"});
for (const auto& entry : name_to_clauses) {
table.push_back({FormatName(entry.first), FormatCounter(entry.second)});
for (const auto& [name, count] : name_to_table_line) {
table.push_back({FormatName(name), FormatCounter(count)});
}
SOLVER_LOG(logger, FormatTable(table));
}
}
// TODO(user): Add some library to simplify this "transposition". Ideally we
// could merge small table with few columns. I am thinking list (row_name,
// col_name, count) + function that create table?
void SharedLinear2Bounds::LogStatistics(SolverLogger* logger) {
absl::MutexLock mutex_lock(&mutex_);
absl::btree_map<std::string, Stats> name_to_table_line;
for (int id = 0; id < id_to_stats_.size(); ++id) {
const Stats stats = id_to_stats_[id];
if (!stats.empty()) {
name_to_table_line[id_to_worker_name_[id]] = stats;
}
}
for (int import_id = 0; import_id < import_id_to_index_.size(); ++import_id) {
name_to_table_line[import_id_to_name_[import_id]].num_imported =
import_id_to_num_imported_[import_id];
}
if (!name_to_table_line.empty()) {
std::vector<std::vector<std::string>> table;
table.push_back({"Linear2 shared", "New", "Updated", "Imported"});
for (const auto& [name, stats] : name_to_table_line) {
table.push_back({FormatName(name), FormatCounter(stats.num_new),
FormatCounter(stats.num_update),
FormatCounter(stats.num_imported)});
}
SOLVER_LOG(logger, FormatTable(table));
}
@@ -1522,6 +1557,69 @@ void SharedClausesManager::Synchronize() {
}
}
void SharedLinear2Bounds::Add(int id, Key expr, IntegerValue lb,
IntegerValue ub) {
DCHECK(expr.IsCanonicalized());
absl::MutexLock mutex_lock(&mutex_);
auto [it, inserted] = shared_bounds_.insert({expr, {lb, ub}});
if (inserted) {
// It is new.
id_to_stats_[id].num_new++;
newly_updated_keys_.push_back(expr);
} else {
// Update the individual bounds if the new ones are better.
auto& bounds = it->second;
const bool update_lb = lb > bounds.first;
if (update_lb) bounds.first = lb;
const bool update_ub = ub < bounds.second;
if (update_ub) bounds.second = ub;
if (update_lb || update_ub) {
id_to_stats_[id].num_update++;
newly_updated_keys_.push_back(expr);
}
}
}
int SharedLinear2Bounds::RegisterNewImportId(std::string name) {
absl::MutexLock mutex_lock(&mutex_);
const int import_id = import_id_to_index_.size();
import_id_to_name_.push_back(name);
import_id_to_index_.push_back(0);
import_id_to_num_imported_.push_back(0);
return import_id;
}
std::vector<
std::pair<SharedLinear2Bounds::Key, std::pair<IntegerValue, IntegerValue>>>
SharedLinear2Bounds::NewlyUpdatedBounds(int import_id) {
std::vector<std::pair<Key, std::pair<IntegerValue, IntegerValue>>> result;
absl::MutexLock mutex_lock(&mutex_);
MaybeCompressNewlyUpdateKeys();
const int size = newly_updated_keys_.size();
for (int i = import_id_to_index_[import_id]; i < size; ++i) {
const auto& key = newly_updated_keys_[i];
result.push_back({key, shared_bounds_[key]});
}
import_id_to_index_[import_id] = size;
return result;
}
void SharedLinear2Bounds::MaybeCompressNewlyUpdateKeys() {
int min_index = 0;
for (const int index : import_id_to_index_) {
min_index = std::min(index, min_index);
}
if (min_index == 0) return;
newly_updated_keys_.erase(newly_updated_keys_.begin(),
newly_updated_keys_.begin() + min_index);
for (int& index_ref : import_id_to_index_) {
index_ref -= min_index;
}
}
void SharedStatistics::AddStats(
absl::Span<const std::pair<std::string, int64_t>> stats) {
absl::MutexLock mutex_lock(&mutex_);

View File

@@ -848,8 +848,7 @@ class SharedClausesManager {
std::vector<std::pair<int, int>>* new_clauses);
// Ids are used to identify which worker is exporting/importing clauses.
int RegisterNewId(bool may_terminate_early);
void SetWorkerNameForId(int id, absl::string_view worker_name);
int RegisterNewId(absl::string_view worker_name, bool may_terminate_early);
// Search statistics.
void LogStatistics(SolverLogger* logger);
@@ -893,8 +892,100 @@ class SharedClausesManager {
const bool always_synchronize_ = true;
// Stats:
std::vector<int64_t> id_to_clauses_exported_;
absl::flat_hash_map<int, std::string> id_to_worker_name_;
std::vector<int64_t> id_to_num_exported_ ABSL_GUARDED_BY(mutex_);
std::vector<int64_t> id_to_num_updated_ ABSL_GUARDED_BY(mutex_);
std::vector<std::string> id_to_worker_name_ ABSL_GUARDED_BY(mutex_);
};
// A class that allows to exchange root level bounds on linear2.
//
// TODO(user): Add Synchronize() support and only publish new bounds when this
// is called.
class SharedLinear2Bounds {
public:
int RegisterNewId(std::string worker_name);
void LogStatistics(SolverLogger* logger);
// This should only contain canonicalized expression.
// See the code for IsCanonicalized() for the definition.
struct Key {
int vars[2];
IntegerValue coeffs[2];
bool IsCanonicalized() {
return coeffs[0] > 0 && coeffs[1] != 0 && vars[0] < vars[1] &&
std::gcd(coeffs[0].value(), coeffs[1].value()) == 1;
}
bool operator==(const Key& o) const {
return vars[0] == o.vars[0] && vars[1] == o.vars[1] &&
coeffs[0] == o.coeffs[0] && coeffs[1] == o.coeffs[1];
}
template <typename H>
friend H AbslHashValue(H h, const Key& k) {
return H::combine(std::move(h), k.vars[0], k.vars[1], k.coeffs[0],
k.coeffs[1]);
}
};
// Exports new bounds on the given expr (should be canonicalized).
void Add(int id, Key expr, IntegerValue lb, IntegerValue ub);
// This is called less often, and maybe not every-worker that exports want to
// export, so we use a separate id space. Because we rely on hash map to
// check if a bound is new, it is not such a big deal that a worker re-read
// once the bounds it exported.
int RegisterNewImportId(std::string name);
// Returns the linear2 and their bounds.
// We only return changes since the last call with the same id.
std::vector<std::pair<Key, std::pair<IntegerValue, IntegerValue>>>
NewlyUpdatedBounds(int import_id);
// This is not filled by NewlyUpdatedBounds() because we want to track the
// bounds that were not already known by the worker at the time of the import,
// and we don't have this information here.
void NotifyNumImported(int import_id, int num) {
absl::MutexLock mutex_lock(&mutex_);
import_id_to_num_imported_[import_id] += num;
}
private:
void MaybeCompressNewlyUpdateKeys() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
absl::Mutex mutex_;
// The best known bounds for each key.
absl::flat_hash_map<Key, std::pair<IntegerValue, IntegerValue>> shared_bounds_
ABSL_GUARDED_BY(mutex_);
// Ever growing list of updated position in shared_bounds_.
// Note that we do reduce it in MaybeCompressNewlyUpdateKeys(), but that
// requires all registered workers to have at least imported some bounds.
//
// TODO(user): use indirect addressing so that newly_updated_keys_ can just
// deal with indices, and it is a bit tighter memory wise? We also avoid
// hash-lookups on NewlyUpdatedBounds(). But since this is only called at
// level zero on new bounds, I don't think we care.
std::vector<Key> newly_updated_keys_;
// For import.
std::vector<std::string> import_id_to_name_ ABSL_GUARDED_BY(mutex_);
std::vector<int> import_id_to_index_ ABSL_GUARDED_BY(mutex_);
std::vector<int> import_id_to_num_imported_ ABSL_GUARDED_BY(mutex_);
// Just for reporting at the end of the solve.
struct Stats {
int64_t num_new = 0;
int64_t num_update = 0;
int64_t num_imported = 0; // Copy of import_id_to_num_imported_.
bool empty() const {
return num_new == 0 && num_update == 0 && num_imported == 0;
}
};
std::vector<Stats> id_to_stats_ ABSL_GUARDED_BY(mutex_);
std::vector<std::string> id_to_worker_name_ ABSL_GUARDED_BY(mutex_);
};
// Simple class to add statistics by name and print them at the end.

View File

@@ -834,8 +834,8 @@ TEST(SharedResponseManagerTest, Callback) {
TEST(SharedClausesManagerTest, SyncApi) {
SharedClausesManager manager(/*always_synchronize=*/true);
EXPECT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
EXPECT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
EXPECT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false));
EXPECT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false));
manager.AddBinaryClause(/*id=*/0, 1, 2);
std::vector<std::pair<int, int>> new_clauses;
@@ -922,8 +922,8 @@ TEST(UniqueClauseStreamTest, DropsClauses) {
TEST(SharedClausesManagerTest, NonSyncApi) {
SharedClausesManager manager(/*always_synchronize=*/false);
EXPECT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
EXPECT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
EXPECT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false));
EXPECT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false));
manager.AddBinaryClause(/*id=*/0, 1, 2);
std::vector<std::pair<int, int>> new_clauses;
@@ -971,8 +971,8 @@ TEST(SharedClausesManagerTest, NonSyncApi) {
TEST(SharedClausesManagerTest, ShareGlueClauses) {
SharedClausesManager manager(/*always_synchronize=*/true);
ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false));
ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false));
UniqueClauseStream stream0;
UniqueClauseStream stream1;
// Add a bunch of clauses that will be skipped batch.
@@ -999,8 +999,8 @@ TEST(SharedClausesManagerTest, ShareGlueClauses) {
TEST(SharedClausesManagerTest, LbdThresholdIncrease) {
SharedClausesManager manager(/*always_synchronize=*/true);
ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false));
ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false));
UniqueClauseStream stream0;
UniqueClauseStream stream1;
const int kExpectedClauses = UniqueClauseStream::kMaxLiteralsPerBatch / 5;
@@ -1027,8 +1027,8 @@ TEST(SharedClausesManagerTest, LbdThresholdIncrease) {
TEST(SharedClausesManagerTest, LbdThresholdDecrease) {
SharedClausesManager manager(/*always_synchronize=*/true);
ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false));
ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false));
ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false));
ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false));
UniqueClauseStream stream0;
UniqueClauseStream stream1;

View File

@@ -23,29 +23,47 @@ namespace operations_research {
void SigintHandler::Register(const std::function<void()>& f) {
handler_ = [this, f]() -> void {
const int num_sigint_calls = ++num_sigint_calls_;
if (num_sigint_calls < 3) {
const int num_calls = ++num_calls_;
if (num_calls < 3) {
LOG(INFO)
<< "^C pressed " << num_sigint_calls << " times. "
<< "^C pressed " << num_calls << " times. "
<< "Interrupting the solver. Press 3 times to force termination.";
if (num_sigint_calls == 1) f();
} else if (num_sigint_calls == 3) {
if (num_calls == 1) f();
} else if (num_calls == 3) {
LOG(INFO) << "^C pressed 3 times. Forcing termination.";
exit(EXIT_FAILURE);
} else {
// Another thread is already running exit(), do nothing.
}
};
signal(SIGINT, &ControlCHandler);
signal(SIGINT, &SigHandler);
}
// This method will be called by the system after the SIGINT signal.
// The parameter is the signal received.
void SigintHandler::ControlCHandler(int sig) { handler_(); }
void SigintHandler::SigHandler(int) { handler_(); }
// Unregister the SIGINT handler.
SigintHandler::~SigintHandler() { signal(SIGINT, SIG_DFL); }
// Unregister the signal handlers.
SigintHandler::~SigintHandler() {
if (handler_ != nullptr) signal(SIGINT, SIG_DFL);
}
thread_local std::function<void()> SigintHandler::handler_;
void SigtermHandler::Register(const std::function<void()>& f) {
handler_ = [f]() -> void { f(); };
signal(SIGTERM, &SigHandler);
}
// This method will be called by the system after the SIGTERM signal.
// The parameter is the signal received.
void SigtermHandler::SigHandler(int) { handler_(); }
// Unregister the signal handlers.
SigtermHandler::~SigtermHandler() {
if (handler_ != nullptr) signal(SIGTERM, SIG_DFL);
}
thread_local std::function<void()> SigtermHandler::handler_;
} // namespace operations_research

View File

@@ -21,7 +21,7 @@ namespace operations_research {
class SigintHandler {
public:
SigintHandler() {}
SigintHandler() = default;
~SigintHandler();
// Catches ^C and call f() the first time this happen. If ^C is pressed 3
@@ -29,9 +29,23 @@ class SigintHandler {
void Register(const std::function<void()>& f);
private:
static void ControlCHandler(int s);
std::atomic<int> num_calls_ = 0;
std::atomic<int> num_sigint_calls_ = 0;
static void SigHandler(int s);
thread_local static std::function<void()> handler_;
};
class SigtermHandler {
public:
SigtermHandler() = default;
~SigtermHandler();
// Catches SIGTERM and call f(). It is recommended that f() calls exit() to
// terminate the program.
void Register(const std::function<void()>& f);
private:
static void SigHandler(int s);
thread_local static std::function<void()> handler_;
};

View File

@@ -724,7 +724,9 @@ class ClosedInterval::Iterator {
// arithmetic.
uint64_t current_;
};
#if __cplusplus >= 202002L
static_assert(std::input_iterator<ClosedInterval::Iterator>);
#endif
// begin()/end() are required for iteration over ClosedInterval in a range for
// loop.
inline ClosedInterval::Iterator begin(ClosedInterval interval) {