diff --git a/constraint_solver/constraint_solveri.h b/constraint_solver/constraint_solveri.h index 3899505db9..1afb39848d 100644 --- a/constraint_solver/constraint_solveri.h +++ b/constraint_solver/constraint_solveri.h @@ -667,15 +667,23 @@ class IntVarLocalSearchFilter : public LocalSearchFilter { IntVarLocalSearchFilter(const IntVar* const* vars, int size); ~IntVarLocalSearchFilter(); protected: + // Add variables to "track" to the filter. + void AddVars(const IntVar* const* vars, int size); + // This method should not be overridden. Override OnSynchronize() instead + // which is called before exiting this method. + virtual void Synchronize(const Assignment* assignment); + virtual void OnSynchronize() {} bool FindIndex(const IntVar* const var, int64* index) const { DCHECK(index != NULL); return FindCopy(var_to_index_, var, index); } - int size() const { return size_; } - IntVar* Vars(int index) const { return vars_[index]; } + int Size() const { return size_; } + IntVar* Var(int index) const { return vars_[index]; } + int64 Value(int index) const { return values_[index]; } private: scoped_array vars_; - const int size_; + scoped_array values_; + int size_; hash_map var_to_index_; }; diff --git a/constraint_solver/local_search.cc b/constraint_solver/local_search.cc index 7b7dbb0e32..a8ac9acae2 100644 --- a/constraint_solver/local_search.cc +++ b/constraint_solver/local_search.cc @@ -1948,20 +1948,51 @@ LocalSearchFilter* Solver::MakeVariableDomainFilter() { IntVarLocalSearchFilter::IntVarLocalSearchFilter(const IntVar* const* vars, int size) - : vars_(NULL), - size_(size) { + : vars_(NULL), values_(NULL), size_(0) { + AddVars(vars, size); CHECK_GE(size_, 0); - if (size_ > 0) { - vars_.reset(new IntVar*[size_]); - memcpy(vars_.get(), vars, size_ * sizeof(*vars)); - for (int i = 0; i < size_; ++i) { - var_to_index_[vars_[i]] = i; +} + +void IntVarLocalSearchFilter::AddVars(const IntVar* const* vars, int size) { + if (size > 0) { + for (int i = 0; i < size; ++i) { + var_to_index_[vars[i]] = i + size_; } + const int new_size = size_ + size; + IntVar** new_vars = new IntVar*[new_size]; + if (size_ > 0) { + memcpy(new_vars, vars_.get(), size_ * sizeof(*new_vars)); + } + memcpy(new_vars + size_, vars, size * sizeof(*vars)); + vars_.reset(new_vars); + values_.reset(new int64[new_size]); + memset(values_.get(), 0, sizeof(values_.get())); + size_ = new_size; } } IntVarLocalSearchFilter::~IntVarLocalSearchFilter() {} +void IntVarLocalSearchFilter::Synchronize(const Assignment* assignment) { + const Assignment::IntContainer& container = assignment->IntVarContainer(); + const int size = container.Size(); + typedef hash_map::const_iterator IndexMapIterator; + IndexMapIterator indices_end = var_to_index_.end(); + for (int i = 0; i < size; ++i) { + const IntVarElement& element = container.Element(i); + const IntVar* var = element.Var(); + if (i < size_ && vars_[i] == var) { + values_[i] = element.Value(); + } else { + IndexMapIterator iterator = var_to_index_.find(var); + if (iterator != indices_end) { + values_[iterator->second] = element.Value(); + } + } + } + OnSynchronize(); +} + // ----- Objective filter ------ // Assignment is accepted if it improves the best objective value found so far. // 'Values' callback takes an index of a variable and its value and returns the @@ -1979,9 +2010,8 @@ class ObjectiveFilter : public IntVarLocalSearchFilter { LSOperation* op); virtual ~ObjectiveFilter(); virtual bool Accept(const Assignment* delta, const Assignment* deltadelta); - virtual void Synchronize(const Assignment* assignment); - virtual int64 SynchronizedElementValue(const Assignment* assignment, - int64 index) = 0; + virtual void OnSynchronize(); + virtual int64 SynchronizedElementValue(int64 index) = 0; virtual bool EvaluateElementValue(const Assignment::IntContainer& container, int index, int* container_index, @@ -1993,6 +2023,7 @@ class ObjectiveFilter : public IntVarLocalSearchFilter { const int64* const out_values, bool cache_delta_values); + const int primary_vars_size_; int64* const cache_; int64* const delta_cache_; const IntVar* const objective_; @@ -2009,6 +2040,7 @@ ObjectiveFilter::ObjectiveFilter(const IntVar* const* vars, Solver::LocalSearchFilterBound filter_enum, LSOperation* op) : IntVarLocalSearchFilter(vars, var_size), + primary_vars_size_(var_size), cache_(new int64[var_size]), delta_cache_(new int64[var_size]), objective_(objective), @@ -2018,7 +2050,7 @@ ObjectiveFilter::ObjectiveFilter(const IntVar* const* vars, old_delta_value_(0), incremental_(false) { CHECK(op_ != NULL); - for (int i = 0; i < size(); ++i) { + for (int i = 0; i < Size(); ++i) { cache_[i] = 0; delta_cache_[i] = 0; } @@ -2046,7 +2078,7 @@ bool ObjectiveFilter::Accept(const Assignment* delta, incremental_ = true; } else { if (incremental_) { - for (int i = 0; i < size(); ++i) { + for (int i = 0; i < primary_vars_size_; ++i) { delta_cache_[i] = cache_[i]; } old_delta_value_ = old_value_; @@ -2078,10 +2110,10 @@ bool ObjectiveFilter::Accept(const Assignment* delta, } } -void ObjectiveFilter::Synchronize(const Assignment* assignment) { +void ObjectiveFilter::OnSynchronize() { op_->Init(); - for (int i = 0; i < size(); ++i) { - const int64 obj_value = SynchronizedElementValue(assignment, i); + for (int i = 0; i < primary_vars_size_; ++i) { + const int64 obj_value = SynchronizedElementValue(i); cache_[i] = obj_value; delta_cache_[i] = obj_value; op_->Update(obj_value); @@ -2103,9 +2135,9 @@ int64 ObjectiveFilter::Evaluate(const Assignment* delta, const IntVarElement& new_element = container.Element(i); const IntVar* var = new_element.Var(); int64 index = -1; - if (FindIndex(var, &index)) { + if (FindIndex(var, &index) && index < primary_vars_size_) { op_->Remove(out_values[index]); - int64 obj_value; + int64 obj_value = 0LL; if (EvaluateElementValue(container, index, &i, &obj_value)) { op_->Update(obj_value); if (cache_delta_values) { @@ -2126,30 +2158,29 @@ class BinaryObjectiveFilter : public ObjectiveFilter { Solver::LocalSearchFilterBound filter_enum, LSOperation* op); virtual ~BinaryObjectiveFilter() {} - virtual int64 SynchronizedElementValue(const Assignment* assignment, - int64 index); + virtual int64 SynchronizedElementValue(int64 index); virtual bool EvaluateElementValue(const Assignment::IntContainer& container, int index, int* container_index, int64* obj_value); private: - scoped_ptr values_; + scoped_ptr value_evaluator_; }; BinaryObjectiveFilter::BinaryObjectiveFilter( const IntVar* const* vars, int size, - Solver::IndexEvaluator2* values, + Solver::IndexEvaluator2* value_evaluator, const IntVar* const objective, Solver::LocalSearchFilterBound filter_enum, LSOperation* op) - : ObjectiveFilter(vars, size, objective, filter_enum, op), values_(values) { - values_->CheckIsRepeatable(); + : ObjectiveFilter(vars, size, objective, filter_enum, op), + value_evaluator_(value_evaluator) { + value_evaluator_->CheckIsRepeatable(); } -int64 BinaryObjectiveFilter::SynchronizedElementValue( - const Assignment* assignment, int64 index) { - return values_->Run(index, assignment->Value(Vars(index))); +int64 BinaryObjectiveFilter::SynchronizedElementValue(int64 index) { + return value_evaluator_->Run(index, Value(index)); } bool BinaryObjectiveFilter::EvaluateElementValue( @@ -2159,12 +2190,12 @@ bool BinaryObjectiveFilter::EvaluateElementValue( int64* obj_value) { const IntVarElement& element = container.Element(*container_index); if (element.Activated()) { - *obj_value = values_->Run(index, element.Value()); + *obj_value = value_evaluator_->Run(index, element.Value()); return true; } else { const IntVar* var = element.Var(); if (var->Bound()) { - *obj_value = values_->Run(index, var->Min()); + *obj_value = value_evaluator_->Run(index, var->Min()); return true; } } @@ -2176,47 +2207,42 @@ class TernaryObjectiveFilter : public ObjectiveFilter { TernaryObjectiveFilter(const IntVar* const* vars, const IntVar* const* secondary_vars, int size, - Solver::IndexEvaluator3* values, + Solver::IndexEvaluator3* value_evaluator, const IntVar* const objective, Solver::LocalSearchFilterBound filter_enum, LSOperation* op); virtual ~TernaryObjectiveFilter() {} - virtual int64 SynchronizedElementValue(const Assignment* assignment, - int64 index); + virtual int64 SynchronizedElementValue(int64 index); bool EvaluateElementValue(const Assignment::IntContainer& container, int index, int* container_index, int64* obj_value); private: - scoped_array secondary_vars_; - scoped_ptr values_; + int secondary_vars_offset_; + scoped_ptr value_evaluator_; }; TernaryObjectiveFilter::TernaryObjectiveFilter( const IntVar* const* vars, const IntVar* const* secondary_vars, int var_size, - Solver::IndexEvaluator3* values, + Solver::IndexEvaluator3* value_evaluator, const IntVar* const objective, Solver::LocalSearchFilterBound filter_enum, LSOperation* op) : ObjectiveFilter(vars, var_size, objective, filter_enum, op), - values_(values) { - values_->CheckIsRepeatable(); - CHECK_GE(size(), 0); - if (size() > 0) { - secondary_vars_.reset(new IntVar*[size()]); - memcpy(secondary_vars_.get(), - secondary_vars, size() * sizeof(*secondary_vars)); - } + secondary_vars_offset_(var_size), + value_evaluator_(value_evaluator) { + value_evaluator_->CheckIsRepeatable(); + AddVars(secondary_vars, var_size); + CHECK_GE(Size(), 0); } -int64 TernaryObjectiveFilter::SynchronizedElementValue( - const Assignment* assignment, - int64 index) { - return values_->Run(index, - assignment->Value(Vars(index)), - assignment->Value(secondary_vars_[index])); +int64 TernaryObjectiveFilter::SynchronizedElementValue(int64 index) { + DCHECK_LT(index, secondary_vars_offset_); + return value_evaluator_->Run(index, + Value(index), + Value(index + secondary_vars_offset_)); } bool TernaryObjectiveFilter::EvaluateElementValue( @@ -2224,25 +2250,33 @@ bool TernaryObjectiveFilter::EvaluateElementValue( int index, int* container_index, int64* obj_value) { + DCHECK_LT(index, secondary_vars_offset_); + *obj_value = 0LL; const IntVarElement& element = container.Element(*container_index); - const IntVar* secondary_var = secondary_vars_[index]; + const IntVar* secondary_var = Var(index + secondary_vars_offset_); if (element.Activated()) { const int64 value = element.Value(); int hint_index = *container_index + 1; if (hint_index < container.Size() && secondary_var == container.Element(hint_index).Var()) { *obj_value = - values_->Run(index, value, container.Element(hint_index).Value()); + value_evaluator_->Run(index, + value, + container.Element(hint_index).Value()); *container_index = hint_index; } else { *obj_value = - values_->Run(index, value, container.Element(secondary_var).Value()); + value_evaluator_->Run(index, + value, + container.Element(secondary_var).Value()); } return true; } else { const IntVar* var = element.Var(); if (var->Bound() && secondary_var->Bound()) { - *obj_value = values_->Run(index, var->Min(), secondary_var->Min()); + *obj_value = value_evaluator_->Run(index, + var->Min(), + secondary_var->Min()); return true; } }