nested expression simplifications (3 - (x - 2)) -> 1 - x...

This commit is contained in:
lperron@google.com
2012-06-30 10:38:29 +00:00
parent 8ed0b3b014
commit 4efd2c4a6c
2 changed files with 279 additions and 179 deletions

View File

@@ -93,19 +93,21 @@ string CountValueEqCst::DebugString() const {
void CountValueEqCst::Post() {
for (int i = 0; i < size_; ++i) {
IntVar* const var = vars_[i];
if (!var->Bound()) {
if (!var->Bound() && var->Contains(value_)) {
Demon* d = MakeConstraintDemon1(solver(),
this,
&CountValueEqCst::OneBound,
"OneBound",
i);
var->WhenBound(d);
if (var->Contains(value_)) {
d = MakeConstraintDemon1(solver(),
this,
&CountValueEqCst::OneDomain,
"OneDomain",
i);
d = MakeConstraintDemon1(solver(),
this,
&CountValueEqCst::OneDomain,
"OneDomain",
i);
if (var->Min() == value_ || var->Max() == value_) {
var->WhenRange(d);
} else {
var->WhenDomain(d);
}
}
@@ -196,6 +198,47 @@ Constraint* Solver::MakeCount(const std::vector<IntVar*>& vars, int64 v, int64 c
for (ConstIter<std::vector<IntVar*> > it(vars); !it.at_end(); ++it) {
CHECK_EQ(this, (*it)->solver());
}
int count_min = 0;
int count_max = 0;
int count_active = 0;
int count_forced = 0;
for (int i = 0; i < vars.size(); ++i) {
if (vars[i]->Contains(v)) {
count_active++;
if (vars[i]->Bound()) {
count_forced++;
}
if (vars[i]->Max() - vars[i]->Min() <= 1) {
if (vars[i]->Min() == v) {
count_min++;
} else if (vars[i]->Max() == v) {
count_max++;
}
}
}
}
if (count_active == 0) {
return (c == 0 ? MakeTrueConstraint() : MakeFalseConstraint());
} else if (count_forced == count_active) {
return (c == count_forced ? MakeTrueConstraint() : MakeFalseConstraint());
}
if (count_min == count_active) {
std::vector<IntVar*> terms;
for (int i = 0; i < vars.size(); ++i) {
if (vars[i]->Contains(v)) {
terms.push_back(MakeSum(vars[i], -v)->Var());
}
}
return MakeSumEquality(terms, count_active - c);
} else if (count_max == count_active) {
std::vector<IntVar*> terms;
for (int i = 0; i < vars.size(); ++i) {
if (vars[i]->Contains(v)) {
terms.push_back(MakeSum(vars[i], -v + 1)->Var());
}
}
return MakeSumEquality(terms, c);
}
return RevAlloc(new CountValueEqCst(this, vars.data(), vars.size(), v, c));
}
@@ -262,19 +305,21 @@ string CountValueEq::DebugString() const {
void CountValueEq::Post() {
for (int i = 0; i < size_; ++i) {
IntVar* const var = vars_[i];
if (!var->Bound()) {
if (!var->Bound() && var->Contains(value_)) {
Demon* d = MakeConstraintDemon1(solver(),
this,
&CountValueEq::OneBound,
"OneBound",
i);
var->WhenBound(d);
if (var->Contains(value_)) {
d = MakeConstraintDemon1(solver(),
this,
&CountValueEq::OneDomain,
"OneDomain",
i);
d = MakeConstraintDemon1(solver(),
this,
&CountValueEq::OneDomain,
"OneDomain",
i);
if (var->Min() == value_ || var->Max() == value_) {
var->WhenRange(d);
} else {
var->WhenDomain(d);
}
}
@@ -389,7 +434,11 @@ Constraint* Solver::MakeCount(const std::vector<IntVar*>& vars, int64 v, IntVar*
CHECK_EQ(this, (*it)->solver());
}
CHECK_EQ(this, c->solver());
return RevAlloc(new CountValueEq(this, vars.data(), vars.size(), v, c));
if (c->Bound()) {
return MakeCount(vars, v, c->Min());
} else {
return RevAlloc(new CountValueEq(this, vars.data(), vars.size(), v, c));
}
}
// ---------- Distribute ----------

View File

@@ -1832,58 +1832,43 @@ string IntConst::DebugString() const {
// ----- x + c variable, optimized case -----
class PlusCstIntVar : public IntVar {
class PlusCstVar : public IntVar {
public:
class PlusCstIntVarIterator : public UnaryIterator {
public:
PlusCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev)
: UnaryIterator(v, hole, rev), cst_(c) {}
PlusCstVar(Solver* const s, IntVar* v, int64 c)
: IntVar(s), var_(v), cst_(c) {}
virtual ~PlusCstIntVarIterator() {}
virtual ~PlusCstVar() {}
virtual int64 Value() const {
return iterator_->Value() + cst_;
}
private:
const int64 cst_;
};
PlusCstIntVar(Solver* const s, IntVar* v, int64 c);
virtual ~PlusCstIntVar();
virtual int64 Min() const;
virtual void SetMin(int64 m);
virtual int64 Max() const;
virtual void SetMax(int64 m);
virtual void SetRange(int64 l, int64 u);
virtual void SetValue(int64 v);
virtual bool Bound() const;
virtual int64 Value() const;
virtual void RemoveValue(int64 v);
virtual void RemoveInterval(int64 l, int64 u);
virtual uint64 Size() const;
virtual bool Contains(int64 v) const;
virtual void WhenRange(Demon* d);
virtual void WhenBound(Demon* d);
virtual void WhenDomain(Demon* d);
virtual IntVarIterator* MakeHoleIterator(bool reversible) const {
return COND_REV_ALLOC(reversible,
new PlusCstIntVarIterator(var_, cst_,
true, reversible));
virtual void WhenRange(Demon* d) {
var_->WhenRange(d);
}
virtual IntVarIterator* MakeDomainIterator(bool reversible) const {
return COND_REV_ALLOC(reversible,
new PlusCstIntVarIterator(var_, cst_,
false, reversible));
virtual void WhenBound(Demon* d) {
var_->WhenBound(d);
}
virtual void WhenDomain(Demon* d) {
var_->WhenDomain(d);
}
virtual int64 OldMin() const {
return var_->OldMin() + cst_;
}
virtual int64 OldMax() const {
return var_->OldMax() + cst_;
}
virtual string DebugString() const;
virtual string DebugString() const {
if (HasName()) {
return StringPrintf("%s(%s + %" GG_LL_FORMAT "d)",
name().c_str(), var_->DebugString().c_str(), cst_);
} else {
return StringPrintf("(%s + %" GG_LL_FORMAT "d)",
var_->DebugString().c_str(), cst_);
}
}
virtual int VarType() const { return VAR_ADD_CST; }
virtual void Accept(ModelVisitor* const visitor) const {
@@ -1901,80 +1886,97 @@ class PlusCstIntVar : public IntVar {
return var_->IsDifferent(constant - cst_);
}
private:
IntVar* SubVar() const { return var_; }
int64 Constant() const { return cst_; }
protected:
IntVar* const var_;
const int64 cst_;
};
PlusCstIntVar::PlusCstIntVar(Solver* const s, IntVar* v, int64 c)
: IntVar(s), var_(v), cst_(c) {}
PlusCstIntVar::~PlusCstIntVar() {}
int64 PlusCstIntVar::Min() const {
return var_->Min() + cst_;
}
class PlusCstIntVar : public PlusCstVar {
public:
class PlusCstIntVarIterator : public UnaryIterator {
public:
PlusCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev)
: UnaryIterator(v, hole, rev), cst_(c) {}
void PlusCstIntVar::SetMin(int64 m) {
var_->SetMin(m - cst_);
}
virtual ~PlusCstIntVarIterator() {}
int64 PlusCstIntVar::Max() const {
return var_->Max() + cst_;
}
virtual int64 Value() const {
return iterator_->Value() + cst_;
}
void PlusCstIntVar::SetMax(int64 m) {
var_->SetMax(m - cst_);
}
private:
const int64 cst_;
};
void PlusCstIntVar::SetRange(int64 l, int64 u) {
var_->SetRange(l - cst_, u - cst_);
}
PlusCstIntVar(Solver* const s, IntVar* v, int64 c) : PlusCstVar(s, v, c) {}
void PlusCstIntVar::SetValue(int64 v) {
var_->SetValue(v - cst_);
virtual ~PlusCstIntVar() {}
virtual int64 Min() const {
return var_->Min() + cst_;
}
bool PlusCstIntVar::Bound() const {
return var_->Bound();
}
void PlusCstIntVar::WhenRange(Demon* d) {
var_->WhenRange(d);
}
virtual void SetMin(int64 m) {
var_->SetMin(m - cst_);
}
int64 PlusCstIntVar::Value() const {
return var_->Value() + cst_;
}
virtual int64 Max() const {
return var_->Max() + cst_;
}
void PlusCstIntVar::RemoveValue(int64 v) {
var_->RemoveValue(v - cst_);
}
virtual void SetMax(int64 m) {
var_->SetMax(m - cst_);
}
void PlusCstIntVar::RemoveInterval(int64 l, int64 u) {
var_->RemoveInterval(l - cst_, u - cst_);
}
virtual void SetRange(int64 l, int64 u) {
var_->SetRange(l - cst_, u - cst_);
}
void PlusCstIntVar::WhenBound(Demon* d) {
var_->WhenBound(d);
}
virtual void SetValue(int64 v) {
var_->SetValue(v - cst_);
}
void PlusCstIntVar::WhenDomain(Demon* d) {
var_->WhenDomain(d);
}
virtual int64 Value() const {
return var_->Value() + cst_;
}
uint64 PlusCstIntVar::Size() const {
return var_->Size();
}
virtual bool Bound() const {
return var_->Bound();
}
bool PlusCstIntVar::Contains(int64 v) const {
return var_->Contains(v - cst_);
}
virtual void RemoveValue(int64 v) {
var_->RemoveValue(v - cst_);
}
string PlusCstIntVar::DebugString() const {
return StringPrintf("(%s + %" GG_LL_FORMAT "d)",
var_->DebugString().c_str(), cst_);
}
virtual void RemoveInterval(int64 l, int64 u) {
var_->RemoveInterval(l - cst_, u - cst_);
}
class PlusCstDomainIntVar : public IntVar {
virtual uint64 Size() const {
return var_->Size();
}
virtual bool Contains(int64 v) const {
return var_->Contains(v - cst_);
}
virtual IntVarIterator* MakeHoleIterator(bool reversible) const {
return COND_REV_ALLOC(reversible,
new PlusCstIntVarIterator(var_, cst_,
true, reversible));
}
virtual IntVarIterator* MakeDomainIterator(bool reversible) const {
return COND_REV_ALLOC(reversible,
new PlusCstIntVarIterator(var_, cst_,
false, reversible));
}
};
class PlusCstDomainIntVar : public PlusCstVar {
public:
class PlusCstDomainIntVarIterator : public UnaryIterator {
public:
@@ -1993,8 +1995,11 @@ class PlusCstDomainIntVar : public IntVar {
private:
const int64 cst_;
};
PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64 c);
virtual ~PlusCstDomainIntVar();
PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64 c)
: PlusCstVar(s, v, c) {}
virtual ~PlusCstDomainIntVar() {}
virtual int64 Min() const;
virtual void SetMin(int64 m);
@@ -2008,9 +2013,11 @@ class PlusCstDomainIntVar : public IntVar {
virtual void RemoveInterval(int64 l, int64 u);
virtual uint64 Size() const;
virtual bool Contains(int64 v) const;
virtual void WhenRange(Demon* d);
virtual void WhenBound(Demon* d);
virtual void WhenDomain(Demon* d);
DomainIntVar* domain_int_var() const {
return reinterpret_cast<DomainIntVar*>(var_);
}
virtual IntVarIterator* MakeHoleIterator(bool reversible) const {
return COND_REV_ALLOC(reversible,
new PlusCstDomainIntVarIterator(var_, cst_,
@@ -2021,106 +2028,56 @@ class PlusCstDomainIntVar : public IntVar {
new PlusCstDomainIntVarIterator(var_, cst_,
false, reversible));
}
virtual int64 OldMin() const {
return var_->OldMin() + cst_;
}
virtual int64 OldMax() const {
return var_->OldMax() + cst_;
}
virtual string DebugString() const;
virtual int VarType() const { return VAR_ADD_CST; }
virtual void Accept(ModelVisitor* const visitor) const {
visitor->VisitIntegerVariable(this,
ModelVisitor::kSumOperation,
cst_,
var_);
}
virtual IntVar* IsEqual(int64 constant) {
return var_->IsEqual(constant - cst_);
}
virtual IntVar* IsDifferent(int64 constant) {
return var_->IsDifferent(constant - cst_);
}
private:
DomainIntVar* const var_;
const int64 cst_;
};
PlusCstDomainIntVar::PlusCstDomainIntVar(Solver* const s,
DomainIntVar* v,
int64 c)
: IntVar(s), var_(v), cst_(c) {}
PlusCstDomainIntVar::~PlusCstDomainIntVar() {}
int64 PlusCstDomainIntVar::Min() const {
return var_->min_.Value() + cst_;
return domain_int_var()->min_.Value() + cst_;
}
void PlusCstDomainIntVar::SetMin(int64 m) {
var_->DomainIntVar::SetMin(m - cst_);
domain_int_var()->DomainIntVar::SetMin(m - cst_);
}
int64 PlusCstDomainIntVar::Max() const {
return var_->max_.Value() + cst_;
return domain_int_var()->max_.Value() + cst_;
}
void PlusCstDomainIntVar::SetMax(int64 m) {
var_->DomainIntVar::SetMax(m - cst_);
domain_int_var()->DomainIntVar::SetMax(m - cst_);
}
void PlusCstDomainIntVar::SetRange(int64 l, int64 u) {
var_->DomainIntVar::SetRange(l - cst_, u - cst_);
domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
}
void PlusCstDomainIntVar::SetValue(int64 v) {
var_->DomainIntVar::SetValue(v - cst_);
domain_int_var()->DomainIntVar::SetValue(v - cst_);
}
bool PlusCstDomainIntVar::Bound() const {
return var_->min_.Value() == var_->max_.Value();
}
void PlusCstDomainIntVar::WhenRange(Demon* d) {
var_->WhenRange(d);
return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
}
int64 PlusCstDomainIntVar::Value() const {
CHECK_EQ(var_->min_.Value(), var_->max_.Value()) << "variable is not bound";
return var_->min_.Value() + cst_;
CHECK_EQ(domain_int_var()->min_.Value(),
domain_int_var()->max_.Value()) << "variable is not bound";
return domain_int_var()->min_.Value() + cst_;
}
void PlusCstDomainIntVar::RemoveValue(int64 v) {
var_->DomainIntVar::RemoveValue(v - cst_);
domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
}
void PlusCstDomainIntVar::RemoveInterval(int64 l, int64 u) {
var_->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
}
void PlusCstDomainIntVar::WhenBound(Demon* d) {
var_->WhenBound(d);
}
void PlusCstDomainIntVar::WhenDomain(Demon* d) {
var_->WhenDomain(d);
domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
}
uint64 PlusCstDomainIntVar::Size() const {
return var_->DomainIntVar::Size();
return domain_int_var()->DomainIntVar::Size();
}
bool PlusCstDomainIntVar::Contains(int64 v) const {
return var_->DomainIntVar::Contains(v - cst_);
}
string PlusCstDomainIntVar::DebugString() const {
return StringPrintf("(%s + %" GG_LL_FORMAT "d)",
var_->DebugString().c_str(), cst_);
return domain_int_var()->DomainIntVar::Contains(v - cst_);
}
// c - x variable, optimized case
@@ -2193,6 +2150,9 @@ class SubCstIntVar : public IntVar {
return var_->IsDifferent(cst_ - constant);
}
IntVar* SubVar() const { return var_; }
int64 Constant() const { return cst_; }
private:
IntVar* const var_;
const int64 cst_;
@@ -2332,6 +2292,8 @@ class OppIntVar : public IntVar {
return var_->IsDifferent(-constant);
}
IntVar* SubVar() const { return var_; }
private:
IntVar* const var_;
};
@@ -5782,6 +5744,11 @@ IntVar* Solver::MakeIntVar(int64 min, int64 max, const string& name) {
}
if (min == 0 && max == 1) {
return RegisterIntVar(RevAlloc(new BooleanVar(this, name)));
} else if (max - min == 1) {
const string inner_name = "inner_" + name;
return RegisterIntVar(
MakeSum(RevAlloc(new BooleanVar(this, inner_name)),
min)->VarWithName(name));
} else {
return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
}
@@ -5966,7 +5933,58 @@ IntExpr* Solver::MakeSum(IntExpr* const e, int64 v) {
IntExpr* result = Cache()->FindExprConstantExpression(
e, v, ModelCache::EXPR_CONSTANT_SUM);
if (result == NULL) {
result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, e, v)));
if (e->IsVar()) {
IntVar* const var = e->Var();
switch (var->VarType()) {
case DOMAIN_INT_VAR: {
result = RegisterIntExpr(RevAlloc(
new PlusCstDomainIntVar(
this, reinterpret_cast<DomainIntVar*>(var), v)));
break;
}
case CONST_VAR: {
result = RegisterIntExpr(MakeIntConst(var->Min() + v));
break;
}
case VAR_ADD_CST: {
PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
IntVar* const sub_var = add_var->SubVar();
const int64 new_constant = v + add_var->Constant();
if (new_constant == 0) {
result = sub_var;
} else {
if (sub_var->VarType() == DOMAIN_INT_VAR) {
DomainIntVar* const dvar =
reinterpret_cast<DomainIntVar*>(sub_var);
result = RegisterIntExpr(
RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
} else {
result = RegisterIntExpr(
RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
}
}
break;
}
case CST_SUB_VAR: {
SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
IntVar* const sub_var = add_var->SubVar();
const int64 new_constant = v - add_var->Constant();
result =
RegisterIntExpr(new SubCstIntVar(this, sub_var, new_constant));
break;
}
case OPP_VAR: {
OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
IntVar* const sub_var = add_var->SubVar();
result = RegisterIntExpr(new SubCstIntVar(this, sub_var, v));
break;
}
default:
result = RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, v)));
}
} else {
result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, e, v)));
}
Cache()->InsertExprConstantExpression(
result, e, v, ModelCache::EXPR_CONSTANT_SUM);
}
@@ -6230,7 +6248,40 @@ IntExpr* Solver::MakeDifference(int64 v, IntExpr* const e) {
IntExpr* result = Cache()->FindExprConstantExpression(
e, v, ModelCache::EXPR_CONSTANT_DIFFERENCE);
if (result == NULL) {
result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, e, v)));
if (e->IsVar()) {
IntVar* const var = e->Var();
switch (var->VarType()) {
case VAR_ADD_CST: {
PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
IntVar* const sub_var = add_var->SubVar();
const int64 new_constant = v - add_var->Constant();
if (new_constant == 0) {
result = sub_var;
} else {
result = RegisterIntExpr(
RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
}
break;
}
case CST_SUB_VAR: {
SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
IntVar* const sub_var = add_var->SubVar();
const int64 new_constant = v - add_var->Constant();
result = MakeSum(sub_var, new_constant);
break;
}
case OPP_VAR: {
OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
IntVar* const sub_var = add_var->SubVar();
result = MakeSum(sub_var, v);
break;
}
default:
result = RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, v)));
}
} else {
result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, e, v)));
}
Cache()->InsertExprConstantExpression(
result, e, v, ModelCache::EXPR_CONSTANT_DIFFERENCE);
}