speedup local search

This commit is contained in:
lperron@google.com
2010-11-03 00:08:52 +00:00
parent 469e9e1640
commit 41e12a1010
2 changed files with 97 additions and 55 deletions

View File

@@ -667,15 +667,23 @@ class IntVarLocalSearchFilter : public LocalSearchFilter {
IntVarLocalSearchFilter(const IntVar* const* vars, int size);
~IntVarLocalSearchFilter();
protected:
// Add variables to "track" to the filter.
void AddVars(const IntVar* const* vars, int size);
// This method should not be overridden. Override OnSynchronize() instead
// which is called before exiting this method.
virtual void Synchronize(const Assignment* assignment);
virtual void OnSynchronize() {}
bool FindIndex(const IntVar* const var, int64* index) const {
DCHECK(index != NULL);
return FindCopy(var_to_index_, var, index);
}
int size() const { return size_; }
IntVar* Vars(int index) const { return vars_[index]; }
int Size() const { return size_; }
IntVar* Var(int index) const { return vars_[index]; }
int64 Value(int index) const { return values_[index]; }
private:
scoped_array<IntVar*> vars_;
const int size_;
scoped_array<int64> values_;
int size_;
hash_map<const IntVar*, int64> var_to_index_;
};

View File

@@ -1948,20 +1948,51 @@ LocalSearchFilter* Solver::MakeVariableDomainFilter() {
IntVarLocalSearchFilter::IntVarLocalSearchFilter(const IntVar* const* vars,
int size)
: vars_(NULL),
size_(size) {
: vars_(NULL), values_(NULL), size_(0) {
AddVars(vars, size);
CHECK_GE(size_, 0);
if (size_ > 0) {
vars_.reset(new IntVar*[size_]);
memcpy(vars_.get(), vars, size_ * sizeof(*vars));
for (int i = 0; i < size_; ++i) {
var_to_index_[vars_[i]] = i;
}
void IntVarLocalSearchFilter::AddVars(const IntVar* const* vars, int size) {
if (size > 0) {
for (int i = 0; i < size; ++i) {
var_to_index_[vars[i]] = i + size_;
}
const int new_size = size_ + size;
IntVar** new_vars = new IntVar*[new_size];
if (size_ > 0) {
memcpy(new_vars, vars_.get(), size_ * sizeof(*new_vars));
}
memcpy(new_vars + size_, vars, size * sizeof(*vars));
vars_.reset(new_vars);
values_.reset(new int64[new_size]);
memset(values_.get(), 0, sizeof(values_.get()));
size_ = new_size;
}
}
IntVarLocalSearchFilter::~IntVarLocalSearchFilter() {}
void IntVarLocalSearchFilter::Synchronize(const Assignment* assignment) {
const Assignment::IntContainer& container = assignment->IntVarContainer();
const int size = container.Size();
typedef hash_map<const IntVar*, int64>::const_iterator IndexMapIterator;
IndexMapIterator indices_end = var_to_index_.end();
for (int i = 0; i < size; ++i) {
const IntVarElement& element = container.Element(i);
const IntVar* var = element.Var();
if (i < size_ && vars_[i] == var) {
values_[i] = element.Value();
} else {
IndexMapIterator iterator = var_to_index_.find(var);
if (iterator != indices_end) {
values_[iterator->second] = element.Value();
}
}
}
OnSynchronize();
}
// ----- Objective filter ------
// Assignment is accepted if it improves the best objective value found so far.
// 'Values' callback takes an index of a variable and its value and returns the
@@ -1979,9 +2010,8 @@ class ObjectiveFilter : public IntVarLocalSearchFilter {
LSOperation* op);
virtual ~ObjectiveFilter();
virtual bool Accept(const Assignment* delta, const Assignment* deltadelta);
virtual void Synchronize(const Assignment* assignment);
virtual int64 SynchronizedElementValue(const Assignment* assignment,
int64 index) = 0;
virtual void OnSynchronize();
virtual int64 SynchronizedElementValue(int64 index) = 0;
virtual bool EvaluateElementValue(const Assignment::IntContainer& container,
int index,
int* container_index,
@@ -1993,6 +2023,7 @@ class ObjectiveFilter : public IntVarLocalSearchFilter {
const int64* const out_values,
bool cache_delta_values);
const int primary_vars_size_;
int64* const cache_;
int64* const delta_cache_;
const IntVar* const objective_;
@@ -2009,6 +2040,7 @@ ObjectiveFilter::ObjectiveFilter(const IntVar* const* vars,
Solver::LocalSearchFilterBound filter_enum,
LSOperation* op)
: IntVarLocalSearchFilter(vars, var_size),
primary_vars_size_(var_size),
cache_(new int64[var_size]),
delta_cache_(new int64[var_size]),
objective_(objective),
@@ -2018,7 +2050,7 @@ ObjectiveFilter::ObjectiveFilter(const IntVar* const* vars,
old_delta_value_(0),
incremental_(false) {
CHECK(op_ != NULL);
for (int i = 0; i < size(); ++i) {
for (int i = 0; i < Size(); ++i) {
cache_[i] = 0;
delta_cache_[i] = 0;
}
@@ -2046,7 +2078,7 @@ bool ObjectiveFilter::Accept(const Assignment* delta,
incremental_ = true;
} else {
if (incremental_) {
for (int i = 0; i < size(); ++i) {
for (int i = 0; i < primary_vars_size_; ++i) {
delta_cache_[i] = cache_[i];
}
old_delta_value_ = old_value_;
@@ -2078,10 +2110,10 @@ bool ObjectiveFilter::Accept(const Assignment* delta,
}
}
void ObjectiveFilter::Synchronize(const Assignment* assignment) {
void ObjectiveFilter::OnSynchronize() {
op_->Init();
for (int i = 0; i < size(); ++i) {
const int64 obj_value = SynchronizedElementValue(assignment, i);
for (int i = 0; i < primary_vars_size_; ++i) {
const int64 obj_value = SynchronizedElementValue(i);
cache_[i] = obj_value;
delta_cache_[i] = obj_value;
op_->Update(obj_value);
@@ -2103,9 +2135,9 @@ int64 ObjectiveFilter::Evaluate(const Assignment* delta,
const IntVarElement& new_element = container.Element(i);
const IntVar* var = new_element.Var();
int64 index = -1;
if (FindIndex(var, &index)) {
if (FindIndex(var, &index) && index < primary_vars_size_) {
op_->Remove(out_values[index]);
int64 obj_value;
int64 obj_value = 0LL;
if (EvaluateElementValue(container, index, &i, &obj_value)) {
op_->Update(obj_value);
if (cache_delta_values) {
@@ -2126,30 +2158,29 @@ class BinaryObjectiveFilter : public ObjectiveFilter {
Solver::LocalSearchFilterBound filter_enum,
LSOperation* op);
virtual ~BinaryObjectiveFilter() {}
virtual int64 SynchronizedElementValue(const Assignment* assignment,
int64 index);
virtual int64 SynchronizedElementValue(int64 index);
virtual bool EvaluateElementValue(const Assignment::IntContainer& container,
int index,
int* container_index,
int64* obj_value);
private:
scoped_ptr<Solver::IndexEvaluator2> values_;
scoped_ptr<Solver::IndexEvaluator2> value_evaluator_;
};
BinaryObjectiveFilter::BinaryObjectiveFilter(
const IntVar* const* vars,
int size,
Solver::IndexEvaluator2* values,
Solver::IndexEvaluator2* value_evaluator,
const IntVar* const objective,
Solver::LocalSearchFilterBound filter_enum,
LSOperation* op)
: ObjectiveFilter(vars, size, objective, filter_enum, op), values_(values) {
values_->CheckIsRepeatable();
: ObjectiveFilter(vars, size, objective, filter_enum, op),
value_evaluator_(value_evaluator) {
value_evaluator_->CheckIsRepeatable();
}
int64 BinaryObjectiveFilter::SynchronizedElementValue(
const Assignment* assignment, int64 index) {
return values_->Run(index, assignment->Value(Vars(index)));
int64 BinaryObjectiveFilter::SynchronizedElementValue(int64 index) {
return value_evaluator_->Run(index, Value(index));
}
bool BinaryObjectiveFilter::EvaluateElementValue(
@@ -2159,12 +2190,12 @@ bool BinaryObjectiveFilter::EvaluateElementValue(
int64* obj_value) {
const IntVarElement& element = container.Element(*container_index);
if (element.Activated()) {
*obj_value = values_->Run(index, element.Value());
*obj_value = value_evaluator_->Run(index, element.Value());
return true;
} else {
const IntVar* var = element.Var();
if (var->Bound()) {
*obj_value = values_->Run(index, var->Min());
*obj_value = value_evaluator_->Run(index, var->Min());
return true;
}
}
@@ -2176,47 +2207,42 @@ class TernaryObjectiveFilter : public ObjectiveFilter {
TernaryObjectiveFilter(const IntVar* const* vars,
const IntVar* const* secondary_vars,
int size,
Solver::IndexEvaluator3* values,
Solver::IndexEvaluator3* value_evaluator,
const IntVar* const objective,
Solver::LocalSearchFilterBound filter_enum,
LSOperation* op);
virtual ~TernaryObjectiveFilter() {}
virtual int64 SynchronizedElementValue(const Assignment* assignment,
int64 index);
virtual int64 SynchronizedElementValue(int64 index);
bool EvaluateElementValue(const Assignment::IntContainer& container,
int index,
int* container_index,
int64* obj_value);
private:
scoped_array<IntVar*> secondary_vars_;
scoped_ptr<Solver::IndexEvaluator3> values_;
int secondary_vars_offset_;
scoped_ptr<Solver::IndexEvaluator3> value_evaluator_;
};
TernaryObjectiveFilter::TernaryObjectiveFilter(
const IntVar* const* vars,
const IntVar* const* secondary_vars,
int var_size,
Solver::IndexEvaluator3* values,
Solver::IndexEvaluator3* value_evaluator,
const IntVar* const objective,
Solver::LocalSearchFilterBound filter_enum,
LSOperation* op)
: ObjectiveFilter(vars, var_size, objective, filter_enum, op),
values_(values) {
values_->CheckIsRepeatable();
CHECK_GE(size(), 0);
if (size() > 0) {
secondary_vars_.reset(new IntVar*[size()]);
memcpy(secondary_vars_.get(),
secondary_vars, size() * sizeof(*secondary_vars));
}
secondary_vars_offset_(var_size),
value_evaluator_(value_evaluator) {
value_evaluator_->CheckIsRepeatable();
AddVars(secondary_vars, var_size);
CHECK_GE(Size(), 0);
}
int64 TernaryObjectiveFilter::SynchronizedElementValue(
const Assignment* assignment,
int64 index) {
return values_->Run(index,
assignment->Value(Vars(index)),
assignment->Value(secondary_vars_[index]));
int64 TernaryObjectiveFilter::SynchronizedElementValue(int64 index) {
DCHECK_LT(index, secondary_vars_offset_);
return value_evaluator_->Run(index,
Value(index),
Value(index + secondary_vars_offset_));
}
bool TernaryObjectiveFilter::EvaluateElementValue(
@@ -2224,25 +2250,33 @@ bool TernaryObjectiveFilter::EvaluateElementValue(
int index,
int* container_index,
int64* obj_value) {
DCHECK_LT(index, secondary_vars_offset_);
*obj_value = 0LL;
const IntVarElement& element = container.Element(*container_index);
const IntVar* secondary_var = secondary_vars_[index];
const IntVar* secondary_var = Var(index + secondary_vars_offset_);
if (element.Activated()) {
const int64 value = element.Value();
int hint_index = *container_index + 1;
if (hint_index < container.Size()
&& secondary_var == container.Element(hint_index).Var()) {
*obj_value =
values_->Run(index, value, container.Element(hint_index).Value());
value_evaluator_->Run(index,
value,
container.Element(hint_index).Value());
*container_index = hint_index;
} else {
*obj_value =
values_->Run(index, value, container.Element(secondary_var).Value());
value_evaluator_->Run(index,
value,
container.Element(secondary_var).Value());
}
return true;
} else {
const IntVar* var = element.Var();
if (var->Bound() && secondary_var->Bound()) {
*obj_value = values_->Run(index, var->Min(), secondary_var->Min());
*obj_value = value_evaluator_->Run(index,
var->Min(),
secondary_var->Min());
return true;
}
}