diff --git a/examples/flatzinc/cube_sum.fzn b/examples/flatzinc/cube_sum.fzn index 68848aa230..ac7b77834d 100644 --- a/examples/flatzinc/cube_sum.fzn +++ b/examples/flatzinc/cube_sum.fzn @@ -567,4 +567,5 @@ constraint int_times(x[7], x[7], INT____00214) :: defines_var(INT____00214); constraint int_times(x[8], x[8], INT____00216) :: defines_var(INT____00216); constraint int_times(x[9], x[9], INT____00218) :: defines_var(INT____00218); constraint int_times(x[10], x[10], INT____00220) :: defines_var(INT____00220); + solve :: int_search([x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9], x[10], INT____00021], first_fail, indomain, complete) minimize INT____00021; diff --git a/examples/flatzinc/four_power.fzn b/examples/flatzinc/four_power.fzn index eda4f34e09..7fc4268a65 100644 --- a/examples/flatzinc/four_power.fzn +++ b/examples/flatzinc/four_power.fzn @@ -1,8 +1,13 @@ predicate all_different_int(array [int] of var int: x); predicate count(array [int] of var int: x, var int: y, var int: c); +predicate fixed_cumulative(array [int] of var int: s, array [int] of int: d, array [int] of int: r, int: b); predicate global_cardinality(array [int] of var int: x, array [int] of int: cover, array [int] of var int: counts); +predicate maximum_int(var int: m, array [int] of var int: x); +predicate minimum_int(var int: m, array [int] of var int: x); +predicate sort(array [int] of var int: x, array [int] of var int: y); predicate table_bool(array [int] of var bool: x, array [int, int] of bool: t); predicate table_int(array [int] of var int: x, array [int, int] of int: t); +predicate var_cumulative(array [int] of var int: s, array [int] of int: d, array [int] of int: r, var int: b); var 1..160000: INT____00001 :: is_defined_var :: var_is_introduced; var 1..64000000: INT____00002 :: is_defined_var :: var_is_introduced; var 1..25600000000: INT____00003 :: is_defined_var :: var_is_introduced; @@ -33,7 +38,7 @@ constraint int_lin_eq([-1, 1, 1, 1, 1], [INT____00016, INT____00003, INT____0000 constraint int_times(INT____00001, x1, INT____00002) :: defines_var(INT____00002); constraint int_times(INT____00002, x1, INT____00003) :: defines_var(INT____00003); constraint int_times(INT____00004, x2, INT____00005) :: defines_var(INT____00005); -constraint int_times(INT____00005, x3, INT____00006) :: defines_var(INT____00006); +constraint int_times(INT____00005, x2, INT____00006) :: defines_var(INT____00006); constraint int_times(INT____00007, x3, INT____00008) :: defines_var(INT____00008); constraint int_times(INT____00008, x3, INT____00009) :: defines_var(INT____00009); constraint int_times(INT____00010, x4, INT____00011) :: defines_var(INT____00011); diff --git a/src/constraint_solver/constraint_solver.cc b/src/constraint_solver/constraint_solver.cc index 93bcffaecd..9b6016f1a0 100644 --- a/src/constraint_solver/constraint_solver.cc +++ b/src/constraint_solver/constraint_solver.cc @@ -2468,6 +2468,17 @@ void Solver::Fail() { searches_.back()->JumpBack(); } +// ----- Cast Expression ----- + +IntExpr* Solver::CastExpression(IntVar* const var) const { + const IntegerCastInfo* const cast_info = + FindOrNull(cast_information_, var); + if (cast_info != NULL) { + return cast_info->expression; + } + return NULL; +} + // --- Propagation object names --- string Solver::GetName(const PropagationBaseObject* object) { @@ -2619,6 +2630,7 @@ const char ModelVisitor::kOpposite[] = "Opposite"; const char ModelVisitor::kPack[] = "Pack"; const char ModelVisitor::kPathCumul[] = "PathCumul"; const char ModelVisitor::kPerformedExpr[] = "PerformedExpression"; +const char ModelVisitor::kPower[] = "Power"; const char ModelVisitor::kProduct[] = "Product"; const char ModelVisitor::kScalProd[] = "ScalarProduct"; const char ModelVisitor::kScalProdEqual[] = "ScalarProductEqual"; diff --git a/src/constraint_solver/constraint_solver.h b/src/constraint_solver/constraint_solver.h index 0141ff6cb3..93839c1e7b 100644 --- a/src/constraint_solver/constraint_solver.h +++ b/src/constraint_solver/constraint_solver.h @@ -1258,6 +1258,8 @@ class Solver { IntExpr* MakeAbs(IntExpr* const expr); // expr * expr IntExpr* MakeSquare(IntExpr* const expr); + // expr ^ n (n > 0) + IntExpr* MakePower(IntExpr* const expr, int64 n); // vals[expr] IntExpr* MakeElement(const std::vector& vals, IntVar* const index); @@ -2848,6 +2850,9 @@ class Solver { string GetName(const PropagationBaseObject* object); void SetName(const PropagationBaseObject* object, const string& name); + // Internal. + IntExpr* CastExpression(IntVar* const var) const; + const string name_; const SolverParameters parameters_; hash_map propagation_object_names_; @@ -3171,6 +3176,7 @@ class ModelVisitor : public BaseObject { static const char kPack[]; static const char kPathCumul[]; static const char kPerformedExpr[]; + static const char kPower[]; static const char kProduct[]; static const char kScalProd[]; static const char kScalProdEqual[]; diff --git a/src/constraint_solver/expressions.cc b/src/constraint_solver/expressions.cc index 5275050d6c..6335cba7b7 100644 --- a/src/constraint_solver/expressions.cc +++ b/src/constraint_solver/expressions.cc @@ -4030,7 +4030,6 @@ int64 IntAbs::Max() const { // ----- Square ----- - // TODO(user): shouldn't we compare to kint32max^2 instead of kint64max? class IntSquare : public BaseIntExpr { public: @@ -4102,13 +4101,17 @@ class IntSquare : public BaseIntExpr { visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this); } - private: + IntExpr* expr() const { + return expr_; + } + + protected: IntExpr* const expr_; }; -class PosIntSquare : public BaseIntExpr { +class PosIntSquare : public IntSquare { public: - PosIntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {} + PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {} virtual ~PosIntSquare() {} virtual int64 Min() const { @@ -4136,28 +4139,234 @@ class PosIntSquare : public BaseIntExpr { const int64 root = static_cast(floor(sqrt(static_cast(m)))); expr_->SetMax(root); } +}; + +// ----- EvenPower ----- + +class BasePower : public BaseIntExpr { + public: + BasePower(Solver* const s, IntExpr* const e, int64 n) + : BaseIntExpr(s), + expr_(e), + pow_(n), + limit_(static_cast(floor(exp(log(kint64max) / pow_)))) {} + + virtual ~BasePower() {} + virtual bool Bound() const { return expr_->Bound(); } + + IntExpr* expr() const { + return expr_; + } + + int64 exponant() const { + return pow_; + } + virtual void WhenRange(Demon* d) { expr_->WhenRange(d); } + virtual string name() const { - return StringPrintf("PosIntSquare(%s)", expr_->name().c_str()); + return StringPrintf("IntPower(%s, %" GG_LL_FORMAT "d)", + expr_->name().c_str(), + pow_); } + virtual string DebugString() const { - return StringPrintf("PosIntSquare(%s)", expr_->DebugString().c_str()); + return StringPrintf("IntPower(%s, %" GG_LL_FORMAT "d)", + expr_->DebugString().c_str(), + pow_); } virtual void Accept(ModelVisitor* const visitor) const { - visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this); + visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this); visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument, expr_); - visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this); + visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_); + visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this); + } + + protected: + int64 Pown(int64 value) const { + if (value >= limit_) { + return kint64max; + } + if (value <= -limit_) { + if (pow_ % 2 == 0) { + return kint64max; + } else { + return kint64min; + } + } + + int64 result = value; + for (int i = 1; i < pow_; ++i) { + result *= value; + } + return result; + } + + int64 SqrnDown(int64 value) const { + if (value == kint64min) { + return kint64min; + } + if (value == kint64max) { + return kint64max; + } + int64 res = 0; + if (value >= 0) { + const double sq = exp(log(value) / pow_); + res = static_cast(floor(sq)); + } else { + CHECK_EQ(1, pow_ % 2); + const double sq = exp(log(-value) / pow_); + res = -static_cast(ceil(sq)); + } + const int64 pow_res = Pown(res + 1); + if (pow_res <= value) { + return res + 1; + } else { + return res; + } + } + + int64 SqrnUp(int64 value) const { + if (value == kint64min) { + return kint64min; + } + if (value == kint64max) { + return kint64max; + } + int64 res = 0; + if (value >= 0) { + const double sq = exp(log(value) / pow_); + res = static_cast(ceil(sq)); + } else { + CHECK_EQ(1, pow_ % 2); + const double sq = exp(log(-value) / pow_); + res = -static_cast(floor(sq)); + } + const int64 pow_res = Pown(res - 1); + if (pow_res >= value) { + return res - 1; + } else { + return res; + } } - private: IntExpr* const expr_; + const int64 pow_; + const int64 limit_; +}; + +class IntEvenPower : public BasePower { + public: + IntEvenPower(Solver* const s, IntExpr* const e, int64 n) + : BasePower(s, e, n) {} + + virtual ~IntEvenPower() {} + + virtual int64 Min() const { + const int64 emin = expr_->Min(); + if (emin >= 0) { + return Pown(emin); + } + const int64 emax = expr_->Max(); + if (emax < 0) { + return Pown(emax); + } + return 0LL; + } + virtual void SetMin(int64 m) { + if (m <= 0) { + return; + } + const int64 emin = expr_->Min(); + const int64 emax = expr_->Max(); + const int64 root = SqrnUp(m); + if (emin >= 0) { + expr_->SetMin(root); + } else if (emax <= 0) { + expr_->SetMax(-root); + } else if (expr_->IsVar()) { + reinterpret_cast(expr_)->RemoveInterval(-root + 1, root - 1); + } + } + + virtual int64 Max() const { + return std::max(Pown(expr_->Min()), Pown(expr_->Max())); + } + + virtual void SetMax(int64 m) { + if (m < 0) { + solver()->Fail(); + } + if (m == kint64max) { + return; + } + const int64 root = SqrnDown(m); + expr_->SetRange(-root, root); + } +}; + +class PosIntEvenPower : public BasePower { + public: + PosIntEvenPower(Solver* const s, + IntExpr* const e, + int64 pow) : BasePower(s, e, pow) {} + + virtual ~PosIntEvenPower() {} + + virtual int64 Min() const { + return Pown(expr_->Min()); + } + + virtual void SetMin(int64 m) { + if (m <= 0) { + return; + } + expr_->SetMin(SqrnUp(m)); + } + virtual int64 Max() const { + return Pown(expr_->Max()); + } + + virtual void SetMax(int64 m) { + if (m < 0) { + solver()->Fail(); + } + if (m == kint64max) { + return; + } + expr_->SetMax(SqrnDown(m)); + } +}; + +class IntOddPower : public BasePower { + public: + IntOddPower(Solver* const s, IntExpr* const e, int64 n) + : BasePower(s, e, n) {} + + virtual ~IntOddPower() {} + + virtual int64 Min() const { + return Pown(expr_->Min()); + } + + virtual void SetMin(int64 m) { + expr_->SetMin(SqrnUp(m)); + } + + virtual int64 Max() const { + return Pown(expr_->Max()); + } + + virtual void SetMax(int64 m) { + expr_->SetMax(SqrnDown(m)); + } }; // ----- Min(expr, expr) ----- @@ -5406,12 +5615,68 @@ IntExpr* Solver::MakeProd(IntExpr* const l, IntExpr* const r) { if (l->Bound()) { return MakeProd(r, l->Min()); } + if (r->Bound()) { return MakeProd(l, r->Min()); } - if (l == r) { - return MakeSquare(l); + + IntExpr* left = l; + IntExpr* right = r; + int64 left_exponant = 1; + int64 right_exponant = 1; + if (dynamic_cast(l) != NULL) { + BasePower* const left_power = dynamic_cast(l); + left = left_power->expr(); + left_exponant = left_power->exponant(); } + if (dynamic_cast(l) != NULL) { + IntSquare* const left_power = dynamic_cast(l); + left = left_power->expr(); + left_exponant = 2; + } + if (left->IsVar()) { + IntVar* const left_var = left->Var(); + IntExpr* const left_sub = CastExpression(left_var); + if (left_sub != NULL && dynamic_cast(left_sub) != NULL) { + BasePower* const left_power = dynamic_cast(left_sub); + left = left_power->expr(); + left_exponant = left_power->exponant(); + } + if (left_sub != NULL && dynamic_cast(left_sub) != NULL) { + IntSquare* const left_power = dynamic_cast(left_sub); + left = left_power->expr(); + left_exponant = 2; + } + } + if (dynamic_cast(r) != NULL) { + BasePower* const right_power = dynamic_cast(l); + right = right_power->expr(); + right_exponant = right_power->exponant(); + } + if (dynamic_cast(r) != NULL) { + IntSquare* const right_power = dynamic_cast(l); + right = right_power->expr(); + right_exponant = 2; + } + if (right->IsVar()) { + IntVar* const right_var = right->Var(); + IntExpr* const right_sub = CastExpression(right_var); + if (right_sub != NULL && dynamic_cast(right_sub) != NULL) { + BasePower* const right_power = dynamic_cast(right_sub); + right = right_power->expr(); + right_exponant = right_power->exponant(); + } + if (right_sub != NULL && dynamic_cast(right_sub) != NULL) { + IntSquare* const right_power = dynamic_cast(right_sub); + right = right_power->expr(); + right_exponant = 2; + } + } + + if (left == right) { + return MakePower(left, left_exponant + right_exponant); + } + CHECK_EQ(this, l->solver()); CHECK_EQ(this, r->solver()); if (l->IsVar() && l->Var()->VarType() == BOOLEAN_VAR) { @@ -5517,6 +5782,44 @@ IntExpr* Solver::MakeSquare(IntExpr* const e) { return result; } +IntExpr* Solver::MakePower(IntExpr* const e, int64 n) { + CHECK_EQ(this, e->solver()); + CHECK_GE(n, 0); + if (e->Bound()) { + const int64 v = e->Min(); + int64 result = 1; + for (int i = 0; i < n; ++i) { + result = CapProd(result, v); + } + return MakeIntConst(result); + } + switch (n) { + case 0: + return MakeIntConst(1); + case 1: + return e; + case 2: + return MakeSquare(e); + default: { + IntExpr* result = NULL; + //Cache()->FindExprExpression( e, ModelCache::EXPR_SQUARE); + // if (result == NULL) { + if (n % 2 == 0) { // even. + if (e->Min() >= 0 ) { + result = RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, e, n))); + } else { + result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, e, n))); + } + } else { + result = RegisterIntExpr(RevAlloc(new IntOddPower(this, e, n))); + } + // Cache()->InsertExprExpression(result, e, ModelCache::EXPR_SQUARE); + // } + return result; + } + } +} + IntExpr* Solver::MakeMin(IntExpr* const l, IntExpr* const r) { CHECK_EQ(this, l->solver()); CHECK_EQ(this, r->solver());