diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 3adcb4d8ba..5ecc0455a8 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -983,7 +983,8 @@ int IntegerTrail::FindTrailIndexOfVarBefore(IntegerVariable var, int IntegerTrail::FindLowestTrailIndexThatExplainBound( IntegerLiteral i_lit) const { DCHECK_LE(i_lit.bound, var_lbs_[i_lit.var]); - if (i_lit.bound <= LevelZeroLowerBound(i_lit.var)) return -1; + DCHECK(!IsTrueAtLevelZero(i_lit)); + int trail_index = var_trail_index_[i_lit.var]; // Check the validity of the cached index and use it if possible. This caching @@ -1003,6 +1004,7 @@ int IntegerTrail::FindLowestTrailIndexThatExplainBound( int prev_trail_index = trail_index; while (true) { + ++work_done_in_explain_lower_than_; if (trail_index >= var_trail_index_cache_threshold_) { var_trail_index_cache_[i_lit.var] = trail_index; } @@ -1171,10 +1173,9 @@ std::vector* IntegerTrail::InitializeConflict( lazy_reasons_.back().Explain(conflict, &tmp_queue_); } else { conflict->assign(literals_reason.begin(), literals_reason.end()); - const int num_vars = var_lbs_.size(); for (const IntegerLiteral& literal : bounds_reason) { - const int trail_index = FindLowestTrailIndexThatExplainBound(literal); - if (trail_index >= num_vars) tmp_queue_.push_back(trail_index); + if (IsTrueAtLevelZero(literal)) continue; + tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(literal)); } } return conflict; @@ -1553,9 +1554,8 @@ bool IntegerTrail::EnqueueInternal( // efficiency and a potential smaller reason. auto* conflict = InitializeConflict(i_lit, use_lazy_reason, literal_reason, integer_reason); - { - const int trail_index = FindLowestTrailIndexThatExplainBound(ub_reason); - if (trail_index >= 0) tmp_queue_.push_back(trail_index); + if (!IsTrueAtLevelZero(ub_reason)) { + tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(ub_reason)); } MergeReasonIntoInternal(conflict, NextConflictId()); return false; @@ -1771,12 +1771,10 @@ absl::Span IntegerTrail::Dependencies(int reason_index) const { int new_size = 0; int* data = trail_index_reason_buffer_.data() + start; - const int num_vars = var_lbs_.size(); for (int i = start; i < end; ++i) { - const int dep = - FindLowestTrailIndexThatExplainBound(bounds_reason_buffer_[i]); - if (dep >= num_vars) { - data[new_size++] = dep; + const IntegerLiteral to_explain = bounds_reason_buffer_[i]; + if (!IsTrueAtLevelZero(to_explain)) { + data[new_size++] = FindLowestTrailIndexThatExplainBound(to_explain); } } cached_sizes_[reason_index] = new_size; @@ -1818,14 +1816,10 @@ std::vector IntegerTrail::ReasonFor(IntegerLiteral literal) const { void IntegerTrail::MergeReasonInto(absl::Span literals, std::vector* output) const { DCHECK(tmp_queue_.empty()); - const int num_vars = var_lbs_.size(); for (const IntegerLiteral& literal : literals) { if (literal.IsAlwaysTrue()) continue; - const int trail_index = FindLowestTrailIndexThatExplainBound(literal); - - // Any indices lower than that means that there is no reason needed. - // Note that it is important for size to be signed because of -1 indices. - if (trail_index >= num_vars) tmp_queue_.push_back(trail_index); + if (IsTrueAtLevelZero(literal)) continue; + tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(literal)); } return MergeReasonIntoInternal(output, -1); } diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 9802f74a75..14e485fdad 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -523,6 +523,7 @@ class IntegerTrail final : public SatPropagator { // Returns the current value (if known) of an IntegerLiteral. bool IntegerLiteralIsTrue(IntegerLiteral l) const; bool IntegerLiteralIsFalse(IntegerLiteral l) const; + bool IsTrueAtLevelZero(IntegerLiteral l) const; // Returns globally valid lower/upper bound on the given integer variable. IntegerValue LevelZeroLowerBound(IntegerVariable var) const; @@ -796,39 +797,38 @@ class IntegerTrail final : public SatPropagator { void AddAllGreaterThanConstantReason(absl::Span exprs, IntegerValue target_min, std::vector* indices) const { - int64_t num_processed = 0; + constexpr int64_t check_period = 1e6; + int64_t limit_check = work_done_in_explain_lower_than_ + check_period; for (const AffineExpression& expr : exprs) { if (expr.IsConstant()) { DCHECK_GE(expr.constant, target_min); continue; } DCHECK_NE(expr.var, kNoIntegerVariable); + const IntegerLiteral to_explain = expr.GreaterOrEqual(target_min); + if (IsTrueAtLevelZero(to_explain)) continue; // On large routing problems, we can spend a lot of time in this loop. - // We check the time limit every 5 processed expressions. - if (++num_processed % 5 == 0 && time_limit_->LimitReached()) return; + if (work_done_in_explain_lower_than_ > limit_check) { + limit_check = work_done_in_explain_lower_than_ + check_period; + if (time_limit_->LimitReached()) return; + } // Skip if we already have an explanation for expr >= target_min. Note // that we already do that while processing the returned indices, so this // mainly save a FindLowestTrailIndexThatExplainBound() call per skipped // indices, which can still be costly. { - const int index = tmp_var_to_trail_index_in_queue_[expr.var]; + const int index = tmp_var_to_trail_index_in_queue_[to_explain.var]; if (index == std::numeric_limits::max()) continue; - if (index > 0 && - expr.ValueAt(integer_trail_[index].bound) >= target_min) { + if (index > 0 && integer_trail_[index].bound >= to_explain.bound) { has_dependency_ = true; continue; } } // We need to find the index that explain the bound. - // Note that this will skip if the condition is true at level zero. - const int index = - FindLowestTrailIndexThatExplainBound(expr.GreaterOrEqual(target_min)); - if (index >= 0) { - indices->push_back(index); - } + indices->push_back(FindLowestTrailIndexThatExplainBound(to_explain)); } } @@ -885,8 +885,8 @@ class IntegerTrail final : public SatPropagator { int64_t conflict_id) const; // Returns the lowest trail index of a TrailEntry that can be used to explain - // the given IntegerLiteral. The literal must be currently true (CHECKed). - // Returns -1 if the explanation is trivial. + // the given IntegerLiteral. The literal must be currently true but not true + // at level zero (DCHECKed). int FindLowestTrailIndexThatExplainBound(IntegerLiteral i_lit) const; // This must be called before Dependencies() or AppendLiteralsReason(). @@ -1033,6 +1033,8 @@ class IntegerTrail final : public SatPropagator { std::vector*> watchers_; std::vector reversible_classes_; + mutable int64_t work_done_in_explain_lower_than_ = 0; + mutable Domain temp_domain_; DelayedRootLevelDeduction* delayed_to_fix_; IntegerDomains* domains_; @@ -1417,6 +1419,10 @@ inline bool IntegerTrail::IntegerLiteralIsFalse(IntegerLiteral l) const { return l.bound > UpperBound(l.var); } +inline bool IntegerTrail::IsTrueAtLevelZero(IntegerLiteral l) const { + return l.bound <= LevelZeroLowerBound(l.var); +} + // The level zero bounds are stored at the beginning of the trail and they also // serves as sentinels. Their index match the variables index. inline IntegerValue IntegerTrail::LevelZeroLowerBound( diff --git a/ortools/sat/integer_base.cc b/ortools/sat/integer_base.cc index d514001c31..f39463353f 100644 --- a/ortools/sat/integer_base.cc +++ b/ortools/sat/integer_base.cc @@ -214,26 +214,6 @@ IntegerValue BestBinaryRelationBounds::GetUpperBound( return kMaxIntegerValue; } -// TODO(user): Maybe introduce a CanonicalizedLinear2 class so we automatically -// get the better function, and it documents when we have canonicalized -// expression. -IntegerValue BestBinaryRelationBounds::UpperBoundWhenCanonicalized( - LinearExpression2 expr) const { - DCHECK_EQ(expr.DivideByGcd(), 1); - DCHECK(expr.IsCanonicalized()); - const bool negated = expr.NegateForCanonicalization(); - const auto it = best_bounds_.find(expr); - if (it != best_bounds_.end()) { - const auto [known_lb, known_ub] = it->second; - if (negated) { - return -known_lb; - } else { - return known_ub; - } - } - return kMaxIntegerValue; -} - std::vector> BestBinaryRelationBounds::GetSortedNonTrivialUpperBounds() const { std::vector> root_relations_sorted; diff --git a/ortools/sat/integer_base.h b/ortools/sat/integer_base.h index 572f62a906..ad4331e5a5 100644 --- a/ortools/sat/integer_base.h +++ b/ortools/sat/integer_base.h @@ -559,6 +559,28 @@ std::ostream& operator<<(std::ostream& os, const ValueLiteralPair& p); DEFINE_STRONG_INDEX_TYPE(IntervalVariable); const IntervalVariable kNoIntervalVariable(-1); +// This functions appears in hot spot, and so it is important to inline it. +// +// TODO(user): Maybe introduce a CanonicalizedLinear2 class so we automatically +// get the better function, and it documents when we have canonicalized +// expression. +inline IntegerValue BestBinaryRelationBounds::UpperBoundWhenCanonicalized( + LinearExpression2 expr) const { + DCHECK_EQ(expr.DivideByGcd(), 1); + DCHECK(expr.IsCanonicalized()); + const bool negated = expr.NegateForCanonicalization(); + const auto it = best_bounds_.find(expr); + if (it != best_bounds_.end()) { + const auto [known_lb, known_ub] = it->second; + if (negated) { + return -known_lb; + } else { + return known_ub; + } + } + return kMaxIntegerValue; +} + // ============================================================================ // Implementation. // ============================================================================ @@ -599,8 +621,8 @@ inline IntegerLiteral AffineExpression::GreaterOrEqual( : IntegerLiteral::FalseLiteral(); } DCHECK_GT(coeff, 0); - return IntegerLiteral::GreaterOrEqual(var, - CeilRatio(bound - constant, coeff)); + return IntegerLiteral::GreaterOrEqual( + var, coeff == 1 ? bound - constant : CeilRatio(bound - constant, coeff)); } // var * coeff + constant <= bound. @@ -610,7 +632,8 @@ inline IntegerLiteral AffineExpression::LowerOrEqual(IntegerValue bound) const { : IntegerLiteral::FalseLiteral(); } DCHECK_GT(coeff, 0); - return IntegerLiteral::LowerOrEqual(var, FloorRatio(bound - constant, coeff)); + return IntegerLiteral::LowerOrEqual( + var, coeff == 1 ? bound - constant : FloorRatio(bound - constant, coeff)); } } // namespace sat diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index 2e2a5bcf47..b82d97b8fb 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -1943,8 +1943,7 @@ IntegerValue Linear2Bounds::NonTrivialUpperBoundForGcd1( } DCHECK_NE(expr.coeffs[1], 0); DCHECK_EQ(1, expr.DivideByGcd()); - IntegerValue ub = kMaxIntegerValue; - ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(expr)); + IntegerValue ub = root_level_bounds_->GetUpperBoundNoTrail(expr); ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(expr)); ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(expr)); return ub;