diff --git a/src/constraint_solver/constraint_solver.h b/src/constraint_solver/constraint_solver.h index 36ada815e4..3c51028267 100644 --- a/src/constraint_solver/constraint_solver.h +++ b/src/constraint_solver/constraint_solver.h @@ -3738,6 +3738,10 @@ class IntVar : public IntExpr { // Accepts the given visitor. virtual void Accept(ModelVisitor* const visitor) const; + // IsEqual + virtual IntVar* IsEqual(int64 constant) = 0; + virtual IntVar* IsDifferent(int64 constant) = 0; + private: DISALLOW_COPY_AND_ASSIGN(IntVar); }; diff --git a/src/constraint_solver/expressions.cc b/src/constraint_solver/expressions.cc index 0888b4f962..4fd39142f3 100644 --- a/src/constraint_solver/expressions.cc +++ b/src/constraint_solver/expressions.cc @@ -301,6 +301,14 @@ class DomainIntVar : public IntVar { } } + virtual IntVar* IsEqual(int64 constant) { + return NULL; // IMPLEMENT ME. + } + + virtual IntVar* IsDifferent(int64 constant) { + return NULL; // IMPLEMENT ME + } + void Process(); void Push(); void ClearInProcess(); @@ -1342,6 +1350,28 @@ class BooleanVar : public IntVar { virtual string DebugString() const; virtual int VarType() const { return BOOLEAN_VAR; } + virtual IntVar* IsEqual(int64 constant) { + if (constant > 1 || constant < 0) { + return solver()->MakeIntConst(0); + } + if (constant == 1) { + return this; + } else { // constant == 0. + return solver()->MakeDifference(1, this)->Var(); + } + } + + virtual IntVar* IsDifferent(int64 constant) { + if (constant > 1 || constant < 0) { + return solver()->MakeIntConst(1); + } + if (constant == 1) { + return solver()->MakeDifference(1, this)->Var(); + } else { // constant == 0. + return this; + } + } + void RestoreValue() { value_ = kUnboundBooleanVarValue; } virtual string BaseName() const { return "BooleanVar"; } @@ -1524,6 +1554,22 @@ class IntConst : public IntVar { virtual string DebugString() const; virtual int VarType() const { return CONST_VAR; } + virtual IntVar* IsEqual(int64 constant) { + if (constant == value_) { + return solver()->MakeIntConst(1); + } else { + return solver()->MakeIntConst(0); + } + } + + virtual IntVar* IsDifferent(int64 constant) { + if (constant == value_) { + return solver()->MakeIntConst(0); + } else { + return solver()->MakeIntConst(1); + } + } + private: int64 value_; }; @@ -1602,6 +1648,14 @@ class PlusCstIntVar : public IntVar { var_); } + virtual IntVar* IsEqual(int64 constant) { + return var_->IsEqual(constant - cst_); + } + + virtual IntVar* IsDifferent(int64 constant) { + return var_->IsDifferent(constant - cst_); + } + private: IntVar* const var_; const int64 cst_; @@ -1739,6 +1793,14 @@ class PlusCstDomainIntVar : public IntVar { var_); } + virtual IntVar* IsEqual(int64 constant) { + return var_->IsEqual(constant - cst_); + } + + virtual IntVar* IsDifferent(int64 constant) { + return var_->IsDifferent(constant - cst_); + } + private: DomainIntVar* const var_; const int64 cst_; @@ -1878,6 +1940,14 @@ class SubCstIntVar : public IntVar { var_); } + virtual IntVar* IsEqual(int64 constant) { + return var_->IsEqual(cst_ - constant); + } + + virtual IntVar* IsDifferent(int64 constant) { + return var_->IsDifferent(cst_ - constant); + } + private: IntVar* const var_; const int64 cst_; @@ -2009,6 +2079,14 @@ class OppIntVar : public IntVar { var_); } + virtual IntVar* IsEqual(int64 constant) { + return var_->IsEqual(-constant); + } + + virtual IntVar* IsDifferent(int64 constant) { + return var_->IsDifferent(-constant); + } + private: IntVar* const var_; }; @@ -2162,6 +2240,22 @@ class TimesPosCstIntVar : public IntVar { var_); } + virtual IntVar* IsEqual(int64 constant) { + if (constant % cst_ == 0) { + return var_->IsEqual(constant / cst_); + } else { + return solver()->MakeIntConst(0); + } + } + + virtual IntVar* IsDifferent(int64 constant) { + if (constant % cst_ == 0) { + return var_->IsDifferent(constant / cst_); + } else { + return solver()->MakeIntConst(1); + } + } + private: IntVar* const var_; const int64 cst_; @@ -2315,6 +2409,22 @@ class TimesPosCstBoolVar : public IntVar { var_); } + virtual IntVar* IsEqual(int64 constant) { + if (constant % cst_ == 0) { + return var_->IsEqual(constant / cst_); + } else { + return solver()->MakeIntConst(0); + } + } + + virtual IntVar* IsDifferent(int64 constant) { + if (constant % cst_ == 0) { + return var_->IsDifferent(constant / cst_); + } else { + return solver()->MakeIntConst(1); + } + } + private: BooleanVar* const var_; const int64 cst_;