extract size 2 var element expressions into IfThenElseExpr
This commit is contained in:
@@ -1279,6 +1279,12 @@ class Solver {
|
||||
// It assumes that vars are all different.
|
||||
IntExpr* MakeIndexExpression(const std::vector<IntVar*>& vars, int64 value);
|
||||
|
||||
// Special cases with arrays of size two.
|
||||
IntExpr* MakeIfThenElse(IntVar* const condition, int64 then_value,
|
||||
int64 else_value);
|
||||
IntExpr* MakeIfThenElse(IntVar* const condition, IntExpr* const then_expr,
|
||||
IntExpr* const else_expr);
|
||||
|
||||
// std::min(vars)
|
||||
IntExpr* MakeMin(const std::vector<IntVar*>& vars);
|
||||
// min (left, right)
|
||||
|
||||
@@ -965,6 +965,136 @@ IntExpr* Solver::MakeElement(ResultCallback2<int64, int64, int64>* values,
|
||||
|
||||
// ---------- Generalized element ----------
|
||||
|
||||
// ----- IfThenElseExpr -----
|
||||
|
||||
class IfThenElseExpr : public BaseIntExpr {
|
||||
public:
|
||||
IfThenElseExpr(Solver* const solver, IntVar* const index, IntExpr* zero,
|
||||
IntExpr* const one)
|
||||
: BaseIntExpr(solver), index_(index), zero_(zero), one_(one) {}
|
||||
|
||||
virtual ~IfThenElseExpr() {}
|
||||
|
||||
virtual int64 Min() const {
|
||||
if (index_->Max() == 0) {
|
||||
return zero_->Min();
|
||||
} else if (index_->Min() == 1) {
|
||||
return one_->Min();
|
||||
} else {
|
||||
return std::min(zero_->Min(), one_->Min());
|
||||
}
|
||||
}
|
||||
|
||||
virtual void SetMin(int64 m) {
|
||||
if (index_->Max() == 0) {
|
||||
zero_->SetMin(m);
|
||||
} else if (index_->Min() == 1) {
|
||||
one_->SetMin(m);
|
||||
} else {
|
||||
if (m > zero_->Max()) {
|
||||
index_->SetValue(1);
|
||||
one_->SetMin(m);
|
||||
} else if (m > one_->Max()) {
|
||||
index_->SetValue(0);
|
||||
zero_->SetMin(m);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual int64 Max() const {
|
||||
if (index_->Max() == 0) {
|
||||
return zero_->Max();
|
||||
} else if (index_->Min() == 1) {
|
||||
return one_->Max();
|
||||
} else {
|
||||
return std::max(zero_->Max(), one_->Max());
|
||||
}
|
||||
}
|
||||
|
||||
virtual void SetMax(int64 m) {
|
||||
if (index_->Max() == 0) {
|
||||
zero_->SetMax(m);
|
||||
} else if (index_->Min() == 1) {
|
||||
one_->SetMax(m);
|
||||
} else {
|
||||
if (m < zero_->Min()) {
|
||||
index_->SetValue(1);
|
||||
one_->SetMax(m);
|
||||
} else if (m < one_->Min()) {
|
||||
index_->SetValue(0);
|
||||
zero_->SetMax(m);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual void Range(int64* l, int64* u) {
|
||||
if (index_->Max() == 0) {
|
||||
zero_->Range(l, u);
|
||||
} else if (index_->Min() == 1) {
|
||||
one_->Range(l, u);
|
||||
} else {
|
||||
int64 zl = 0;
|
||||
int64 zu = 0;
|
||||
int64 ol = 0;
|
||||
int64 ou = 0;
|
||||
zero_->Range(&zl, &zu);
|
||||
one_->Range(&ol, &ou);
|
||||
*l = std::min(zl, ol);
|
||||
*u = std::max(zu, ou);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void SetRange(int64 mi, int64 ma) {
|
||||
if (index_->Max() == 0) {
|
||||
zero_->SetRange(mi, ma);
|
||||
} else if (index_->Min() == 1) {
|
||||
one_->SetRange(mi, ma);
|
||||
} else {
|
||||
if (ma < zero_->Min() || mi > zero_->Max()) {
|
||||
index_->SetValue(1);
|
||||
one_->SetRange(mi, ma);
|
||||
} else if (ma < one_->Min() || mi > one_->Max()) {
|
||||
index_->SetValue(0);
|
||||
zero_->SetRange(mi, ma);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual bool Bound() const {
|
||||
if (index_->Max() == 0) {
|
||||
return zero_->Bound();
|
||||
} else if (index_->Min() == 1) {
|
||||
return one_->Bound();
|
||||
} else {
|
||||
return zero_->Bound() && one_->Bound() && zero_->Min() == one_->Min();
|
||||
}
|
||||
}
|
||||
|
||||
virtual void WhenRange(Demon* d) {
|
||||
if (index_->Min() == 0) {
|
||||
return zero_->WhenRange(d);
|
||||
}
|
||||
if (index_->Max() == 1) {
|
||||
one_->WhenRange(d);
|
||||
}
|
||||
if (!index_->Bound()) {
|
||||
index_->WhenRange(d);
|
||||
}
|
||||
}
|
||||
|
||||
virtual std::string DebugString() const {
|
||||
return StringPrintf("IfThenElseExpr(%s, [%s, %s])",
|
||||
index_->DebugString().c_str(),
|
||||
zero_->DebugString().c_str(),
|
||||
one_->DebugString().c_str());
|
||||
}
|
||||
|
||||
private:
|
||||
IntVar* const index_;
|
||||
IntExpr* const zero_;
|
||||
IntExpr* const one_;
|
||||
};
|
||||
|
||||
// ----- IntExprArrayElementCt -----
|
||||
|
||||
// This constraint implements vars[index] == var. It is delayed such
|
||||
@@ -1298,6 +1428,17 @@ Constraint* MakeElementEqualityFunc(Solver* const solver,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
IntExpr* Solver::MakeIfThenElse(IntVar* const condition, int64 then_value,
|
||||
int64 else_value) {
|
||||
return MakeSum(MakeProd(condition, then_value - else_value), else_value);
|
||||
}
|
||||
|
||||
IntExpr* Solver::MakeIfThenElse(IntVar* const condition,
|
||||
IntExpr* const then_expr,
|
||||
IntExpr* const else_expr) {
|
||||
return RevAlloc(new IfThenElseExpr(this, condition, else_expr, then_expr));
|
||||
}
|
||||
|
||||
IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars, IntVar* const index) {
|
||||
if (index->Bound()) {
|
||||
return vars[index->Min()];
|
||||
@@ -1310,6 +1451,14 @@ IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars, IntVar* const ind
|
||||
}
|
||||
return MakeElement(values, index);
|
||||
}
|
||||
if (index->Size() == 2 && index->Min() + 1 == index->Max() &&
|
||||
index->Min() >= 0 && index->Max() < vars.size()) {
|
||||
// Let's get the index between 0 and 1.
|
||||
IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();
|
||||
IntVar* const zero = vars[index->Min()];
|
||||
IntVar* const one = vars[index->Max()];
|
||||
return RevAlloc(new IfThenElseExpr(this, scaled_index, zero, one));
|
||||
}
|
||||
int64 emin = kint64max;
|
||||
int64 emax = kint64min;
|
||||
std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));
|
||||
|
||||
@@ -257,38 +257,67 @@ void ExtractArrayIntElement(FzSolver* fzsolver, FzConstraint* ct) {
|
||||
void ExtractArrayVarIntElement(FzSolver* fzsolver, FzConstraint* ct) {
|
||||
Solver* const solver = fzsolver->solver();
|
||||
IntExpr* const index = fzsolver->GetExpression(ct->Arg(0));
|
||||
const std::vector<IntVar*> vars = fzsolver->GetVariableArray(ct->Arg(1));
|
||||
const int64 array_size = ct->Arg(1).variables.size();
|
||||
const int64 imin = std::max(index->Min(), 1LL);
|
||||
const int64 imax = std::min(index->Max(), static_cast<int64>(vars.size()));
|
||||
const int64 imax = std::min(index->Max(), array_size);
|
||||
IntVar* const shifted_index = solver->MakeSum(index, -imin)->Var();
|
||||
const int64 size = imax - imin + 1;
|
||||
std::vector<IntVar*> var_array(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
var_array[i] = vars[i + imin - 1];
|
||||
}
|
||||
if (ct->target_variable != nullptr) {
|
||||
DCHECK_EQ(ct->Arg(2).Var(), ct->target_variable);
|
||||
IntExpr* const target = solver->MakeElement(var_array, shifted_index);
|
||||
FZVLOG << " - creating " << ct->target_variable->DebugString()
|
||||
<< " := " << target->DebugString() << FZENDL;
|
||||
fzsolver->SetExtracted(ct->target_variable, target);
|
||||
} else {
|
||||
Constraint* constraint = nullptr;
|
||||
if (ct->Arg(2).HasOneValue()) {
|
||||
const int64 target = ct->Arg(2).Value();
|
||||
if (fzsolver->IsAllDifferent(ct->Arg(1).variables)) {
|
||||
constraint =
|
||||
solver->MakeIndexOfConstraint(var_array, shifted_index, target);
|
||||
if (array_size == 2 && imin == 1 && imax == 2) {
|
||||
IntExpr* const zero = fzsolver->Extract(ct->Arg(1).variables[0]);
|
||||
IntExpr* const one = fzsolver->Extract(ct->Arg(1).variables[1]);
|
||||
if (ct->target_variable != nullptr) {
|
||||
DCHECK_EQ(ct->Arg(2).Var(), ct->target_variable);
|
||||
IntExpr* const zero = fzsolver->Extract(ct->Arg(1).variables[0]);
|
||||
IntExpr* const one = fzsolver->Extract(ct->Arg(1).variables[1]);
|
||||
IntExpr* const target =
|
||||
solver->MakeIfThenElse(shifted_index->Var(), one, zero);
|
||||
FZVLOG << " - creating " << ct->target_variable->DebugString()
|
||||
<< " := " << target->DebugString() << FZENDL;
|
||||
fzsolver->SetExtracted(ct->target_variable, target);
|
||||
} else {
|
||||
Constraint* constraint = nullptr;
|
||||
if (ct->Arg(2).HasOneValue()) {
|
||||
const int64 target = ct->Arg(2).Value();
|
||||
constraint = solver->MakeEquality(
|
||||
solver->MakeIfThenElse(shifted_index->Var(), one, zero), target);
|
||||
} else {
|
||||
IntVar* const target = fzsolver->GetExpression(ct->Arg(2))->Var();
|
||||
constraint = solver->MakeEquality(
|
||||
solver->MakeIfThenElse(shifted_index->Var(), one, zero), target);
|
||||
}
|
||||
AddConstraint(solver, ct, constraint);
|
||||
}
|
||||
} else {
|
||||
const std::vector<IntVar*> vars = fzsolver->GetVariableArray(ct->Arg(1));
|
||||
const int64 size = imax - imin + 1;
|
||||
std::vector<IntVar*> var_array(size);
|
||||
for (int i = 0; i < size; ++i) {
|
||||
var_array[i] = vars[i + imin - 1];
|
||||
}
|
||||
|
||||
if (ct->target_variable != nullptr) {
|
||||
DCHECK_EQ(ct->Arg(2).Var(), ct->target_variable);
|
||||
IntExpr* const target = solver->MakeElement(var_array, shifted_index);
|
||||
FZVLOG << " - creating " << ct->target_variable->DebugString()
|
||||
<< " := " << target->DebugString() << FZENDL;
|
||||
fzsolver->SetExtracted(ct->target_variable, target);
|
||||
} else {
|
||||
Constraint* constraint = nullptr;
|
||||
if (ct->Arg(2).HasOneValue()) {
|
||||
const int64 target = ct->Arg(2).Value();
|
||||
if (fzsolver->IsAllDifferent(ct->Arg(1).variables)) {
|
||||
constraint =
|
||||
solver->MakeIndexOfConstraint(var_array, shifted_index, target);
|
||||
} else {
|
||||
constraint =
|
||||
solver->MakeElementEquality(var_array, shifted_index, target);
|
||||
}
|
||||
} else {
|
||||
IntVar* const target = fzsolver->GetExpression(ct->Arg(2))->Var();
|
||||
constraint =
|
||||
solver->MakeElementEquality(var_array, shifted_index, target);
|
||||
}
|
||||
} else {
|
||||
IntVar* const target = fzsolver->GetExpression(ct->Arg(2))->Var();
|
||||
constraint =
|
||||
solver->MakeElementEquality(var_array, shifted_index, target);
|
||||
AddConstraint(solver, ct, constraint);
|
||||
}
|
||||
AddConstraint(solver, ct, constraint);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -929,8 +958,8 @@ void ExtractIntEqReif(FzSolver* fzsolver, FzConstraint* ct) {
|
||||
if (ct->target_variable != nullptr) {
|
||||
CHECK_EQ(ct->target_variable, ct->Arg(2).Var());
|
||||
if (ct->Arg(1).HasOneValue()) {
|
||||
IntVar* const boolvar =
|
||||
solver->MakeIsEqualCstVar(left, ct->Arg(1).Value());
|
||||
const int64 value = ct->Arg(1).Value();
|
||||
IntVar* const boolvar = solver->MakeIsEqualCstVar(left, value);
|
||||
FZVLOG << " - creating " << ct->target_variable->DebugString()
|
||||
<< " := " << boolvar->DebugString() << FZENDL;
|
||||
fzsolver->SetExtracted(ct->target_variable, boolvar);
|
||||
@@ -1355,10 +1384,10 @@ void ExtractIntLinEqReif(FzSolver* fzsolver, FzConstraint* ct) {
|
||||
if (ct->target_variable != nullptr) {
|
||||
if (AreAllBooleans(vars) && AreAllOnes(coeffs)) {
|
||||
IntVar* const boolvar = solver->MakeBoolVar();
|
||||
PostIsBooleanSumInRange(fzsolver->Sat(), solver, vars, rhs, rhs,
|
||||
boolvar);
|
||||
FZVLOG << " - creating " << ct->target_variable->DebugString()
|
||||
<< " := " << boolvar->DebugString() << FZENDL;
|
||||
PostIsBooleanSumInRange(fzsolver->Sat(), solver, vars, rhs, rhs,
|
||||
boolvar);
|
||||
fzsolver->SetExtracted(ct->target_variable, boolvar);
|
||||
} else {
|
||||
IntVar* const boolvar =
|
||||
|
||||
Reference in New Issue
Block a user