MakePower, automatic building from square and product, much better at solving cube_sum

This commit is contained in:
lperron@google.com
2012-06-19 15:24:32 +00:00
parent af0b70f23b
commit c35e4416fc
5 changed files with 339 additions and 12 deletions

View File

@@ -567,4 +567,5 @@ constraint int_times(x[7], x[7], INT____00214) :: defines_var(INT____00214);
constraint int_times(x[8], x[8], INT____00216) :: defines_var(INT____00216);
constraint int_times(x[9], x[9], INT____00218) :: defines_var(INT____00218);
constraint int_times(x[10], x[10], INT____00220) :: defines_var(INT____00220);
solve :: int_search([x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10], INT____00021], first_fail, indomain, complete) minimize INT____00021;

View File

@@ -1,8 +1,13 @@
predicate all_different_int(array [int] of var int: x);
predicate count(array [int] of var int: x, var int: y, var int: c);
predicate fixed_cumulative(array [int] of var int: s, array [int] of int: d, array [int] of int: r, int: b);
predicate global_cardinality(array [int] of var int: x, array [int] of int: cover, array [int] of var int: counts);
predicate maximum_int(var int: m, array [int] of var int: x);
predicate minimum_int(var int: m, array [int] of var int: x);
predicate sort(array [int] of var int: x, array [int] of var int: y);
predicate table_bool(array [int] of var bool: x, array [int, int] of bool: t);
predicate table_int(array [int] of var int: x, array [int, int] of int: t);
predicate var_cumulative(array [int] of var int: s, array [int] of int: d, array [int] of int: r, var int: b);
var 1..160000: INT____00001 :: is_defined_var :: var_is_introduced;
var 1..64000000: INT____00002 :: is_defined_var :: var_is_introduced;
var 1..25600000000: INT____00003 :: is_defined_var :: var_is_introduced;
@@ -33,7 +38,7 @@ constraint int_lin_eq([-1, 1, 1, 1, 1], [INT____00016, INT____00003, INT____0000
constraint int_times(INT____00001, x1, INT____00002) :: defines_var(INT____00002);
constraint int_times(INT____00002, x1, INT____00003) :: defines_var(INT____00003);
constraint int_times(INT____00004, x2, INT____00005) :: defines_var(INT____00005);
constraint int_times(INT____00005, x3, INT____00006) :: defines_var(INT____00006);
constraint int_times(INT____00005, x2, INT____00006) :: defines_var(INT____00006);
constraint int_times(INT____00007, x3, INT____00008) :: defines_var(INT____00008);
constraint int_times(INT____00008, x3, INT____00009) :: defines_var(INT____00009);
constraint int_times(INT____00010, x4, INT____00011) :: defines_var(INT____00011);

View File

@@ -2468,6 +2468,17 @@ void Solver::Fail() {
searches_.back()->JumpBack();
}
// ----- Cast Expression -----
IntExpr* Solver::CastExpression(IntVar* const var) const {
const IntegerCastInfo* const cast_info =
FindOrNull(cast_information_, var);
if (cast_info != NULL) {
return cast_info->expression;
}
return NULL;
}
// --- Propagation object names ---
string Solver::GetName(const PropagationBaseObject* object) {
@@ -2619,6 +2630,7 @@ const char ModelVisitor::kOpposite[] = "Opposite";
const char ModelVisitor::kPack[] = "Pack";
const char ModelVisitor::kPathCumul[] = "PathCumul";
const char ModelVisitor::kPerformedExpr[] = "PerformedExpression";
const char ModelVisitor::kPower[] = "Power";
const char ModelVisitor::kProduct[] = "Product";
const char ModelVisitor::kScalProd[] = "ScalarProduct";
const char ModelVisitor::kScalProdEqual[] = "ScalarProductEqual";

View File

@@ -1258,6 +1258,8 @@ class Solver {
IntExpr* MakeAbs(IntExpr* const expr);
// expr * expr
IntExpr* MakeSquare(IntExpr* const expr);
// expr ^ n (n > 0)
IntExpr* MakePower(IntExpr* const expr, int64 n);
// vals[expr]
IntExpr* MakeElement(const std::vector<int64>& vals, IntVar* const index);
@@ -2848,6 +2850,9 @@ class Solver {
string GetName(const PropagationBaseObject* object);
void SetName(const PropagationBaseObject* object, const string& name);
// Internal.
IntExpr* CastExpression(IntVar* const var) const;
const string name_;
const SolverParameters parameters_;
hash_map<const PropagationBaseObject*, string> propagation_object_names_;
@@ -3171,6 +3176,7 @@ class ModelVisitor : public BaseObject {
static const char kPack[];
static const char kPathCumul[];
static const char kPerformedExpr[];
static const char kPower[];
static const char kProduct[];
static const char kScalProd[];
static const char kScalProdEqual[];

View File

@@ -4030,7 +4030,6 @@ int64 IntAbs::Max() const {
// ----- Square -----
// TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
class IntSquare : public BaseIntExpr {
public:
@@ -4102,13 +4101,17 @@ class IntSquare : public BaseIntExpr {
visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
}
private:
IntExpr* expr() const {
return expr_;
}
protected:
IntExpr* const expr_;
};
class PosIntSquare : public BaseIntExpr {
class PosIntSquare : public IntSquare {
public:
PosIntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
virtual ~PosIntSquare() {}
virtual int64 Min() const {
@@ -4136,28 +4139,234 @@ class PosIntSquare : public BaseIntExpr {
const int64 root = static_cast<int64>(floor(sqrt(static_cast<double>(m))));
expr_->SetMax(root);
}
};
// ----- EvenPower -----
class BasePower : public BaseIntExpr {
public:
BasePower(Solver* const s, IntExpr* const e, int64 n)
: BaseIntExpr(s),
expr_(e),
pow_(n),
limit_(static_cast<int64>(floor(exp(log(kint64max) / pow_)))) {}
virtual ~BasePower() {}
virtual bool Bound() const {
return expr_->Bound();
}
IntExpr* expr() const {
return expr_;
}
int64 exponant() const {
return pow_;
}
virtual void WhenRange(Demon* d) {
expr_->WhenRange(d);
}
virtual string name() const {
return StringPrintf("PosIntSquare(%s)", expr_->name().c_str());
return StringPrintf("IntPower(%s, %" GG_LL_FORMAT "d)",
expr_->name().c_str(),
pow_);
}
virtual string DebugString() const {
return StringPrintf("PosIntSquare(%s)", expr_->DebugString().c_str());
return StringPrintf("IntPower(%s, %" GG_LL_FORMAT "d)",
expr_->DebugString().c_str(),
pow_);
}
virtual void Accept(ModelVisitor* const visitor) const {
visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
expr_);
visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
}
protected:
int64 Pown(int64 value) const {
if (value >= limit_) {
return kint64max;
}
if (value <= -limit_) {
if (pow_ % 2 == 0) {
return kint64max;
} else {
return kint64min;
}
}
int64 result = value;
for (int i = 1; i < pow_; ++i) {
result *= value;
}
return result;
}
int64 SqrnDown(int64 value) const {
if (value == kint64min) {
return kint64min;
}
if (value == kint64max) {
return kint64max;
}
int64 res = 0;
if (value >= 0) {
const double sq = exp(log(value) / pow_);
res = static_cast<int64>(floor(sq));
} else {
CHECK_EQ(1, pow_ % 2);
const double sq = exp(log(-value) / pow_);
res = -static_cast<int64>(ceil(sq));
}
const int64 pow_res = Pown(res + 1);
if (pow_res <= value) {
return res + 1;
} else {
return res;
}
}
int64 SqrnUp(int64 value) const {
if (value == kint64min) {
return kint64min;
}
if (value == kint64max) {
return kint64max;
}
int64 res = 0;
if (value >= 0) {
const double sq = exp(log(value) / pow_);
res = static_cast<int64>(ceil(sq));
} else {
CHECK_EQ(1, pow_ % 2);
const double sq = exp(log(-value) / pow_);
res = -static_cast<int64>(floor(sq));
}
const int64 pow_res = Pown(res - 1);
if (pow_res >= value) {
return res - 1;
} else {
return res;
}
}
private:
IntExpr* const expr_;
const int64 pow_;
const int64 limit_;
};
class IntEvenPower : public BasePower {
public:
IntEvenPower(Solver* const s, IntExpr* const e, int64 n)
: BasePower(s, e, n) {}
virtual ~IntEvenPower() {}
virtual int64 Min() const {
const int64 emin = expr_->Min();
if (emin >= 0) {
return Pown(emin);
}
const int64 emax = expr_->Max();
if (emax < 0) {
return Pown(emax);
}
return 0LL;
}
virtual void SetMin(int64 m) {
if (m <= 0) {
return;
}
const int64 emin = expr_->Min();
const int64 emax = expr_->Max();
const int64 root = SqrnUp(m);
if (emin >= 0) {
expr_->SetMin(root);
} else if (emax <= 0) {
expr_->SetMax(-root);
} else if (expr_->IsVar()) {
reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
}
}
virtual int64 Max() const {
return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
}
virtual void SetMax(int64 m) {
if (m < 0) {
solver()->Fail();
}
if (m == kint64max) {
return;
}
const int64 root = SqrnDown(m);
expr_->SetRange(-root, root);
}
};
class PosIntEvenPower : public BasePower {
public:
PosIntEvenPower(Solver* const s,
IntExpr* const e,
int64 pow) : BasePower(s, e, pow) {}
virtual ~PosIntEvenPower() {}
virtual int64 Min() const {
return Pown(expr_->Min());
}
virtual void SetMin(int64 m) {
if (m <= 0) {
return;
}
expr_->SetMin(SqrnUp(m));
}
virtual int64 Max() const {
return Pown(expr_->Max());
}
virtual void SetMax(int64 m) {
if (m < 0) {
solver()->Fail();
}
if (m == kint64max) {
return;
}
expr_->SetMax(SqrnDown(m));
}
};
class IntOddPower : public BasePower {
public:
IntOddPower(Solver* const s, IntExpr* const e, int64 n)
: BasePower(s, e, n) {}
virtual ~IntOddPower() {}
virtual int64 Min() const {
return Pown(expr_->Min());
}
virtual void SetMin(int64 m) {
expr_->SetMin(SqrnUp(m));
}
virtual int64 Max() const {
return Pown(expr_->Max());
}
virtual void SetMax(int64 m) {
expr_->SetMax(SqrnDown(m));
}
};
// ----- Min(expr, expr) -----
@@ -5406,12 +5615,68 @@ IntExpr* Solver::MakeProd(IntExpr* const l, IntExpr* const r) {
if (l->Bound()) {
return MakeProd(r, l->Min());
}
if (r->Bound()) {
return MakeProd(l, r->Min());
}
if (l == r) {
return MakeSquare(l);
IntExpr* left = l;
IntExpr* right = r;
int64 left_exponant = 1;
int64 right_exponant = 1;
if (dynamic_cast<BasePower*>(l) != NULL) {
BasePower* const left_power = dynamic_cast<BasePower*>(l);
left = left_power->expr();
left_exponant = left_power->exponant();
}
if (dynamic_cast<IntSquare*>(l) != NULL) {
IntSquare* const left_power = dynamic_cast<IntSquare*>(l);
left = left_power->expr();
left_exponant = 2;
}
if (left->IsVar()) {
IntVar* const left_var = left->Var();
IntExpr* const left_sub = CastExpression(left_var);
if (left_sub != NULL && dynamic_cast<BasePower*>(left_sub) != NULL) {
BasePower* const left_power = dynamic_cast<BasePower*>(left_sub);
left = left_power->expr();
left_exponant = left_power->exponant();
}
if (left_sub != NULL && dynamic_cast<IntSquare*>(left_sub) != NULL) {
IntSquare* const left_power = dynamic_cast<IntSquare*>(left_sub);
left = left_power->expr();
left_exponant = 2;
}
}
if (dynamic_cast<BasePower*>(r) != NULL) {
BasePower* const right_power = dynamic_cast<BasePower*>(l);
right = right_power->expr();
right_exponant = right_power->exponant();
}
if (dynamic_cast<IntSquare*>(r) != NULL) {
IntSquare* const right_power = dynamic_cast<IntSquare*>(l);
right = right_power->expr();
right_exponant = 2;
}
if (right->IsVar()) {
IntVar* const right_var = right->Var();
IntExpr* const right_sub = CastExpression(right_var);
if (right_sub != NULL && dynamic_cast<BasePower*>(right_sub) != NULL) {
BasePower* const right_power = dynamic_cast<BasePower*>(right_sub);
right = right_power->expr();
right_exponant = right_power->exponant();
}
if (right_sub != NULL && dynamic_cast<IntSquare*>(right_sub) != NULL) {
IntSquare* const right_power = dynamic_cast<IntSquare*>(right_sub);
right = right_power->expr();
right_exponant = 2;
}
}
if (left == right) {
return MakePower(left, left_exponant + right_exponant);
}
CHECK_EQ(this, l->solver());
CHECK_EQ(this, r->solver());
if (l->IsVar() && l->Var()->VarType() == BOOLEAN_VAR) {
@@ -5517,6 +5782,44 @@ IntExpr* Solver::MakeSquare(IntExpr* const e) {
return result;
}
IntExpr* Solver::MakePower(IntExpr* const e, int64 n) {
CHECK_EQ(this, e->solver());
CHECK_GE(n, 0);
if (e->Bound()) {
const int64 v = e->Min();
int64 result = 1;
for (int i = 0; i < n; ++i) {
result = CapProd(result, v);
}
return MakeIntConst(result);
}
switch (n) {
case 0:
return MakeIntConst(1);
case 1:
return e;
case 2:
return MakeSquare(e);
default: {
IntExpr* result = NULL;
//Cache()->FindExprExpression( e, ModelCache::EXPR_SQUARE);
// if (result == NULL) {
if (n % 2 == 0) { // even.
if (e->Min() >= 0 ) {
result = RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, e, n)));
} else {
result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, e, n)));
}
} else {
result = RegisterIntExpr(RevAlloc(new IntOddPower(this, e, n)));
}
// Cache()->InsertExprExpression(result, e, ModelCache::EXPR_SQUARE);
// }
return result;
}
}
}
IntExpr* Solver::MakeMin(IntExpr* const l, IntExpr* const r) {
CHECK_EQ(this, l->solver());
CHECK_EQ(this, r->solver());