22 #include "absl/strings/str_format.h"
23 #include "absl/strings/str_join.h"
34 "Initial size of the array of the hash "
35 "table of caches for objects of type Var(x == 3)");
// EqualityExprCst: constraint enforcing `expr_ == value_` (DebugString renders
// it as "(<expr> == <value>)"). Var() reifies it through
// MakeIsEqualCstVar(expr_->Var(), value_); Accept() reports it to a
// ModelVisitor as kEquality with the expression and kValueArgument.
// NOTE(review): this listing is lossy. The embedded line numbers jump
// (43->45, 50->52, 56->58, 59->67), so the access specifier, a declaration of
// Post() (EqualityExprCst::Post() is defined out of line below), the final
// argument of VisitIntegerExpressionArgument, closing braces and the data
// members are not shown. It does not compile as shown.
43 class EqualityExprCst :
public Constraint {
45 EqualityExprCst(Solver*
const s, IntExpr*
const e,
int64 v);
46 ~EqualityExprCst()
override {}
48 void InitialPropagate()
override;
49 IntVar* Var()
override {
50 return solver()->MakeIsEqualCstVar(
expr_->Var(), value_);
52 std::string DebugString()
const override;
54 void Accept(ModelVisitor*
const visitor)
const override {
55 visitor->BeginVisitConstraint(ModelVisitor::kEquality,
this);
56 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
58 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
59 visitor->EndVisitConstraint(ModelVisitor::kEquality,
this);
// Constructor: stores the constrained expression and the constant it must
// equal.
67 EqualityExprCst::EqualityExprCst(Solver*
const s, IntExpr*
const e,
int64 v)
68 : Constraint(s),
expr_(e), value_(v) {}
// Post(): only when expr_ is not a plain variable, creates a demon that re-runs
// InitialPropagate(). NOTE(review): listing lines 73-76 (where that demon is
// presumably attached to expr_) and the closing braces are missing.
70 void EqualityExprCst::Post() {
71 if (!
expr_->IsVar()) {
72 Demon* d = solver()->MakeConstraintInitialPropagateCallback(
this);
// InitialPropagate(): fixes expr_ to value_.
77 void EqualityExprCst::InitialPropagate() {
expr_->SetValue(value_); }
// DebugString(): renders "(<expr> == <value>)". The closing brace is not shown.
79 std::string EqualityExprCst::DebugString()
const {
80 return absl::StrFormat(
"(%s == %d)",
expr_->DebugString(), value_);
// MakeEquality(e, v): builds the constraint `e == v`, simplifying first:
//  - when IsADifference() splits e into (left, right) -- presumably
//    e == left - right -- returns MakeEquality(left, right + v);
//  - when e is a variable whose domain does not contain v, a false constraint;
//  - when e is already fixed to v (Min() == Max() == v), a true constraint;
//  - otherwise a new EqualityExprCst(this, e, v).
// CHECK-fails if e belongs to another solver.
// NOTE(review): lossy listing -- numbering gaps (e.g. 93->95) hide the closing
// braces.
84 Constraint* Solver::MakeEquality(IntExpr*
const e,
int64 v) {
85 CHECK_EQ(
this, e->solver());
86 IntExpr* left =
nullptr;
87 IntExpr* right =
nullptr;
88 if (IsADifference(e, &left, &right)) {
89 return MakeEquality(left, MakeSum(right, v));
90 }
else if (e->IsVar() && !e->Var()->Contains(v)) {
91 return MakeFalseConstraint();
92 }
else if (e->Min() == e->Max() && e->Min() == v) {
93 return MakeTrueConstraint();
95 return RevAlloc(
new EqualityExprCst(
this, e, v));
// Same logic as the int64 overload above, for an `int` constant.
99 Constraint* Solver::MakeEquality(IntExpr*
const e,
int v) {
100 CHECK_EQ(
this, e->solver());
101 IntExpr* left =
nullptr;
102 IntExpr* right =
nullptr;
103 if (IsADifference(e, &left, &right)) {
104 return MakeEquality(left, MakeSum(right, v));
105 }
else if (e->IsVar() && !e->Var()->Contains(v)) {
106 return MakeFalseConstraint();
107 }
else if (e->Min() == e->Max() && e->Min() == v) {
108 return MakeTrueConstraint();
110 return RevAlloc(
new EqualityExprCst(
this, e, v));
// GreaterEqExprCst: constraint enforcing `expr_ >= value_` (DebugString:
// "(<expr> >= <value>)"). Var() reifies it via
// MakeIsGreaterOrEqualCstVar(expr_->Var(), value_); Accept() reports
// kGreaterOrEqual with the expression and kValueArgument.
// NOTE(review): lossy listing -- numbering gaps (e.g. 126->129, 131->133,
// 134->138) hide closing braces, the final argument of
// VisitIntegerExpressionArgument, access specifiers and the value_/demon_
// member declarations (both are initialized by the constructor below).
118 class GreaterEqExprCst :
public Constraint {
120 GreaterEqExprCst(Solver*
const s, IntExpr*
const e,
int64 v);
121 ~GreaterEqExprCst()
override {}
122 void Post()
override;
123 void InitialPropagate()
override;
124 std::string DebugString()
const override;
125 IntVar* Var()
override {
126 return solver()->MakeIsGreaterOrEqualCstVar(
expr_->Var(), value_);
129 void Accept(ModelVisitor*
const visitor)
const override {
130 visitor->BeginVisitConstraint(ModelVisitor::kGreaterOrEqual,
this);
131 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
133 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
134 visitor->EndVisitConstraint(ModelVisitor::kGreaterOrEqual,
this);
138 IntExpr*
const expr_;
// Constructor: stores expr_ and value_; demon_ starts as nullptr and is only
// set by Post().
143 GreaterEqExprCst::GreaterEqExprCst(Solver*
const s, IntExpr*
const e,
int64 v)
144 : Constraint(s),
expr_(e), value_(v), demon_(nullptr) {}
// Post(): when expr_ is not a plain variable and its current Min() is still
// below value_, creates a demon that re-runs InitialPropagate() and attaches
// it with WhenRange(). NOTE(review): listing lines 150-155 are not shown.
146 void GreaterEqExprCst::Post() {
147 if (!
expr_->IsVar() &&
expr_->Min() < value_) {
148 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
149 expr_->WhenRange(demon_);
// InitialPropagate(): raises expr_'s lower bound to value_; if a demon was
// registered and expr_->Min() now reaches value_, inhibits that demon.
156 void GreaterEqExprCst::InitialPropagate() {
157 expr_->SetMin(value_);
158 if (demon_ !=
nullptr &&
expr_->Min() >= value_) {
159 demon_->inhibit(solver());
// DebugString(): renders "(<expr> >= <value>)".
163 std::string GreaterEqExprCst::DebugString()
const {
164 return absl::StrFormat(
"(%s >= %d)",
expr_->DebugString(), value_);
// MakeGreaterOrEqual(e, v): builds `e >= v`. CHECK-fails if e belongs to
// another solver. The guard of the true-constraint return (listing line 170)
// is missing here -- presumably e->Min() >= v. Returns a false constraint when
// e->Max() < v, otherwise a new GreaterEqExprCst(this, e, v).
168 Constraint* Solver::MakeGreaterOrEqual(IntExpr*
const e,
int64 v) {
169 CHECK_EQ(
this, e->solver());
171 return MakeTrueConstraint();
172 }
else if (e->Max() < v) {
173 return MakeFalseConstraint();
175 return RevAlloc(
new GreaterEqExprCst(
this, e, v));
// Same as above for an `int` constant (its guard line 181 is also missing).
179 Constraint* Solver::MakeGreaterOrEqual(IntExpr*
const e,
int v) {
180 CHECK_EQ(
this, e->solver());
182 return MakeTrueConstraint();
183 }
else if (e->Max() < v) {
184 return MakeFalseConstraint();
186 return RevAlloc(
new GreaterEqExprCst(
this, e, v));
// MakeGreater(e, v): builds `e > v` as GreaterEqExprCst(this, e, v + 1).
// CHECK-fails if e belongs to another solver. The guard of the
// true-constraint return (listing line 192) is missing -- presumably
// e->Min() > v. Returns a false constraint when e->Max() <= v; on the
// fall-through path e->Max() > v, so `v + 1` cannot overflow int64.
190 Constraint* Solver::MakeGreater(IntExpr*
const e,
int64 v) {
191 CHECK_EQ(
this, e->solver());
193 return MakeTrueConstraint();
194 }
else if (e->Max() <= v) {
195 return MakeFalseConstraint();
197 return RevAlloc(
new GreaterEqExprCst(
this, e, v + 1));
201 Constraint* Solver::MakeGreater(IntExpr*
const e,
int v) {
202 CHECK_EQ(
this, e->solver());
204 return MakeTrueConstraint();
205 }
else if (e->Max() <= v) {
206 return MakeFalseConstraint();
208 return RevAlloc(
new GreaterEqExprCst(
this, e, v + 1));
// LessEqExprCst: constraint enforcing `expr_ <= value_` (DebugString:
// "(<expr> <= <value>)"). Var() reifies it via
// MakeIsLessOrEqualCstVar(expr_->Var(), value_); Accept() reports
// kLessOrEqual with the expression and kValueArgument.
// NOTE(review): lossy listing -- numbering gaps (e.g. 224->226, 228->230,
// 231->235) hide closing braces, the final argument of
// VisitIntegerExpressionArgument, access specifiers and the value_/demon_
// member declarations.
216 class LessEqExprCst :
public Constraint {
218 LessEqExprCst(Solver*
const s, IntExpr*
const e,
int64 v);
219 ~LessEqExprCst()
override {}
220 void Post()
override;
221 void InitialPropagate()
override;
222 std::string DebugString()
const override;
223 IntVar* Var()
override {
224 return solver()->MakeIsLessOrEqualCstVar(
expr_->Var(), value_);
226 void Accept(ModelVisitor*
const visitor)
const override {
227 visitor->BeginVisitConstraint(ModelVisitor::kLessOrEqual,
this);
228 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
230 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
231 visitor->EndVisitConstraint(ModelVisitor::kLessOrEqual,
this);
235 IntExpr*
const expr_;
// Constructor: stores expr_ and value_; demon_ starts as nullptr and is only
// set by Post().
240 LessEqExprCst::LessEqExprCst(Solver*
const s, IntExpr*
const e,
int64 v)
241 : Constraint(s),
expr_(e), value_(v), demon_(nullptr) {}
// Post(): when expr_ is not a plain variable and its current Max() is still
// above value_, creates a demon that re-runs InitialPropagate() and attaches
// it with WhenRange(). NOTE(review): listing lines 247-252 are not shown.
243 void LessEqExprCst::Post() {
244 if (!
expr_->IsVar() &&
expr_->Max() > value_) {
245 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
246 expr_->WhenRange(demon_);
// InitialPropagate(): lowers expr_'s upper bound to value_; if a demon was
// registered and expr_->Max() is now <= value_, inhibits that demon.
253 void LessEqExprCst::InitialPropagate() {
254 expr_->SetMax(value_);
255 if (demon_ !=
nullptr &&
expr_->Max() <= value_) {
256 demon_->inhibit(solver());
// DebugString(): renders "(<expr> <= <value>)".
260 std::string LessEqExprCst::DebugString()
const {
261 return absl::StrFormat(
"(%s <= %d)",
expr_->DebugString(), value_);
// MakeLessOrEqual(e, v): builds `e <= v`. CHECK-fails if e belongs to another
// solver. The guard of the true-constraint return (listing line 267) is
// missing -- presumably e->Max() <= v. Returns a false constraint when
// e->Min() > v, otherwise a new LessEqExprCst(this, e, v).
265 Constraint* Solver::MakeLessOrEqual(IntExpr*
const e,
int64 v) {
266 CHECK_EQ(
this, e->solver());
268 return MakeTrueConstraint();
269 }
else if (e->Min() > v) {
270 return MakeFalseConstraint();
272 return RevAlloc(
new LessEqExprCst(
this, e, v));
// Same as above for an `int` constant (its guard line 278 is also missing).
276 Constraint* Solver::MakeLessOrEqual(IntExpr*
const e,
int v) {
277 CHECK_EQ(
this, e->solver());
279 return MakeTrueConstraint();
280 }
else if (e->Min() > v) {
281 return MakeFalseConstraint();
283 return RevAlloc(
new LessEqExprCst(
this, e, v));
// MakeLess(e, v): builds `e < v` as LessEqExprCst(this, e, v - 1).
// CHECK-fails if e belongs to another solver. The guard of the
// true-constraint return (listing line 289) is missing -- presumably
// e->Max() < v. Returns a false constraint when e->Min() >= v; on the
// fall-through path e->Min() < v, so `v - 1` cannot overflow int64.
287 Constraint* Solver::MakeLess(IntExpr*
const e,
int64 v) {
288 CHECK_EQ(
this, e->solver());
290 return MakeTrueConstraint();
291 }
else if (e->Min() >= v) {
292 return MakeFalseConstraint();
294 return RevAlloc(
new LessEqExprCst(
this, e, v - 1));
298 Constraint* Solver::MakeLess(IntExpr*
const e,
int v) {
299 CHECK_EQ(
this, e->solver());
301 return MakeTrueConstraint();
302 }
else if (e->Min() >= v) {
303 return MakeFalseConstraint();
305 return RevAlloc(
new LessEqExprCst(
this, e, v - 1));
// DiffCst: constraint enforcing `var_ != value_` on an IntVar (DebugString:
// "(<var> != <value>)"). Post() is empty; the work happens in
// InitialPropagate() and BoundPropagate(). Var() reifies it via
// MakeIsDifferentCstVar(var_, value_); Accept() reports kNonEqual.
// NOTE(review): lossy listing -- the constructor declaration (line 315), the
// final argument of VisitIntegerExpressionArgument (327), closing braces and
// the var_/value_/demon_ members are not shown.
313 class DiffCst :
public Constraint {
316 ~DiffCst()
override {}
317 void Post()
override {}
318 void InitialPropagate()
override;
319 void BoundPropagate();
320 std::string DebugString()
const override;
321 IntVar* Var()
override {
322 return solver()->MakeIsDifferentCstVar(var_, value_);
324 void Accept(ModelVisitor*
const visitor)
const override {
325 visitor->BeginVisitConstraint(ModelVisitor::kNonEqual,
this);
326 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
328 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
329 visitor->EndVisitConstraint(ModelVisitor::kNonEqual,
this);
// Helper that decides whether var_'s domain is "large"; its body is not shown.
333 bool HasLargeDomain(IntVar*
var);
// Constructor: stores var_ and value_; demon_ starts as nullptr.
340 DiffCst::DiffCst(Solver*
const s, IntVar*
const var,
int64 value)
341 : Constraint(s), var_(
var), value_(
value), demon_(nullptr) {}
// InitialPropagate(): for a large domain, attaches demon_ to var_ range
// changes (WhenRange). Otherwise it removes value_ directly.
// NOTE(review): listing lines 345-346 (presumably creating demon_) and 348
// (the branch separator) are missing.
343 void DiffCst::InitialPropagate() {
344 if (HasLargeDomain(var_)) {
347 var_->WhenRange(demon_);
349 var_->RemoveValue(value_);
// BoundPropagate(): range-based propagation for the large-domain case.
//  - value_ outside [min, max]: the constraint holds, so demon_ is inhibited.
//  - value_ equal to a bound: that bound is moved one step past value_.
//  - once the domain is no longer "large": inhibits demon_ and removes value_.
// NOTE(review): value_ + 1 / value_ - 1 would overflow if value_ sits at the
// int64 extremes.
353 void DiffCst::BoundPropagate() {
354 const int64 var_min = var_->Min();
355 const int64 var_max = var_->Max();
356 if (var_min > value_ || var_max < value_) {
357 demon_->inhibit(solver());
358 }
else if (var_min == value_) {
359 var_->SetMin(value_ + 1);
360 }
else if (var_max == value_) {
361 var_->SetMax(value_ - 1);
362 }
else if (!HasLargeDomain(var_)) {
363 demon_->inhibit(solver());
364 var_->RemoveValue(value_);
// DebugString(): renders "(<var> != <value>)".
368 std::string DiffCst::DebugString()
const {
369 return absl::StrFormat(
"(%s != %d)", var_->DebugString(), value_);
// HasLargeDomain(): body (listing lines 373-376) not shown in this listing.
372 bool DiffCst::HasLargeDomain(IntVar*
var) {
// MakeNonEquality(e, v): builds `e != v`, simplifying first:
//  - when IsADifference() splits e into (left, right) -- presumably
//    e == left - right -- returns MakeNonEquality(left, right + v);
//  - when e is a variable whose domain does not contain v, a true constraint;
//  - when e is bound to v, a false constraint;
//  - otherwise a DiffCst on e->Var().
// CHECK-fails if e belongs to another solver.
377 Constraint* Solver::MakeNonEquality(IntExpr*
const e,
int64 v) {
378 CHECK_EQ(
this, e->solver());
379 IntExpr* left =
nullptr;
380 IntExpr* right =
nullptr;
381 if (IsADifference(e, &left, &right)) {
382 return MakeNonEquality(left, MakeSum(right, v));
383 }
else if (e->IsVar() && !e->Var()->Contains(v)) {
384 return MakeTrueConstraint();
385 }
else if (e->Bound() && e->Min() == v) {
386 return MakeFalseConstraint();
388 return RevAlloc(
new DiffCst(
this, e->Var(), v));
// Same logic as the int64 overload above, for an `int` constant.
392 Constraint* Solver::MakeNonEquality(IntExpr*
const e,
int v) {
393 CHECK_EQ(
this, e->solver());
394 IntExpr* left =
nullptr;
395 IntExpr* right =
nullptr;
396 if (IsADifference(e, &left, &right)) {
397 return MakeNonEquality(left, MakeSum(right, v));
398 }
else if (e->IsVar() && !e->Var()->Contains(v)) {
399 return MakeTrueConstraint();
400 }
else if (e->Bound() && e->Min() == v) {
401 return MakeFalseConstraint();
403 return RevAlloc(
new DiffCst(
this, e->Var(), v));
// IsEqualCstCt: reified constraint linking the Boolean target b (handed to
// CastConstraint) to `var_ == cst_`. Accept() reports kIsEqual with the
// expression, kValueArgument = cst_ and the target.
// Post(): attaches a demon that re-runs InitialPropagate() on every domain
// change of var_ (WhenDomain).
// NOTE(review): InitialPropagate() is heavily truncated in this listing.
// `inhibit` starts as var_->Bound(). The `u` used on listing line 421 is
// defined on a missing line (420). cst_ is removed from var_ (only attempted
// when var_->Size() <= 0xFFFFFF) or var_ is fixed to cst_, under conditions
// that are not shown (422-424, 427-429). The demon is inhibited on line 435
// under a condition that is not shown. Argument continuations and the members
// are also missing.
409 class IsEqualCstCt :
public CastConstraint {
411 IsEqualCstCt(Solver*
const s, IntVar*
const v,
int64 c, IntVar*
const b)
412 : CastConstraint(s,
b), var_(v),
cst_(c), demon_(nullptr) {}
413 void Post()
override {
414 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
415 var_->WhenDomain(demon_);
418 void InitialPropagate()
override {
419 bool inhibit = var_->Bound();
421 int64 l = inhibit ? u : 0;
425 if (var_->Size() <= 0xFFFFFF) {
426 var_->RemoveValue(
cst_);
430 var_->SetValue(
cst_);
435 demon_->inhibit(solver());
438 std::string DebugString()
const override {
439 return absl::StrFormat(
"IsEqualCstCt(%s, %d, %s)", var_->DebugString(),
443 void Accept(ModelVisitor*
const visitor)
const override {
444 visitor->BeginVisitConstraint(ModelVisitor::kIsEqual,
this);
445 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
447 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
cst_);
448 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
450 visitor->EndVisitConstraint(ModelVisitor::kIsEqual,
this);
// MakeIsEqualCstVar(var, value): returns a variable equal to (var == value).
//  - If IsADifference() splits var into (left, right), delegates to
//    MakeIsEqualVar(left, right + value).
//  - Two shortcuts return `value + 1 - var` or `var - value + 1` as a variable.
//    These equal the indicator (var == value) exactly when var's domain is
//    {value, value + 1} or {value - 1, value} respectively -- presumably what
//    the missing guards (listing lines 465-467 and 469) test.
//  - Returns the constant 0 under a guard that is not shown (line 471).
//  - Otherwise creates a Boolean variable named "Is(<var> == <value>)" and
//    posts MakeIsEqualCstCt(var, value, boolvar).
// NOTE(review): lines 473-477 and the final return are missing from this
// listing.
460 IntVar* Solver::MakeIsEqualCstVar(IntExpr*
const var,
int64 value) {
461 IntExpr* left =
nullptr;
462 IntExpr* right =
nullptr;
463 if (IsADifference(
var, &left, &right)) {
464 return MakeIsEqualVar(left, MakeSum(right,
value));
468 return MakeDifference(
value + 1,
var)->Var();
470 return MakeSum(
var, -
value + 1)->Var();
472 return MakeIntConst(0);
478 IntVar*
const boolvar =
479 MakeBoolVar(absl::StrFormat(
"Is(%s == %d)",
var->DebugString(),
value));
480 AddConstraint(MakeIsEqualCstCt(
var,
value, boolvar));
// MakeIsEqualCstCt(var, value, boolvar): constraint boolvar == (var == value).
// CHECK-fails unless var and boolvar belong to this solver. The visible
// branches, whose guards (listing lines 489-500) are mostly missing, are:
//  - boolvar == value + 1 - var and boolvar == var - value + 1. These are exact
//    when var's domain is {value, value + 1} or {value - 1, value}.
//  - MakeIsLessOrEqualCstCt or MakeIsGreaterOrEqualCstCt(var, value, boolvar).
//    These are equivalent when value is var's minimum or maximum --
//    presumably what the missing guards check.
//  - When boolvar is already bound, a branch whose body (503-509) is missing.
// Otherwise the function does three things:
//  - records boolvar in model_cache_ as EXPR_CONSTANT_IS_EQUAL for
//    (var, value);
//  - rewrites differences via MakeIsEqualCt(left, right + value, boolvar);
//  - finally allocates an IsEqualCstCt on var->Var().
485 Constraint* Solver::MakeIsEqualCstCt(IntExpr*
const var,
int64 value,
486 IntVar*
const boolvar) {
487 CHECK_EQ(
this,
var->solver());
488 CHECK_EQ(
this, boolvar->solver());
491 return MakeEquality(MakeDifference(
value + 1,
var), boolvar);
493 return MakeIsLessOrEqualCstCt(
var,
value, boolvar);
497 return MakeEquality(MakeSum(
var, -
value + 1), boolvar);
499 return MakeIsGreaterOrEqualCstCt(
var,
value, boolvar);
501 if (boolvar->Bound()) {
502 if (boolvar->Min() == 0) {
510 model_cache_->InsertExprConstantExpression(
511 boolvar,
var,
value, ModelCache::EXPR_CONSTANT_IS_EQUAL);
512 IntExpr* left =
nullptr;
513 IntExpr* right =
nullptr;
514 if (IsADifference(
var, &left, &right)) {
515 return MakeIsEqualCt(left, MakeSum(right,
value), boolvar);
517 return RevAlloc(
new IsEqualCstCt(
this,
var->Var(),
value, boolvar));
// IsDiffCstCt: reified constraint linking the Boolean target b (handed to
// CastConstraint) to `var_ != cst_`. Accept() reports kIsDifferent with the
// expression, kValueArgument = cst_ and the target.
// Post(): attaches a demon that re-runs InitialPropagate() on every domain
// change of var_ (WhenDomain).
// NOTE(review): InitialPropagate() is heavily truncated in this listing.
// `inhibit` starts as var_->Bound(). The `l` used on listing line 538 is
// defined on a missing line (537). cst_ is removed from var_ (only attempted
// when var_->Size() <= 0xFFFFFF) or var_ is fixed to cst_, under conditions
// that are not shown. The demon is inhibited on line 552 under a condition
// that is not shown.
524 class IsDiffCstCt :
public CastConstraint {
526 IsDiffCstCt(Solver*
const s, IntVar*
const v,
int64 c, IntVar*
const b)
527 : CastConstraint(s,
b), var_(v),
cst_(c), demon_(nullptr) {}
529 void Post()
override {
530 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
531 var_->WhenDomain(demon_);
535 void InitialPropagate()
override {
536 bool inhibit = var_->Bound();
538 int64 u = inhibit ? l : 1;
542 if (var_->Size() <= 0xFFFFFF) {
543 var_->RemoveValue(
cst_);
547 var_->SetValue(
cst_);
552 demon_->inhibit(solver());
556 std::string DebugString()
const override {
557 return absl::StrFormat(
"IsDiffCstCt(%s, %d, %s)", var_->DebugString(),
cst_,
561 void Accept(ModelVisitor*
const visitor)
const override {
562 visitor->BeginVisitConstraint(ModelVisitor::kIsDifferent,
this);
563 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
565 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
cst_);
566 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
568 visitor->EndVisitConstraint(ModelVisitor::kIsDifferent,
this);
// MakeIsDifferentCstVar(var, value): returns a variable equal to
// (var != value). If IsADifference() splits var into (left, right), it
// delegates to MakeIsDifferentVar(left, right + value). Otherwise (branch line
// 583 missing) it returns var->Var()->IsDifferent(value).
578 IntVar* Solver::MakeIsDifferentCstVar(IntExpr*
const var,
int64 value) {
579 IntExpr* left =
nullptr;
580 IntExpr* right =
nullptr;
581 if (IsADifference(
var, &left, &right)) {
582 return MakeIsDifferentVar(left, MakeSum(right,
value));
584 return var->Var()->IsDifferent(
value);
// MakeIsDifferentCstCt(var, value, boolvar): constraint
// boolvar == (var != value). CHECK-fails unless var and boolvar belong to this
// solver. Visible branches (several guards, listing lines 591-600, are
// missing):
//  - boolvar == (var >= value + 1), or boolvar == (var <= value - 1). These
//    are equivalent when value is var's minimum or maximum -- presumably what
//    the missing guards check.
//  - boolvar == 1 when var is a variable whose domain does not contain value.
//  - boolvar == 0 under a guard that is not shown (line 600).
//  - When boolvar is already bound, a branch whose body (605-609) is missing.
// Otherwise the function does three things:
//  - records boolvar in model_cache_ as EXPR_CONSTANT_IS_NOT_EQUAL;
//  - rewrites differences via MakeIsDifferentCt;
//  - allocates an IsDiffCstCt on var->Var().
587 Constraint* Solver::MakeIsDifferentCstCt(IntExpr*
const var,
int64 value,
588 IntVar*
const boolvar) {
589 CHECK_EQ(
this,
var->solver());
590 CHECK_EQ(
this, boolvar->solver());
592 return MakeIsGreaterOrEqualCstCt(
var,
value + 1, boolvar);
595 return MakeIsLessOrEqualCstCt(
var,
value - 1, boolvar);
597 if (
var->IsVar() && !
var->Var()->Contains(
value)) {
598 return MakeEquality(boolvar,
int64{1});
601 return MakeEquality(boolvar, Zero());
603 if (boolvar->Bound()) {
604 if (boolvar->Min() == 0) {
610 model_cache_->InsertExprConstantExpression(
611 boolvar,
var,
value, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
612 IntExpr* left =
nullptr;
613 IntExpr* right =
nullptr;
614 if (IsADifference(
var, &left, &right)) {
615 return MakeIsDifferentCt(left, MakeSum(right,
value), boolvar);
617 return RevAlloc(
new IsDiffCstCt(
this,
var->Var(),
value, boolvar));
// IsGreaterEqualCstCt: reified constraint linking the Boolean target b
// (handed to CastConstraint) to `expr_ >= cst_`. Accept() reports
// kIsGreaterOrEqual with the expression, kValueArgument = cst_ and the target.
// Post(): attaches a demon that re-runs InitialPropagate() on range changes of
// expr_ (WhenRange).
// NOTE(review): lossy listing -- the constructor's last parameter line
// (listing line 627, declaring `b`), most of InitialPropagate() (636-650),
// argument continuations and members other than expr_ are not shown.
624 class IsGreaterEqualCstCt :
public CastConstraint {
626 IsGreaterEqualCstCt(Solver*
const s, IntExpr*
const v,
int64 c,
628 : CastConstraint(s,
b),
expr_(v),
cst_(c), demon_(nullptr) {}
629 void Post()
override {
630 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
631 expr_->WhenRange(demon_);
634 void InitialPropagate()
override {
635 bool inhibit =
false;
651 demon_->inhibit(solver());
654 std::string DebugString()
const override {
655 return absl::StrFormat(
"IsGreaterEqualCstCt(%s, %d, %s)",
660 void Accept(ModelVisitor*
const visitor)
const override {
661 visitor->BeginVisitConstraint(ModelVisitor::kIsGreaterOrEqual,
this);
662 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
664 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
cst_);
665 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
667 visitor->EndVisitConstraint(ModelVisitor::kIsGreaterOrEqual,
this);
671 IntExpr*
const expr_;
// MakeIsGreaterOrEqualCstVar(var, value): returns a variable equal to
// (var >= value). The visible returns are:
//  - constant 1;
//  - constant 0;
//  - var->Var()->IsGreaterOrEqual(value).
// Their guards (listing lines 678, 680-681, 683-684) are missing -- presumably
// the var->Min() >= value, var->Max() < value and var->IsVar() cases.
// Otherwise it creates a Boolean variable named "Is(<var> >= <value>)" and
// posts MakeIsGreaterOrEqualCstCt(var, value, boolvar). The final return is
// not shown.
677 IntVar* Solver::MakeIsGreaterOrEqualCstVar(IntExpr*
const var,
int64 value) {
679 return MakeIntConst(
int64{1});
682 return MakeIntConst(
int64{0});
685 return var->Var()->IsGreaterOrEqual(
value);
687 IntVar*
const boolvar =
688 MakeBoolVar(absl::StrFormat(
"Is(%s >= %d)",
var->DebugString(),
value));
689 AddConstraint(MakeIsGreaterOrEqualCstCt(
var,
value, boolvar));
// MakeIsGreaterCstVar(var, value): (var > value) expressed as
// (var >= value + 1). NOTE(review): `value + 1` overflows int64 when value is
// the int64 maximum.
694 IntVar* Solver::MakeIsGreaterCstVar(IntExpr*
const var,
int64 value) {
695 return MakeIsGreaterOrEqualCstVar(
var,
value + 1);
// MakeIsGreaterOrEqualCstCt(var, value, boolvar): constraint
// boolvar == (var >= value).
// When boolvar is already bound: if it is 0, the body (listing lines 702-703)
// is not shown; the visible return for that case is
// MakeGreaterOrEqual(var, value).
// Otherwise it CHECKs both arguments' solvers, records boolvar in model_cache_
// as EXPR_CONSTANT_IS_GREATER_OR_EQUAL, and allocates an IsGreaterEqualCstCt.
// NOTE(review): the solver CHECKs run only after the bound-boolvar early
// return, unlike MakeIsEqualCstCt which checks first.
698 Constraint* Solver::MakeIsGreaterOrEqualCstCt(IntExpr*
const var,
int64 value,
699 IntVar*
const boolvar) {
700 if (boolvar->Bound()) {
701 if (boolvar->Min() == 0) {
704 return MakeGreaterOrEqual(
var,
value);
707 CHECK_EQ(
this,
var->solver());
708 CHECK_EQ(
this, boolvar->solver());
709 model_cache_->InsertExprConstantExpression(
710 boolvar,
var,
value, ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
711 return RevAlloc(
new IsGreaterEqualCstCt(
this,
var,
value, boolvar));
// MakeIsGreaterCstCt(v, c, b): b == (v > c), i.e. b == (v >= c + 1). The
// parameter line declaring `b` (listing line 715) is missing.
// NOTE(review): `c + 1` overflows int64 when c is the int64 maximum.
714 Constraint* Solver::MakeIsGreaterCstCt(IntExpr*
const v,
int64 c,
716 return MakeIsGreaterOrEqualCstCt(v, c + 1,
b);
// IsLessEqualCstCt: reified constraint linking the Boolean target b (handed
// to CastConstraint) to `expr_ <= cst_`. Accept() reports kIsLessOrEqual with
// the expression, kValueArgument = cst_ and the target.
// Post(): attaches a demon that re-runs InitialPropagate() on range changes of
// expr_ (WhenRange).
// NOTE(review): lossy listing -- most of InitialPropagate() (listing lines
// 735-749), argument continuations and members other than expr_ are not
// shown.
722 class IsLessEqualCstCt :
public CastConstraint {
724 IsLessEqualCstCt(Solver*
const s, IntExpr*
const v,
int64 c, IntVar*
const b)
725 : CastConstraint(s,
b),
expr_(v),
cst_(c), demon_(nullptr) {}
727 void Post()
override {
728 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
729 expr_->WhenRange(demon_);
733 void InitialPropagate()
override {
734 bool inhibit =
false;
750 demon_->inhibit(solver());
754 std::string DebugString()
const override {
755 return absl::StrFormat(
"IsLessEqualCstCt(%s, %d, %s)",
expr_->DebugString(),
759 void Accept(ModelVisitor*
const visitor)
const override {
760 visitor->BeginVisitConstraint(ModelVisitor::kIsLessOrEqual,
this);
761 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
763 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
cst_);
764 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
766 visitor->EndVisitConstraint(ModelVisitor::kIsLessOrEqual,
this);
770 IntExpr*
const expr_;
// MakeIsLessOrEqualCstVar(var, value): returns a variable equal to
// (var <= value). The visible returns are:
//  - constant 1;
//  - constant 0;
//  - var->Var()->IsLessOrEqual(value).
// Their guards (listing lines 777, 779-780, 782-783) are missing.
// Otherwise it creates a Boolean variable named "Is(<var> <= <value>)" and
// posts MakeIsLessOrEqualCstCt(var, value, boolvar). The final return is not
// shown.
776 IntVar* Solver::MakeIsLessOrEqualCstVar(IntExpr*
const var,
int64 value) {
778 return MakeIntConst(
int64{1});
781 return MakeIntConst(
int64{0});
784 return var->Var()->IsLessOrEqual(
value);
786 IntVar*
const boolvar =
787 MakeBoolVar(absl::StrFormat(
"Is(%s <= %d)",
var->DebugString(),
value));
788 AddConstraint(MakeIsLessOrEqualCstCt(
var,
value, boolvar));
// MakeIsLessCstVar(var, value): (var < value) expressed as
// (var <= value - 1). NOTE(review): `value - 1` overflows int64 when value is
// the int64 minimum.
793 IntVar* Solver::MakeIsLessCstVar(IntExpr*
const var,
int64 value) {
794 return MakeIsLessOrEqualCstVar(
var,
value - 1);
// MakeIsLessOrEqualCstCt(var, value, boolvar): constraint
// boolvar == (var <= value).
// The bound-boolvar shortcut's body (listing lines 801-805) is not shown.
// Otherwise it CHECKs both arguments' solvers, records boolvar in model_cache_
// as EXPR_CONSTANT_IS_LESS_OR_EQUAL, and allocates an IsLessEqualCstCt.
799 Constraint* Solver::MakeIsLessOrEqualCstCt(IntExpr*
const var,
int64 value,
798 IntVar*
const boolvar) {
799 if (boolvar->Bound()) {
800 if (boolvar->Min() == 0) {
806 CHECK_EQ(
this,
var->solver());
807 CHECK_EQ(
this, boolvar->solver());
808 model_cache_->InsertExprConstantExpression(
809 boolvar,
var,
value, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
810 return RevAlloc(
new IsLessEqualCstCt(
this,
var,
value, boolvar));
// MakeIsLessCstCt(v, c, b): b == (v < c), i.e. b == (v <= c - 1). The
// parameter line declaring `b` (listing line 814) is missing.
813 Constraint* Solver::MakeIsLessCstCt(IntExpr*
const v,
int64 c,
815 return MakeIsLessOrEqualCstCt(v, c - 1,
b);
// BetweenCt: constraint enforcing min_ <= expr_ <= max_ (DebugString:
// "BetweenCt(<expr>, <min>, <max>)"). Accept() reports kBetween with
// kMinArgument, the expression and kMaxArgument.
// Post(): only for non-variable expressions, attaches a demon that re-runs
// InitialPropagate() on range changes.
// InitialPropagate(): calls SetRange(min_, max_), re-reads expr_'s range, and
// inhibits the demon (if any) once that range lies inside [min_, max_].
// NOTE(review): the emin/emax declarations (listing lines 835-836), argument
// continuations, closing braces and the remaining members are not shown.
821 class BetweenCt :
public Constraint {
823 BetweenCt(Solver*
const s, IntExpr*
const v,
int64 l,
int64 u)
824 : Constraint(s),
expr_(v), min_(l), max_(u), demon_(nullptr) {}
826 void Post()
override {
827 if (!
expr_->IsVar()) {
828 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
829 expr_->WhenRange(demon_);
833 void InitialPropagate()
override {
834 expr_->SetRange(min_, max_);
837 expr_->Range(&emin, &emax);
838 if (demon_ !=
nullptr && emin >= min_ && emax <= max_) {
839 demon_->inhibit(solver());
843 std::string DebugString()
const override {
844 return absl::StrFormat(
"BetweenCt(%s, %d, %d)",
expr_->DebugString(), min_,
848 void Accept(ModelVisitor*
const visitor)
const override {
849 visitor->BeginVisitConstraint(ModelVisitor::kBetween,
this);
850 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
851 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
853 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
854 visitor->EndVisitConstraint(ModelVisitor::kBetween,
this);
858 IntExpr*
const expr_;
// NotBetweenCt: constraint enforcing that expr_ lies outside [min_, max_]
// (DebugString: "NotBetweenCt(<expr>, <min>, <max>)").
// Post(): always attaches a demon that re-runs InitialPropagate() on range
// changes of expr_.
// InitialPropagate(): reads expr_'s range. Under a condition on listing line
// 880 that is not shown (presumably emin >= min_), it forces expr_ above
// max_. Otherwise, if emax <= max_, it forces expr_ below min_. For
// non-variable expressions whose range is already disjoint from
// [min_, max_], it inhibits the demon.
// NOTE(review): the emin/emax declarations (877-878), lines 884-885 and the
// closing braces are missing.
866 class NotBetweenCt :
public Constraint {
868 NotBetweenCt(Solver*
const s, IntExpr*
const v,
int64 l,
int64 u)
869 : Constraint(s),
expr_(v), min_(l), max_(u), demon_(nullptr) {}
871 void Post()
override {
872 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
873 expr_->WhenRange(demon_);
876 void InitialPropagate()
override {
879 expr_->Range(&emin, &emax);
881 expr_->SetMin(max_ + 1);
882 }
else if (emax <= max_) {
883 expr_->SetMax(min_ - 1);
886 if (!
expr_->IsVar() && (emax < min_ || emin > max_)) {
887 demon_->inhibit(solver());
891 std::string DebugString()
const override {
892 return absl::StrFormat(
"NotBetweenCt(%s, %d, %d)",
expr_->DebugString(),
896 void Accept(ModelVisitor*
const visitor)
const override {
897 visitor->BeginVisitConstraint(ModelVisitor::kNotBetween,
this);
898 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
899 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
901 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
902 visitor->EndVisitConstraint(ModelVisitor::kBetween,
this);
// The expression that must stay outside [min_, max_]. The min_/max_/demon_
// member declarations are not shown in this listing.
906 IntExpr*
const expr_;
// ExtractExprProductCoeff(expr): while the solver's IsProduct() recognizes
// *expr as a product, replaces *expr by the inner expression and multiplies
// the running `prod` by the coefficient. It peels off nested constant factors
// in place. NOTE(review): the declarations of prod/coeff (listing lines
// 913-914), the return statement and the closing brace are not shown.
912 int64 ExtractExprProductCoeff(IntExpr** expr) {
915 while ((*expr)->solver()->IsProduct(*expr, expr, &coeff)) prod *= coeff;
// MakeBetweenCt(expr, l, u): builds l <= expr <= u.
//  - Inside a block whose opening lines (922-923) are missing: a false
//    constraint when l > u, else MakeEquality(expr, l) (presumably the
//    l == u case).
//  - Using expr's current range: a false constraint when disjoint from
//    [l, u], a true constraint when contained.
//  - A one-sided constraint when only one bound is active.
//  - Then extracts a product coefficient from expr (handling in lines 938-948
//    is not shown) and finally allocates a BetweenCt.
// DCHECKs that expr belongs to this solver.
920 Constraint* Solver::MakeBetweenCt(IntExpr* expr,
int64 l,
int64 u) {
921 DCHECK_EQ(
this, expr->solver());
924 if (l > u)
return MakeFalseConstraint();
925 return MakeEquality(expr, l);
929 expr->Range(&emin, &emax);
931 if (emax < l || emin > u)
return MakeFalseConstraint();
932 if (emin >= l && emax <= u)
return MakeTrueConstraint();
934 if (emax <= u)
return MakeGreaterOrEqual(expr, l);
935 if (emin >= l)
return MakeLessOrEqual(expr, u);
937 int64 coeff = ExtractExprProductCoeff(&expr);
949 return RevAlloc(
new BetweenCt(
this, expr, l, u));
// MakeNotBetweenCt(expr, l, u): builds "expr outside [l, u]".
//  - A true constraint under a guard that is not shown (listing line 955-956,
//    presumably an empty interval).
//  - Using expr's current range: true when disjoint from [l, u], false when
//    contained in it.
//  - expr > u when emin >= l; expr < l when emax <= u.
//  - Otherwise a NotBetweenCt.
// DCHECKs that expr belongs to this solver.
953 Constraint* Solver::MakeNotBetweenCt(IntExpr* expr,
int64 l,
int64 u) {
954 DCHECK_EQ(
this, expr->solver());
957 return MakeTrueConstraint();
962 expr->Range(&emin, &emax);
964 if (emax < l || emin > u)
return MakeTrueConstraint();
965 if (emin >= l && emax <= u)
return MakeFalseConstraint();
967 if (emin >= l)
return MakeGreater(expr, u);
968 if (emax <= u)
return MakeLess(expr, l);
971 return RevAlloc(
new NotBetweenCt(
this, expr, l, u));
// IsBetweenCt: reified constraint boolvar_ == (min_ <= expr_ <= max_).
// Accept() reports kIsBetween with kMinArgument, the expression, kMaxArgument
// and the target.
// Post(): one demon re-runs InitialPropagate() on range changes of expr_ and
// when boolvar_ becomes bound.
// InitialPropagate(), from expr_'s current range:
//  - boolvar_'s upper bound u is 0 when the range is disjoint from
//    [min_, max_];
//  - its lower bound l is 1 when the range is contained in [min_, max_].
// Once boolvar_ is bound:
//  - If it is 0 and expr_ is a variable, [min_, max_] is removed from its
//    domain.
//  - If it is 0 and expr_ is not a variable, expr_ is pushed past max_ (when
//    emin > min_) or below min_ (when emax < max_).
//  - Otherwise expr_ is restricted to [min_, max_].
// The demon is inhibited only when `inhibit` was set and expr_ is a variable.
// NOTE(review): constructor parameter/initializer lines (980-987), the
// emin/emax declarations, the lines that set `inhibit` (1003, 1007,
// 1012-1016), argument continuations and the min_/max_/demon_ members are not
// shown.
977 class IsBetweenCt :
public Constraint {
979 IsBetweenCt(Solver*
const s, IntExpr*
const e,
int64 l,
int64 u,
988 void Post()
override {
989 demon_ = solver()->MakeConstraintInitialPropagateCallback(
this);
990 expr_->WhenRange(demon_);
991 boolvar_->WhenBound(demon_);
994 void InitialPropagate()
override {
995 bool inhibit =
false;
998 expr_->Range(&emin, &emax);
999 int64 u = 1 - (emin > max_ || emax < min_);
1000 int64 l = emax <= max_ && emin >= min_;
1001 boolvar_->SetRange(l, u);
1002 if (boolvar_->Bound()) {
1004 if (boolvar_->Min() == 0) {
1005 if (
expr_->IsVar()) {
1006 expr_->Var()->RemoveInterval(min_, max_);
1008 }
else if (emin > min_) {
1009 expr_->SetMin(max_ + 1);
1010 }
else if (emax < max_) {
1011 expr_->SetMax(min_ - 1);
1014 expr_->SetRange(min_, max_);
1017 if (inhibit &&
expr_->IsVar()) {
1018 demon_->inhibit(solver());
1023 std::string DebugString()
const override {
1024 return absl::StrFormat(
"IsBetweenCt(%s, %d, %d, %s)",
expr_->DebugString(),
1025 min_, max_, boolvar_->DebugString());
1028 void Accept(ModelVisitor*
const visitor)
const override {
1029 visitor->BeginVisitConstraint(ModelVisitor::kIsBetween,
this);
1030 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
1031 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1033 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
1034 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1036 visitor->EndVisitConstraint(ModelVisitor::kIsBetween,
this);
1040 IntExpr*
const expr_;
1043 IntVar*
const boolvar_;
// MakeIsBetweenCt(expr, l, u, b): constraint b == (l <= expr <= u). The
// parameter line declaring `b` (listing line 1049) is missing.
// CHECK-fails unless expr and b belong to this solver.
//  - b == 0 when l > u; MakeIsEqualCstCt(expr, l, b) under a guard that is not
//    shown (presumably l == u).
//  - Using expr's range: b == 0 when disjoint from [l, u], b == 1 when
//    contained.
//  - A one-sided reified constraint when only one bound is active.
//  - Then extracts a product coefficient (handling not shown, 1068-1079) and
//    allocates an IsBetweenCt.
1048 Constraint* Solver::MakeIsBetweenCt(IntExpr* expr,
int64 l,
int64 u,
1050 CHECK_EQ(
this, expr->solver());
1051 CHECK_EQ(
this,
b->solver());
1054 if (l > u)
return MakeEquality(
b, Zero());
1055 return MakeIsEqualCstCt(expr, l,
b);
1059 expr->Range(&emin, &emax);
1061 if (emax < l || emin > u)
return MakeEquality(
b, Zero());
1062 if (emin >= l && emax <= u)
return MakeEquality(
b, 1);
1064 if (emax <= u)
return MakeIsGreaterOrEqualCstCt(expr, l,
b);
1065 if (emin >= l)
return MakeIsLessOrEqualCstCt(expr, u,
b);
1067 int64 coeff = ExtractExprProductCoeff(&expr);
1080 return RevAlloc(
new IsBetweenCt(
this, expr, l, u,
b));
// MakeIsBetweenVar(v, l, u): creates a Boolean variable b and posts
// MakeIsBetweenCt(v, l, u, b). The return statement is not shown.
1084 IntVar* Solver::MakeIsBetweenVar(IntExpr*
const v,
int64 l,
int64 u) {
1085 CHECK_EQ(
this, v->solver());
1086 IntVar*
const b = MakeBoolVar();
1087 AddConstraint(MakeIsBetweenCt(v, l, u,
b));
// MemberCt: constraint restricting var_ to values_ (DebugString:
// "Member(<var>, <v1, v2, ...>)"). Post() is empty. InitialPropagate() calls
// var_->SetValues(values_) once. Accept() reports kMember with the expression
// and kValuesArgument.
// The constructor parameter is named sorted_values; callers presumably pass
// sorted values -- not checked here.
// NOTE(review): argument continuations, closing braces and the var_
// declaration are not shown in this listing.
1097 class MemberCt :
public Constraint {
1099 MemberCt(Solver*
const s, IntVar*
const v,
1100 const std::vector<int64>& sorted_values)
1101 : Constraint(s), var_(v), values_(sorted_values) {
1102 DCHECK(v !=
nullptr);
1103 DCHECK(s !=
nullptr);
1106 void Post()
override {}
1108 void InitialPropagate()
override { var_->SetValues(values_); }
1110 std::string DebugString()
const override {
1111 return absl::StrFormat(
"Member(%s, %s)", var_->DebugString(),
1112 absl::StrJoin(values_,
", "));
1115 void Accept(ModelVisitor*
const visitor)
const override {
1116 visitor->BeginVisitConstraint(ModelVisitor::kMember,
this);
1117 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1119 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1120 visitor->EndVisitConstraint(ModelVisitor::kMember,
this);
1125 const std::vector<int64> values_;
// NotMemberCt: constraint forbidding every value of values_ for var_
// (DebugString: "NotMember(<var>, <v1, v2, ...>)"). Post() is empty.
// InitialPropagate() calls var_->RemoveValues(values_) once.
// NOTE(review): closing braces are not shown in this listing.
1128 class NotMemberCt :
public Constraint {
1130 NotMemberCt(Solver*
const s, IntVar*
const v,
1131 const std::vector<int64>& sorted_values)
1132 : Constraint(s), var_(v), values_(sorted_values) {
1133 DCHECK(v !=
nullptr);
1134 DCHECK(s !=
nullptr);
1137 void Post()
override {}
1139 void InitialPropagate()
override { var_->RemoveValues(values_); }
1141 std::string DebugString()
const override {
1142 return absl::StrFormat(
"NotMember(%s, %s)", var_->DebugString(),
1143 absl::StrJoin(values_,
", "));
1146 void Accept(ModelVisitor*
const visitor)
const override {
1147 visitor->BeginVisitConstraint(ModelVisitor::kMember,
this);
1148 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1150 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1151 visitor->EndVisitConstraint(ModelVisitor::kMember,
this);
// The forbidden values. The var_ declaration and the closing `};` are not
// shown in this listing.
1156 const std::vector<int64> values_;
// MakeMemberCt(expr, values): builds "expr takes one of `values`".
//  - Peels constant factors off expr via ExtractExprProductCoeff.
//  - Under a guard that is not shown (listing line 1163, presumably
//    coeff == 0), returns true iff 0 is among values.
//  - Otherwise keeps only the values divisible by coeff, divided by it.
//  - Then keeps only the values inside expr's current range [emin, emax]; an
//    empty result is a false constraint.
//  - A single value becomes MakeEquality.
//  - A contiguous run (size == back - front + 1) becomes MakeBetweenCt. This
//    test assumes copied_values is sorted and unique -- presumably done in
//    the missing lines 1189-1191.
//  - When emax - emin < 2 * size, the complement within [emin, emax] has no
//    more elements than the kept set, so the constraint is stated on the
//    complement: MakeNonEquality for one value, NotMemberCt otherwise.
//  - Else a MemberCt.
// NOTE(review): declarations of num_kept/emin/emax and several branch lines
// are missing from this listing.
1160 Constraint* Solver::MakeMemberCt(IntExpr* expr,
1161 const std::vector<int64>& values) {
1162 const int64 coeff = ExtractExprProductCoeff(&expr);
1164 return std::find(values.begin(), values.end(), 0) == values.end()
1165 ? MakeFalseConstraint()
1166 : MakeTrueConstraint();
1168 std::vector<int64> copied_values = values;
1173 for (
const int64 v : copied_values) {
1174 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1176 copied_values.resize(num_kept);
1182 expr->Range(&emin, &emax);
1183 for (
const int64 v : copied_values) {
1184 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1186 copied_values.resize(num_kept);
1188 if (copied_values.empty())
return MakeFalseConstraint();
1192 if (copied_values.size() == 1)
return MakeEquality(expr, copied_values[0]);
1194 if (copied_values.size() ==
1195 copied_values.back() - copied_values.front() + 1) {
1197 return MakeBetweenCt(expr, copied_values.front(), copied_values.back());
1202 if (emax - emin < 2 * copied_values.size()) {
1204 std::vector<bool> is_among_input_values(emax - emin + 1,
false);
1205 for (
const int64 v : copied_values) is_among_input_values[v - emin] =
true;
1208 copied_values.clear();
1209 for (
int64 v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1210 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1214 DCHECK_GE(copied_values.size(), 1);
1215 if (copied_values.size() == 1) {
1216 return MakeNonEquality(expr, copied_values[0]);
1218 return RevAlloc(
new NotMemberCt(
this, expr->Var(), copied_values));
1221 return RevAlloc(
new MemberCt(
this, expr->Var(), copied_values));
// `int` overload; its body (listing lines 1226-1228) is not shown.
1224 Constraint* Solver::MakeMemberCt(IntExpr*
const expr,
1225 const std::vector<int>& values) {
// MakeNotMemberCt(expr, values): builds "expr takes none of `values`". It
// mirrors MakeMemberCt with the outcomes swapped.
//  - Peels constant factors; under a guard that is not shown (presumably
//    coeff == 0), returns false iff 0 is among values.
//  - Keeps only multiples of coeff (divided), then only values in
//    [emin, emax]; an empty result is a true constraint.
//  - A single value becomes MakeNonEquality; a contiguous run becomes
//    MakeNotBetweenCt.
//  - When emax - emin < 2 * size, the allowed complement within [emin, emax]
//    is stated instead: MakeEquality for one value, MemberCt otherwise.
//  - Else a NotMemberCt.
// NOTE(review): declarations of num_kept/emin/emax and several lines
// (presumably sorting/deduplication) are missing from this listing.
1229 Constraint* Solver::MakeNotMemberCt(IntExpr* expr,
1230 const std::vector<int64>& values) {
1231 const int64 coeff = ExtractExprProductCoeff(&expr);
1233 return std::find(values.begin(), values.end(), 0) == values.end()
1234 ? MakeTrueConstraint()
1235 : MakeFalseConstraint();
1237 std::vector<int64> copied_values = values;
1242 for (
const int64 v : copied_values) {
1243 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1245 copied_values.resize(num_kept);
1251 expr->Range(&emin, &emax);
1252 for (
const int64 v : copied_values) {
1253 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1255 copied_values.resize(num_kept);
1257 if (copied_values.empty())
return MakeTrueConstraint();
1261 if (copied_values.size() == 1)
return MakeNonEquality(expr, copied_values[0]);
1263 if (copied_values.size() ==
1264 copied_values.back() - copied_values.front() + 1) {
1265 return MakeNotBetweenCt(expr, copied_values.front(), copied_values.back());
1270 if (emax - emin < 2 * copied_values.size()) {
1272 std::vector<bool> is_among_input_values(emax - emin + 1,
false);
1273 for (
const int64 v : copied_values) is_among_input_values[v - emin] =
true;
1276 copied_values.clear();
1277 for (
int64 v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1278 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1282 DCHECK_GE(copied_values.size(), 1);
1283 if (copied_values.size() == 1) {
1284 return MakeEquality(expr, copied_values[0]);
1286 return RevAlloc(
new MemberCt(
this, expr->Var(), copied_values));
1289 return RevAlloc(
new NotMemberCt(
this, expr->Var(), copied_values));
// `int` overload; its body (listing lines 1294-1296) is not shown.
1292 Constraint* Solver::MakeNotMemberCt(IntExpr*
const expr,
1293 const std::vector<int>& values) {
// IsMemberCt: reified constraint linking boolvar_ to "var_ takes one of
// values_". values_as_set_ holds the same values in a hash set, and domain_
// is a domain iterator over var_. Accept() reports kIsMember with the
// expression, kValuesArgument and the target.
// Post(): attaches demon_ to var_ domain changes (when var_ is unbound) and a
// TargetBound demon to boolvar_ (when unbound).
// A method whose header is missing (listing lines ~1358-1360) does the
// following:
//  - Scans values_ cyclically from support_ for a value still in var_'s
//    domain, updating support_.
//  - Checks whether neg_support_ is still in the domain, otherwise scans the
//    domain for a new neg_support_ (the test on line 1380 is not shown).
//  - Inhibits demon_ and sets boolvar_ to 1 or 0 in the branches shown.
// TargetBound(): once boolvar_ is fixed, inhibits demon_ and either restricts
// var_ to values_ (boolvar_ == 1) or removes them (otherwise).
// NOTE(review): large parts of this class are missing from this listing:
//  - constructor initializers for var_, boolvar_, support_, neg_support_;
//  - demon_ creation in Post() and the declaration of `bdemon`;
//  - InitialPropagate()'s bound-target branch;
//  - argument continuations and other members.
1300 class IsMemberCt :
public Constraint {
1302 IsMemberCt(Solver*
const s, IntVar*
const v,
1303 const std::vector<int64>& sorted_values, IntVar*
const b)
1306 values_as_set_(sorted_values.begin(), sorted_values.end()),
1307 values_(sorted_values),
1311 domain_(var_->MakeDomainIterator(true)),
1313 DCHECK(v !=
nullptr);
1314 DCHECK(s !=
nullptr);
1315 DCHECK(
b !=
nullptr);
1321 void Post()
override {
1324 if (!var_->Bound()) {
1325 var_->WhenDomain(demon_);
1327 if (!boolvar_->Bound()) {
1329 solver(),
this, &IsMemberCt::TargetBound,
"TargetBound");
1330 boolvar_->WhenBound(bdemon);
1334 void InitialPropagate()
override {
1335 boolvar_->SetRange(0, 1);
1336 if (boolvar_->Bound()) {
1343 std::string DebugString()
const override {
1344 return absl::StrFormat(
"IsMemberCt(%s, %s, %s)", var_->DebugString(),
1345 absl::StrJoin(values_,
", "),
1346 boolvar_->DebugString());
1349 void Accept(ModelVisitor*
const visitor)
const override {
1350 visitor->BeginVisitConstraint(ModelVisitor::kIsMember,
this);
1351 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1353 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1354 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1356 visitor->EndVisitConstraint(ModelVisitor::kIsMember,
this);
1361 if (boolvar_->Bound()) {
1364 for (
int offset = 0; offset < values_.size(); ++offset) {
1365 const int candidate = (support_ + offset) % values_.size();
1366 if (var_->Contains(values_[candidate])) {
1367 support_ = candidate;
1368 if (var_->Bound()) {
1369 demon_->inhibit(solver());
1370 boolvar_->SetValue(1);
1375 if (var_->Contains(neg_support_)) {
1379 for (
const int64 value : InitAndGetValues(domain_)) {
1381 neg_support_ =
value;
1387 demon_->inhibit(solver());
1388 boolvar_->SetValue(1);
1393 demon_->inhibit(solver());
1394 boolvar_->SetValue(0);
1398 void TargetBound() {
1399 DCHECK(boolvar_->Bound());
1400 if (boolvar_->Min() == 1LL) {
1401 demon_->inhibit(solver());
1402 var_->SetValues(values_);
1404 demon_->inhibit(solver());
1405 var_->RemoveValues(values_);
1410 absl::flat_hash_set<int64> values_as_set_;
1411 std::vector<int64> values_;
1412 IntVar*
const boolvar_;
1415 IntVarIterator*
const domain_;
// BuildIsMemberCt(solver, expr, values, boolvar): shared implementation of
// MakeIsMemberCt for int and int64 value vectors. The template header
// (listing line 1419) is not shown; T is the value type.
//  - If expr is a product sub * coef with coef not 0 or 1, rewrites the
//    values (loop lines 1430-1434 not shown) and recurses on sub.
//  - Otherwise sorts and deduplicates values through std::set<T>, then
//    filters them against var's domain (when expr is a variable) or expr's
//    range. The filter tests on lines 1444/1454 are not shown.
//  - Sets all_values when the filtered values cover the whole domain/range.
// Result:
//  - boolvar == 0 for no values; boolvar == 1 when all values remain;
//  - MakeIsEqualCstCt for one value;
//  - MakeIsBetweenCt for a contiguous ascending run (ascending because it
//    comes from std::set iteration);
//  - otherwise an IsMemberCt on expr->Var().
1420 Constraint* BuildIsMemberCt(Solver*
const solver, IntExpr*
const expr,
1421 const std::vector<T>& values,
1422 IntVar*
const boolvar) {
1425 IntExpr* sub =
nullptr;
1427 if (solver->IsProduct(expr, &sub, &
coef) &&
coef != 0 &&
coef != 1) {
1428 std::vector<int64> new_values;
1429 new_values.reserve(values.size());
1435 return BuildIsMemberCt(solver, sub, new_values, boolvar);
1438 std::set<T> set_of_values(values.begin(), values.end());
1439 std::vector<int64> filtered_values;
1440 bool all_values =
false;
1441 if (expr->IsVar()) {
1442 IntVar*
const var = expr->Var();
1443 for (
const T
value : set_of_values) {
1445 filtered_values.push_back(
value);
1448 all_values = (filtered_values.size() ==
var->Size());
1452 expr->Range(&emin, &emax);
1453 for (
const T
value : set_of_values) {
1455 filtered_values.push_back(
value);
1458 all_values = (filtered_values.size() == emax - emin + 1);
1460 if (filtered_values.empty()) {
1461 return solver->MakeEquality(boolvar, Zero());
1462 }
else if (all_values) {
1463 return solver->MakeEquality(boolvar, 1);
1464 }
else if (filtered_values.size() == 1) {
1465 return solver->MakeIsEqualCstCt(expr, filtered_values.back(), boolvar);
1466 }
else if (filtered_values.back() ==
1467 filtered_values.front() + filtered_values.size() - 1) {
1469 return solver->MakeIsBetweenCt(expr, filtered_values.front(),
1470 filtered_values.back(), boolvar);
1472 return solver->RevAlloc(
1473 new IsMemberCt(solver, expr->Var(), filtered_values, boolvar));
1478 Constraint* Solver::MakeIsMemberCt(IntExpr*
const expr,
1479 const std::vector<int64>& values,
1480 IntVar*
const boolvar) {
1481 return BuildIsMemberCt(
this, expr, values, boolvar);
1484 Constraint* Solver::MakeIsMemberCt(IntExpr*
const expr,
1485 const std::vector<int>& values,
1486 IntVar*
const boolvar) {
1487 return BuildIsMemberCt(
this, expr, values, boolvar);
1490 IntVar* Solver::MakeIsMemberVar(IntExpr*
const expr,
1491 const std::vector<int64>& values) {
1492 IntVar*
const b = MakeBoolVar();
1493 AddConstraint(MakeIsMemberCt(expr, values,
b));
1497 IntVar* Solver::MakeIsMemberVar(IntExpr*
const expr,
1498 const std::vector<int>& values) {
1499 IntVar*
const b = MakeBoolVar();
1500 AddConstraint(MakeIsMemberCt(expr, values,
b));
// Constraint: var_ must not take any value inside one of the intervals stored
// in intervals_. In the lines shown, propagation only moves var_'s bounds
// (SetMin/SetMax); forbidden values strictly inside the current range are not
// removed.
// NOTE(review): this listing omits several original lines (early returns,
// the loop body in Accept, closing braces, access specifiers); comments
// describe only what is shown.
1505 class SortedDisjointForbiddenIntervalsConstraint :
public Constraint {
1507 SortedDisjointForbiddenIntervalsConstraint(
1508 Solver*
const solver, IntVar*
const var,
1509 SortedDisjointIntervalList intervals)
1510 : Constraint(solver), var_(
var), intervals_(std::move(intervals)) {}
1512 ~SortedDisjointForbiddenIntervalsConstraint()
override {}
// Re-runs InitialPropagate whenever var_'s bounds change.
1514 void Post()
override {
1515 Demon*
const demon = solver()->MakeConstraintInitialPropagateCallback(
this);
1516 var_->WhenRange(demon);
// Bound tightening:
//  - if vmin lies inside a forbidden interval, raise it to that interval's
//    end + 1;
//  - if vmax lies inside a forbidden interval, lower it to that interval's
//    start - 1.
// Assumes FirstIntervalGreaterOrEqual(vmin) returns the first interval whose
// end is >= vmin, LastIntervalLessOrEqual(vmax) returns the last interval
// whose start is <= vmax, and both return end() when no such interval exists
// — confirm against SortedDisjointIntervalList. CapAdd/CapSub presumably
// saturate at the int64 limits instead of overflowing.
1519 void InitialPropagate()
override {
1520 const int64 vmin = var_->Min();
1521 const int64 vmax = var_->Max();
1522 const auto first_interval_it = intervals_.FirstIntervalGreaterOrEqual(vmin);
1523 if (first_interval_it == intervals_.end()) {
1527 const auto last_interval_it = intervals_.LastIntervalLessOrEqual(vmax);
1528 if (last_interval_it == intervals_.end()) {
1534 if (vmin >= first_interval_it->start) {
1537 var_->SetMin(
CapAdd(first_interval_it->end, 1));
1539 if (vmax <= last_interval_it->end) {
1541 var_->SetMax(
CapSub(last_interval_it->start, 1));
1545 std::string DebugString()
const override {
1546 return absl::StrFormat(
"ForbiddenIntervalCt(%s, %s)", var_->DebugString(),
1547 intervals_.DebugString());
// Reports the constraint as kNotMember with an expression argument and the
// intervals as parallel `starts` / `ends` integer arrays. The loop body that
// fills them is not shown in this listing.
1550 void Accept(ModelVisitor*
const visitor)
const override {
1551 visitor->BeginVisitConstraint(ModelVisitor::kNotMember,
this);
1552 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1554 std::vector<int64> starts;
1555 std::vector<int64> ends;
1556 for (
auto&
interval : intervals_) {
1560 visitor->VisitIntegerArrayArgument(ModelVisitor::kStartsArgument, starts);
1561 visitor->VisitIntegerArrayArgument(ModelVisitor::kEndsArgument, ends);
1562 visitor->EndVisitConstraint(ModelVisitor::kNotMember,
this);
// Forbidden intervals, taken by move in the constructor and never modified.
1567 const SortedDisjointIntervalList intervals_;
1571 Constraint* Solver::MakeNotMemberCt(IntExpr*
const expr,
1572 std::vector<int64> starts,
1573 std::vector<int64> ends) {
1574 return RevAlloc(
new SortedDisjointForbiddenIntervalsConstraint(
1575 this, expr->Var(), {starts, ends}));
1578 Constraint* Solver::MakeNotMemberCt(IntExpr*
const expr,
1579 std::vector<int> starts,
1580 std::vector<int> ends) {
1581 return RevAlloc(
new SortedDisjointForbiddenIntervalsConstraint(
1582 this, expr->Var(), {starts, ends}));
1585 Constraint* Solver::MakeNotMemberCt(IntExpr* expr,
1586 SortedDisjointIntervalList intervals) {
1587 return RevAlloc(
new SortedDisjointForbiddenIntervalsConstraint(
1588 this, expr->Var(), std::move(intervals)));