diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index ea7cabe5e5..3ff6d084b3 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -1591,6 +1591,8 @@ GenericLiteralWatcher::GenericLiteralWatcher(Model* model) // this one. model->GetOrCreate()->AddLastPropagator(this); + integer_trail_->RegisterReversibleClass( + &id_to_greatest_common_level_since_last_call_); integer_trail_->RegisterWatcher(&modified_vars_); queue_by_priority_.resize(2); // Because default priority is 1. } @@ -1769,11 +1771,6 @@ void GenericLiteralWatcher::Untrail(const Trail& trail, int trail_index) { propagation_trail_index_ = trail_index; modified_vars_.ClearAndResize(integer_trail_->NumIntegerVariables()); in_queue_.assign(watchers_.size(), false); - - const int level = trail.CurrentDecisionLevel(); - for (int& ref : id_to_greatest_common_level_since_last_call_) { - ref = std::min(ref, level); - } } // Registers a propagator and returns its unique ids. @@ -1781,7 +1778,7 @@ int GenericLiteralWatcher::Register(PropagatorInterface* propagator) { const int id = watchers_.size(); watchers_.push_back(propagator); id_to_level_at_last_call_.push_back(0); - id_to_greatest_common_level_since_last_call_.push_back(0); + id_to_greatest_common_level_since_last_call_.GrowByOne(); id_to_reversible_classes_.push_back(std::vector()); id_to_reversible_ints_.push_back(std::vector()); id_to_watch_indices_.push_back(std::vector()); diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index dada8e4841..75e5e2bca4 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -1111,7 +1111,7 @@ class GenericLiteralWatcher : public SatPropagator { // Data for each propagator. std::vector id_to_level_at_last_call_; - std::vector id_to_greatest_common_level_since_last_call_; + RevVector id_to_greatest_common_level_since_last_call_; std::vector> id_to_reversible_classes_; std::vector> id_to_reversible_ints_; std::vector> id_to_watch_indices_; diff --git a/ortools/util/rev.h b/ortools/util/rev.h index 24168444bc..489ad6461a 100644 --- a/ortools/util/rev.h +++ b/ortools/util/rev.h @@ -82,6 +82,49 @@ class RevRepository : public ReversibleInterface { std::vector> stack_; }; +// A basic reversible vector implementation. +template +class RevVector : public ReversibleInterface { + public: + const T& operator[](int index) const { return vector_[index]; } + T& operator[](int index) { + // Save on the stack first. + stack_.push_back({index, vector_[index]}); + return vector_[index]; + } + + int size() const { return vector_.size(); } + + void Grow(int new_size) { + CHECK_GE(new_size, vector_.size()); + vector_.resize(new_size); + } + + void GrowByOne() { vector_.resize(vector_.size() + 1); } + + int Level() const { return end_of_level_.size(); } + + void SetLevel(int level) final { + DCHECK_GE(level, 0); + if (level == Level()) return; + if (level < Level()) { + const int index = end_of_level_[level]; + end_of_level_.resize(level); // Shrinks. + for (int i = stack_.size() - 1; i >= index; --i) { + vector_[stack_[i].first] = stack_[i].second; + } + stack_.resize(index); + } else { + end_of_level_.resize(level, stack_.size()); // Grows. + } + } + + private: + std::vector end_of_level_; // In stack_. + std::vector> stack_; + std::vector vector_; +}; + template void RevRepository::SetLevel(int level) { DCHECK_GE(level, 0); @@ -90,11 +133,10 @@ void RevRepository::SetLevel(int level) { if (level < Level()) { const int index = end_of_level_[level]; end_of_level_.resize(level); // Shrinks. - while (stack_.size() > index) { - const auto& p = stack_.back(); - *p.first = p.second; - stack_.pop_back(); + for (int i = stack_.size() - 1; i >= index; --i) { + *stack_[i].first = stack_[i].second; } + stack_.resize(index); } else { end_of_level_.resize(level, stack_.size()); // Grows. }