From 4efd2c4a6c40fe42bee1cc32c2ddc5cc40adeeae Mon Sep 17 00:00:00 2001 From: "lperron@google.com" Date: Sat, 30 Jun 2012 10:38:29 +0000 Subject: [PATCH] nested expression simplifications (3 - (x - 2)) -> 1 - x... --- src/constraint_solver/count_cst.cc | 79 ++++-- src/constraint_solver/expressions.cc | 379 +++++++++++++++------------ 2 files changed, 279 insertions(+), 179 deletions(-) diff --git a/src/constraint_solver/count_cst.cc b/src/constraint_solver/count_cst.cc index ecf9dacbf5..6b9b161911 100644 --- a/src/constraint_solver/count_cst.cc +++ b/src/constraint_solver/count_cst.cc @@ -93,19 +93,21 @@ string CountValueEqCst::DebugString() const { void CountValueEqCst::Post() { for (int i = 0; i < size_; ++i) { IntVar* const var = vars_[i]; - if (!var->Bound()) { + if (!var->Bound() && var->Contains(value_)) { Demon* d = MakeConstraintDemon1(solver(), this, &CountValueEqCst::OneBound, "OneBound", i); var->WhenBound(d); - if (var->Contains(value_)) { - d = MakeConstraintDemon1(solver(), - this, - &CountValueEqCst::OneDomain, - "OneDomain", - i); + d = MakeConstraintDemon1(solver(), + this, + &CountValueEqCst::OneDomain, + "OneDomain", + i); + if (var->Min() == value_ || var->Max() == value_) { + var->WhenRange(d); + } else { var->WhenDomain(d); } } @@ -196,6 +198,47 @@ Constraint* Solver::MakeCount(const std::vector& vars, int64 v, int64 c for (ConstIter > it(vars); !it.at_end(); ++it) { CHECK_EQ(this, (*it)->solver()); } + int count_min = 0; + int count_max = 0; + int count_active = 0; + int count_forced = 0; + for (int i = 0; i < vars.size(); ++i) { + if (vars[i]->Contains(v)) { + count_active++; + if (vars[i]->Bound()) { + count_forced++; + } + if (vars[i]->Max() - vars[i]->Min() <= 1) { + if (vars[i]->Min() == v) { + count_min++; + } else if (vars[i]->Max() == v) { + count_max++; + } + } + } + } + if (count_active == 0) { + return (c == 0 ? MakeTrueConstraint() : MakeFalseConstraint()); + } else if (count_forced == count_active) { + return (c == count_forced ? MakeTrueConstraint() : MakeFalseConstraint()); + } + if (count_min == count_active) { + std::vector terms; + for (int i = 0; i < vars.size(); ++i) { + if (vars[i]->Contains(v)) { + terms.push_back(MakeSum(vars[i], -v)->Var()); + } + } + return MakeSumEquality(terms, count_active - c); + } else if (count_max == count_active) { + std::vector terms; + for (int i = 0; i < vars.size(); ++i) { + if (vars[i]->Contains(v)) { + terms.push_back(MakeSum(vars[i], -v + 1)->Var()); + } + } + return MakeSumEquality(terms, c); + } return RevAlloc(new CountValueEqCst(this, vars.data(), vars.size(), v, c)); } @@ -262,19 +305,21 @@ string CountValueEq::DebugString() const { void CountValueEq::Post() { for (int i = 0; i < size_; ++i) { IntVar* const var = vars_[i]; - if (!var->Bound()) { + if (!var->Bound() && var->Contains(value_)) { Demon* d = MakeConstraintDemon1(solver(), this, &CountValueEq::OneBound, "OneBound", i); var->WhenBound(d); - if (var->Contains(value_)) { - d = MakeConstraintDemon1(solver(), - this, - &CountValueEq::OneDomain, - "OneDomain", - i); + d = MakeConstraintDemon1(solver(), + this, + &CountValueEq::OneDomain, + "OneDomain", + i); + if (var->Min() == value_ || var->Max() == value_) { + var->WhenRange(d); + } else { var->WhenDomain(d); } } @@ -389,7 +434,11 @@ Constraint* Solver::MakeCount(const std::vector& vars, int64 v, IntVar* CHECK_EQ(this, (*it)->solver()); } CHECK_EQ(this, c->solver()); - return RevAlloc(new CountValueEq(this, vars.data(), vars.size(), v, c)); + if (c->Bound()) { + return MakeCount(vars, v, c->Min()); + } else { + return RevAlloc(new CountValueEq(this, vars.data(), vars.size(), v, c)); + } } // ---------- Distribute ---------- diff --git a/src/constraint_solver/expressions.cc b/src/constraint_solver/expressions.cc index 7e73d49389..caccf99790 100644 --- a/src/constraint_solver/expressions.cc +++ b/src/constraint_solver/expressions.cc @@ -1832,58 +1832,43 @@ string IntConst::DebugString() const { // ----- x + c variable, optimized case ----- -class PlusCstIntVar : public IntVar { +class PlusCstVar : public IntVar { public: - class PlusCstIntVarIterator : public UnaryIterator { - public: - PlusCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev) - : UnaryIterator(v, hole, rev), cst_(c) {} + PlusCstVar(Solver* const s, IntVar* v, int64 c) + : IntVar(s), var_(v), cst_(c) {} - virtual ~PlusCstIntVarIterator() {} + virtual ~PlusCstVar() {} - virtual int64 Value() const { - return iterator_->Value() + cst_; - } - - private: - const int64 cst_; - }; - PlusCstIntVar(Solver* const s, IntVar* v, int64 c); - virtual ~PlusCstIntVar(); - - virtual int64 Min() const; - virtual void SetMin(int64 m); - virtual int64 Max() const; - virtual void SetMax(int64 m); - virtual void SetRange(int64 l, int64 u); - virtual void SetValue(int64 v); - virtual bool Bound() const; - virtual int64 Value() const; - virtual void RemoveValue(int64 v); - virtual void RemoveInterval(int64 l, int64 u); - virtual uint64 Size() const; - virtual bool Contains(int64 v) const; - virtual void WhenRange(Demon* d); - virtual void WhenBound(Demon* d); - virtual void WhenDomain(Demon* d); - virtual IntVarIterator* MakeHoleIterator(bool reversible) const { - return COND_REV_ALLOC(reversible, - new PlusCstIntVarIterator(var_, cst_, - true, reversible)); + virtual void WhenRange(Demon* d) { + var_->WhenRange(d); } - virtual IntVarIterator* MakeDomainIterator(bool reversible) const { - return COND_REV_ALLOC(reversible, - new PlusCstIntVarIterator(var_, cst_, - false, reversible)); + + virtual void WhenBound(Demon* d) { + var_->WhenBound(d); } + + virtual void WhenDomain(Demon* d) { + var_->WhenDomain(d); + } + virtual int64 OldMin() const { return var_->OldMin() + cst_; } + virtual int64 OldMax() const { return var_->OldMax() + cst_; } - virtual string DebugString() const; + virtual string DebugString() const { + if (HasName()) { + return StringPrintf("%s(%s + %" GG_LL_FORMAT "d)", + name().c_str(), var_->DebugString().c_str(), cst_); + } else { + return StringPrintf("(%s + %" GG_LL_FORMAT "d)", + var_->DebugString().c_str(), cst_); + } + } + virtual int VarType() const { return VAR_ADD_CST; } virtual void Accept(ModelVisitor* const visitor) const { @@ -1901,80 +1886,97 @@ class PlusCstIntVar : public IntVar { return var_->IsDifferent(constant - cst_); } - private: + IntVar* SubVar() const { return var_; } + + int64 Constant() const { return cst_; } + + protected: IntVar* const var_; const int64 cst_; }; -PlusCstIntVar::PlusCstIntVar(Solver* const s, IntVar* v, int64 c) - : IntVar(s), var_(v), cst_(c) {} -PlusCstIntVar::~PlusCstIntVar() {} -int64 PlusCstIntVar::Min() const { - return var_->Min() + cst_; -} +class PlusCstIntVar : public PlusCstVar { + public: + class PlusCstIntVarIterator : public UnaryIterator { + public: + PlusCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev) + : UnaryIterator(v, hole, rev), cst_(c) {} -void PlusCstIntVar::SetMin(int64 m) { - var_->SetMin(m - cst_); -} + virtual ~PlusCstIntVarIterator() {} -int64 PlusCstIntVar::Max() const { - return var_->Max() + cst_; -} + virtual int64 Value() const { + return iterator_->Value() + cst_; + } -void PlusCstIntVar::SetMax(int64 m) { - var_->SetMax(m - cst_); -} + private: + const int64 cst_; + }; -void PlusCstIntVar::SetRange(int64 l, int64 u) { - var_->SetRange(l - cst_, u - cst_); -} + PlusCstIntVar(Solver* const s, IntVar* v, int64 c) : PlusCstVar(s, v, c) {} -void PlusCstIntVar::SetValue(int64 v) { - var_->SetValue(v - cst_); + virtual ~PlusCstIntVar() {} + + virtual int64 Min() const { + return var_->Min() + cst_; } -bool PlusCstIntVar::Bound() const { - return var_->Bound(); -} -void PlusCstIntVar::WhenRange(Demon* d) { - var_->WhenRange(d); -} + virtual void SetMin(int64 m) { + var_->SetMin(m - cst_); + } -int64 PlusCstIntVar::Value() const { - return var_->Value() + cst_; -} + virtual int64 Max() const { + return var_->Max() + cst_; + } -void PlusCstIntVar::RemoveValue(int64 v) { - var_->RemoveValue(v - cst_); -} + virtual void SetMax(int64 m) { + var_->SetMax(m - cst_); + } -void PlusCstIntVar::RemoveInterval(int64 l, int64 u) { - var_->RemoveInterval(l - cst_, u - cst_); -} + virtual void SetRange(int64 l, int64 u) { + var_->SetRange(l - cst_, u - cst_); + } -void PlusCstIntVar::WhenBound(Demon* d) { - var_->WhenBound(d); -} + virtual void SetValue(int64 v) { + var_->SetValue(v - cst_); + } -void PlusCstIntVar::WhenDomain(Demon* d) { - var_->WhenDomain(d); -} + virtual int64 Value() const { + return var_->Value() + cst_; + } -uint64 PlusCstIntVar::Size() const { - return var_->Size(); -} + virtual bool Bound() const { + return var_->Bound(); + } -bool PlusCstIntVar::Contains(int64 v) const { - return var_->Contains(v - cst_); -} + virtual void RemoveValue(int64 v) { + var_->RemoveValue(v - cst_); + } -string PlusCstIntVar::DebugString() const { - return StringPrintf("(%s + %" GG_LL_FORMAT "d)", - var_->DebugString().c_str(), cst_); -} + virtual void RemoveInterval(int64 l, int64 u) { + var_->RemoveInterval(l - cst_, u - cst_); + } -class PlusCstDomainIntVar : public IntVar { + virtual uint64 Size() const { + return var_->Size(); + } + + virtual bool Contains(int64 v) const { + return var_->Contains(v - cst_); + } + virtual IntVarIterator* MakeHoleIterator(bool reversible) const { + return COND_REV_ALLOC(reversible, + new PlusCstIntVarIterator(var_, cst_, + true, reversible)); + } + virtual IntVarIterator* MakeDomainIterator(bool reversible) const { + return COND_REV_ALLOC(reversible, + new PlusCstIntVarIterator(var_, cst_, + false, reversible)); + } +}; + +class PlusCstDomainIntVar : public PlusCstVar { public: class PlusCstDomainIntVarIterator : public UnaryIterator { public: @@ -1993,8 +1995,11 @@ class PlusCstDomainIntVar : public IntVar { private: const int64 cst_; }; - PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64 c); - virtual ~PlusCstDomainIntVar(); + + PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64 c) + : PlusCstVar(s, v, c) {} + + virtual ~PlusCstDomainIntVar() {} virtual int64 Min() const; virtual void SetMin(int64 m); @@ -2008,9 +2013,11 @@ class PlusCstDomainIntVar : public IntVar { virtual void RemoveInterval(int64 l, int64 u); virtual uint64 Size() const; virtual bool Contains(int64 v) const; - virtual void WhenRange(Demon* d); - virtual void WhenBound(Demon* d); - virtual void WhenDomain(Demon* d); + + DomainIntVar* domain_int_var() const { + return reinterpret_cast(var_); + } + virtual IntVarIterator* MakeHoleIterator(bool reversible) const { return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(var_, cst_, @@ -2021,106 +2028,56 @@ class PlusCstDomainIntVar : public IntVar { new PlusCstDomainIntVarIterator(var_, cst_, false, reversible)); } - virtual int64 OldMin() const { - return var_->OldMin() + cst_; - } - virtual int64 OldMax() const { - return var_->OldMax() + cst_; - } - - virtual string DebugString() const; - virtual int VarType() const { return VAR_ADD_CST; } - - virtual void Accept(ModelVisitor* const visitor) const { - visitor->VisitIntegerVariable(this, - ModelVisitor::kSumOperation, - cst_, - var_); - } - - virtual IntVar* IsEqual(int64 constant) { - return var_->IsEqual(constant - cst_); - } - - virtual IntVar* IsDifferent(int64 constant) { - return var_->IsDifferent(constant - cst_); - } - - private: - DomainIntVar* const var_; - const int64 cst_; }; -PlusCstDomainIntVar::PlusCstDomainIntVar(Solver* const s, - DomainIntVar* v, - int64 c) - : IntVar(s), var_(v), cst_(c) {} -PlusCstDomainIntVar::~PlusCstDomainIntVar() {} - int64 PlusCstDomainIntVar::Min() const { - return var_->min_.Value() + cst_; + return domain_int_var()->min_.Value() + cst_; } void PlusCstDomainIntVar::SetMin(int64 m) { - var_->DomainIntVar::SetMin(m - cst_); + domain_int_var()->DomainIntVar::SetMin(m - cst_); } int64 PlusCstDomainIntVar::Max() const { - return var_->max_.Value() + cst_; + return domain_int_var()->max_.Value() + cst_; } void PlusCstDomainIntVar::SetMax(int64 m) { - var_->DomainIntVar::SetMax(m - cst_); + domain_int_var()->DomainIntVar::SetMax(m - cst_); } void PlusCstDomainIntVar::SetRange(int64 l, int64 u) { - var_->DomainIntVar::SetRange(l - cst_, u - cst_); + domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_); } void PlusCstDomainIntVar::SetValue(int64 v) { - var_->DomainIntVar::SetValue(v - cst_); + domain_int_var()->DomainIntVar::SetValue(v - cst_); } bool PlusCstDomainIntVar::Bound() const { - return var_->min_.Value() == var_->max_.Value(); -} - -void PlusCstDomainIntVar::WhenRange(Demon* d) { - var_->WhenRange(d); + return domain_int_var()->min_.Value() == domain_int_var()->max_.Value(); } int64 PlusCstDomainIntVar::Value() const { - CHECK_EQ(var_->min_.Value(), var_->max_.Value()) << "variable is not bound"; - return var_->min_.Value() + cst_; + CHECK_EQ(domain_int_var()->min_.Value(), + domain_int_var()->max_.Value()) << "variable is not bound"; + return domain_int_var()->min_.Value() + cst_; } void PlusCstDomainIntVar::RemoveValue(int64 v) { - var_->DomainIntVar::RemoveValue(v - cst_); + domain_int_var()->DomainIntVar::RemoveValue(v - cst_); } void PlusCstDomainIntVar::RemoveInterval(int64 l, int64 u) { - var_->DomainIntVar::RemoveInterval(l - cst_, u - cst_); -} - -void PlusCstDomainIntVar::WhenBound(Demon* d) { - var_->WhenBound(d); -} - -void PlusCstDomainIntVar::WhenDomain(Demon* d) { - var_->WhenDomain(d); + domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_); } uint64 PlusCstDomainIntVar::Size() const { - return var_->DomainIntVar::Size(); + return domain_int_var()->DomainIntVar::Size(); } bool PlusCstDomainIntVar::Contains(int64 v) const { - return var_->DomainIntVar::Contains(v - cst_); -} - -string PlusCstDomainIntVar::DebugString() const { - return StringPrintf("(%s + %" GG_LL_FORMAT "d)", - var_->DebugString().c_str(), cst_); + return domain_int_var()->DomainIntVar::Contains(v - cst_); } // c - x variable, optimized case @@ -2193,6 +2150,9 @@ class SubCstIntVar : public IntVar { return var_->IsDifferent(cst_ - constant); } + IntVar* SubVar() const { return var_; } + int64 Constant() const { return cst_; } + private: IntVar* const var_; const int64 cst_; @@ -2332,6 +2292,8 @@ class OppIntVar : public IntVar { return var_->IsDifferent(-constant); } + IntVar* SubVar() const { return var_; } + private: IntVar* const var_; }; @@ -5782,6 +5744,11 @@ IntVar* Solver::MakeIntVar(int64 min, int64 max, const string& name) { } if (min == 0 && max == 1) { return RegisterIntVar(RevAlloc(new BooleanVar(this, name))); + } else if (max - min == 1) { + const string inner_name = "inner_" + name; + return RegisterIntVar( + MakeSum(RevAlloc(new BooleanVar(this, inner_name)), + min)->VarWithName(name)); } else { return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name))); } @@ -5966,7 +5933,58 @@ IntExpr* Solver::MakeSum(IntExpr* const e, int64 v) { IntExpr* result = Cache()->FindExprConstantExpression( e, v, ModelCache::EXPR_CONSTANT_SUM); if (result == NULL) { - result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, e, v))); + if (e->IsVar()) { + IntVar* const var = e->Var(); + switch (var->VarType()) { + case DOMAIN_INT_VAR: { + result = RegisterIntExpr(RevAlloc( + new PlusCstDomainIntVar( + this, reinterpret_cast(var), v))); + break; + } + case CONST_VAR: { + result = RegisterIntExpr(MakeIntConst(var->Min() + v)); + break; + } + case VAR_ADD_CST: { + PlusCstVar* const add_var = reinterpret_cast(var); + IntVar* const sub_var = add_var->SubVar(); + const int64 new_constant = v + add_var->Constant(); + if (new_constant == 0) { + result = sub_var; + } else { + if (sub_var->VarType() == DOMAIN_INT_VAR) { + DomainIntVar* const dvar = + reinterpret_cast(sub_var); + result = RegisterIntExpr( + RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant))); + } else { + result = RegisterIntExpr( + RevAlloc(new PlusCstIntVar(this, sub_var, new_constant))); + } + } + break; + } + case CST_SUB_VAR: { + SubCstIntVar* const add_var = reinterpret_cast(var); + IntVar* const sub_var = add_var->SubVar(); + const int64 new_constant = v - add_var->Constant(); + result = + RegisterIntExpr(new SubCstIntVar(this, sub_var, new_constant)); + break; + } + case OPP_VAR: { + OppIntVar* const add_var = reinterpret_cast(var); + IntVar* const sub_var = add_var->SubVar(); + result = RegisterIntExpr(new SubCstIntVar(this, sub_var, v)); + break; + } + default: + result = RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, v))); + } + } else { + result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, e, v))); + } Cache()->InsertExprConstantExpression( result, e, v, ModelCache::EXPR_CONSTANT_SUM); } @@ -6230,7 +6248,40 @@ IntExpr* Solver::MakeDifference(int64 v, IntExpr* const e) { IntExpr* result = Cache()->FindExprConstantExpression( e, v, ModelCache::EXPR_CONSTANT_DIFFERENCE); if (result == NULL) { - result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, e, v))); + if (e->IsVar()) { + IntVar* const var = e->Var(); + switch (var->VarType()) { + case VAR_ADD_CST: { + PlusCstVar* const add_var = reinterpret_cast(var); + IntVar* const sub_var = add_var->SubVar(); + const int64 new_constant = v - add_var->Constant(); + if (new_constant == 0) { + result = sub_var; + } else { + result = RegisterIntExpr( + RevAlloc(new SubCstIntVar(this, sub_var, new_constant))); + } + break; + } + case CST_SUB_VAR: { + SubCstIntVar* const add_var = reinterpret_cast(var); + IntVar* const sub_var = add_var->SubVar(); + const int64 new_constant = v - add_var->Constant(); + result = MakeSum(sub_var, new_constant); + break; + } + case OPP_VAR: { + OppIntVar* const add_var = reinterpret_cast(var); + IntVar* const sub_var = add_var->SubVar(); + result = MakeSum(sub_var, v); + break; + } + default: + result = RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, v))); + } + } else { + result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, e, v))); + } Cache()->InsertExprConstantExpression( result, e, v, ModelCache::EXPR_CONSTANT_DIFFERENCE); }