extract size 2 var element expressions into IfThenElseExpr

This commit is contained in:
lperron@google.com
2014-07-30 16:32:00 +00:00
parent 12bc709f05
commit 07c93b6ff3
3 changed files with 213 additions and 29 deletions

View File

@@ -1279,6 +1279,12 @@ class Solver {
// It assumes that vars are all different.
IntExpr* MakeIndexExpression(const std::vector<IntVar*>& vars, int64 value);
// Special cases with arrays of size two.
IntExpr* MakeIfThenElse(IntVar* const condition, int64 then_value,
int64 else_value);
IntExpr* MakeIfThenElse(IntVar* const condition, IntExpr* const then_expr,
IntExpr* const else_expr);
// std::min(vars)
IntExpr* MakeMin(const std::vector<IntVar*>& vars);
// min (left, right)

View File

@@ -965,6 +965,136 @@ IntExpr* Solver::MakeElement(ResultCallback2<int64, int64, int64>* values,
// ---------- Generalized element ----------
// ----- IfThenElseExpr -----
class IfThenElseExpr : public BaseIntExpr {
public:
IfThenElseExpr(Solver* const solver, IntVar* const index, IntExpr* zero,
IntExpr* const one)
: BaseIntExpr(solver), index_(index), zero_(zero), one_(one) {}
virtual ~IfThenElseExpr() {}
virtual int64 Min() const {
if (index_->Max() == 0) {
return zero_->Min();
} else if (index_->Min() == 1) {
return one_->Min();
} else {
return std::min(zero_->Min(), one_->Min());
}
}
virtual void SetMin(int64 m) {
if (index_->Max() == 0) {
zero_->SetMin(m);
} else if (index_->Min() == 1) {
one_->SetMin(m);
} else {
if (m > zero_->Max()) {
index_->SetValue(1);
one_->SetMin(m);
} else if (m > one_->Max()) {
index_->SetValue(0);
zero_->SetMin(m);
}
}
}
virtual int64 Max() const {
if (index_->Max() == 0) {
return zero_->Max();
} else if (index_->Min() == 1) {
return one_->Max();
} else {
return std::max(zero_->Max(), one_->Max());
}
}
virtual void SetMax(int64 m) {
if (index_->Max() == 0) {
zero_->SetMax(m);
} else if (index_->Min() == 1) {
one_->SetMax(m);
} else {
if (m < zero_->Min()) {
index_->SetValue(1);
one_->SetMax(m);
} else if (m < one_->Min()) {
index_->SetValue(0);
zero_->SetMax(m);
}
}
}
virtual void Range(int64* l, int64* u) {
if (index_->Max() == 0) {
zero_->Range(l, u);
} else if (index_->Min() == 1) {
one_->Range(l, u);
} else {
int64 zl = 0;
int64 zu = 0;
int64 ol = 0;
int64 ou = 0;
zero_->Range(&zl, &zu);
one_->Range(&ol, &ou);
*l = std::min(zl, ol);
*u = std::max(zu, ou);
}
}
virtual void SetRange(int64 mi, int64 ma) {
if (index_->Max() == 0) {
zero_->SetRange(mi, ma);
} else if (index_->Min() == 1) {
one_->SetRange(mi, ma);
} else {
if (ma < zero_->Min() || mi > zero_->Max()) {
index_->SetValue(1);
one_->SetRange(mi, ma);
} else if (ma < one_->Min() || mi > one_->Max()) {
index_->SetValue(0);
zero_->SetRange(mi, ma);
}
}
}
virtual bool Bound() const {
if (index_->Max() == 0) {
return zero_->Bound();
} else if (index_->Min() == 1) {
return one_->Bound();
} else {
return zero_->Bound() && one_->Bound() && zero_->Min() == one_->Min();
}
}
virtual void WhenRange(Demon* d) {
if (index_->Min() == 0) {
return zero_->WhenRange(d);
}
if (index_->Max() == 1) {
one_->WhenRange(d);
}
if (!index_->Bound()) {
index_->WhenRange(d);
}
}
virtual std::string DebugString() const {
return StringPrintf("IfThenElseExpr(%s, [%s, %s])",
index_->DebugString().c_str(),
zero_->DebugString().c_str(),
one_->DebugString().c_str());
}
private:
IntVar* const index_;
IntExpr* const zero_;
IntExpr* const one_;
};
// ----- IntExprArrayElementCt -----
// This constraint implements vars[index] == var. It is delayed such
@@ -1298,6 +1428,17 @@ Constraint* MakeElementEqualityFunc(Solver* const solver,
}
} // namespace
IntExpr* Solver::MakeIfThenElse(IntVar* const condition, int64 then_value,
int64 else_value) {
return MakeSum(MakeProd(condition, then_value - else_value), else_value);
}
IntExpr* Solver::MakeIfThenElse(IntVar* const condition,
IntExpr* const then_expr,
IntExpr* const else_expr) {
return RevAlloc(new IfThenElseExpr(this, condition, else_expr, then_expr));
}
IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars, IntVar* const index) {
if (index->Bound()) {
return vars[index->Min()];
@@ -1310,6 +1451,14 @@ IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars, IntVar* const ind
}
return MakeElement(values, index);
}
if (index->Size() == 2 && index->Min() + 1 == index->Max() &&
index->Min() >= 0 && index->Max() < vars.size()) {
// Let's get the index between 0 and 1.
IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();
IntVar* const zero = vars[index->Min()];
IntVar* const one = vars[index->Max()];
return RevAlloc(new IfThenElseExpr(this, scaled_index, zero, one));
}
int64 emin = kint64max;
int64 emax = kint64min;
std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));

