diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index afeb1caa85..9466d5aa80 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -1508,8 +1508,8 @@ void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { case 0: { auto* integer_trail = m->GetOrCreate(); auto* sat_solver = m->GetOrCreate(); - if (!integer_trail->Enqueue(prod.LowerOrEqual(1), {}) || - !integer_trail->Enqueue(prod.GreaterOrEqual(1), {})) { + if (!integer_trail->Enqueue(prod.LowerOrEqual(1)) || + !integer_trail->Enqueue(prod.GreaterOrEqual(1))) { sat_solver->NotifyThatModelIsUnsat(); } break; diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index 60fe83d05c..16635f0d7c 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -181,7 +181,10 @@ void TaskSet::NotifyEntryIsNowLastIfPresent(const Entry& e) { for (int i = 0;; ++i) { if (i == size) return; if (sorted_tasks_[i].task == e.task) { - sorted_tasks_.erase(sorted_tasks_.begin() + i); + for (int j = i; j + 1 < size; ++j) { + sorted_tasks_[j] = sorted_tasks_[j + 1]; + } + sorted_tasks_.pop_back(); break; } } @@ -413,7 +416,7 @@ bool CombinedDisjunctive::Propagate() { // TODO(user): Maybe factor out the code? It does require a function with a // lot of arguments though. helper_->ClearReason(); - const std::vector& sorted_tasks = + const absl::Span sorted_tasks = task_sets_[best_d_index].SortedTasks(); const IntegerValue window_start = sorted_tasks[best_critical_index].start_min; @@ -652,49 +655,50 @@ bool DisjunctiveDetectablePrecedences::Propagate() { // start_max >= end_min, so wouldn't be in detectable precedence. task_by_increasing_end_min_.clear(); IntegerValue window_end = kMinIntegerValue; + IntegerValue max_end_min = kMinIntegerValue; for (const TaskTime task_time : helper_->TaskByIncreasingStartMin()) { const int task = task_time.task_index; if (helper_->IsAbsent(task)) continue; // Note that the helper returns value assuming the task is present. const IntegerValue start_min = helper_->StartMin(task); - const IntegerValue size_min = helper_->SizeMin(task); const IntegerValue end_min = helper_->EndMin(task); - DCHECK_GE(end_min, start_min + size_min); if (start_min < window_end) { + const IntegerValue size_min = helper_->SizeMin(task); + DCHECK_GE(end_min, start_min + size_min); + task_by_increasing_end_min_.push_back({task, end_min}); + max_end_min = std::max(max_end_min, end_min); window_end = std::max(window_end, start_min) + size_min; continue; } // Process current window. - if (task_by_increasing_end_min_.size() > 1 && !PropagateSubwindow()) { + if (task_by_increasing_end_min_.size() > 1 && + !PropagateSubwindow(max_end_min)) { return false; } // Start of the next window. task_by_increasing_end_min_.clear(); task_by_increasing_end_min_.push_back({task, end_min}); + max_end_min = end_min; window_end = end_min; } - if (task_by_increasing_end_min_.size() > 1 && !PropagateSubwindow()) { + if (task_by_increasing_end_min_.size() > 1 && + !PropagateSubwindow(max_end_min)) { return false; } return true; } -bool DisjunctiveDetectablePrecedences::PropagateSubwindow() { +bool DisjunctiveDetectablePrecedences::PropagateSubwindow( + const IntegerValue max_end_min) { DCHECK(!task_by_increasing_end_min_.empty()); - // The vector is already sorted by shifted_start_min, so there is likely a - // good correlation, hence the incremental sort. - IncrementalSort(task_by_increasing_end_min_.begin(), - task_by_increasing_end_min_.end()); - const IntegerValue max_end_min = task_by_increasing_end_min_.back().time; - // Fill and sort task_by_increasing_start_max_. // // TODO(user): we should use start max if present, but more generally, all @@ -708,9 +712,16 @@ bool DisjunctiveDetectablePrecedences::PropagateSubwindow() { } } if (task_by_increasing_start_max_.empty()) return true; + std::sort(task_by_increasing_start_max_.begin(), task_by_increasing_start_max_.end()); + // The vector is already sorted by shifted_start_min, so there is likely a + // good correlation, hence the incremental sort. + IncrementalSort(task_by_increasing_end_min_.begin(), + task_by_increasing_end_min_.end()); + DCHECK_EQ(max_end_min, task_by_increasing_end_min_.back().time); + // Invariant: need_update is false implies that task_set_end_min is equal to // task_set_.ComputeEndMin(). // @@ -802,7 +813,7 @@ bool DisjunctiveDetectablePrecedences::PropagateSubwindow() { // Note that this works as well when IsPresent(t) is false. if (task_set_end_min > helper_->StartMin(t)) { const int critical_index = task_set_.GetCriticalIndex(); - const std::vector& sorted_tasks = + const absl::Span sorted_tasks = task_set_.SortedTasks(); helper_->ClearReason(); @@ -1251,7 +1262,8 @@ bool DisjunctiveNotLast::PropagateSubwindow() { // Find the largest start-max of the critical tasks (excluding t). The // end-max for t need to be smaller than or equal to this. IntegerValue largest_ct_start_max = kMinIntegerValue; - const std::vector& sorted_tasks = task_set_.SortedTasks(); + const absl::Span sorted_tasks = + task_set_.SortedTasks(); const int sorted_tasks_size = sorted_tasks.size(); for (int i = critical_index; i < sorted_tasks_size; ++i) { const int ct = sorted_tasks[i].task; diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index 6b3f50e011..d61a158820 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -53,7 +53,7 @@ void AddDisjunctiveWithBooleanPrecedencesOnly( // for most of the function here, not a O(log(n)) one. class TaskSet { public: - explicit TaskSet(int num_tasks) { sorted_tasks_.reserve(num_tasks); } + explicit TaskSet(int num_tasks) { sorted_tasks_.ClearAndReserve(num_tasks); } struct Entry { int task; @@ -113,10 +113,10 @@ class TaskSet { // another unneeded loop. int GetCriticalIndex() const { return optimized_restart_; } - const std::vector& SortedTasks() const { return sorted_tasks_; } + absl::Span SortedTasks() const { return sorted_tasks_; } private: - std::vector sorted_tasks_; + FixedCapacityVector sorted_tasks_; mutable int optimized_restart_ = 0; }; @@ -160,18 +160,22 @@ class DisjunctiveDetectablePrecedences : public PropagatorInterface { SchedulingConstraintHelper* helper) : time_direction_(time_direction), helper_(helper), - task_set_(helper->NumTasks()) {} + task_set_(helper->NumTasks()) { + task_by_increasing_end_min_.ClearAndReserve(helper->NumTasks()); + task_by_increasing_start_max_.ClearAndReserve(helper->NumTasks()); + to_propagate_.ClearAndReserve(helper->NumTasks()); + } bool Propagate() final; int RegisterWith(GenericLiteralWatcher* watcher); private: - bool PropagateSubwindow(); + bool PropagateSubwindow(IntegerValue max_end_min); - std::vector task_by_increasing_end_min_; - std::vector task_by_increasing_start_max_; + FixedCapacityVector task_by_increasing_end_min_; + FixedCapacityVector task_by_increasing_start_max_; std::vector processed_; - std::vector to_propagate_; + FixedCapacityVector to_propagate_; const bool time_direction_; SchedulingConstraintHelper* helper_; @@ -214,15 +218,18 @@ class DisjunctiveNotLast : public PropagatorInterface { DisjunctiveNotLast(bool time_direction, SchedulingConstraintHelper* helper) : time_direction_(time_direction), helper_(helper), - task_set_(helper->NumTasks()) {} + task_set_(helper->NumTasks()) { + start_min_window_.ClearAndReserve(helper->NumTasks()); + start_max_window_.ClearAndReserve(helper->NumTasks()); + } bool Propagate() final; int RegisterWith(GenericLiteralWatcher* watcher); private: bool PropagateSubwindow(); - std::vector start_min_window_; - std::vector start_max_window_; + FixedCapacityVector start_min_window_; + FixedCapacityVector start_max_window_; const bool time_direction_; SchedulingConstraintHelper* helper_; @@ -233,7 +240,11 @@ class DisjunctiveEdgeFinding : public PropagatorInterface { public: DisjunctiveEdgeFinding(bool time_direction, SchedulingConstraintHelper* helper) - : time_direction_(time_direction), helper_(helper) {} + : time_direction_(time_direction), helper_(helper) { + task_by_increasing_end_max_.ClearAndReserve(helper->NumTasks()); + window_.ClearAndReserve(helper->NumTasks()); + event_size_.ClearAndReserve(helper->NumTasks()); + } bool Propagate() final; int RegisterWith(GenericLiteralWatcher* watcher); @@ -244,12 +255,12 @@ class DisjunctiveEdgeFinding : public PropagatorInterface { SchedulingConstraintHelper* helper_; // This only contains non-gray tasks. - std::vector task_by_increasing_end_max_; + FixedCapacityVector task_by_increasing_end_max_; // All these member are indexed in the same way. - std::vector window_; + FixedCapacityVector window_; ThetaLambdaTree theta_tree_; - std::vector event_size_; + FixedCapacityVector event_size_; // Task indexed. std::vector non_gray_task_to_event_; @@ -267,7 +278,10 @@ class DisjunctivePrecedences : public PropagatorInterface { helper_(helper), integer_trail_(model->GetOrCreate()), precedence_relations_(model->GetOrCreate()), - shared_stats_(model->GetOrCreate()) {} + shared_stats_(model->GetOrCreate()) { + window_.ClearAndReserve(helper->NumTasks()); + index_to_end_vars_.ClearAndReserve(helper->NumTasks()); + } ~DisjunctivePrecedences() override; bool Propagate() final; @@ -284,8 +298,8 @@ class DisjunctivePrecedences : public PropagatorInterface { int64_t num_propagations_ = 0; - std::vector window_; - std::vector index_to_end_vars_; + FixedCapacityVector window_; + FixedCapacityVector index_to_end_vars_; std::vector indices_before_; std::vector skip_; diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index b33c75974a..534cdf1878 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -1149,22 +1149,24 @@ void IntegerTrail::RemoveLevelZeroBounds( } std::vector* IntegerTrail::InitializeConflict( - IntegerLiteral integer_literal, const LazyReasonFunction& lazy_reason, + IntegerLiteral integer_literal, bool use_lazy_reason, absl::Span literals_reason, absl::Span bounds_reason) { DCHECK(tmp_queue_.empty()); std::vector* conflict = trail_->MutableConflict(); - if (lazy_reason == nullptr) { + if (use_lazy_reason) { + // We use the current trail index here. + conflict->clear(); + const int trail_index = integer_trail_.size(); + lazy_reasons_[trail_index].Explain(integer_literal, trail_index, conflict, + &tmp_queue_); + } else { conflict->assign(literals_reason.begin(), literals_reason.end()); const int num_vars = var_lbs_.size(); for (const IntegerLiteral& literal : bounds_reason) { const int trail_index = FindLowestTrailIndexThatExplainBound(literal); if (trail_index >= num_vars) tmp_queue_.push_back(trail_index); } - } else { - // We use the current trail index here. - conflict->clear(); - lazy_reason(integer_literal, integer_trail_.size(), conflict, &tmp_queue_); } return conflict; } @@ -1248,13 +1250,6 @@ bool IntegerTrail::SafeEnqueue( return Enqueue(i_lit, {}, tmp_cleaned_reason_); } -bool IntegerTrail::Enqueue(IntegerLiteral i_lit, - absl::Span literal_reason, - absl::Span integer_reason) { - return EnqueueInternal(i_lit, nullptr, literal_reason, integer_reason, - integer_trail_.size()); -} - bool IntegerTrail::ConditionalEnqueue( Literal lit, IntegerLiteral i_lit, std::vector* literal_reason, std::vector* integer_reason) { @@ -1292,19 +1287,6 @@ bool IntegerTrail::ConditionalEnqueue( return true; } -bool IntegerTrail::Enqueue(IntegerLiteral i_lit, - absl::Span literal_reason, - absl::Span integer_reason, - int trail_index_with_same_reason) { - return EnqueueInternal(i_lit, nullptr, literal_reason, integer_reason, - trail_index_with_same_reason); -} - -bool IntegerTrail::Enqueue(IntegerLiteral i_lit, - LazyReasonFunction lazy_reason) { - return EnqueueInternal(i_lit, lazy_reason, {}, {}, integer_trail_.size()); -} - bool IntegerTrail::ReasonIsValid( absl::Span literal_reason, absl::Span integer_reason) { @@ -1394,15 +1376,15 @@ bool IntegerTrail::ReasonIsValid( void IntegerTrail::EnqueueLiteral( Literal literal, absl::Span literal_reason, absl::Span integer_reason) { - EnqueueLiteralInternal(literal, nullptr, literal_reason, integer_reason); + EnqueueLiteralInternal(literal, false, literal_reason, integer_reason); } void IntegerTrail::EnqueueLiteralInternal( - Literal literal, LazyReasonFunction lazy_reason, + Literal literal, bool use_lazy_reason, absl::Span literal_reason, absl::Span integer_reason) { DCHECK(!trail_->Assignment().LiteralIsAssigned(literal)); - DCHECK(lazy_reason != nullptr || + DCHECK(!use_lazy_reason || ReasonIsValid(literal, literal_reason, integer_reason)); if (integer_search_levels_.empty()) { // Level zero. We don't keep any reason. @@ -1412,7 +1394,7 @@ void IntegerTrail::EnqueueLiteralInternal( // If we are fixing something at a positive level, remember it. if (!integer_search_levels_.empty() && integer_reason.empty() && - literal_reason.empty() && lazy_reason == nullptr) { + literal_reason.empty() && !use_lazy_reason) { delayed_to_fix_->literal_to_fix.push_back(literal); } @@ -1422,23 +1404,10 @@ void IntegerTrail::EnqueueLiteralInternal( } boolean_trail_index_to_integer_one_[trail_index] = integer_trail_.size(); - int reason_index = literals_reason_starts_.size(); - if (lazy_reason != nullptr) { - if (integer_trail_.size() >= lazy_reasons_.size()) { - lazy_reasons_.resize(integer_trail_.size() + 1, nullptr); - } - lazy_reasons_[integer_trail_.size()] = lazy_reason; - reason_index = -1; - } else { - // Copy the reason. - literals_reason_starts_.push_back(literals_reason_buffer_.size()); - literals_reason_buffer_.insert(literals_reason_buffer_.end(), - literal_reason.begin(), - literal_reason.end()); - bounds_reason_starts_.push_back(bounds_reason_buffer_.size()); - bounds_reason_buffer_.insert(bounds_reason_buffer_.end(), - integer_reason.begin(), integer_reason.end()); - } + const int reason_index = + use_lazy_reason + ? -1 + : AppendReasonToInternalBuffers(literal_reason, integer_reason); integer_trail_.push_back({/*bound=*/IntegerValue(0), /*var=*/kNoIntegerVariable, @@ -1520,12 +1489,34 @@ void IntegerTrail::CanonicalizeLiteralIfNeeded(IntegerLiteral* i_lit) { } } +int IntegerTrail::AppendReasonToInternalBuffers( + absl::Span literal_reason, + absl::Span integer_reason) { + const int reason_index = literals_reason_starts_.size(); + DCHECK_EQ(reason_index, bounds_reason_starts_.size()); + + literals_reason_starts_.push_back(literals_reason_buffer_.size()); + if (!literal_reason.empty()) { + literals_reason_buffer_.insert(literals_reason_buffer_.end(), + literal_reason.begin(), + literal_reason.end()); + } + + bounds_reason_starts_.push_back(bounds_reason_buffer_.size()); + if (!integer_reason.empty()) { + bounds_reason_buffer_.insert(bounds_reason_buffer_.end(), + integer_reason.begin(), integer_reason.end()); + } + + return reason_index; +} + bool IntegerTrail::EnqueueInternal( - IntegerLiteral i_lit, LazyReasonFunction lazy_reason, + IntegerLiteral i_lit, bool use_lazy_reason, absl::Span literal_reason, absl::Span integer_reason, int trail_index_with_same_reason) { - DCHECK(lazy_reason != nullptr || + DCHECK(use_lazy_reason || ReasonIsValid(i_lit, literal_reason, integer_reason)); const IntegerVariable var(i_lit.var); @@ -1550,8 +1541,8 @@ bool IntegerTrail::EnqueueInternal( // Note that we want only one call to MergeReasonIntoInternal() for // efficiency and a potential smaller reason. - auto* conflict = - InitializeConflict(i_lit, lazy_reason, literal_reason, integer_reason); + auto* conflict = InitializeConflict(i_lit, use_lazy_reason, literal_reason, + integer_reason); { const int trail_index = FindLowestTrailIndexThatExplainBound(ub_reason); const int num_vars = var_lbs_.size(); // must be signed. @@ -1606,8 +1597,8 @@ bool IntegerTrail::EnqueueInternal( if (literal_index != kNoLiteralIndex) { const Literal to_enqueue = Literal(literal_index); if (trail_->Assignment().LiteralIsFalse(to_enqueue)) { - auto* conflict = InitializeConflict(i_lit, lazy_reason, literal_reason, - integer_reason); + auto* conflict = InitializeConflict(i_lit, use_lazy_reason, + literal_reason, integer_reason); conflict->push_back(to_enqueue); MergeReasonIntoInternal(conflict); return false; @@ -1620,7 +1611,7 @@ bool IntegerTrail::EnqueueInternal( if (bound >= i_lit.bound) { DCHECK_EQ(bound, i_lit.bound); if (!trail_->Assignment().LiteralIsTrue(to_enqueue)) { - EnqueueLiteralInternal(to_enqueue, lazy_reason, literal_reason, + EnqueueLiteralInternal(to_enqueue, use_lazy_reason, literal_reason, integer_reason); } return EnqueueAssociatedIntegerLiteral(i_lit, to_enqueue); @@ -1638,7 +1629,7 @@ bool IntegerTrail::EnqueueInternal( boolean_trail_index_to_integer_one_.resize(trail_index + 1); } boolean_trail_index_to_integer_one_[trail_index] = - trail_index_with_same_reason; + integer_trail_.size(); trail_->Enqueue(to_enqueue, propagator_id_); } } @@ -1662,32 +1653,16 @@ bool IntegerTrail::EnqueueInternal( // If we are not at level zero but there is not reason, we have a root level // deduction. Remember it so that we don't forget on the next restart. if (!integer_search_levels_.empty() && integer_reason.empty() && - literal_reason.empty() && lazy_reason == nullptr && - trail_index_with_same_reason >= integer_trail_.size()) { + literal_reason.empty() && !use_lazy_reason) { if (!RootLevelEnqueue(i_lit)) return false; } - int reason_index = literals_reason_starts_.size(); - if (lazy_reason != nullptr) { - if (integer_trail_.size() >= lazy_reasons_.size()) { - lazy_reasons_.resize(integer_trail_.size() + 1, nullptr); - } - lazy_reasons_[integer_trail_.size()] = lazy_reason; + int reason_index; + if (use_lazy_reason) { reason_index = -1; } else if (trail_index_with_same_reason >= integer_trail_.size()) { - // Save the reason into our internal buffers. - literals_reason_starts_.push_back(literals_reason_buffer_.size()); - if (!literal_reason.empty()) { - literals_reason_buffer_.insert(literals_reason_buffer_.end(), - literal_reason.begin(), - literal_reason.end()); - } - bounds_reason_starts_.push_back(bounds_reason_buffer_.size()); - if (!integer_reason.empty()) { - bounds_reason_buffer_.insert(bounds_reason_buffer_.end(), - integer_reason.begin(), - integer_reason.end()); - } + reason_index = + AppendReasonToInternalBuffers(literal_reason, integer_reason); } else { reason_index = integer_trail_[trail_index_with_same_reason].reason_index; } @@ -1741,12 +1716,8 @@ bool IntegerTrail::EnqueueAssociatedIntegerLiteral(IntegerLiteral i_lit, } DCHECK_GT(trail_->CurrentDecisionLevel(), 0); - const int reason_index = literals_reason_starts_.size(); - CHECK_EQ(reason_index, bounds_reason_starts_.size()); - literals_reason_starts_.push_back(literals_reason_buffer_.size()); - bounds_reason_starts_.push_back(bounds_reason_buffer_.size()); - literals_reason_buffer_.push_back(literal_reason.Negated()); - + const int reason_index = + AppendReasonToInternalBuffers({literal_reason.Negated()}, {}); const int prev_trail_index = var_trail_index_[i_lit.var]; integer_trail_.push_back({/*bound=*/i_lit.bound, /*var=*/i_lit.var, @@ -1763,8 +1734,9 @@ void IntegerTrail::ComputeLazyReasonIfNeeded(int trail_index) const { if (reason_index == -1) { const TrailEntry& entry = integer_trail_[trail_index]; const IntegerLiteral literal(entry.var, entry.bound); - lazy_reasons_[trail_index](literal, trail_index, &lazy_reason_literals_, - &lazy_reason_trail_indices_); + lazy_reasons_[trail_index].Explain(literal, trail_index, + &lazy_reason_literals_, + &lazy_reason_trail_indices_); } } diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 472522028a..59f27b4df6 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -747,6 +747,31 @@ class IntegerEncoder { mutable std::vector partial_encoding_; }; +class LazyReasonInterface { + public: + LazyReasonInterface() = default; + virtual ~LazyReasonInterface() = default; + + // The function is provided with the IntegerLiteral to explain and its index + // in the integer trail. It must fill the two vectors so that literals + // contains any Literal part of the reason and dependencies contains the trail + // index of any IntegerLiteral that is also part of the reason. + // + // Remark: sometimes this is called to fill the conflict while the literal to + // explain is propagated. In this case, trail_index will be the current trail + // index, and we cannot assume that there is anything filled yet in + // integer_literal[trail_index]. + // + // TODO(user): Right now this is only used by "linear" propagator, if we need + // more we could replace {id, propagation_slack} by a generic payload so that + // each implementation can cast it to its need. Then the memory will just be + // the max size of this payload data (16 bytes should be fine). + virtual void Explain(int id, IntegerValue propagation_slack, + IntegerLiteral literal_to_explain, int trail_index, + std::vector* literals_reason, + std::vector* trail_indices_reason) = 0; +}; + // This class maintains a set of integer variables with their current bounds. // Bounds can be propagated from an external "source" and this class helps // to maintain the reason for each propagation. @@ -947,9 +972,15 @@ class IntegerTrail final : public SatPropagator { // TODO(user): If the given bound is equal to the current bound, maybe the new // reason is better? how to decide and what to do in this case? to think about // it. Currently we simply don't do anything. + ABSL_MUST_USE_RESULT bool Enqueue(IntegerLiteral i_lit) { + return EnqueueInternal(i_lit, false, {}, {}, integer_trail_.size()); + } ABSL_MUST_USE_RESULT bool Enqueue( IntegerLiteral i_lit, absl::Span literal_reason, - absl::Span integer_reason); + absl::Span integer_reason) { + return EnqueueInternal(i_lit, false, literal_reason, integer_reason, + integer_trail_.size()); + } // Enqueue new information about a variable bound. It has the same behavior // as the Enqueue() method, except that it accepts true and false integer @@ -983,24 +1014,22 @@ class IntegerTrail final : public SatPropagator { ABSL_MUST_USE_RESULT bool Enqueue( IntegerLiteral i_lit, absl::Span literal_reason, absl::Span integer_reason, - int trail_index_with_same_reason); + int trail_index_with_same_reason) { + return EnqueueInternal(i_lit, false, literal_reason, integer_reason, + trail_index_with_same_reason); + } // Lazy reason API. - // - // The function is provided with the IntegerLiteral to explain and its index - // in the integer trail. It must fill the two vectors so that literals - // contains any Literal part of the reason and dependencies contains the trail - // index of any IntegerLiteral that is also part of the reason. - // - // Remark: sometimes this is called to fill the conflict while the literal - // to explain is propagated. In this case, trail_index_of_literal will be - // the current trail index, and we cannot assume that there is anything filled - // yet in integer_literal[trail_index_of_literal]. - using LazyReasonFunction = std::function* literals, std::vector* dependencies)>; - ABSL_MUST_USE_RESULT bool Enqueue(IntegerLiteral i_lit, - LazyReasonFunction lazy_reason); + ABSL_MUST_USE_RESULT bool EnqueueWithLazyReason( + IntegerLiteral i_lit, int id, IntegerValue propagation_slack, + LazyReasonInterface* explainer) { + const int trail_index = integer_trail_.size(); + if (trail_index >= lazy_reasons_.size()) { + lazy_reasons_.resize(trail_index + 1); + } + lazy_reasons_[trail_index] = {explainer, propagation_slack, id}; + return EnqueueInternal(i_lit, true, {}, {}, 0); + } // Sometimes we infer some root level bounds but we are not at the root level. // In this case, we will update the level-zero bounds right away, but will @@ -1145,19 +1174,24 @@ class IntegerTrail final : public SatPropagator { // common conflict initialization that must terminate by a call to // MergeReasonIntoInternal(conflict) where conflict is the returned vector. std::vector* InitializeConflict( - IntegerLiteral integer_literal, const LazyReasonFunction& lazy_reason, + IntegerLiteral integer_literal, bool use_lazy_reason, absl::Span literals_reason, absl::Span bounds_reason); + // Saves the given reason and return its index. + int AppendReasonToInternalBuffers( + absl::Span literal_reason, + absl::Span integer_reason); + // Internal implementation of the different public Enqueue() functions. ABSL_MUST_USE_RESULT bool EnqueueInternal( - IntegerLiteral i_lit, LazyReasonFunction lazy_reason, + IntegerLiteral i_lit, bool use_lazy_reason, absl::Span literal_reason, absl::Span integer_reason, int trail_index_with_same_reason); // Internal implementation of the EnqueueLiteral() functions. - void EnqueueLiteralInternal(Literal literal, LazyReasonFunction lazy_reason, + void EnqueueLiteralInternal(Literal literal, bool use_lazy_reason, absl::Span literal_reason, absl::Span integer_reason); @@ -1229,7 +1263,20 @@ class IntegerTrail final : public SatPropagator { int32_t reason_index; }; std::vector integer_trail_; - std::vector lazy_reasons_; + + struct LazyReasonEntry { + LazyReasonInterface* explainer; + IntegerValue propagation_slack; + int id; + + void Explain(IntegerLiteral literal_to_explain, int trail_index_of_literal, + std::vector* literals, + std::vector* dependencies) const { + explainer->Explain(id, propagation_slack, literal_to_explain, + trail_index_of_literal, literals, dependencies); + } + }; + std::vector lazy_reasons_; // Start of each decision levels in integer_trail_. // TODO(user): use more general reversible mechanism? diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 607a7a6ee0..25f8092195 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -227,6 +227,35 @@ LinearConstraintPropagator::ConditionalLb( } } +template +void LinearConstraintPropagator::Explain( + int /*id*/, IntegerValue propagation_slack, + IntegerLiteral literal_to_explain, int trail_index, + std::vector* literals_reason, + std::vector* trail_indices_reason) { + *literals_reason = literal_reason_; + trail_indices_reason->clear(); + shared_->reason_coeffs.clear(); + for (int i = 0; i < size_; ++i) { + const IntegerVariable var = vars_[i]; + if (PositiveVariable(var) == PositiveVariable(literal_to_explain.var)) { + continue; + } + const int index = + shared_->integer_trail->FindTrailIndexOfVarBefore(var, trail_index); + if (index >= 0) { + trail_indices_reason->push_back(index); + if (propagation_slack > 0) { + shared_->reason_coeffs.push_back(coeffs_[i]); + } + } + } + if (propagation_slack > 0) { + shared_->integer_trail->RelaxLinearReason( + propagation_slack, shared_->reason_coeffs, trail_indices_reason); + } +} + template bool LinearConstraintPropagator::Propagate() { // Reified case: If any of the enforcement_literals are false, we ignore the @@ -353,36 +382,9 @@ bool LinearConstraintPropagator::Propagate() { new_ub = lb + div; propagation_slack = (div + 1) * coeff - slack - 1; } - if (!shared_->integer_trail->Enqueue( - IntegerLiteral::LowerOrEqual(var, new_ub), - /*lazy_reason=*/[this, propagation_slack]( - IntegerLiteral i_lit, int trail_index, - std::vector* literal_reason, - std::vector* trail_indices_reason) { - *literal_reason = literal_reason_; - trail_indices_reason->clear(); - shared_->reason_coeffs.clear(); - for (int i = 0; i < size_; ++i) { - const IntegerVariable var = vars_[i]; - if (PositiveVariable(var) == PositiveVariable(i_lit.var)) { - continue; - } - const int index = - shared_->integer_trail->FindTrailIndexOfVarBefore( - var, trail_index); - if (index >= 0) { - trail_indices_reason->push_back(index); - if (propagation_slack > 0) { - shared_->reason_coeffs.push_back(coeffs_[i]); - } - } - } - if (propagation_slack > 0) { - shared_->integer_trail->RelaxLinearReason( - propagation_slack, shared_->reason_coeffs, - trail_indices_reason); - } - })) { + if (!shared_->integer_trail->EnqueueWithLazyReason( + IntegerLiteral::LowerOrEqual(var, new_ub), 0, propagation_slack, + this)) { // TODO(user): this is never supposed to happen since if we didn't have a // conflict above, we should be able to reduce the upper bound. It might // indicate an issue with our Boolean <-> integer encoding. @@ -650,9 +652,48 @@ LinMinPropagator::LinMinPropagator(const std::vector& exprs, model_(model), integer_trail_(model_->GetOrCreate()) {} +void LinMinPropagator::Explain(int id, IntegerValue propagation_slack, + IntegerLiteral literal_to_explain, + int trail_index, + std::vector* literals_reason, + std::vector* trail_indices_reason) { + const auto& vars = exprs_[id].vars; + const auto& coeffs = exprs_[id].coeffs; + literals_reason->clear(); + trail_indices_reason->clear(); + std::vector reason_coeffs; + const int size = vars.size(); + for (int i = 0; i < size; ++i) { + const IntegerVariable var = vars[i]; + if (PositiveVariable(var) == PositiveVariable(literal_to_explain.var)) { + continue; + } + const int index = + integer_trail_->FindTrailIndexOfVarBefore(var, trail_index); + if (index >= 0) { + trail_indices_reason->push_back(index); + if (propagation_slack > 0) { + reason_coeffs.push_back(coeffs[i]); + } + } + } + if (propagation_slack > 0) { + integer_trail_->RelaxLinearReason(propagation_slack, reason_coeffs, + trail_indices_reason); + } + // Now add the old integer_reason that triggered this propagation. + for (IntegerLiteral reason_lit : integer_reason_for_unique_candidate_) { + const int index = + integer_trail_->FindTrailIndexOfVarBefore(reason_lit.var, trail_index); + if (index >= 0) { + trail_indices_reason->push_back(index); + } + } +} + bool LinMinPropagator::PropagateLinearUpperBound( - const std::vector& vars, - const std::vector& coeffs, const IntegerValue upper_bound) { + int id, absl::Span vars, + absl::Span coeffs, const IntegerValue upper_bound) { IntegerValue sum_lb = IntegerValue(0); const int num_vars = vars.size(); max_variations_.resize(num_vars); @@ -699,46 +740,10 @@ bool LinMinPropagator::PropagateLinearUpperBound( const IntegerValue coeff = coeffs[i]; const IntegerValue div = slack / coeff; const IntegerValue new_ub = integer_trail_->LowerBound(var) + div; - const IntegerValue propagation_slack = (div + 1) * coeff - slack - 1; - if (!integer_trail_->Enqueue( - IntegerLiteral::LowerOrEqual(var, new_ub), - /*lazy_reason=*/[this, &vars, &coeffs, propagation_slack]( - IntegerLiteral i_lit, int trail_index, - std::vector* literal_reason, - std::vector* trail_indices_reason) { - literal_reason->clear(); - trail_indices_reason->clear(); - std::vector reason_coeffs; - const int size = vars.size(); - for (int i = 0; i < size; ++i) { - const IntegerVariable var = vars[i]; - if (PositiveVariable(var) == PositiveVariable(i_lit.var)) { - continue; - } - const int index = - integer_trail_->FindTrailIndexOfVarBefore(var, trail_index); - if (index >= 0) { - trail_indices_reason->push_back(index); - if (propagation_slack > 0) { - reason_coeffs.push_back(coeffs[i]); - } - } - } - if (propagation_slack > 0) { - integer_trail_->RelaxLinearReason( - propagation_slack, reason_coeffs, trail_indices_reason); - } - // Now add the old integer_reason that triggered this propagation. - for (IntegerLiteral reason_lit : - integer_reason_for_unique_candidate_) { - const int index = integer_trail_->FindTrailIndexOfVarBefore( - reason_lit.var, trail_index); - if (index >= 0) { - trail_indices_reason->push_back(index); - } - } - })) { + if (!integer_trail_->EnqueueWithLazyReason( + IntegerLiteral::LowerOrEqual(var, new_ub), id, propagation_slack, + this)) { return false; } } @@ -815,7 +820,7 @@ bool LinMinPropagator::Propagate() { } return PropagateLinearUpperBound( - exprs_[last_possible_min_interval].vars, + last_possible_min_interval, exprs_[last_possible_min_interval].vars, exprs_[last_possible_min_interval].coeffs, current_min_ub - exprs_[last_possible_min_interval].offset); } diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index 7f17b88e6c..104cfeeb21 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -64,7 +64,8 @@ namespace sat { // constraint implementation. But we do need support for enforcement literals // there. template -class LinearConstraintPropagator : public PropagatorInterface { +class LinearConstraintPropagator : public PropagatorInterface, + LazyReasonInterface { public: // If refied_literal is kNoLiteralIndex then this is a normal constraint, // otherwise we enforce the implication refied_literal => constraint is true. @@ -99,6 +100,12 @@ class LinearConstraintPropagator : public PropagatorInterface { std::pair ConditionalLb( IntegerLiteral integer_literal, IntegerVariable target_var) const; + // For LazyReasonInterface. + void Explain(int id, IntegerValue propagation_slack, + IntegerLiteral literal_to_explain, int trail_index, + std::vector* literals_reason, + std::vector* trail_indices_reason) final; + private: // Fills integer_reason_ with all the current lower_bounds. The real // explanation may require removing one of them, but as an optimization, we @@ -233,7 +240,7 @@ class MinPropagator : public PropagatorInterface { // Same as MinPropagator except this works on min = MIN(exprs) where exprs are // linear expressions. It uses IntegerSumLE to propagate bounds on the exprs. // Assumes Canonical expressions (all positive coefficients). -class LinMinPropagator : public PropagatorInterface { +class LinMinPropagator : public PropagatorInterface, LazyReasonInterface { public: LinMinPropagator(const std::vector& exprs, IntegerVariable min_var, Model* model); @@ -243,12 +250,18 @@ class LinMinPropagator : public PropagatorInterface { bool Propagate() final; void RegisterWith(GenericLiteralWatcher* watcher); + // For LazyReasonInterface. + void Explain(int id, IntegerValue propagation_slack, + IntegerLiteral literal_to_explain, int trail_index, + std::vector* literals_reason, + std::vector* trail_indices_reason) final; + private: // Lighter version of IntegerSumLE. This uses the current value of // integer_reason_ in addition to the reason for propagating the linear // constraint. The coeffs are assumed to be positive here. - bool PropagateLinearUpperBound(const std::vector& vars, - const std::vector& coeffs, + bool PropagateLinearUpperBound(int id, absl::Span vars, + absl::Span coeffs, IntegerValue upper_bound); const std::vector exprs_; diff --git a/ortools/sat/intervals.cc b/ortools/sat/intervals.cc index 5f34a48951..5aad6ea516 100644 --- a/ortools/sat/intervals.cc +++ b/ortools/sat/intervals.cc @@ -293,7 +293,7 @@ bool SchedulingConstraintHelper::Propagate() { bool SchedulingConstraintHelper::IncrementalPropagate( const std::vector& watch_indices) { - for (const int t : watch_indices) recompute_cache_[t] = true; + for (const int t : watch_indices) recompute_cache_.Set(t); return true; } @@ -326,7 +326,6 @@ void SchedulingConstraintHelper::RegisterWith(GenericLiteralWatcher* watcher) { } bool SchedulingConstraintHelper::UpdateCachedValues(int t) { - recompute_cache_[t] = false; if (IsAbsent(t)) return true; IntegerValue smin = integer_trail_->LowerBound(starts_[t]); @@ -432,7 +431,10 @@ void SchedulingConstraintHelper::InitSortedVectors() { const int num_tasks = starts_.size(); recompute_all_cache_ = true; - recompute_cache_.resize(num_tasks, true); + recompute_cache_.Resize(num_tasks); + for (int t = 0; t < num_tasks; ++t) { + recompute_cache_.Set(t); + } // Make sure all the cached_* arrays can hold enough data. CHECK_LE(num_tasks, capacity_); @@ -485,12 +487,11 @@ bool SchedulingConstraintHelper::SynchronizeAndSetTimeDirection( if (!UpdateCachedValues(t)) return false; } } else { - for (int t = 0; t < recompute_cache_.size(); ++t) { - if (recompute_cache_[t]) { - if (!UpdateCachedValues(t)) return false; - } + for (const int t : recompute_cache_) { + if (!UpdateCachedValues(t)) return false; } } + recompute_cache_.ClearAll(); recompute_all_cache_ = false; return true; } @@ -506,13 +507,17 @@ IntegerValue SchedulingConstraintHelper::GetCurrentMinDistanceBetweenTasks( return kMinIntegerValue; } - const IntegerValue offset = + // We take the max of the level zero offset and the one coming from a + // conditional precedence at true. + const IntegerValue conditional_offset = precedence_relations_->GetConditionalOffset(before.var, after.var); - if (offset == kMinIntegerValue) return kMinIntegerValue; + const IntegerValue known = integer_trail_->LevelZeroLowerBound(after.var) - + integer_trail_->LevelZeroUpperBound(before.var); + const IntegerValue offset = std::max(conditional_offset, known); const IntegerValue needed_offset = before.constant - after.constant; const IntegerValue distance = offset - needed_offset; - if (add_reason_if_after && distance >= 0) { + if (add_reason_if_after && distance >= 0 && known < conditional_offset) { for (const Literal l : precedence_relations_->GetConditionalEnforcements( before.var, after.var)) { literal_reason_.push_back(l.Negated()); @@ -722,6 +727,7 @@ bool SchedulingConstraintHelper::PushIntervalBound(int t, IntegerLiteral lit) { if (!PushIntegerLiteralIfTaskPresent(t, lit)) return false; if (IsAbsent(t)) return true; if (!UpdateCachedValues(t)) return false; + recompute_cache_.Clear(t); return true; } diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index 97ee3fecfa..9fddea6d22 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -582,7 +582,7 @@ class SchedulingConstraintHelper : public PropagatorInterface, // If recompute_cache_[t] is true, then we need to update all the cached // value for the task t in SynchronizeAndSetTimeDirection(). bool recompute_all_cache_ = true; - std::vector recompute_cache_; + Bitset64 recompute_cache_; // Reason vectors. std::vector literal_reason_; diff --git a/ortools/sat/linear_propagation.cc b/ortools/sat/linear_propagation.cc index 95dd45336c..db1b60f13d 100644 --- a/ortools/sat/linear_propagation.cc +++ b/ortools/sat/linear_propagation.cc @@ -603,6 +603,7 @@ bool LinearPropagator::AddConstraint( info.rev_rhs = upper_bound; info.rev_size = vars.size(); infos_.push_back(std::move(info)); + initial_rhs_.push_back(upper_bound); } id_to_propagation_count_.push_back(0); @@ -643,20 +644,17 @@ bool LinearPropagator::AddConstraint( watcher_->CallOnNextPropagate(watcher_id_); } - // When a conditional precedence becomes enforced, add it. Note that - // we cannot just use rev_size == 2 since we might miss some - // explanation if a longer constraint only have 2 non-fixed variable - // now.. It is however okay not to push precedence involving a fixed - // variable, since these should be reflected in the variable domain - // anyway. + // When a conditional precedence becomes enforced, add it. + // Note that we only look at relation that were a "precedence" from + // the start, note the one currently of size 2 if we ignore fixed + // variables. if (status == EnforcementStatus::IS_ENFORCED) { const auto info = infos_[id]; - if (info.initial_size == 2 && info.rev_size == 2 && - info.all_coeffs_are_one) { + if (info.initial_size == 2 && info.all_coeffs_are_one) { const auto vars = GetVariables(info); precedences_->PushConditionalRelation( enforcement_propagator_->GetEnforcementLiterals(enf_id), - vars[0], vars[1], info.rev_rhs); + vars[0], vars[1], initial_rhs_[id]); } } }); @@ -887,6 +885,39 @@ bool LinearPropagator::PropagateInfeasibleConstraint(int id, integer_reason_); } +void LinearPropagator::Explain(int id, IntegerValue propagation_slack, + IntegerLiteral literal_to_explain, + int trail_index, + std::vector* literals_reason, + std::vector* trail_indices_reason) { + literals_reason->clear(); + trail_indices_reason->clear(); + const ConstraintInfo& info = infos_[id]; + enforcement_propagator_->AddEnforcementReason(info.enf_id, literals_reason); + reason_coeffs_.clear(); + + const auto coeffs = GetCoeffs(info); + const auto vars = GetVariables(info); + for (int i = 0; i < info.initial_size; ++i) { + const IntegerVariable var = vars[i]; + if (PositiveVariable(var) == PositiveVariable(literal_to_explain.var)) { + continue; + } + const int index = + integer_trail_->FindTrailIndexOfVarBefore(var, trail_index); + if (index >= 0) { + trail_indices_reason->push_back(index); + if (propagation_slack > 0) { + reason_coeffs_.push_back(coeffs[i]); + } + } + } + if (propagation_slack > 0) { + integer_trail_->RelaxLinearReason(propagation_slack, reason_coeffs_, + trail_indices_reason); + } +} + bool LinearPropagator::PropagateOneConstraint(int id) { const auto [slack, num_to_push] = AnalyzeConstraint(id); if (slack < 0) return PropagateInfeasibleConstraint(id, slack); @@ -927,39 +958,9 @@ bool LinearPropagator::PropagateOneConstraint(int id) { const IntegerValue div = slack / coeff; const IntegerValue new_ub = integer_trail_->LowerBound(var) + div; const IntegerValue propagation_slack = (div + 1) * coeff - slack - 1; - if (!integer_trail_->Enqueue( - IntegerLiteral::LowerOrEqual(var, new_ub), - /*lazy_reason=*/[this, info, propagation_slack]( - IntegerLiteral i_lit, int trail_index, - std::vector* literal_reason, - std::vector* trail_indices_reason) { - literal_reason->clear(); - trail_indices_reason->clear(); - enforcement_propagator_->AddEnforcementReason(info.enf_id, - literal_reason); - reason_coeffs_.clear(); - - const auto coeffs = GetCoeffs(info); - const auto vars = GetVariables(info); - for (int i = 0; i < info.initial_size; ++i) { - const IntegerVariable var = vars[i]; - if (PositiveVariable(var) == PositiveVariable(i_lit.var)) { - continue; - } - const int index = - integer_trail_->FindTrailIndexOfVarBefore(var, trail_index); - if (index >= 0) { - trail_indices_reason->push_back(index); - if (propagation_slack > 0) { - reason_coeffs_.push_back(coeffs[i]); - } - } - } - if (propagation_slack > 0) { - integer_trail_->RelaxLinearReason( - propagation_slack, reason_coeffs_, trail_indices_reason); - } - })) { + if (!integer_trail_->EnqueueWithLazyReason( + IntegerLiteral::LowerOrEqual(var, new_ub), id, propagation_slack, + this)) { return false; } diff --git a/ortools/sat/linear_propagation.h b/ortools/sat/linear_propagation.h index 1b2bd5214a..20e470935d 100644 --- a/ortools/sat/linear_propagation.h +++ b/ortools/sat/linear_propagation.h @@ -297,7 +297,9 @@ class ConstraintPropagationOrder { // - Lack detection and propagation of at least one of these linear is true // which can be used to propagate more bound if a variable appear in all these // constraint. -class LinearPropagator : public PropagatorInterface, ReversibleInterface { +class LinearPropagator : public PropagatorInterface, + ReversibleInterface, + LazyReasonInterface { public: explicit LinearPropagator(Model* model); ~LinearPropagator() override; @@ -313,6 +315,12 @@ class LinearPropagator : public PropagatorInterface, ReversibleInterface { absl::Span coeffs, IntegerValue upper_bound); + // For LazyReasonInterface. + void Explain(int id, IntegerValue propagation_slack, + IntegerLiteral literal_to_explain, int trail_index, + std::vector* literals_reason, + std::vector* trail_indices_reason) final; + private: // We try to pack the struct as much as possible. Using a maximum size of // 1 << 29 should be okay since we split long constraint anyway. Technically @@ -393,6 +401,7 @@ class LinearPropagator : public PropagatorInterface, ReversibleInterface { // Per constraint info used during propagation. Note that we keep pointer for // the rev_size/rhs there, so we do need a deque. std::deque infos_; + std::vector initial_rhs_; // Buffer of the constraints data. std::vector variables_buffer_; diff --git a/ortools/sat/theta_tree.cc b/ortools/sat/theta_tree.cc index 3a236d6d34..18a983f33c 100644 --- a/ortools/sat/theta_tree.cc +++ b/ortools/sat/theta_tree.cc @@ -27,7 +27,8 @@ ThetaLambdaTree::ThetaLambdaTree() = default; template typename ThetaLambdaTree::TreeNode -ThetaLambdaTree::ComposeTreeNodes(TreeNode left, TreeNode right) { +ThetaLambdaTree::ComposeTreeNodes(const TreeNode& left, + const TreeNode& right) { return {std::max(right.envelope, left.envelope + right.sum_of_energy_min), std::max(right.envelope_opt, right.sum_of_energy_min + @@ -213,11 +214,12 @@ IntegerType ThetaLambdaTree::GetEnvelopeOf(int event) const { template void ThetaLambdaTree::RefreshNode(int node) { + TreeNode* tree = tree_.data(); do { const int right = node | 1; const int left = right ^ 1; node >>= 1; - tree_[node] = ComposeTreeNodes(tree_[left], tree_[right]); + tree[node] = ComposeTreeNodes(tree[left], tree[right]); } while (node > 1); } diff --git a/ortools/sat/theta_tree.h b/ortools/sat/theta_tree.h index bd833b6b5d..d817a328b0 100644 --- a/ortools/sat/theta_tree.h +++ b/ortools/sat/theta_tree.h @@ -207,7 +207,7 @@ class ThetaLambdaTree { IntegerType max_of_energy_delta; }; - TreeNode ComposeTreeNodes(TreeNode left, TreeNode right); + TreeNode ComposeTreeNodes(const TreeNode& left, const TreeNode& right); int GetLeafFromEvent(int event) const; int GetEventFromLeaf(int leaf) const; diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 697d031fe1..fb5f337ee1 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -178,6 +178,7 @@ class FixedCapacityVector { T& back() { return data_[size_ - 1]; } void clear() { size_ = 0; } + void resize(size_t size) { size_ = size; } void pop_back() { --size_; } void push_back(T t) { data_[size_++] = t; }