[CP-SAT] tweak and improve code
This commit is contained in:
@@ -983,7 +983,8 @@ int IntegerTrail::FindTrailIndexOfVarBefore(IntegerVariable var,
|
||||
int IntegerTrail::FindLowestTrailIndexThatExplainBound(
|
||||
IntegerLiteral i_lit) const {
|
||||
DCHECK_LE(i_lit.bound, var_lbs_[i_lit.var]);
|
||||
if (i_lit.bound <= LevelZeroLowerBound(i_lit.var)) return -1;
|
||||
DCHECK(!IsTrueAtLevelZero(i_lit));
|
||||
|
||||
int trail_index = var_trail_index_[i_lit.var];
|
||||
|
||||
// Check the validity of the cached index and use it if possible. This caching
|
||||
@@ -1003,6 +1004,7 @@ int IntegerTrail::FindLowestTrailIndexThatExplainBound(
|
||||
|
||||
int prev_trail_index = trail_index;
|
||||
while (true) {
|
||||
++work_done_in_explain_lower_than_;
|
||||
if (trail_index >= var_trail_index_cache_threshold_) {
|
||||
var_trail_index_cache_[i_lit.var] = trail_index;
|
||||
}
|
||||
@@ -1171,10 +1173,9 @@ std::vector<Literal>* IntegerTrail::InitializeConflict(
|
||||
lazy_reasons_.back().Explain(conflict, &tmp_queue_);
|
||||
} else {
|
||||
conflict->assign(literals_reason.begin(), literals_reason.end());
|
||||
const int num_vars = var_lbs_.size();
|
||||
for (const IntegerLiteral& literal : bounds_reason) {
|
||||
const int trail_index = FindLowestTrailIndexThatExplainBound(literal);
|
||||
if (trail_index >= num_vars) tmp_queue_.push_back(trail_index);
|
||||
if (IsTrueAtLevelZero(literal)) continue;
|
||||
tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(literal));
|
||||
}
|
||||
}
|
||||
return conflict;
|
||||
@@ -1553,9 +1554,8 @@ bool IntegerTrail::EnqueueInternal(
|
||||
// efficiency and a potential smaller reason.
|
||||
auto* conflict = InitializeConflict(i_lit, use_lazy_reason, literal_reason,
|
||||
integer_reason);
|
||||
{
|
||||
const int trail_index = FindLowestTrailIndexThatExplainBound(ub_reason);
|
||||
if (trail_index >= 0) tmp_queue_.push_back(trail_index);
|
||||
if (!IsTrueAtLevelZero(ub_reason)) {
|
||||
tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(ub_reason));
|
||||
}
|
||||
MergeReasonIntoInternal(conflict, NextConflictId());
|
||||
return false;
|
||||
@@ -1771,12 +1771,10 @@ absl::Span<const int> IntegerTrail::Dependencies(int reason_index) const {
|
||||
|
||||
int new_size = 0;
|
||||
int* data = trail_index_reason_buffer_.data() + start;
|
||||
const int num_vars = var_lbs_.size();
|
||||
for (int i = start; i < end; ++i) {
|
||||
const int dep =
|
||||
FindLowestTrailIndexThatExplainBound(bounds_reason_buffer_[i]);
|
||||
if (dep >= num_vars) {
|
||||
data[new_size++] = dep;
|
||||
const IntegerLiteral to_explain = bounds_reason_buffer_[i];
|
||||
if (!IsTrueAtLevelZero(to_explain)) {
|
||||
data[new_size++] = FindLowestTrailIndexThatExplainBound(to_explain);
|
||||
}
|
||||
}
|
||||
cached_sizes_[reason_index] = new_size;
|
||||
@@ -1818,14 +1816,10 @@ std::vector<Literal> IntegerTrail::ReasonFor(IntegerLiteral literal) const {
|
||||
void IntegerTrail::MergeReasonInto(absl::Span<const IntegerLiteral> literals,
|
||||
std::vector<Literal>* output) const {
|
||||
DCHECK(tmp_queue_.empty());
|
||||
const int num_vars = var_lbs_.size();
|
||||
for (const IntegerLiteral& literal : literals) {
|
||||
if (literal.IsAlwaysTrue()) continue;
|
||||
const int trail_index = FindLowestTrailIndexThatExplainBound(literal);
|
||||
|
||||
// Any indices lower than that means that there is no reason needed.
|
||||
// Note that it is important for size to be signed because of -1 indices.
|
||||
if (trail_index >= num_vars) tmp_queue_.push_back(trail_index);
|
||||
if (IsTrueAtLevelZero(literal)) continue;
|
||||
tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(literal));
|
||||
}
|
||||
return MergeReasonIntoInternal(output, -1);
|
||||
}
|
||||
|
||||
@@ -523,6 +523,7 @@ class IntegerTrail final : public SatPropagator {
|
||||
// Returns the current value (if known) of an IntegerLiteral.
|
||||
bool IntegerLiteralIsTrue(IntegerLiteral l) const;
|
||||
bool IntegerLiteralIsFalse(IntegerLiteral l) const;
|
||||
bool IsTrueAtLevelZero(IntegerLiteral l) const;
|
||||
|
||||
// Returns globally valid lower/upper bound on the given integer variable.
|
||||
IntegerValue LevelZeroLowerBound(IntegerVariable var) const;
|
||||
@@ -796,39 +797,38 @@ class IntegerTrail final : public SatPropagator {
|
||||
void AddAllGreaterThanConstantReason(absl::Span<AffineExpression> exprs,
|
||||
IntegerValue target_min,
|
||||
std::vector<int>* indices) const {
|
||||
int64_t num_processed = 0;
|
||||
constexpr int64_t check_period = 1e6;
|
||||
int64_t limit_check = work_done_in_explain_lower_than_ + check_period;
|
||||
for (const AffineExpression& expr : exprs) {
|
||||
if (expr.IsConstant()) {
|
||||
DCHECK_GE(expr.constant, target_min);
|
||||
continue;
|
||||
}
|
||||
DCHECK_NE(expr.var, kNoIntegerVariable);
|
||||
const IntegerLiteral to_explain = expr.GreaterOrEqual(target_min);
|
||||
if (IsTrueAtLevelZero(to_explain)) continue;
|
||||
|
||||
// On large routing problems, we can spend a lot of time in this loop.
|
||||
// We check the time limit every 5 processed expressions.
|
||||
if (++num_processed % 5 == 0 && time_limit_->LimitReached()) return;
|
||||
if (work_done_in_explain_lower_than_ > limit_check) {
|
||||
limit_check = work_done_in_explain_lower_than_ + check_period;
|
||||
if (time_limit_->LimitReached()) return;
|
||||
}
|
||||
|
||||
// Skip if we already have an explanation for expr >= target_min. Note
|
||||
// that we already do that while processing the returned indices, so this
|
||||
// mainly save a FindLowestTrailIndexThatExplainBound() call per skipped
|
||||
// indices, which can still be costly.
|
||||
{
|
||||
const int index = tmp_var_to_trail_index_in_queue_[expr.var];
|
||||
const int index = tmp_var_to_trail_index_in_queue_[to_explain.var];
|
||||
if (index == std::numeric_limits<int>::max()) continue;
|
||||
if (index > 0 &&
|
||||
expr.ValueAt(integer_trail_[index].bound) >= target_min) {
|
||||
if (index > 0 && integer_trail_[index].bound >= to_explain.bound) {
|
||||
has_dependency_ = true;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// We need to find the index that explain the bound.
|
||||
// Note that this will skip if the condition is true at level zero.
|
||||
const int index =
|
||||
FindLowestTrailIndexThatExplainBound(expr.GreaterOrEqual(target_min));
|
||||
if (index >= 0) {
|
||||
indices->push_back(index);
|
||||
}
|
||||
indices->push_back(FindLowestTrailIndexThatExplainBound(to_explain));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -885,8 +885,8 @@ class IntegerTrail final : public SatPropagator {
|
||||
int64_t conflict_id) const;
|
||||
|
||||
// Returns the lowest trail index of a TrailEntry that can be used to explain
|
||||
// the given IntegerLiteral. The literal must be currently true (CHECKed).
|
||||
// Returns -1 if the explanation is trivial.
|
||||
// the given IntegerLiteral. The literal must be currently true but not true
|
||||
// at level zero (DCHECKed).
|
||||
int FindLowestTrailIndexThatExplainBound(IntegerLiteral i_lit) const;
|
||||
|
||||
// This must be called before Dependencies() or AppendLiteralsReason().
|
||||
@@ -1033,6 +1033,8 @@ class IntegerTrail final : public SatPropagator {
|
||||
std::vector<SparseBitset<IntegerVariable>*> watchers_;
|
||||
std::vector<ReversibleInterface*> reversible_classes_;
|
||||
|
||||
mutable int64_t work_done_in_explain_lower_than_ = 0;
|
||||
|
||||
mutable Domain temp_domain_;
|
||||
DelayedRootLevelDeduction* delayed_to_fix_;
|
||||
IntegerDomains* domains_;
|
||||
@@ -1417,6 +1419,10 @@ inline bool IntegerTrail::IntegerLiteralIsFalse(IntegerLiteral l) const {
|
||||
return l.bound > UpperBound(l.var);
|
||||
}
|
||||
|
||||
inline bool IntegerTrail::IsTrueAtLevelZero(IntegerLiteral l) const {
|
||||
return l.bound <= LevelZeroLowerBound(l.var);
|
||||
}
|
||||
|
||||
// The level zero bounds are stored at the beginning of the trail and they also
|
||||
// serves as sentinels. Their index match the variables index.
|
||||
inline IntegerValue IntegerTrail::LevelZeroLowerBound(
|
||||
|
||||
@@ -214,26 +214,6 @@ IntegerValue BestBinaryRelationBounds::GetUpperBound(
|
||||
return kMaxIntegerValue;
|
||||
}
|
||||
|
||||
// TODO(user): Maybe introduce a CanonicalizedLinear2 class so we automatically
|
||||
// get the better function, and it documents when we have canonicalized
|
||||
// expression.
|
||||
IntegerValue BestBinaryRelationBounds::UpperBoundWhenCanonicalized(
|
||||
LinearExpression2 expr) const {
|
||||
DCHECK_EQ(expr.DivideByGcd(), 1);
|
||||
DCHECK(expr.IsCanonicalized());
|
||||
const bool negated = expr.NegateForCanonicalization();
|
||||
const auto it = best_bounds_.find(expr);
|
||||
if (it != best_bounds_.end()) {
|
||||
const auto [known_lb, known_ub] = it->second;
|
||||
if (negated) {
|
||||
return -known_lb;
|
||||
} else {
|
||||
return known_ub;
|
||||
}
|
||||
}
|
||||
return kMaxIntegerValue;
|
||||
}
|
||||
|
||||
std::vector<std::pair<LinearExpression2, IntegerValue>>
|
||||
BestBinaryRelationBounds::GetSortedNonTrivialUpperBounds() const {
|
||||
std::vector<std::pair<LinearExpression2, IntegerValue>> root_relations_sorted;
|
||||
|
||||
@@ -559,6 +559,28 @@ std::ostream& operator<<(std::ostream& os, const ValueLiteralPair& p);
|
||||
DEFINE_STRONG_INDEX_TYPE(IntervalVariable);
|
||||
const IntervalVariable kNoIntervalVariable(-1);
|
||||
|
||||
// This functions appears in hot spot, and so it is important to inline it.
|
||||
//
|
||||
// TODO(user): Maybe introduce a CanonicalizedLinear2 class so we automatically
|
||||
// get the better function, and it documents when we have canonicalized
|
||||
// expression.
|
||||
inline IntegerValue BestBinaryRelationBounds::UpperBoundWhenCanonicalized(
|
||||
LinearExpression2 expr) const {
|
||||
DCHECK_EQ(expr.DivideByGcd(), 1);
|
||||
DCHECK(expr.IsCanonicalized());
|
||||
const bool negated = expr.NegateForCanonicalization();
|
||||
const auto it = best_bounds_.find(expr);
|
||||
if (it != best_bounds_.end()) {
|
||||
const auto [known_lb, known_ub] = it->second;
|
||||
if (negated) {
|
||||
return -known_lb;
|
||||
} else {
|
||||
return known_ub;
|
||||
}
|
||||
}
|
||||
return kMaxIntegerValue;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Implementation.
|
||||
// ============================================================================
|
||||
@@ -599,8 +621,8 @@ inline IntegerLiteral AffineExpression::GreaterOrEqual(
|
||||
: IntegerLiteral::FalseLiteral();
|
||||
}
|
||||
DCHECK_GT(coeff, 0);
|
||||
return IntegerLiteral::GreaterOrEqual(var,
|
||||
CeilRatio(bound - constant, coeff));
|
||||
return IntegerLiteral::GreaterOrEqual(
|
||||
var, coeff == 1 ? bound - constant : CeilRatio(bound - constant, coeff));
|
||||
}
|
||||
|
||||
// var * coeff + constant <= bound.
|
||||
@@ -610,7 +632,8 @@ inline IntegerLiteral AffineExpression::LowerOrEqual(IntegerValue bound) const {
|
||||
: IntegerLiteral::FalseLiteral();
|
||||
}
|
||||
DCHECK_GT(coeff, 0);
|
||||
return IntegerLiteral::LowerOrEqual(var, FloorRatio(bound - constant, coeff));
|
||||
return IntegerLiteral::LowerOrEqual(
|
||||
var, coeff == 1 ? bound - constant : FloorRatio(bound - constant, coeff));
|
||||
}
|
||||
|
||||
} // namespace sat
|
||||
|
||||
@@ -1943,8 +1943,7 @@ IntegerValue Linear2Bounds::NonTrivialUpperBoundForGcd1(
|
||||
}
|
||||
DCHECK_NE(expr.coeffs[1], 0);
|
||||
DCHECK_EQ(1, expr.DivideByGcd());
|
||||
IntegerValue ub = kMaxIntegerValue;
|
||||
ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(expr));
|
||||
IntegerValue ub = root_level_bounds_->GetUpperBoundNoTrail(expr);
|
||||
ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(expr));
|
||||
ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(expr));
|
||||
return ub;
|
||||
|
||||
Reference in New Issue
Block a user