View File

@@ -257,38 +257,67 @@ void ExtractArrayIntElement(FzSolver* fzsolver, FzConstraint* ct) {
void ExtractArrayVarIntElement(FzSolver* fzsolver, FzConstraint* ct) {
Solver* const solver = fzsolver->solver();
IntExpr* const index = fzsolver->GetExpression(ct->Arg(0));
const std::vector<IntVar*> vars = fzsolver->GetVariableArray(ct->Arg(1));
const int64 array_size = ct->Arg(1).variables.size();
const int64 imin = std::max(index->Min(), 1LL);
const int64 imax = std::min(index->Max(), static_cast<int64>(vars.size()));
const int64 imax = std::min(index->Max(), array_size);
IntVar* const shifted_index = solver->MakeSum(index, -imin)->Var();
const int64 size = imax - imin + 1;
std::vector<IntVar*> var_array(size);
for (int i = 0; i < size; ++i) {
var_array[i] = vars[i + imin - 1];
}
if (ct->target_variable != nullptr) {
DCHECK_EQ(ct->Arg(2).Var(), ct->target_variable);
IntExpr* const target = solver->MakeElement(var_array, shifted_index);
FZVLOG << " - creating " << ct->target_variable->DebugString()
<< " := " << target->DebugString() << FZENDL;
fzsolver->SetExtracted(ct->target_variable, target);
} else {
Constraint* constraint = nullptr;
if (ct->Arg(2).HasOneValue()) {
const int64 target = ct->Arg(2).Value();
if (fzsolver->IsAllDifferent(ct->Arg(1).variables)) {
constraint =
solver->MakeIndexOfConstraint(var_array, shifted_index, target);
if (array_size == 2 && imin == 1 && imax == 2) {
IntExpr* const zero = fzsolver->Extract(ct->Arg(1).variables[0]);
IntExpr* const one = fzsolver->Extract(ct->Arg(1).variables[1]);
if (ct->target_variable != nullptr) {
DCHECK_EQ(ct->Arg(2).Var(), ct->target_variable);
IntExpr* const zero = fzsolver->Extract(ct->Arg(1).variables[0]);
IntExpr* const one = fzsolver->Extract(ct->Arg(1).variables[1]);
IntExpr* const target =
solver->MakeIfThenElse(shifted_index->Var(), one, zero);
FZVLOG << " - creating " << ct->target_variable->DebugString()
<< " := " << target->DebugString() << FZENDL;
fzsolver->SetExtracted(ct->target_variable, target);
} else {
Constraint* constraint = nullptr;
if (ct->Arg(2).HasOneValue()) {
const int64 target = ct->Arg(2).Value();
constraint = solver->MakeEquality(
solver->MakeIfThenElse(shifted_index->Var(), one, zero), target);
} else {
IntVar* const target = fzsolver->GetExpression(ct->Arg(2))->Var();
constraint = solver->MakeEquality(
solver->MakeIfThenElse(shifted_index->Var(), one, zero), target);
}
AddConstraint(solver, ct, constraint);
}
} else {
const std::vector<IntVar*> vars = fzsolver->GetVariableArray(ct->Arg(1));
const int64 size = imax - imin + 1;
std::vector<IntVar*> var_array(size);
for (int i = 0; i < size; ++i) {
var_array[i] = vars[i + imin - 1];
}
if (ct->target_variable != nullptr) {
DCHECK_EQ(ct->Arg(2).Var(), ct->target_variable);
IntExpr* const target = solver->MakeElement(var_array, shifted_index);
FZVLOG << " - creating " << ct->target_variable->DebugString()
<< " := " << target->DebugString() << FZENDL;
fzsolver->SetExtracted(ct->target_variable, target);
} else {
Constraint* constraint = nullptr;
if (ct->Arg(2).HasOneValue()) {
const int64 target = ct->Arg(2).Value();
if (fzsolver->IsAllDifferent(ct->Arg(1).variables)) {
constraint =
solver->MakeIndexOfConstraint(var_array, shifted_index, target);
} else {
constraint =
solver->MakeElementEquality(var_array, shifted_index, target);
}
} else {
IntVar* const target = fzsolver->GetExpression(ct->Arg(2))->Var();
constraint =
solver->MakeElementEquality(var_array, shifted_index, target);
}
} else {
IntVar* const target = fzsolver->GetExpression(ct->Arg(2))->Var();
constraint =
solver->MakeElementEquality(var_array, shifted_index, target);
AddConstraint(solver, ct, constraint);
}
AddConstraint(solver, ct, constraint);
}
}
@@ -929,8 +958,8 @@ void ExtractIntEqReif(FzSolver* fzsolver, FzConstraint* ct) {
if (ct->target_variable != nullptr) {
CHECK_EQ(ct->target_variable, ct->Arg(2).Var());
if (ct->Arg(1).HasOneValue()) {
IntVar* const boolvar =
solver->MakeIsEqualCstVar(left, ct->Arg(1).Value());
const int64 value = ct->Arg(1).Value();
IntVar* const boolvar = solver->MakeIsEqualCstVar(left, value);
FZVLOG << " - creating " << ct->target_variable->DebugString()
<< " := " << boolvar->DebugString() << FZENDL;
fzsolver->SetExtracted(ct->target_variable, boolvar);
@@ -1355,10 +1384,10 @@ void ExtractIntLinEqReif(FzSolver* fzsolver, FzConstraint* ct) {
if (ct->target_variable != nullptr) {
if (AreAllBooleans(vars) && AreAllOnes(coeffs)) {
IntVar* const boolvar = solver->MakeBoolVar();
PostIsBooleanSumInRange(fzsolver->Sat(), solver, vars, rhs, rhs,
boolvar);
FZVLOG << " - creating " << ct->target_variable->DebugString()
<< " := " << boolvar->DebugString() << FZENDL;
PostIsBooleanSumInRange(fzsolver->Sat(), solver, vars, rhs, rhs,
boolvar);
fzsolver->SetExtracted(ct->target_variable, boolvar);
} else {
IntVar* const boolvar =