23 #include "absl/strings/str_format.h"
24 #include "absl/strings/str_join.h"
33 "If true, caching for IntElement is disabled.");
// Forward declaration only; the definition is not part of this excerpt.
// NOTE(review): presumably associates `var` with `expr` inside solver `s` —
// confirm against the definition.
38 void LinkVarExpr(Solver*
const s, IntExpr*
const expr, IntVar*
const var);
// Remembers a pointer to the table that operator() reads; the vector itself
// is not copied, so it must stay alive while the comparator is in use.
45 explicit VectorLess(
const std::vector<T>* values) : values_(values) {}
// Compares two positions by the table entries they select: returns true when
// the entry at x is smaller than the entry at y. Both arguments are used as
// subscripts into *values_.
46 bool operator()(
const T& x,
const T& y)
const {
47 return (*values_)[x] < (*values_)[y];
51 const std::vector<T>* values_;
// Remembers a pointer to the table that operator() reads; the vector itself
// is not copied, so it must stay alive while the comparator is in use.
57 explicit VectorGreater(
const std::vector<T>* values) : values_(values) {}
// Mirror image of VectorLess: returns true when the table entry at x is
// larger than the entry at y. Both arguments are used as subscripts into
// *values_.
58 bool operator()(
const T& x,
const T& y)
const {
59 return (*values_)[x] > (*values_)[y];
63 const std::vector<T>* values_;
68 class BaseIntExprElement :
public BaseIntExpr {
70 BaseIntExprElement(Solver*
const s, IntVar*
const e);
71 ~BaseIntExprElement()
override {}
72 int64_t Min()
const override;
73 int64_t Max()
const override;
74 void Range(int64_t* mi, int64_t* ma)
override;
75 void SetMin(int64_t m)
override;
76 void SetMax(int64_t m)
override;
77 void SetRange(int64_t mi, int64_t ma)
override;
// Reports bound-ness by delegating to the index expression expr_.
78 bool Bound()
const override {
return (
expr_->Bound()); }
// Attaches demon `d` to range events of the index expression; this class
// registers nothing else here.
80 void WhenRange(Demon* d)
override {
expr_->WhenRange(d); }
// Hooks every concrete subclass must supply (all pure virtual).
// ElementValue(index): the value the element expression takes at `index`.
83 virtual int64_t ElementValue(
int index)
const = 0;
// ExprMin()/ExprMax(): the ends of the index range that UpdateSupports() and
// the bound-update macro scan.
84 virtual int64_t ExprMin()
const = 0;
85 virtual int64_t ExprMax()
const = 0;
90 void UpdateSupports()
const;
93 mutable int min_support_;
95 mutable int max_support_;
96 mutable bool initial_update_;
97 IntVarIterator*
const expr_iterator_;
100 BaseIntExprElement::BaseIntExprElement(Solver*
const s, IntVar*
const e)
107 initial_update_(true),
108 expr_iterator_(
expr_->MakeDomainIterator(true)) {
113 int64_t BaseIntExprElement::Min()
const {
118 int64_t BaseIntExprElement::Max()
const {
123 void BaseIntExprElement::Range(int64_t* mi, int64_t* ma) {
129 #define UPDATE_BASE_ELEMENT_INDEX_BOUNDS(test) \
130 const int64_t emin = ExprMin(); \
131 const int64_t emax = ExprMax(); \
132 int64_t nmin = emin; \
133 int64_t value = ElementValue(nmin); \
134 while (nmin < emax && test) { \
136 value = ElementValue(nmin); \
138 if (nmin == emax && test) { \
141 int64_t nmax = emax; \
142 value = ElementValue(nmax); \
143 while (nmax >= nmin && test) { \
145 value = ElementValue(nmax); \
147 expr_->SetRange(nmin, nmax);
149 void BaseIntExprElement::SetMin(int64_t m) {
153 void BaseIntExprElement::SetMax(int64_t m) {
157 void BaseIntExprElement::SetRange(int64_t mi, int64_t ma) {
164 #undef UPDATE_BASE_ELEMENT_INDEX_BOUNDS
// Refreshes the cached extreme values (min_, max_) and the indices that
// produce them (min_support_, max_support_).
// Several interior lines of this function are missing from the excerpt; the
// comments below describe only what is visible.
166 void BaseIntExprElement::UpdateSupports()
const {
// Recompute only on the first call or when one of the remembered supporting
// indices is no longer in the domain of expr_.
167 if (initial_update_ || !
expr_->Contains(min_support_) ||
168 !
expr_->Contains(max_support_)) {
169 const int64_t emin = ExprMin();
170 const int64_t emax = ExprMax();
// Seed both extremes from the value at the top index so that every other
// index only has to beat it.
171 int64_t min_value = ElementValue(emax);
172 int64_t max_value = min_value;
173 int min_support = emax;
174 int max_support = emax;
175 const uint64_t expr_size =
expr_->Size();
// Domain size equal to the width of [emin, emax].
// NOTE(review): presumably the "no holes" fast path that scans the interval
// directly — the loop header is not visible here; confirm.
177 if (expr_size == emax - emin + 1) {
181 if (
value > max_value) {
184 }
else if (
value < min_value) {
// Otherwise walk only the values produced by the domain iterator.
190 for (
const int64_t
index : InitAndGetValues(expr_iterator_)) {
193 if (
value > max_value) {
196 }
else if (
value < min_value) {
// Publish the results through the solver.
// NOTE(review): SaveAndSetValue presumably records the old value so it can be
// restored on backtrack — confirm against Solver.
204 Solver* s = solver();
205 s->SaveAndSetValue(&min_, min_value);
206 s->SaveAndSetValue(&min_support_, min_support);
207 s->SaveAndSetValue(&max_, max_value);
208 s->SaveAndSetValue(&max_support_, max_support);
209 s->SaveAndSetValue(&initial_update_,
false);
218 class IntElementConstraint :
public CastConstraint {
220 IntElementConstraint(Solver*
const s,
const std::vector<int64_t>& values,
221 IntVar*
const index, IntVar*
const elem)
222 : CastConstraint(s, elem),
225 index_iterator_(index_->MakeDomainIterator(true)) {
229 void Post()
override {
231 solver()->MakeDelayedConstraintInitialPropagateCallback(
this);
232 index_->WhenDomain(d);
236 void InitialPropagate()
override {
237 index_->SetRange(0, values_.size() - 1);
240 int64_t new_min = target_var_max;
241 int64_t new_max = target_var_min;
243 for (
const int64_t
index : InitAndGetValues(index_iterator_)) {
245 if (value < target_var_min || value > target_var_max) {
248 if (
value < new_min) {
251 if (
value > new_max) {
262 std::string DebugString()
const override {
263 return absl::StrFormat(
"IntElementConstraint(%s, %s, %s)",
264 absl::StrJoin(values_,
", "), index_->DebugString(),
268 void Accept(ModelVisitor*
const visitor)
const override {
269 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual,
this);
270 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
271 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
273 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
275 visitor->EndVisitConstraint(ModelVisitor::kElementEqual,
this);
279 const std::vector<int64_t> values_;
280 IntVar*
const index_;
281 IntVarIterator*
const index_iterator_;
// Forward declaration only; the definition is not part of this excerpt.
// `values` is taken by non-const pointer, so the callee may modify the vector.
// NOTE(review): presumably builds a variable whose domain is `values` —
// confirm against the definition.
287 IntVar* BuildDomainIntVar(Solver*
const solver, std::vector<int64_t>* values);
289 class IntExprElement :
public BaseIntExprElement {
291 IntExprElement(Solver*
const s,
const std::vector<int64_t>& vals,
293 : BaseIntExprElement(s, expr), values_(vals) {}
295 ~IntExprElement()
override {}
297 std::string
name()
const override {
298 const int size = values_.size();
300 return absl::StrFormat(
"IntElement(array of size %d, %s)", size,
303 return absl::StrFormat(
"IntElement(%s, %s)", absl::StrJoin(values_,
", "),
308 std::string DebugString()
const override {
309 const int size = values_.size();
311 return absl::StrFormat(
"IntElement(array of size %d, %s)", size,
312 expr_->DebugString());
314 return absl::StrFormat(
"IntElement(%s, %s)", absl::StrJoin(values_,
", "),
315 expr_->DebugString());
319 IntVar* CastToVar()
override {
320 Solver*
const s = solver();
321 IntVar*
const var = s->MakeIntVar(values_);
322 s->AddCastConstraint(
323 s->RevAlloc(
new IntElementConstraint(s, values_,
expr_,
var)),
var,
328 void Accept(ModelVisitor*
const visitor)
const override {
329 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement,
this);
330 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
331 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
333 visitor->EndVisitIntegerExpression(ModelVisitor::kElement,
this);
// Table lookup: returns the entry of values_ at position `index`.
// (Original line 338 is missing from this excerpt.)
337 int64_t ElementValue(
int index)
const override {
339 return values_[
index];
// Lowest index worth scanning: expr_->Min(), raised to 0 when negative.
341 int64_t ExprMin()
const override {
342 return std::max<int64_t>(0,
expr_->Min());
// Highest index worth scanning. For a non-empty table this is expr_->Max()
// capped at the last valid position. The result used for an empty table is on
// a line that is missing from this excerpt.
344 int64_t ExprMax()
const override {
345 return values_.empty()
347 : std::min<int64_t>(values_.size() - 1,
expr_->Max());
351 const std::vector<int64_t> values_;
356 class RangeMinimumQueryExprElement :
public BaseIntExpr {
358 RangeMinimumQueryExprElement(Solver* solver,
359 const std::vector<int64_t>& values,
361 ~RangeMinimumQueryExprElement()
override {}
362 int64_t Min()
const override;
363 int64_t Max()
const override;
364 void Range(int64_t* mi, int64_t* ma)
override;
365 void SetMin(int64_t m)
override;
366 void SetMax(int64_t m)
override;
367 void SetRange(int64_t mi, int64_t ma)
override;
// Reports bound-ness by delegating to the index variable.
368 bool Bound()
const override {
return (index_->Bound()); }
// Attaches demon `d` to range events of the index variable.
370 void WhenRange(Demon* d)
override { index_->WhenRange(d); }
371 IntVar* CastToVar()
override {
375 IntVar*
const var = solver()->MakeIntVar(min_rmq_.array());
376 solver()->AddCastConstraint(solver()->RevAlloc(
new IntElementConstraint(
377 solver(), min_rmq_.array(), index_,
var)),
381 void Accept(ModelVisitor*
const visitor)
const override {
382 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement,
this);
383 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
385 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
387 visitor->EndVisitIntegerExpression(ModelVisitor::kElement,
this);
// Lowest usable array position: index_->Min(), raised to 0 when negative.
391 int64_t IndexMin()
const {
return std::max<int64_t>(0, index_->Min()); }
// Highest usable array position: index_->Max(), capped at the last position
// of the array held by min_rmq_.
392 int64_t IndexMax()
const {
393 return std::min<int64_t>(min_rmq_.array().size() - 1, index_->Max());
396 IntVar*
const index_;
397 const RangeMinimumQuery<int64_t, std::less<int64_t>> min_rmq_;
398 const RangeMinimumQuery<int64_t, std::greater<int64_t>> max_rmq_;
// Builds both query structures (min_rmq_ and max_rmq_) from the same `values`
// and stores the index variable.
401 RangeMinimumQueryExprElement::RangeMinimumQueryExprElement(
402 Solver* solver,
const std::vector<int64_t>& values, IntVar*
index)
403 : BaseIntExpr(solver), index_(
index), min_rmq_(values), max_rmq_(values) {
// Null-solver check; note it runs after `solver` has already been handed to
// the BaseIntExpr initialiser above.
404 CHECK(solver !=
nullptr);
// Smallest array value over the current index window, answered by min_rmq_.
// NOTE(review): the end argument is IndexMax() + 1, which suggests a half-open
// [begin, end) query — confirm against RangeMinimumQuery.
408 int64_t RangeMinimumQueryExprElement::Min()
const {
409 return min_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
// Largest array value over the current index window. max_rmq_ is instantiated
// with std::greater, so its "minimum" under that ordering is the maximum.
412 int64_t RangeMinimumQueryExprElement::Max()
const {
413 return max_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
// Writes both bounds of the expression: *mi receives the min_rmq_ answer and
// *ma the max_rmq_ answer. The index window is evaluated once and shared by
// both queries.
416 void RangeMinimumQueryExprElement::Range(int64_t* mi, int64_t* ma) {
417 const int64_t range_min = IndexMin();
418 const int64_t range_max = IndexMax() + 1;
419 *mi = min_rmq_.GetMinimumFromRange(range_min, range_max);
420 *ma = max_rmq_.GetMinimumFromRange(range_min, range_max);
423 #define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test) \
424 const std::vector<int64_t>& values = min_rmq_.array(); \
425 int64_t index_min = IndexMin(); \
426 int64_t index_max = IndexMax(); \
427 int64_t value = values[index_min]; \
428 while (index_min < index_max && (test)) { \
430 value = values[index_min]; \
432 if (index_min == index_max && (test)) { \
435 value = values[index_max]; \
436 while (index_max >= index_min && (test)) { \
438 value = values[index_max]; \
440 index_->SetRange(index_min, index_max);
442 void RangeMinimumQueryExprElement::SetMin(int64_t m) {
446 void RangeMinimumQueryExprElement::SetMax(int64_t m) {
450 void RangeMinimumQueryExprElement::SetRange(int64_t mi, int64_t ma) {
457 #undef UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS
461 class IncreasingIntExprElement :
public BaseIntExpr {
463 IncreasingIntExprElement(Solver*
const s,
const std::vector<int64_t>& values,
464 IntVar*
const index);
465 ~IncreasingIntExprElement()
override {}
467 int64_t Min()
const override;
468 void SetMin(int64_t m)
override;
469 int64_t Max()
const override;
470 void SetMax(int64_t m)
override;
471 void SetRange(int64_t mi, int64_t ma)
override;
// Reports bound-ness by delegating to the index variable.
472 bool Bound()
const override {
return (index_->Bound()); }
474 std::string
name()
const override {
475 return absl::StrFormat(
"IntElement(%s, %s)", absl::StrJoin(values_,
", "),
478 std::string DebugString()
const override {
479 return absl::StrFormat(
"IntElement(%s, %s)", absl::StrJoin(values_,
", "),
480 index_->DebugString());
483 void Accept(ModelVisitor*
const visitor)
const override {
484 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement,
this);
485 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
486 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
488 visitor->EndVisitIntegerExpression(ModelVisitor::kElement,
this);
// Attaches demon `d` to range events of the index variable.
491 void WhenRange(Demon* d)
override { index_->WhenRange(d); }
493 IntVar* CastToVar()
override {
494 Solver*
const s = solver();
495 IntVar*
const var = s->MakeIntVar(values_);
501 const std::vector<int64_t> values_;
502 IntVar*
const index_;
505 IncreasingIntExprElement::IncreasingIntExprElement(
506 Solver*
const s,
const std::vector<int64_t>& values, IntVar*
const index)
507 : BaseIntExpr(s), values_(values), index_(
index) {
// Lower bound of the expression: the table entry at the smallest index
// (clamped to >= 0) when that index lies inside values_. The fallback used
// when it lies past the end is on a line missing from this excerpt.
// NOTE(review): relies on values_ being sorted in non-decreasing order, as the
// class name suggests — confirm at the construction site.
512 int64_t IncreasingIntExprElement::Min()
const {
513 const int64_t expression_min = std::max<int64_t>(0, index_->Min());
514 return (expression_min < values_.size()
515 ? values_[expression_min]
// Raises the lower bound of the expression to `m` by raising the lower bound
// of the index variable.
519 void IncreasingIntExprElement::SetMin(int64_t m) {
// Clamp the index window to valid table positions.
520 const int64_t index_min = std::max<int64_t>(0, index_->Min());
521 const int64_t index_max =
522 std::min<int64_t>(values_.size() - 1, index_->Max());
// Empty window, or even the entry at the top of the window is below m.
// (The body of this branch is missing from the excerpt.)
524 if (index_min > index_max || m > values_[index_max]) {
// NOTE(review): the initialiser of `first` is missing from this excerpt;
// presumably a search for the first entry >= m — confirm.
528 const std::vector<int64_t>::const_iterator first =
// Convert the iterator to a position and make it the new index minimum.
530 const int64_t new_index_min = first - values_.begin();
531 index_->SetMin(new_index_min);
// Upper bound of the expression: the table entry at the largest index (capped
// at the last table position) when that position is non-negative. The
// fallback for a negative position is on a line missing from this excerpt.
534 int64_t IncreasingIntExprElement::Max()
const {
535 const int64_t expression_max =
536 std::min<int64_t>(values_.size() - 1, index_->Max());
537 return (expression_max >= 0 ? values_[expression_max]
// Lowers the upper bound of the expression to `m` by shrinking the index
// range from above.
541 void IncreasingIntExprElement::SetMax(int64_t m) {
542 int64_t index_min = std::max<int64_t>(0, index_->Min());
// NOTE(review): in the visible lines index_min is not checked against
// values_.size() before this subscript — confirm a guard exists elsewhere.
// (The body of this branch is missing from the excerpt.)
543 if (m < values_[index_min]) {
// NOTE(review): the initialiser of `last_after` is missing from this excerpt;
// presumably a search for the first entry > m — confirm.
547 const std::vector<int64_t>::const_iterator last_after =
// One position before `last_after` is the last index still allowed.
549 const int64_t new_index_max = (last_after - values_.begin()) - 1;
// SetRange with 0 as lower end also forces the index to be non-negative.
550 index_->SetRange(0, new_index_max);
// Restricts the expression to [mi, ma] by computing the matching index window
// and applying it in a single SetRange call on the index variable.
553 void IncreasingIntExprElement::SetRange(int64_t mi, int64_t ma) {
// Clamp the index window to valid table positions.
557 const int64_t index_min = std::max<int64_t>(0, index_->Min());
558 const int64_t index_max =
559 std::min<int64_t>(values_.size() - 1, index_->Max());
// Empty request, or [mi, ma] lies entirely outside the values reachable in
// the current window. (The body of this branch is missing from the excerpt.)
561 if (mi > ma || ma < values_[index_min] || mi > values_[index_max]) {
// NOTE(review): the initialisers of `first` and `last_after` are missing from
// this excerpt; presumably searches for the first entry >= mi and the first
// entry > ma respectively — confirm.
565 const std::vector<int64_t>::const_iterator first =
567 const int64_t new_index_min = first - values_.begin();
569 const std::vector<int64_t>::const_iterator last_after =
571 const int64_t new_index_max = (last_after - values_.begin()) - 1;
574 index_->SetRange(new_index_min, new_index_max);
578 IntExpr* BuildElement(Solver*
const solver,
const std::vector<int64_t>& values,
579 IntVar*
const index) {
583 solver->AddConstraint(solver->MakeBetweenCt(
index, 0, values.size() - 1));
584 return solver->MakeIntConst(values[0]);
589 std::vector<int64_t> ones;
591 for (
int i = 0; i < values.size(); ++i) {
592 if (values[i] == 1) {
598 if (ones.size() == 1) {
599 DCHECK_EQ(int64_t{1}, values[ones.back()]);
600 solver->AddConstraint(solver->MakeBetweenCt(
index, 0, values.size() - 1));
601 return solver->MakeIsEqualCstVar(
index, ones.back());
602 }
else if (ones.size() == values.size() - 1) {
603 solver->AddConstraint(solver->MakeBetweenCt(
index, 0, values.size() - 1));
604 return solver->MakeIsDifferentCstVar(
index, first_zero);
605 }
else if (ones.size() == ones.back() - ones.front() + 1) {
606 solver->AddConstraint(solver->MakeBetweenCt(
index, 0, values.size() - 1));
607 IntVar*
const b = solver->MakeBoolVar(
"ContiguousBooleanElementVar");
608 solver->AddConstraint(
609 solver->MakeIsBetweenCt(
index, ones.front(), ones.back(),
b));
612 IntVar*
const b = solver->MakeBoolVar(
"NonContiguousBooleanElementVar");
613 solver->AddConstraint(solver->MakeBetweenCt(
index, 0, values.size() - 1));
614 solver->AddConstraint(solver->MakeIsMemberCt(
index, ones,
b));
618 IntExpr* cache =
nullptr;
619 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
620 cache = solver->Cache()->FindVarConstantArrayExpression(
621 index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
623 if (cache !=
nullptr) {
626 IntExpr* result =
nullptr;
627 if (values.size() >= 2 &&
index->Min() == 0 &&
index->Max() == 1) {
628 result = solver->MakeSum(solver->MakeProd(
index, values[1] - values[0]),
630 }
else if (values.size() == 2 &&
index->Contains(0) &&
index->Contains(1)) {
631 solver->AddConstraint(solver->MakeBetweenCt(
index, 0, 1));
632 result = solver->MakeSum(solver->MakeProd(
index, values[1] - values[0]),
635 result = solver->MakeSum(
index, values[0]);
637 result = solver->RegisterIntExpr(solver->RevAlloc(
638 new IncreasingIntExprElement(solver, values,
index)));
640 if (solver->parameters().use_element_rmq()) {
641 result = solver->RegisterIntExpr(solver->RevAlloc(
642 new RangeMinimumQueryExprElement(solver, values,
index)));
644 result = solver->RegisterIntExpr(
645 solver->RevAlloc(
new IntExprElement(solver, values,
index)));
648 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
649 solver->Cache()->InsertVarConstantArrayExpression(
650 result,
index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
657 IntExpr* Solver::MakeElement(
const std::vector<int64_t>& values,
661 if (
index->Bound()) {
662 return MakeIntConst(values[
index->Min()]);
664 return BuildElement(
this, values,
index);
667 IntExpr* Solver::MakeElement(
const std::vector<int>& values,
671 if (
index->Bound()) {
672 return MakeIntConst(values[
index->Min()]);
680 class IntExprFunctionElement :
public BaseIntExprElement {
684 ~IntExprFunctionElement()
override;
686 std::string
name()
const override {
687 return absl::StrFormat(
"IntFunctionElement(%s)",
expr_->name());
690 std::string DebugString()
const override {
691 return absl::StrFormat(
"IntFunctionElement(%s)",
expr_->DebugString());
694 void Accept(ModelVisitor*
const visitor)
const override {
696 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement,
this);
697 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
699 visitor->VisitInt64ToInt64Extension(values_,
expr_->Min(),
expr_->Max());
700 visitor->EndVisitIntegerExpression(ModelVisitor::kElement,
this);
// Evaluates the stored callback values_ at `index`.
704 int64_t ElementValue(
int index)
const override {
return values_(
index); }
// Lower end of the scanned index range: expr_->Min() as-is, with no clamping.
705 int64_t ExprMin()
const override {
return expr_->Min(); }
// Upper end of the scanned index range: expr_->Max() as-is, with no clamping.
706 int64_t ExprMax()
const override {
return expr_->Max(); }
709 Solver::IndexEvaluator1 values_;
// Takes the evaluator by value and moves it into values_; the solver and index
// variable are forwarded to the BaseIntExprElement constructor.
// (Original line 714, the last parameter, is missing from this excerpt.)
712 IntExprFunctionElement::IntExprFunctionElement(Solver*
const s,
713 Solver::IndexEvaluator1 values,
715 : BaseIntExprElement(s, e), values_(std::move(values)) {
// The stored callback must not be empty.
716 CHECK(values_ !=
nullptr);
// Out-of-line destructor with an empty body; nothing is released explicitly.
719 IntExprFunctionElement::~IntExprFunctionElement() {}
723 class IncreasingIntExprFunctionElement :
public BaseIntExpr {
725 IncreasingIntExprFunctionElement(Solver*
const s,
728 : BaseIntExpr(s), values_(std::move(values)), index_(
index) {
729 DCHECK(values_ !=
nullptr);
734 ~IncreasingIntExprFunctionElement()
override {}
// Lower bound of the expression: the callback evaluated at the smallest index.
// NOTE(review): correct only if values_ is non-decreasing, as the class name
// suggests — confirm.
736 int64_t Min()
const override {
return values_(index_->Min()); }
// Raises the lower bound of the expression to `m` by raising the lower bound
// of the index variable.
738 void SetMin(int64_t m)
override {
739 const int64_t index_min = index_->Min();
740 const int64_t index_max = index_->Max();
// Even the value at the largest index is below m.
// (The body of this branch is missing from the excerpt.)
741 if (m > values_(index_max)) {
// Locate the new lower index bound and apply it.
744 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, m);
745 index_->SetMin(new_index_min);
// Upper bound of the expression: the callback evaluated at the largest index.
// NOTE(review): correct only if values_ is non-decreasing, as the class name
// suggests — confirm.
748 int64_t Max()
const override {
return values_(index_->Max()); }
// Lowers the upper bound of the expression to `m` by lowering the upper bound
// of the index variable.
750 void SetMax(int64_t m)
override {
751 int64_t index_min = index_->Min();
752 int64_t index_max = index_->Max();
// Even the value at the smallest index is above m.
// (The body of this branch is missing from the excerpt.)
753 if (m < values_(index_min)) {
// Locate the new upper index bound and apply it.
756 const int64_t new_index_max = FindNewIndexMax(index_min, index_max, m);
757 index_->SetMax(new_index_max);
// Restricts the expression to [mi, ma] by narrowing the index variable.
760 void SetRange(int64_t mi, int64_t ma)
override {
761 const int64_t index_min = index_->Min();
762 const int64_t index_max = index_->Max();
763 const int64_t value_min = values_(index_min);
764 const int64_t value_max = values_(index_max);
// Empty request, or no overlap with the currently reachable values.
// (The body of this branch is missing from the excerpt.)
765 if (mi > ma || ma < value_min || mi > value_max) {
// The request already contains every reachable value.
// (The body of this branch is missing from the excerpt.)
768 if (mi <= value_min && ma >= value_max) {
// The search for the upper index starts from the freshly computed lower index
// rather than from the original index_min.
773 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, mi);
774 const int64_t new_index_max = FindNewIndexMax(new_index_min, index_max, ma);
776 index_->SetRange(new_index_min, new_index_max);
779 std::string
name()
const override {
780 return absl::StrFormat(
"IncreasingIntExprFunctionElement(values, %s)",
784 std::string DebugString()
const override {
785 return absl::StrFormat(
"IncreasingIntExprFunctionElement(values, %s)",
786 index_->DebugString());
// Attaches demon `d` to range events of the index variable.
789 void WhenRange(Demon* d)
override { index_->WhenRange(d); }
791 void Accept(ModelVisitor*
const visitor)
const override {
796 if (index_->Min() == 0) {
800 visitor->VisitInt64ToInt64Extension(values_, index_->Min(),
// Bisects [index_min, index_max] for an index whose value is >= m and returns
// it.
807 int64_t FindNewIndexMin(int64_t index_min, int64_t index_max, int64_t m) {
// The lowest index already satisfies the bound.
// (The body of this branch is missing from the excerpt.)
808 if (m <= values_(index_min)) {
815 int64_t index_lower_bound = index_min;
816 int64_t index_upper_bound = index_max;
// Loop invariant, asserted in debug builds:
//   values_(index_lower_bound) < m <= values_(index_upper_bound).
// The loop stops once the two bounds are adjacent.
817 while (index_upper_bound - index_lower_bound > 1) {
818 DCHECK_LT(values_(index_lower_bound), m);
819 DCHECK_GE(values_(index_upper_bound), m);
// NOTE(review): the sum can overflow int64_t for extreme index bounds.
820 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
821 const int64_t pivot_value = values_(pivot);
// Keep the half that still brackets m.
822 if (pivot_value < m) {
823 index_lower_bound = pivot;
825 index_upper_bound = pivot;
828 DCHECK(values_(index_upper_bound) >= m);
829 return index_upper_bound;
// Bisects [index_min, index_max] for an index whose value is <= m and returns
// it.
832 int64_t FindNewIndexMax(int64_t index_min, int64_t index_max, int64_t m) {
// The highest index already satisfies the bound.
// (The body of this branch is missing from the excerpt.)
833 if (m >= values_(index_max)) {
840 int64_t index_lower_bound = index_min;
841 int64_t index_upper_bound = index_max;
// Loop invariant, asserted in debug builds:
//   values_(index_lower_bound) <= m < values_(index_upper_bound).
// The loop stops once the two bounds are adjacent.
842 while (index_upper_bound - index_lower_bound > 1) {
843 DCHECK_LE(values_(index_lower_bound), m);
844 DCHECK_GT(values_(index_upper_bound), m);
// NOTE(review): the sum can overflow int64_t for extreme index bounds.
845 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
846 const int64_t pivot_value = values_(pivot);
// Keep the half that still brackets m.
847 if (pivot_value > m) {
848 index_upper_bound = pivot;
850 index_lower_bound = pivot;
853 DCHECK(values_(index_lower_bound) <= m);
854 return index_lower_bound;
858 IntVar*
const index_;
866 RevAlloc(
new IntExprFunctionElement(
this, std::move(values),
index)));
874 RevAlloc(
new IncreasingIntExprFunctionElement(
this, values,
index)));
882 new IncreasingIntExprFunctionElement(
this, opposite_values,
index))));
889 class IntIntExprFunctionElement :
public BaseIntExpr {
893 ~IntIntExprFunctionElement()
override;
894 std::string DebugString()
const override {
895 return absl::StrFormat(
"IntIntFunctionElement(%s,%s)",
896 expr1_->DebugString(), expr2_->DebugString());
898 int64_t Min()
const override;
899 int64_t Max()
const override;
// Bound only when both index variables are bound.
904 bool Bound()
const override {
return expr1_->Bound() && expr2_->Bound(); }
// Attaches the same demon to range events of both index variables.
906 void WhenRange(Demon* d)
override {
907 expr1_->WhenRange(d);
908 expr2_->WhenRange(d);
911 void Accept(ModelVisitor*
const visitor)
const override {
918 const int64_t expr1_min = expr1_->Min();
919 const int64_t expr1_max = expr1_->Max();
922 for (
int i = expr1_min; i <= expr1_max; ++i) {
923 visitor->VisitInt64ToInt64Extension(
924 [
this, i](int64_t j) {
return values_(i, j); }, expr2_->Min(),
// Evaluates the stored two-argument callback at (index1, index2).
931 int64_t ElementValue(
int index1,
int index2)
const {
932 return values_(index1, index2);
934 void UpdateSupports()
const;
936 IntVar*
const expr1_;
937 IntVar*
const expr2_;
938 mutable int64_t min_;
939 mutable int min_support1_;
940 mutable int min_support2_;
941 mutable int64_t max_;
942 mutable int max_support1_;
943 mutable int max_support2_;
944 mutable bool initial_update_;
946 IntVarIterator*
const expr1_iterator_;
947 IntVarIterator*
const expr2_iterator_;
950 IntIntExprFunctionElement::IntIntExprFunctionElement(
962 initial_update_(true),
963 values_(std::move(values)),
964 expr1_iterator_(expr1_->MakeDomainIterator(true)),
965 expr2_iterator_(expr2_->MakeDomainIterator(true)) {
966 CHECK(values_ !=
nullptr);
// Out-of-line destructor with an empty body; nothing is released explicitly.
969 IntIntExprFunctionElement::~IntIntExprFunctionElement() {}
971 int64_t IntIntExprFunctionElement::Min()
const {
976 int64_t IntIntExprFunctionElement::Max()
const {
981 void IntIntExprFunctionElement::Range(int64_t*
lower_bound,
988 #define UPDATE_ELEMENT_INDEX_BOUNDS(test) \
989 const int64_t emin1 = expr1_->Min(); \
990 const int64_t emax1 = expr1_->Max(); \
991 const int64_t emin2 = expr2_->Min(); \
992 const int64_t emax2 = expr2_->Max(); \
993 int64_t nmin1 = emin1; \
994 bool found = false; \
995 while (nmin1 <= emax1 && !found) { \
996 for (int i = emin2; i <= emax2; ++i) { \
997 int64_t value = ElementValue(nmin1, i); \
1007 if (nmin1 > emax1) { \
1010 int64_t nmin2 = emin2; \
1012 while (nmin2 <= emax2 && !found) { \
1013 for (int i = emin1; i <= emax1; ++i) { \
1014 int64_t value = ElementValue(i, nmin2); \
1024 if (nmin2 > emax2) { \
1027 int64_t nmax1 = emax1; \
1029 while (nmax1 >= nmin1 && !found) { \
1030 for (int i = emin2; i <= emax2; ++i) { \
1031 int64_t value = ElementValue(nmax1, i); \
1041 int64_t nmax2 = emax2; \
1043 while (nmax2 >= nmin2 && !found) { \
1044 for (int i = emin1; i <= emax1; ++i) { \
1045 int64_t value = ElementValue(i, nmax2); \
1055 expr1_->SetRange(nmin1, nmax1); \
1056 expr2_->SetRange(nmin2, nmax2);
1058 void IntIntExprFunctionElement::SetMin(int64_t
lower_bound) {
1062 void IntIntExprFunctionElement::SetMax(int64_t
upper_bound) {
1066 void IntIntExprFunctionElement::SetRange(int64_t
lower_bound,
1074 #undef UPDATE_ELEMENT_INDEX_BOUNDS
// Refreshes the cached extreme values (min_, max_) and the index pairs that
// produce them.
// Original line 1078 (part of the condition) and several closing lines are
// missing from this excerpt.
1076 void IntIntExprFunctionElement::UpdateSupports()
const {
// Recompute on the first call or when a remembered supporting index is no
// longer in the domain of its variable.
1077 if (initial_update_ || !expr1_->
Contains(min_support1_) ||
1079 !expr2_->
Contains(max_support2_)) {
1080 const int64_t emax1 = expr1_->
Max();
1081 const int64_t emax2 = expr2_->
Max();
// Seed both extremes from the value at (emax1, emax2) so that every other
// pair only has to beat it.
1082 int64_t min_value = ElementValue(emax1, emax2);
1083 int64_t max_value = min_value;
1084 int min_support1 = emax1;
1085 int max_support1 = emax1;
1086 int min_support2 = emax2;
1087 int max_support2 = emax2;
// Visit every pair drawn from the two domain iterators. The inner iterator is
// re-initialised for each value of the outer one.
1088 for (
const int64_t index1 : InitAndGetValues(expr1_iterator_)) {
1089 for (
const int64_t index2 : InitAndGetValues(expr2_iterator_)) {
1090 const int64_t
value = ElementValue(index1, index2);
// A value cannot be both a new maximum and a new minimum, hence else-if.
1091 if (
value > max_value) {
1093 max_support1 = index1;
1094 max_support2 = index2;
1095 }
else if (
value < min_value) {
1097 min_support1 = index1;
1098 min_support2 = index2;
// Publish the results through the solver.
// NOTE(review): SaveAndSetValue presumably records the old value so it can be
// restored on backtrack — confirm against Solver.
1102 Solver* s = solver();
1103 s->SaveAndSetValue(&min_, min_value);
1104 s->SaveAndSetValue(&min_support1_, min_support1);
1105 s->SaveAndSetValue(&min_support2_, min_support2);
1106 s->SaveAndSetValue(&max_, max_value);
1107 s->SaveAndSetValue(&max_support1_, max_support1);
1108 s->SaveAndSetValue(&max_support2_, max_support2);
1109 s->SaveAndSetValue(&initial_update_,
false);
1119 new IntIntExprFunctionElement(
this, std::move(values), index1, index2)));
1131 condition_(condition),
1151 if (condition_->
Max() == 0) {
1152 zero_->
SetRange(target_var_min, target_var_max);
1153 zero_->
Range(&new_min, &new_max);
1154 }
else if (condition_->
Min() == 1) {
1155 one_->
SetRange(target_var_min, target_var_max);
1156 one_->
Range(&new_min, &new_max);
1158 if (target_var_max < zero_->Min() || target_var_min > zero_->
Max()) {
1160 one_->
SetRange(target_var_min, target_var_max);
1161 one_->
Range(&new_min, &new_max);
1162 }
else if (target_var_max < one_->Min() || target_var_min > one_->
Max()) {
1164 zero_->
SetRange(target_var_min, target_var_max);
1165 zero_->
Range(&new_min, &new_max);
1171 zero_->
Range(&zl, &zu);
1172 one_->
Range(&ol, &ou);
1181 return absl::StrFormat(
"(%s ? %s : %s) == %s", condition_->
DebugString(),
1189 IntVar*
const condition_;
1201 class IntExprEvaluatorElementCt :
public CastConstraint {
1204 int64_t range_start, int64_t range_end,
1205 IntVar*
const index, IntVar*
const target_var);
1206 ~IntExprEvaluatorElementCt()
override {}
1208 void Post()
override;
1209 void InitialPropagate()
override;
1212 void Update(
int index);
1215 std::string DebugString()
const override;
1216 void Accept(ModelVisitor*
const visitor)
const override;
1219 IntVar*
const index_;
1223 const int64_t range_start_;
1224 const int64_t range_end_;
1229 IntExprEvaluatorElementCt::IntExprEvaluatorElementCt(
1231 int64_t range_end, IntVar*
const index, IntVar*
const target_var)
1232 : CastConstraint(s, target_var),
1235 range_start_(range_start),
1236 range_end_(range_end),
1240 void IntExprEvaluatorElementCt::Post() {
1242 solver(),
this, &IntExprEvaluatorElementCt::Propagate,
"Propagate");
1243 for (
int i = range_start_; i < range_end_; ++i) {
1245 current_var->WhenRange(delayed_propagate_demon);
1247 solver(),
this, &IntExprEvaluatorElementCt::Update,
"Update", i);
1248 current_var->WhenRange(update_demon);
1250 index_->
WhenRange(delayed_propagate_demon);
1252 solver(),
this, &IntExprEvaluatorElementCt::UpdateExpr,
"UpdateExpr");
1255 solver(),
this, &IntExprEvaluatorElementCt::Propagate,
"UpdateVar");
// The initial propagation is just one run of the regular Propagate() routine.
1260 void IntExprEvaluatorElementCt::InitialPropagate() { Propagate(); }
1262 void IntExprEvaluatorElementCt::Propagate() {
1263 const int64_t emin =
std::max(range_start_, index_->
Min());
1264 const int64_t emax = std::min<int64_t>(range_end_ - 1, index_->
Max());
1271 int64_t nmin = emin;
1272 for (; nmin <= emax; nmin++) {
1277 if (nmin_var->Min() <= vmax && nmin_var->Max() >= vmin)
break;
1279 int64_t nmax = emax;
1280 for (; nmin <= nmax; nmax--) {
1285 if (nmax_var->Min() <= vmax && nmax_var->Max() >= vmin)
break;
1292 if (min_support_ == -1 || max_support_ == -1) {
1293 int min_support = -1;
1294 int max_support = -1;
1297 for (
int i = index_->
Min(); i <= index_->Max(); ++i) {
1299 const int64_t vmin = var_i->Min();
1303 const int64_t vmax = var_i->Max();
1308 solver()->SaveAndSetValue(&min_support_, min_support);
1309 solver()->SaveAndSetValue(&max_support_, max_support);
// Called with the position of a variable whose range changed. If that
// position is one of the two current supports, both supports are reset to -1.
// Propagate() tests for -1 before recomputing them.
1314 void IntExprEvaluatorElementCt::Update(
int index) {
1315 if (
index == min_support_ ||
index == max_support_) {
1316 solver()->SaveAndSetValue(&min_support_, -1);
1317 solver()->SaveAndSetValue(&max_support_, -1);
// Resets both supports to -1, the value Propagate() tests for before
// recomputing them.
// NOTE(review): original line 1322 is missing from this excerpt and may guard
// the two resets with a condition — confirm.
1321 void IntExprEvaluatorElementCt::UpdateExpr() {
1323 solver()->SaveAndSetValue(&min_support_, -1);
1324 solver()->SaveAndSetValue(&max_support_, -1);
1330 int64_t range_start, int64_t range_end) {
1332 for (int64_t i = range_start; i < range_end; ++i) {
1333 if (i != range_start) {
1336 out += absl::StrFormat(
"%d -> %s", i, evaluator(i)->DebugString());
1342 int64_t range_begin, int64_t range_end) {
1344 if (range_end - range_begin > 10) {
1345 out = absl::StrFormat(
1346 "IntToIntVar(%s, ...%s)",
1347 StringifyEvaluatorBare(evaluator, range_begin, range_begin + 5),
1348 StringifyEvaluatorBare(evaluator, range_end - 5, range_end));
1350 out = absl::StrFormat(
1352 StringifyEvaluatorBare(evaluator, range_begin, range_end));
1358 std::string IntExprEvaluatorElementCt::DebugString()
const {
1359 return StringifyInt64ToIntVar(
evaluator_, range_start_, range_end_);
1362 void IntExprEvaluatorElementCt::Accept(ModelVisitor*
const visitor)
const {
1364 visitor->VisitIntegerVariableEvaluatorArgument(
1377 class IntExprArrayElementCt :
public IntExprEvaluatorElementCt {
1379 IntExprArrayElementCt(Solver*
const s, std::vector<IntVar*> vars,
1380 IntVar*
const index, IntVar*
const target_var);
1382 std::string DebugString()
const override;
1383 void Accept(ModelVisitor*
const visitor)
const override;
1386 const std::vector<IntVar*>
vars_;
1389 IntExprArrayElementCt::IntExprArrayElementCt(Solver*
const s,
1390 std::vector<IntVar*> vars,
1391 IntVar*
const index,
1392 IntVar*
const target_var)
1393 : IntExprEvaluatorElementCt(
1394 s, [this](int64_t idx) {
return vars_[idx]; }, 0, vars.size(),
index,
1396 vars_(std::move(vars)) {}
1398 std::string IntExprArrayElementCt::DebugString()
const {
1399 int64_t size =
vars_.size();
1401 return absl::StrFormat(
1402 "IntExprArrayElement(var array of size %d, %s) == %s", size,
1405 return absl::StrFormat(
"IntExprArrayElement([%s], %s) == %s",
1411 void IntExprArrayElementCt::Accept(ModelVisitor*
const visitor)
const {
1425 class IntExprArrayElementCstCt :
public Constraint {
1427 IntExprArrayElementCstCt(Solver*
const s,
const std::vector<IntVar*>& vars,
1428 IntVar*
const index, int64_t target)
1433 demons_(vars.size()) {}
1435 ~IntExprArrayElementCstCt()
override {}
1437 void Post()
override {
1438 for (
int i = 0; i <
vars_.size(); ++i) {
1440 solver(),
this, &IntExprArrayElementCstCt::Propagate,
"Propagate", i);
1441 vars_[i]->WhenDomain(demons_[i]);
1444 solver(),
this, &IntExprArrayElementCstCt::PropagateIndex,
1446 index_->WhenBound(index_demon);
1449 void InitialPropagate()
override {
1450 for (
int i = 0; i <
vars_.size(); ++i) {
// Reacts to a change in vars_[index]: once that variable can no longer take
// target_, `index` is removed from the index variable and the demon for this
// position is inhibited.
1456 void Propagate(
int index) {
1457 if (!vars_[
index]->Contains(target_)) {
1458 index_->RemoveValue(
index);
1459 demons_[
index]->inhibit(solver());
// Once the index variable is bound, the variable it selects is fixed to
// target_.
1463 void PropagateIndex() {
1464 if (index_->Bound()) {
1465 vars_[index_->Min()]->SetValue(target_);
1469 std::string DebugString()
const override {
1470 return absl::StrFormat(
"IntExprArrayElement([%s], %s) == %d",
1472 index_->DebugString(), target_);
1475 void Accept(ModelVisitor*
const visitor)
const override {
1486 const std::vector<IntVar*>
vars_;
1487 IntVar*
const index_;
1488 const int64_t target_;
1489 std::vector<Demon*> demons_;
1494 class IntExprIndexOfCt :
public Constraint {
1496 IntExprIndexOfCt(Solver*
const s,
const std::vector<IntVar*>& vars,
1497 IntVar*
const index, int64_t target)
1502 demons_(
vars_.size()),
1503 index_iterator_(
index->MakeHoleIterator(true)) {}
1505 ~IntExprIndexOfCt()
override {}
1507 void Post()
override {
1508 for (
int i = 0; i <
vars_.size(); ++i) {
1510 solver(),
this, &IntExprIndexOfCt::Propagate,
"Propagate", i);
1511 vars_[i]->WhenDomain(demons_[i]);
1514 solver(),
this, &IntExprIndexOfCt::PropagateIndex,
"PropagateIndex");
1515 index_->WhenDomain(index_demon);
// First pass over every position i, applying exactly one of three rules.
1518 void InitialPropagate()
override {
1519 for (
int i = 0; i <
vars_.size(); ++i) {
// Rule 1: i is not a possible index, so vars_[i] must not equal target_.
1520 if (!index_->Contains(i)) {
1521 vars_[i]->RemoveValue(target_);
1522 }
// Rule 2: vars_[i] cannot equal target_, so i is not a possible index and its
// demon is inhibited.
else if (!vars_[i]->Contains(target_)) {
1523 index_->RemoveValue(i);
1524 demons_[i]->inhibit(solver());
1525 }
// Rule 3: vars_[i] is bound, and the previous test showed it still contains
// target_, so it equals target_; the index is fixed to i.
else if (vars_[i]->Bound()) {
1526 index_->SetValue(i);
1527 demons_[i]->inhibit(solver());
// Reacts to a domain change of vars_[index].
1532 void Propagate(
int index) {
// target_ is gone from this variable: drop `index` from the index variable
// and inhibit this position's demon.
1533 if (!vars_[
index]->Contains(target_)) {
1534 index_->RemoveValue(
index);
1535 demons_[
index]->inhibit(solver());
1536 }
// The variable is bound and still contains target_, so it equals target_;
// the index is fixed to this position.
else if (vars_[
index]->Bound()) {
1537 index_->SetValue(
index);
1541 void PropagateIndex() {
1542 const int64_t oldmax = index_->OldMax();
1543 const int64_t vmin = index_->Min();
1544 const int64_t vmax = index_->Max();
1547 demons_[
value]->inhibit(solver());
1549 for (
const int64_t
value : InitAndGetValues(index_iterator_)) {
1551 demons_[
value]->inhibit(solver());
1555 demons_[
value]->inhibit(solver());
1557 if (index_->Bound()) {
1558 vars_[index_->Min()]->SetValue(target_);
1562 std::string DebugString()
const override {
1563 return absl::StrFormat(
"IntExprIndexOf([%s], %s) == %d",
1565 index_->DebugString(), target_);
1568 void Accept(ModelVisitor*
const visitor)
const override {
1579 const std::vector<IntVar*>
vars_;
1580 IntVar*
const index_;
1581 const int64_t target_;
1582 std::vector<Demon*> demons_;
1583 IntVarIterator*
const index_iterator_;
// Builds a constraint relating `target` to vals[index].
//   - If `index` is bound: an out-of-range position yields a false
//     constraint, otherwise the result is target == vals[position].
//   - One special case returns target == index + vals[0].
//   - Otherwise an IntElementConstraint is allocated through RevAlloc.
// NOTE(review): the condition guarding the "index + vals[0]" case and the
// closing braces are elided from this listing; presumably that case applies
// when vals is a contiguous increasing sequence -- confirm against the full
// source.
1588 Constraint* MakeElementEqualityFunc(Solver*
const solver,
1589 const std::vector<int64_t>& vals,
1590 IntVar*
const index, IntVar*
const target) {
1591 if (
index->Bound()) {
1592 const int64_t val =
index->Min();
// The negative case is tested first, so the comparison against the unsigned
// size only matters for non-negative val.
1593 if (val < 0 || val >= vals.size()) {
1594 return solver->MakeFalseConstraint();
1596 return solver->MakeEquality(target, vals[val]);
1600 return solver->MakeEquality(target, solver->MakeSum(
index, vals[0]));
1602 return solver->RevAlloc(
1603 new IntElementConstraint(solver, vals,
index, target));
1612 IntVar*
const target_var) {
1614 new IfThenElseCt(
this, condition, then_expr, else_expr, target_var));
1619 if (
index->Bound()) {
1620 return vars[
index->Min()];
1622 const int size = vars.size();
1624 std::vector<int64_t> values(size);
1625 for (
int i = 0; i < size; ++i) {
1626 values[i] = vars[i]->Value();
1631 index->Min() >= 0 &&
index->Max() < vars.size()) {
1636 const std::string
name = absl::StrFormat(
1646 std::unique_ptr<IntVarIterator> iterator(
index->MakeDomainIterator(
false));
1648 if (index_value >= 0 && index_value < size) {
1649 emin =
std::min(emin, vars[index_value]->Min());
1650 emax =
std::max(emax, vars[index_value]->Max());
1653 const std::string vname =
1654 size > 10 ? absl::StrFormat(
"ElementVar(var array of size %d, %s)", size,
1655 index->DebugString())
1656 : absl::StrFormat(
"ElementVar([%s], %s)",
1660 RevAlloc(
new IntExprArrayElementCt(
this, vars,
index, element_var)));
1665 int64_t range_end,
IntVar* argument) {
1666 const std::string index_name =
1668 const std::string vname = absl::StrFormat(
1669 "ElementVar(%s, %s)",
1670 StringifyInt64ToIntVar(vars, range_start, range_end), index_name);
1671 IntVar*
const element_var =
1674 IntExprEvaluatorElementCt* evaluation_ct =
new IntExprEvaluatorElementCt(
1675 this, std::move(vars), range_start, range_end, argument, element_var);
1677 evaluation_ct->Propagate();
1684 return MakeElementEqualityFunc(
this, vals,
index, target);
1697 std::vector<int64_t> values(vars.size());
1698 for (
int i = 0; i < vars.size(); ++i) {
1699 values[i] = vars[i]->Value();
1703 if (
index->Bound()) {
1704 const int64_t val =
index->Min();
1705 if (val < 0 || val >= vars.size()) {
1711 if (target->
Bound()) {
1713 new IntExprArrayElementCstCt(
this, vars,
index, target->
Min()));
1715 return RevAlloc(
new IntExprArrayElementCt(
this, vars,
index, target));
1723 std::vector<int> valid_indices;
1724 for (
int i = 0; i < vars.size(); ++i) {
1725 if (vars[i]->
Value() == target) {
1726 valid_indices.push_back(i);
1731 if (
index->Bound()) {
1732 const int64_t pos =
index->Min();
1733 if (pos >= 0 && pos < vars.size()) {
1740 return RevAlloc(
new IntExprArrayElementCstCt(
this, vars,
index, target));
1746 if (
index->Bound()) {
1747 const int64_t pos =
index->Min();
1748 if (pos >= 0 && pos < vars.size()) {
1755 return RevAlloc(
new IntExprIndexOfCt(
this, vars,
index, target));
1761 IntExpr*
const cache = model_cache_->FindVarArrayConstantExpression(
1763 if (cache !=
nullptr) {
1764 return cache->
Var();
1766 const std::string
name =
1770 model_cache_->InsertVarArrayConstantExpression(
const std::vector< IntVar * > vars_
#define DCHECK_LE(val1, val2)
#define CHECK_EQ(val1, val2)
#define DCHECK_GE(val1, val2)
#define DCHECK_GT(val1, val2)
#define DCHECK_LT(val1, val2)
#define DCHECK(condition)
#define DCHECK_EQ(val1, val2)
Cast constraints are special channeling constraints designed to keep a variable in sync with an expression.
IntVar *const target_var_
A constraint is the main modeling object.
A Demon is the base element of a propagation queue.
void Post() override
This method is called when the constraint is processed by the solver.
void InitialPropagate() override
This method performs the initial propagation of the constraint.
IfThenElseCt(Solver *const solver, IntVar *const condition, IntExpr *const one, IntExpr *const zero, IntVar *const target)
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
std::string DebugString() const override
Utility class to encapsulate an IntVarIterator and use it in a range-based loop.
The class IntExpr is the base of all integer expressions in constraint programming.
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual void SetRange(int64_t l, int64_t u)
This method sets both the min and the max of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual int64_t Min() const =0
virtual int64_t Max() const =0
virtual void Range(int64_t *l, int64_t *u)
By default calls Min() and Max(), but can be redefined when Min and Max code can be factorized.
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
The class IntVar is a subset of IntExpr.
virtual bool Contains(int64_t v) const =0
This method returns whether the value 'v' is in the domain of the variable.
virtual void WhenBound(Demon *d)=0
This method attaches a demon that will be awakened when the variable is bound.
@ VAR_ARRAY_CONSTANT_INDEX
static const char kIndex2Argument[]
static const char kMinArgument[]
static const char kElementEqual[]
static const char kTargetArgument[]
static const char kMaxArgument[]
static const char kEvaluatorArgument[]
static const char kVarsArgument[]
static const char kIndexOf[]
static const char kElement[]
static const char kValuesArgument[]
static const char kIndexArgument[]
virtual std::string name() const
Object naming.
std::string DebugString() const override
IntExpr * MakeIndexExpression(const std::vector< IntVar * > &vars, int64_t value)
Returns the expression expr such that vars[expr] == value.
IntExpr * RegisterIntExpr(IntExpr *const expr)
Registers a new IntExpr and wraps it inside a TraceIntExpr if necessary.
Constraint * MakeFalseConstraint()
This constraint always fails.
Constraint * MakeEquality(IntExpr *const left, IntExpr *const right)
left == right
Constraint * MakeElementEquality(const std::vector< int64_t > &vals, IntVar *const index, IntVar *const target)
IntVar * MakeIntVar(int64_t min, int64_t max, const std::string &name)
MakeIntVar will create the best range based int var for the bounds given.
Constraint * MakeMemberCt(IntExpr *const expr, const std::vector< int64_t > &values)
expr in set.
std::function< int64_t(int64_t, int64_t)> IndexEvaluator2
void AddConstraint(Constraint *const c)
Adds the constraint 'c' to the model.
IntExpr * MakeOpposite(IntExpr *const expr)
-expr
Constraint * MakeIfThenElseCt(IntVar *const condition, IntExpr *const then_expr, IntExpr *const else_expr, IntVar *const target_var)
Special cases with arrays of size two.
Demon * MakeConstraintInitialPropagateCallback(Constraint *const ct)
This method is a specialized case of the MakeConstraintDemon method to call the InitialPropagate method of the constraint 'ct'.
Constraint * MakeIndexOfConstraint(const std::vector< IntVar * > &vars, IntVar *const index, int64_t target)
This constraint is a special case of the element constraint with an array of integer variables, where the index variable is constrained such that vars[index] == target.
IntExpr * MakeElement(const std::vector< int64_t > &values, IntVar *const index)
values[index]
T * RevAlloc(T *object)
Registers the given object as being reversible.
IntExpr * MakeSum(IntExpr *const left, IntExpr *const right)
left + right.
std::function< int64_t(int64_t)> IndexEvaluator1
Callback typedefs.
std::function< IntVar *(int64_t)> Int64ToIntVar
IntExpr * MakeMonotonicElement(IndexEvaluator1 values, bool increasing, IntVar *const index)
Function based element.
std::vector< int64_t > to_remove_
#define UPDATE_ELEMENT_INDEX_BOUNDS(test)
ABSL_FLAG(bool, cp_disable_element_cache, true, "If true, caching for IntElement is disabled.")
#define UPDATE_BASE_ELEMENT_INDEX_BOUNDS(test)
#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test)
std::function< int64_t(const Model &)> Value(IntegerVariable v)
Collection of objects used to extend the Constraint Solver library.
bool IsArrayConstant(const std::vector< T > &values, const T &value)
bool IsIncreasing(const std::vector< T > &values)
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
bool IsArrayBoolean(const std::vector< T > &values)
Demon * MakeConstraintDemon1(Solver *const s, T *const ct, void(T::*method)(P), const std::string &name, P param1)
Demon * MakeDelayedConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
std::string JoinDebugStringPtr(const std::vector< T > &v, const std::string &separator)
bool IsIncreasingContiguous(const std::vector< T > &values)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
void LinkVarExpr(Solver *const s, IntExpr *const expr, IntVar *const var)
bool AreAllBound(const std::vector< IntVar * > &vars)
std::string JoinNamePtr(const std::vector< T > &v, const std::string &separator)
IntervalVar *const target_var_
std::function< int64_t(int64_t, int64_t)> evaluator_