diff --git a/Makefile b/Makefile index 94947d0322..7405a97b48 100644 --- a/Makefile +++ b/Makefile @@ -323,8 +323,11 @@ nqueens: $(CPLIBS) $(BASE_LIBS) objs/nqueens.o objs/tricks.o: examples/tricks.cc $(CCC) $(CFLAGS) -c examples/tricks.cc -o objs/tricks.o -tricks: $(CPLIBS) $(BASE_LIBS) objs/tricks.o - $(CCC) $(CFLAGS) $(LDFLAGS) objs/tricks.o $(CPLIBS) $(BASE_LIBS) -o tricks +objs/global_arith.o: examples/global_arith.cc + $(CCC) $(CFLAGS) -c examples/global_arith.cc -o objs/global_arith.o + +tricks: $(CPLIBS) $(BASE_LIBS) objs/tricks.o objs/global_arith.o + $(CCC) $(CFLAGS) $(LDFLAGS) objs/tricks.o objs/global_arith.o $(CPLIBS) $(BASE_LIBS) -o tricks # Routing Examples diff --git a/examples/global_arith.cc b/examples/global_arith.cc new file mode 100644 index 0000000000..ca46742c10 --- /dev/null +++ b/examples/global_arith.cc @@ -0,0 +1,614 @@ +// Copyright 2010 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "base/stl_util-inl.h" +#include "constraint_solver/constraint_solveri.h" +#include "examples/global_arith.h" + +namespace operations_research { + +class ArithmeticPropagator; + +// ----- SubstitutionMap ----- + +class SubstitutionMap { + public: + struct Offset { // to_replace = var_index + offset + Offset() : var_index(-1), offset(0) {} + Offset(int v, int64 o) : var_index(v), offset(o) {} + int var_index; + int64 offset; + }; + + void AddSubstitution(int left_var, int right_var, int64 right_offset) { + // TODO(lperron) : Perform transitive closure. + substitutions_[left_var] = Offset(right_var, right_offset); + } + + void ProcessAllSubstitutions(Callback3* const hook) { + for (hash_map::const_iterator it = substitutions_.begin(); + it != substitutions_.end(); + ++it) { + hook->Run(it->first, it->second.var_index, it->second.offset); + } + } + private: + hash_map substitutions_; +}; + +// ----- Bounds ----- + +struct Bounds { + Bounds() : lb(kint64min), ub(kint64max) {} + Bounds(int64 l, int64 u) : lb(l), ub(u) {} + + void Intersect(int64 new_lb, int64 new_ub) { + lb = std::max(lb, new_lb); + ub = std::min(ub, new_ub); + } + + void Intersect(const Bounds& other) { + Intersect(other.lb, other.ub); + } + + void Union(int64 new_lb, int64 new_ub) { + lb = std::min(lb, new_lb); + ub = std::max(ub, new_ub); + } + + void Union(const Bounds& other) { + Union(other.lb, other.ub); + } + + bool IsEqual(const Bounds& other) { + return (ub == other.ub && lb == other.lb); + } + + bool IsIncluded(const Bounds& other) { + return (ub <= other.ub && lb >= other.lb); + } + + int64 lb; + int64 ub; +}; + +// ----- BoundsStore ----- + +class BoundsStore { + public: + BoundsStore(vector* initial_bounds) + : initial_bounds_(initial_bounds) {} + + void SetRange(int var_index, int64 lb, int64 ub) { + hash_map::iterator it = modified_bounds_.find(var_index); + if (it == modified_bounds_.end()) { + Bounds new_bounds(lb, ub); + const Bounds& initial = (*initial_bounds_)[var_index]; + new_bounds.Intersect(initial); + if (!new_bounds.IsEqual(initial)) { + modified_bounds_.insert(make_pair(var_index, new_bounds)); + } + } else { + it->second.Intersect(lb, ub); + } + } + + void Clear() { + modified_bounds_.clear(); + } + + const hash_map& modified_bounds() const { + return modified_bounds_; + } + + vector* initial_bounds() const { return initial_bounds_; } + + void Apply() { + for (hash_map::const_iterator it = modified_bounds_.begin(); + it != modified_bounds_.end(); + ++it) { + (*initial_bounds_)[it->first] = it->second; + } + } + + private: + vector* initial_bounds_; + hash_map modified_bounds_; +}; + +// ----- ArithmeticConstraint ----- + +class ArithmeticConstraint { + public: + virtual ~ArithmeticConstraint() {} + + const vector& vars() const { return vars_; } + + virtual bool Propagate(BoundsStore* const store) = 0; + virtual void Replace(int to_replace, int var, int64 offset) = 0; + virtual void Deduce(ArithmeticPropagator* const propagator) const = 0; + virtual string DebugString() const = 0; + private: + const vector vars_; +}; + +// ----- ArithmeticPropagator ----- + +class ArithmeticPropagator : PropagationBaseObject { + public: + ArithmeticPropagator(Solver* const solver, Demon* const demon) + : PropagationBaseObject(solver), demon_(demon) {} + + void ReduceProblem() { + for (int constraint_index = 0; + constraint_index < constraints_.size(); + ++constraint_index) { + constraints_[constraint_index]->Deduce(this); + } + scoped_ptr > hook( + NewPermanentCallback(this, + &ArithmeticPropagator::ProcessOneSubstitution)); + substitution_map_.ProcessAllSubstitutions(hook.get()); + } + + void Post() { + for (int constraint_index = 0; + constraint_index < constraints_.size(); + ++constraint_index) { + const vector& vars = constraints_[constraint_index]->vars(); + for (int var_index = 0; var_index < vars.size(); ++var_index) { + dependencies_[vars[var_index]].push_back(constraint_index); + } + } + } + + void InitialPropagate() { + + } + + void Update(int var_index) { + Enqueue(demon_); + } + + void AddConstraint(ArithmeticConstraint* const ct) { + constraints_.push_back(ct); + } + + void AddVariable(int64 lb, int64 ub) { + bounds_.push_back(Bounds(lb, ub)); + } + + const vector vars() const { return vars_; } + + int VarIndex(IntVar* const var) { + hash_map::const_iterator it = var_map_.find(var); + if (it == var_map_.end()) { + const int index = var_map_.size(); + var_map_[var] = index; + return index; + } else { + return it->second; + } + } + + void AddSubstitution(int left_var, int right_var, int64 right_offset) { + substitution_map_.AddSubstitution(left_var, right_var, right_offset); + } + + void AddNewBounds(int var_index, int64 lb, int64 ub) { + bounds_[var_index].Intersect(lb, ub); + } + + void ProcessOneSubstitution(int left_var, int right_var, int64 right_offset) { + for (int constraint_index = 0; + constraint_index < constraints_.size(); + ++constraint_index) { + ArithmeticConstraint* const constraint = constraints_[constraint_index]; + constraint->Replace(left_var, right_var, right_offset); + } + } + + void PrintModel() { + LOG(INFO) << "Vars:"; + for (int i = 0; i < bounds_.size(); ++i) { + LOG(INFO) << " var<" << i << "> = [" << bounds_[i].lb + << " .. " << bounds_[i].ub << "]"; + } + LOG(INFO) << "Constraints"; + for (int i = 0; i < constraints_.size(); ++i) { + LOG(INFO) << " " << constraints_[i]->DebugString(); + } + } + private: + Demon* const demon_; + vector vars_; + hash_map var_map_; + vector constraints_; + vector bounds_; + vector > dependencies_; // from var indices to constraints. + SubstitutionMap substitution_map_; +}; + +// ----- Custom Constraints ----- + +class VarEqualVarPlusOffset : public ArithmeticConstraint { + public: + VarEqualVarPlusOffset(int left_var, int right_var, int64 right_offset) + : left_var_(left_var), right_var_(right_var), + right_offset_(right_offset) {} + + virtual bool Propagate(BoundsStore* const store) { + return true; + } + + virtual void Replace(int to_replace, int var, int64 offset) { + if ((to_replace == left_var_ && + var == right_offset_ && + offset == right_offset_) || + (to_replace == right_var_ && + var == left_var_ && offset == -right_offset_)) { + return; + } + if (to_replace == left_var_) { + left_var_ = to_replace; + right_offset_ -= offset; + return; + } + if (to_replace == right_var_) { + right_var_ = var; + right_offset_ += offset; + return; + } + } + + virtual void Deduce(ArithmeticPropagator* const propagator) const { + propagator->AddSubstitution(left_var_, right_var_, right_offset_); + } + + virtual string DebugString() const { + if (right_offset_ == 0) { + return StringPrintf("var<%d> == var<%d>", left_var_, right_var_); + } else { + return StringPrintf("var<%d> == var<%d> + %lld", + left_var_, + right_var_, + right_offset_); + } + } + private: + int left_var_; + int right_var_; + int64 right_offset_; +}; + +class RowConstraint : public ArithmeticConstraint { + public: + RowConstraint(int64 lb, int64 ub) : lb_(lb), ub_(ub) {} + virtual ~RowConstraint() {} + + void AddTerm(int var_index, int64 coefficient) { + // TODO(lperron): Check not present. + coefficients_[var_index] = coefficient; + } + + virtual bool Propagate(BoundsStore* const store) { + return true; + } + + virtual void Replace(int to_replace, int var, int64 offset) { + hash_map::iterator find_other = coefficients_.find(to_replace); + if (find_other != coefficients_.end()) { + hash_map::iterator find_var = coefficients_.find(var); + const int64 other_coefficient = find_other->second; + if (lb_ != kint64min) { + lb_ += other_coefficient * offset; + } + if (ub_ != kint64max) { + ub_ += other_coefficient * offset; + } + coefficients_.erase(find_other); + if (find_var == coefficients_.end()) { + coefficients_[var] = other_coefficient; + } else { + find_var->second += other_coefficient; + if (find_var->second == 0) { + coefficients_.erase(find_var); + } + } + } + } + + virtual void Deduce(ArithmeticPropagator* const propagator) const {} + + virtual string DebugString() const { + string output = "("; + bool first = true; + for (hash_map::const_iterator it = coefficients_.begin(); + it != coefficients_.end(); + ++it) { + if (it->second != 0) { + if (first) { + first = false; + if (it->second == 1) { + output += StringPrintf("var<%d>", it->first); + } else if (it->second == -1) { + output += StringPrintf("-var<%d>", it->first); + } else { + output += StringPrintf("%lld*var<%d>", it->second, it->first); + } + } else if (it->second == 1) { + output += StringPrintf(" + var<%d>", it->first); + } else if (it->second == -1) { + output += StringPrintf(" - var<%d>", it->first); + } else if (it->second > 0) { + output += StringPrintf(" + %lld*var<%d>", it->second, it->first); + } else { + output += StringPrintf(" - %lld*var<%d>", -it->second, it->first); + } + } + } + if (lb_ == ub_) { + output += StringPrintf(" == %lld)", ub_); + } else if (lb_ == kint64min) { + output += StringPrintf(" <= %lld)", ub_); + } else if (ub_ == kint64max) { + output += StringPrintf(" >= %lld)", lb_); + } else { + output += StringPrintf(" in [%lld .. %lld])", lb_, ub_); + } + return output; + } + private: + hash_map coefficients_; + int64 lb_; + int64 ub_; +}; + +class OrConstraint : public ArithmeticConstraint { + public: + OrConstraint(ArithmeticConstraint* const left, + ArithmeticConstraint* const right) + : left_(left), right_(right) {} + + virtual ~OrConstraint() {} + + virtual bool Propagate(BoundsStore* const store) { + return true; + } + + virtual void Replace(int to_replace, int var, int64 offset) { + left_->Replace(to_replace, var, offset); + right_->Replace(to_replace, var, offset); + } + + virtual void Deduce(ArithmeticPropagator* const propagator) const {} + + virtual string DebugString() const { + return StringPrintf("Or(%s, %s)", + left_->DebugString().c_str(), + right_->DebugString().c_str()); + } + private: + ArithmeticConstraint* const left_; + ArithmeticConstraint* const right_; +}; + +// ----- GlobalArithmeticConstraint ----- + +GlobalArithmeticConstraint::GlobalArithmeticConstraint(Solver* const solver) + : Constraint(solver), + propagator_(NULL) { + propagator_.reset(new ArithmeticPropagator( + solver, + solver->MakeDelayedConstraintInitialPropagateCallback(this))); +} +GlobalArithmeticConstraint::~GlobalArithmeticConstraint() { + STLDeleteElements(&constraints_); +} + +void GlobalArithmeticConstraint::Post() { + const vector& vars = propagator_->vars(); + for (int var_index = 0; var_index < vars.size(); ++var_index) { + Demon* const demon = + MakeConstraintDemon1(solver(), + this, + &GlobalArithmeticConstraint::Update, + "Update", + var_index); + vars[var_index]->WhenRange(demon); + } + LOG(INFO) << "----- Before reduction -----"; + propagator_->PrintModel(); + LOG(INFO) << "----- After reduction -----"; + propagator_->ReduceProblem(); + propagator_->PrintModel(); + LOG(INFO) << "---------------------------"; + propagator_->Post(); +} + +void GlobalArithmeticConstraint::InitialPropagate() { + propagator_->InitialPropagate(); +} + +void GlobalArithmeticConstraint::Update(int var_index) { + propagator_->Update(var_index); +} + +int GlobalArithmeticConstraint::MakeVarEqualVarPlusOffset( + IntVar* const left_var, + IntVar* const right_var, + int64 right_offset) { + const int left_index = VarIndex(left_var); + const int right_index = VarIndex(right_var); + return Store(new VarEqualVarPlusOffset(left_index, + right_index, + right_offset)); +} + +int GlobalArithmeticConstraint::MakeScalProdGreaterOrEqualConstant( + const vector vars, + const vector coefficients, + int64 constant) { + RowConstraint* const constraint = new RowConstraint(constant, kint64max); + for (int index = 0; index < vars.size(); ++index) { + constraint->AddTerm(VarIndex(vars[index]), coefficients[index]); + } + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeScalProdLessOrEqualConstant( + const vector vars, + const vector coefficients, + int64 constant) { + RowConstraint* const constraint = new RowConstraint(kint64min, constant); + for (int index = 0; index < vars.size(); ++index) { + constraint->AddTerm(VarIndex(vars[index]), coefficients[index]); + } + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeScalProdEqualConstant( + const vector vars, + const vector coefficients, + int64 constant) { + RowConstraint* const constraint = new RowConstraint(constant, constant); + for (int index = 0; index < vars.size(); ++index) { + constraint->AddTerm(VarIndex(vars[index]), coefficients[index]); + } + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeSumGreaterOrEqualConstant( + const vector vars, + int64 constant) { + RowConstraint* const constraint = new RowConstraint(constant, kint64max); + for (int index = 0; index < vars.size(); ++index) { + constraint->AddTerm(VarIndex(vars[index]), 1); + } + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeSumLessOrEqualConstant( + const vector vars, int64 constant) { + RowConstraint* const constraint = new RowConstraint(kint64min, constant); + for (int index = 0; index < vars.size(); ++index) { + constraint->AddTerm(VarIndex(vars[index]), 1); + } + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeSumEqualConstant( + const vector vars, int64 constant) { + RowConstraint* const constraint = new RowConstraint(constant, constant); + for (int index = 0; index < vars.size(); ++index) { + constraint->AddTerm(VarIndex(vars[index]), 1); + } + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeRowConstraint( + int64 lb, + const vector vars, + const vector coefficients, + int64 ub) { + RowConstraint* const constraint = new RowConstraint(lb, ub); + for (int index = 0; index < vars.size(); ++index) { + constraint->AddTerm(VarIndex(vars[index]), coefficients[index]); + } + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeRowConstraint(int64 lb, + IntVar* const v1, + int64 coeff1, + int64 ub) { + RowConstraint* const constraint = new RowConstraint(lb, ub); + constraint->AddTerm(VarIndex(v1), coeff1); + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeRowConstraint(int64 lb, + IntVar* const v1, + int64 coeff1, + IntVar* const v2, + int64 coeff2, + int64 ub) { + RowConstraint* const constraint = new RowConstraint(lb, ub); + constraint->AddTerm(VarIndex(v1), coeff1); + constraint->AddTerm(VarIndex(v2), coeff2); + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeRowConstraint(int64 lb, + IntVar* const v1, + int64 coeff1, + IntVar* const v2, + int64 coeff2, + IntVar* const v3, + int64 coeff3, + int64 ub) { + RowConstraint* const constraint = new RowConstraint(lb, ub); + constraint->AddTerm(VarIndex(v1), coeff1); + constraint->AddTerm(VarIndex(v2), coeff2); + constraint->AddTerm(VarIndex(v3), coeff3); + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeRowConstraint(int64 lb, + IntVar* const v1, + int64 coeff1, + IntVar* const v2, + int64 coeff2, + IntVar* const v3, + int64 coeff3, + IntVar* const v4, + int64 coeff4, + int64 ub) { + RowConstraint* const constraint = new RowConstraint(lb, ub); + constraint->AddTerm(VarIndex(v1), coeff1); + constraint->AddTerm(VarIndex(v2), coeff2); + constraint->AddTerm(VarIndex(v3), coeff3); + constraint->AddTerm(VarIndex(v4), coeff4); + return Store(constraint); +} + +int GlobalArithmeticConstraint::MakeOrConstraint(int left_constraint_index, + int right_constraint_index) { + OrConstraint* const constraint = + new OrConstraint(constraints_[left_constraint_index], + constraints_[right_constraint_index]); + return Store(constraint); +} + +void GlobalArithmeticConstraint::Add(int constraint_index) { + propagator_->AddConstraint(constraints_[constraint_index]); +} + +int GlobalArithmeticConstraint::VarIndex(IntVar* const var) { + hash_map::const_iterator it = var_indices_.find(var); + if (it == var_indices_.end()) { + const int new_index = var_indices_.size(); + var_indices_.insert(make_pair(var, new_index)); + propagator_->AddVariable(var->Min(), var->Max()); + return new_index; + } else { + return it->second; + } +} + +int GlobalArithmeticConstraint::Store(ArithmeticConstraint* const constraint) { + const int constraint_index = constraints_.size(); + constraints_.push_back(constraint); + return constraint_index; +} +} // namespace operations_research diff --git a/examples/global_arith.h b/examples/global_arith.h new file mode 100644 index 0000000000..f653bf7acf --- /dev/null +++ b/examples/global_arith.h @@ -0,0 +1,89 @@ +// Copyright 2010 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef EXAMPLES_GLOBAL_ARITH_H_ +#define EXAMPLES_GLOBAL_ARITH_H_ +namespace operations_research { +class ArithmeticPropagator; +class ArithmeticConstraint; + +class GlobalArithmeticConstraint : public Constraint { + public: + GlobalArithmeticConstraint(Solver* const solver); + virtual ~GlobalArithmeticConstraint(); + + virtual void Post(); + virtual void InitialPropagate(); + void Update(int var_index); + + int MakeVarEqualVarPlusOffset(IntVar* const left_var, + IntVar* const right_var, + int64 right_offset); + int MakeScalProdGreaterOrEqualConstant(const vector vars, + const vector coefficients, + int64 constant); + int MakeScalProdLessOrEqualConstant(const vector vars, + const vector coefficients, + int64 constant); + int MakeScalProdEqualConstant(const vector vars, + const vector coefficients, + int64 constant); + int MakeSumGreaterOrEqualConstant(const vector vars, int64 constant); + int MakeSumLessOrEqualConstant(const vector vars, int64 constant); + int MakeSumEqualConstant(const vector vars, int64 constant); + int MakeRowConstraint(int64 lb, + const vector vars, + const vector coefficients, + int64 ub); + int MakeRowConstraint(int64 lb, + IntVar* const v1, + int64 coeff1, + int64 ub); + int MakeRowConstraint(int64 lb, + IntVar* const v1, + int64 coeff1, + IntVar* const v2, + int64 coeff2, + int64 ub); + int MakeRowConstraint(int64 lb, + IntVar* const v1, + int64 coeff1, + IntVar* const v2, + int64 coeff2, + IntVar* const v3, + int64 coeff3, + int64 ub); + int MakeRowConstraint(int64 lb, + IntVar* const v1, + int64 coeff1, + IntVar* const v2, + int64 coeff2, + IntVar* const v3, + int64 coeff3, + IntVar* const v4, + int64 coeff4, + int64 ub); + int MakeOrConstraint(int left_constraint_index, int right_constraint_index); + + void Add(int constraint_index); + private: + int VarIndex(IntVar* const var); + int Store(ArithmeticConstraint* const constraint); + + scoped_ptr propagator_; + hash_map var_indices_; + vector constraints_; +}; + +} // namespace operations_research +#endif // EXAMPLES_GLOBAL_ARITH_H_ diff --git a/examples/tricks.cc b/examples/tricks.cc index fc8ae42a59..9db336aa1b 100644 --- a/examples/tricks.cc +++ b/examples/tricks.cc @@ -1,615 +1,27 @@ -// Copyright 2010 Google Inc. All Rights Reserved. -// Author: lperron@google.com (Laurent Perron) +// Copyright 2010 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "base/commandlineflags.h" #include "base/integral_types.h" #include "base/logging.h" #include "base/scoped_ptr.h" -#include "base/stl_util-inl.h" #include "base/stringprintf.h" #include "constraint_solver/constraint_solveri.h" +#include "examples/global_arith.h" DEFINE_int32(size, 20, "Size of the problem"); namespace operations_research { - -// ---------- Arithmetic Propagator ---------- - -namespace { - -class ArithmeticPropagator; - -// ----- SubstitutionMap ----- - -class SubstitutionMap { - public: - struct Offset { // to_replace = var_index + offset - Offset() : var_index(-1), offset(0) {} - Offset(int v, int64 o) : var_index(v), offset(o) {} - int var_index; - int64 offset; - }; - - void AddSubstitution(int left_var, int right_var, int64 right_offset) { - // TODO(lperron) : Perform transitive closure. - substitutions_[left_var] = Offset(right_var, right_offset); - } - - void ProcessAllSubstitutions(Callback3* const hook) { - for (hash_map::const_iterator it = substitutions_.begin(); - it != substitutions_.end(); - ++it) { - hook->Run(it->first, it->second.var_index, it->second.offset); - } - } - private: - hash_map substitutions_; -}; - -// ----- Bounds ----- - -struct Bounds { - Bounds() : lb(kint64min), ub(kint64max) {} - Bounds(int64 l, int64 u) : lb(l), ub(u) {} - - void Intersect(int64 new_lb, int64 new_ub) { - lb = std::max(lb, new_lb); - ub = std::min(ub, new_ub); - } - - void Intersect(const Bounds& other) { - Intersect(other.lb, other.ub); - } - - void Union(int64 new_lb, int64 new_ub) { - lb = std::min(lb, new_lb); - ub = std::max(ub, new_ub); - } - - void Union(const Bounds& other) { - Union(other.lb, other.ub); - } - - bool IsEqual(const Bounds& other) { - return (ub == other.ub && lb == other.lb); - } - - bool IsIncluded(const Bounds& other) { - return (ub <= other.ub && lb >= other.lb); - } - - int64 lb; - int64 ub; -}; - -// ----- BoundsStore ----- - -class BoundsStore { - public: - BoundsStore(vector* initial_bounds) - : initial_bounds_(initial_bounds) {} - - void SetRange(int var_index, int64 lb, int64 ub) { - hash_map::iterator it = modified_bounds_.find(var_index); - if (it == modified_bounds_.end()) { - Bounds new_bounds(lb, ub); - const Bounds& initial = (*initial_bounds_)[var_index]; - new_bounds.Intersect(initial); - if (!new_bounds.IsEqual(initial)) { - modified_bounds_.insert(make_pair(var_index, new_bounds)); - } - } else { - it->second.Intersect(lb, ub); - } - } - - void Clear() { - modified_bounds_.clear(); - } - - const hash_map& modified_bounds() const { - return modified_bounds_; - } - - vector* initial_bounds() const { return initial_bounds_; } - - void Apply() { - for (hash_map::const_iterator it = modified_bounds_.begin(); - it != modified_bounds_.end(); - ++it) { - (*initial_bounds_)[it->first] = it->second; - } - } - - private: - vector* initial_bounds_; - hash_map modified_bounds_; -}; - -// ----- ArithmeticConstraint ----- - -class ArithmeticConstraint { - public: - virtual ~ArithmeticConstraint() {} - - const vector& vars() const { return vars_; } - - virtual bool Propagate(BoundsStore* const store) = 0; - virtual void Replace(int to_replace, int var, int64 offset) = 0; - virtual void Deduce(ArithmeticPropagator* const propagator) const = 0; - virtual string DebugString() const = 0; - private: - const vector vars_; -}; - -// ----- ArithmeticPropagator ----- - -class ArithmeticPropagator : PropagationBaseObject { - public: - ArithmeticPropagator(Solver* const solver, Demon* const demon) - : PropagationBaseObject(solver), demon_(demon) {} - - void ReduceProblem() { - for (int constraint_index = 0; - constraint_index < constraints_.size(); - ++constraint_index) { - constraints_[constraint_index]->Deduce(this); - } - scoped_ptr > hook( - NewPermanentCallback(this, - &ArithmeticPropagator::ProcessOneSubstitution)); - substitution_map_.ProcessAllSubstitutions(hook.get()); - } - - void Post() { - for (int constraint_index = 0; - constraint_index < constraints_.size(); - ++constraint_index) { - const vector& vars = constraints_[constraint_index]->vars(); - for (int var_index = 0; var_index < vars.size(); ++var_index) { - dependencies_[vars[var_index]].push_back(constraint_index); - } - } - } - - void InitialPropagate() { - - } - - void Update(int var_index) { - Enqueue(demon_); - } - - void AddConstraint(ArithmeticConstraint* const ct) { - constraints_.push_back(ct); - } - - void AddVariable(int64 lb, int64 ub) { - bounds_.push_back(Bounds(lb, ub)); - } - - const vector vars() const { return vars_; } - - int VarIndex(IntVar* const var) { - hash_map::const_iterator it = var_map_.find(var); - if (it == var_map_.end()) { - const int index = var_map_.size(); - var_map_[var] = index; - return index; - } else { - return it->second; - } - } - - void AddSubstitution(int left_var, int right_var, int64 right_offset) { - substitution_map_.AddSubstitution(left_var, right_var, right_offset); - } - - void AddNewBounds(int var_index, int64 lb, int64 ub) { - bounds_[var_index].Intersect(lb, ub); - } - - void ProcessOneSubstitution(int left_var, int right_var, int64 right_offset) { - for (int constraint_index = 0; - constraint_index < constraints_.size(); - ++constraint_index) { - ArithmeticConstraint* const constraint = constraints_[constraint_index]; - constraint->Replace(left_var, right_var, right_offset); - } - } - - void PrintModel() { - LOG(INFO) << "Vars:"; - for (int i = 0; i < bounds_.size(); ++i) { - LOG(INFO) << " var<" << i << "> = [" << bounds_[i].lb - << " .. " << bounds_[i].ub << "]"; - } - LOG(INFO) << "Constraints"; - for (int i = 0; i < constraints_.size(); ++i) { - LOG(INFO) << " " << constraints_[i]->DebugString(); - } - } - private: - Demon* const demon_; - vector vars_; - hash_map var_map_; - vector constraints_; - vector bounds_; - vector > dependencies_; // from var indices to constraints. - SubstitutionMap substitution_map_; -}; - -// ----- Custom Constraints ----- - -class VarEqualVarPlusOffset : public ArithmeticConstraint { - public: - VarEqualVarPlusOffset(int left_var, int right_var, int64 right_offset) - : left_var_(left_var), right_var_(right_var), - right_offset_(right_offset) {} - - virtual bool Propagate(BoundsStore* const store) { - return true; - } - - virtual void Replace(int to_replace, int var, int64 offset) { - if ((to_replace == left_var_ && - var == right_offset_ && - offset == right_offset_) || - (to_replace == right_var_ && - var == left_var_ && offset == -right_offset_)) { - return; - } - if (to_replace == left_var_) { - left_var_ = to_replace; - right_offset_ -= offset; - return; - } - if (to_replace == right_var_) { - right_var_ = var; - right_offset_ += offset; - return; - } - } - - virtual void Deduce(ArithmeticPropagator* const propagator) const { - propagator->AddSubstitution(left_var_, right_var_, right_offset_); - } - - virtual string DebugString() const { - if (right_offset_ == 0) { - return StringPrintf("var<%d> == var<%d>", left_var_, right_var_); - } else { - return StringPrintf("var<%d> == var<%d> + %lld", - left_var_, - right_var_, - right_offset_); - } - } - private: - int left_var_; - int right_var_; - int64 right_offset_; -}; - -class RowConstraint : public ArithmeticConstraint { - public: - RowConstraint(int64 lb, int64 ub) : lb_(lb), ub_(ub) {} - virtual ~RowConstraint() {} - - void AddTerm(int var_index, int64 coefficient) { - // TODO(lperron): Check not present. - coefficients_[var_index] = coefficient; - } - - virtual bool Propagate(BoundsStore* const store) { - return true; - } - - virtual void Replace(int to_replace, int var, int64 offset) { - hash_map::iterator find_other = coefficients_.find(to_replace); - if (find_other != coefficients_.end()) { - hash_map::iterator find_var = coefficients_.find(var); - const int64 other_coefficient = find_other->second; - if (lb_ != kint64min) { - lb_ += other_coefficient * offset; - } - if (ub_ != kint64max) { - ub_ += other_coefficient * offset; - } - coefficients_.erase(find_other); - if (find_var == coefficients_.end()) { - coefficients_[var] = other_coefficient; - } else { - find_var->second += other_coefficient; - if (find_var->second == 0) { - coefficients_.erase(find_var); - } - } - } - } - - virtual void Deduce(ArithmeticPropagator* const propagator) const {} - - virtual string DebugString() const { - string output = "("; - bool first = true; - for (hash_map::const_iterator it = coefficients_.begin(); - it != coefficients_.end(); - ++it) { - if (it->second != 0) { - if (first) { - first = false; - if (it->second == 1) { - output += StringPrintf("var<%d>", it->first); - } else if (it->second == -1) { - output += StringPrintf("-var<%d>", it->first); - } else { - output += StringPrintf("%lld*var<%d>", it->second, it->first); - } - } else if (it->second == 1) { - output += StringPrintf(" + var<%d>", it->first); - } else if (it->second == -1) { - output += StringPrintf(" - var<%d>", it->first); - } else if (it->second > 0) { - output += StringPrintf(" + %lld*var<%d>", it->second, it->first); - } else { - output += StringPrintf(" - %lld*var<%d>", -it->second, it->first); - } - } - } - if (lb_ == ub_) { - output += StringPrintf(" == %lld)", ub_); - } else if (lb_ == kint64min) { - output += StringPrintf(" <= %lld)", ub_); - } else if (ub_ == kint64max) { - output += StringPrintf(" >= %lld)", lb_); - } else { - output += StringPrintf(" in [%lld .. %lld])", lb_, ub_); - } - return output; - } - private: - hash_map coefficients_; - int64 lb_; - int64 ub_; -}; - -class OrConstraint : public ArithmeticConstraint { - public: - OrConstraint(ArithmeticConstraint* const left, - ArithmeticConstraint* const right) - : left_(left), right_(right) {} - - virtual ~OrConstraint() {} - - virtual bool Propagate(BoundsStore* const store) { - return true; - } - - virtual void Replace(int to_replace, int var, int64 offset) { - left_->Replace(to_replace, var, offset); - right_->Replace(to_replace, var, offset); - } - - virtual void Deduce(ArithmeticPropagator* const propagator) const {} - - virtual string DebugString() const { - return StringPrintf("Or(%s, %s)", - left_->DebugString().c_str(), - right_->DebugString().c_str()); - } - private: - ArithmeticConstraint* const left_; - ArithmeticConstraint* const right_; -}; - -// ----- GlobalArithmeticConstraint ----- - -class GlobalArithmeticConstraint : public Constraint { - public: - GlobalArithmeticConstraint(Solver* const solver) - : Constraint(solver), - propagator_( - solver, - solver->MakeDelayedConstraintInitialPropagateCallback(this)) {} - virtual ~GlobalArithmeticConstraint() { - STLDeleteElements(&constraints_); - } - - virtual void Post() { - const vector& vars = propagator_.vars(); - for (int var_index = 0; var_index < vars.size(); ++var_index) { - Demon* const demon = - MakeConstraintDemon1(solver(), - this, - &GlobalArithmeticConstraint::Update, - "Update", - var_index); - vars[var_index]->WhenRange(demon); - } - LOG(INFO) << "----- Before reduction -----"; - propagator_.PrintModel(); - LOG(INFO) << "----- After reduction -----"; - propagator_.ReduceProblem(); - propagator_.PrintModel(); - LOG(INFO) << "---------------------------"; - propagator_.Post(); - } - - virtual void InitialPropagate() { - propagator_.InitialPropagate(); - } - - void Update(int var_index) { - propagator_.Update(var_index); - } - - int MakeVarEqualVarPlusOffset(IntVar* const left_var, - IntVar* const right_var, - int64 right_offset) { - const int left_index = VarIndex(left_var); - const int right_index = VarIndex(right_var); - return Store(new VarEqualVarPlusOffset(left_index, - right_index, - right_offset)); - } - - int MakeScalProdGreaterOrEqualConstant(const vector vars, - const vector coefficients, - int64 constant) { - RowConstraint* const constraint = new RowConstraint(constant, kint64max); - for (int index = 0; index < vars.size(); ++index) { - constraint->AddTerm(VarIndex(vars[index]), coefficients[index]); - } - return Store(constraint); - } - - int MakeScalProdLessOrEqualConstant(const vector vars, - const vector coefficients, - int64 constant) { - RowConstraint* const constraint = new RowConstraint(kint64min, constant); - for (int index = 0; index < vars.size(); ++index) { - constraint->AddTerm(VarIndex(vars[index]), coefficients[index]); - } - return Store(constraint); - } - - int MakeScalProdEqualConstant(const vector vars, - const vector coefficients, - int64 constant) { - RowConstraint* const constraint = new RowConstraint(constant, constant); - for (int index = 0; index < vars.size(); ++index) { - constraint->AddTerm(VarIndex(vars[index]), coefficients[index]); - } - return Store(constraint); - } - - int MakeSumGreaterOrEqualConstant(const vector vars, - int64 constant) { - RowConstraint* const constraint = new RowConstraint(constant, kint64max); - for (int index = 0; index < vars.size(); ++index) { - constraint->AddTerm(VarIndex(vars[index]), 1); - } - return Store(constraint); - } - - int MakeSumLessOrEqualConstant(const vector vars, int64 constant) { - RowConstraint* const constraint = new RowConstraint(kint64min, constant); - for (int index = 0; index < vars.size(); ++index) { - constraint->AddTerm(VarIndex(vars[index]), 1); - } - return Store(constraint); - } - - int MakeSumEqualConstant(const vector vars, int64 constant) { - RowConstraint* const constraint = new RowConstraint(constant, constant); - for (int index = 0; index < vars.size(); ++index) { - constraint->AddTerm(VarIndex(vars[index]), 1); - } - return Store(constraint); - } - - int MakeRowConstraint(int64 lb, - const vector vars, - const vector coefficients, - int64 ub) { - RowConstraint* const constraint = new RowConstraint(lb, ub); - for (int index = 0; index < vars.size(); ++index) { - constraint->AddTerm(VarIndex(vars[index]), coefficients[index]); - } - return Store(constraint); - } - - int MakeRowConstraint(int64 lb, - IntVar* const v1, - int64 coeff1, - int64 ub) { - RowConstraint* const constraint = new RowConstraint(lb, ub); - constraint->AddTerm(VarIndex(v1), coeff1); - return Store(constraint); - } - - int MakeRowConstraint(int64 lb, - IntVar* const v1, - int64 coeff1, - IntVar* const v2, - int64 coeff2, - int64 ub) { - RowConstraint* const constraint = new RowConstraint(lb, ub); - constraint->AddTerm(VarIndex(v1), coeff1); - constraint->AddTerm(VarIndex(v2), coeff2); - return Store(constraint); - } - - int MakeRowConstraint(int64 lb, - IntVar* const v1, - int64 coeff1, - IntVar* const v2, - int64 coeff2, - IntVar* const v3, - int64 coeff3, - int64 ub) { - RowConstraint* const constraint = new RowConstraint(lb, ub); - constraint->AddTerm(VarIndex(v1), coeff1); - constraint->AddTerm(VarIndex(v2), coeff2); - constraint->AddTerm(VarIndex(v3), coeff3); - return Store(constraint); - } - - int MakeRowConstraint(int64 lb, - IntVar* const v1, - int64 coeff1, - IntVar* const v2, - int64 coeff2, - IntVar* const v3, - int64 coeff3, - IntVar* const v4, - int64 coeff4, - int64 ub) { - RowConstraint* const constraint = new RowConstraint(lb, ub); - constraint->AddTerm(VarIndex(v1), coeff1); - constraint->AddTerm(VarIndex(v2), coeff2); - constraint->AddTerm(VarIndex(v3), coeff3); - constraint->AddTerm(VarIndex(v4), coeff4); - return Store(constraint); - } - - int MakeOrConstraint(int left_constraint_index, int right_constraint_index) { - OrConstraint* const constraint = - new OrConstraint(constraints_[left_constraint_index], - constraints_[right_constraint_index]); - return Store(constraint); - } - - void Add(int constraint_index) { - propagator_.AddConstraint(constraints_[constraint_index]); - } - private: - int VarIndex(IntVar* const var) { - hash_map::const_iterator it = var_indices_.find(var); - if (it == var_indices_.end()) { - const int new_index = var_indices_.size(); - var_indices_.insert(make_pair(var, new_index)); - propagator_.AddVariable(var->Min(), var->Max()); - return new_index; - } else { - return it->second; - } - } - - int Store(ArithmeticConstraint* const constraint) { - const int constraint_index = constraints_.size(); - constraints_.push_back(constraint); - return constraint_index; - } - - ArithmeticPropagator propagator_; - hash_map var_indices_; - vector constraints_; -}; - -} // namespace - // ---------- Examples ---------- void DeepSearchTreeArith(int size) { @@ -623,13 +35,13 @@ void DeepSearchTreeArith(int size) { GlobalArithmeticConstraint* const global = solver.RevAlloc(new GlobalArithmeticConstraint(&solver)); - + global->Add(global->MakeVarEqualVarPlusOffset(v1, v2, 0)); global->Add(global->MakeVarEqualVarPlusOffset(v2, v3, 0)); - const int left = + const int left = global->MakeRowConstraint(0, v1, -1, v2, -1, v3, 1, kint64max); - const int right = + const int right = global->MakeRowConstraint(0, v1, -1, v2, 1, v3, -1, kint64max); global->Add(global->MakeOrConstraint(left, right)); @@ -647,7 +59,7 @@ void SlowPropagationArith(int size) { GlobalArithmeticConstraint* const global = solver.RevAlloc(new GlobalArithmeticConstraint(&solver)); - + global->Add(global->MakeRowConstraint(1, v1, 1, v2, -1, kint64max)); global->Add(global->MakeRowConstraint(0, v1, -1, v2, 1, kint64max));