OR-Tools  9.1
element.cc
Go to the documentation of this file.
1// Copyright 2010-2021 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <cstdint>
16#include <limits>
17#include <memory>
18#include <numeric>
19#include <string>
20#include <utility>
21#include <vector>
22
23#include "absl/strings/str_format.h"
24#include "absl/strings/str_join.h"
31
32ABSL_FLAG(bool, cp_disable_element_cache, true,
33 "If true, caching for IntElement is disabled.");
34
35namespace operations_research {
36
37// ----- IntExprElement -----
38void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var);
39
40namespace {
41
// Comparator on indices: x comes before y when (*values)[x] < (*values)[y].
// The vector is borrowed, not owned, and must outlive the comparator.
template <class T>
class VectorLess {
 public:
  explicit VectorLess(const std::vector<T>* values) : values_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    return table[x] < table[y];
  }

 private:
  const std::vector<T>* values_;
};
53
// Comparator on indices: x comes before y when (*values)[x] > (*values)[y].
// The vector is borrowed, not owned, and must outlive the comparator.
template <class T>
class VectorGreater {
 public:
  explicit VectorGreater(const std::vector<T>* values) : values_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    return table[y] < table[x];
  }

 private:
  const std::vector<T>* values_;
};
65
66// ----- BaseIntExprElement -----
67
68class BaseIntExprElement : public BaseIntExpr {
69 public:
70 BaseIntExprElement(Solver* const s, IntVar* const e);
71 ~BaseIntExprElement() override {}
72 int64_t Min() const override;
73 int64_t Max() const override;
74 void Range(int64_t* mi, int64_t* ma) override;
75 void SetMin(int64_t m) override;
76 void SetMax(int64_t m) override;
77 void SetRange(int64_t mi, int64_t ma) override;
78 bool Bound() const override { return (expr_->Bound()); }
79 // TODO(user) : improve me, the previous test is not always true
80 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
81
82 protected:
83 virtual int64_t ElementValue(int index) const = 0;
84 virtual int64_t ExprMin() const = 0;
85 virtual int64_t ExprMax() const = 0;
86
87 IntVar* const expr_;
88
89 private:
90 void UpdateSupports() const;
91
92 mutable int64_t min_;
93 mutable int min_support_;
94 mutable int64_t max_;
95 mutable int max_support_;
96 mutable bool initial_update_;
97 IntVarIterator* const expr_iterator_;
98};
99
100BaseIntExprElement::BaseIntExprElement(Solver* const s, IntVar* const e)
101 : BaseIntExpr(s),
102 expr_(e),
103 min_(0),
104 min_support_(-1),
105 max_(0),
106 max_support_(-1),
107 initial_update_(true),
108 expr_iterator_(expr_->MakeDomainIterator(true)) {
109 CHECK(s != nullptr);
110 CHECK(e != nullptr);
111}
112
113int64_t BaseIntExprElement::Min() const {
114 UpdateSupports();
115 return min_;
116}
117
118int64_t BaseIntExprElement::Max() const {
119 UpdateSupports();
120 return max_;
121}
122
123void BaseIntExprElement::Range(int64_t* mi, int64_t* ma) {
124 UpdateSupports();
125 *mi = min_;
126 *ma = max_;
127}
128
129#define UPDATE_BASE_ELEMENT_INDEX_BOUNDS(test) \
130 const int64_t emin = ExprMin(); \
131 const int64_t emax = ExprMax(); \
132 int64_t nmin = emin; \
133 int64_t value = ElementValue(nmin); \
134 while (nmin < emax && test) { \
135 nmin++; \
136 value = ElementValue(nmin); \
137 } \
138 if (nmin == emax && test) { \
139 solver()->Fail(); \
140 } \
141 int64_t nmax = emax; \
142 value = ElementValue(nmax); \
143 while (nmax >= nmin && test) { \
144 nmax--; \
145 value = ElementValue(nmax); \
146 } \
147 expr_->SetRange(nmin, nmax);
148
149void BaseIntExprElement::SetMin(int64_t m) {
151}
152
153void BaseIntExprElement::SetMax(int64_t m) {
155}
156
157void BaseIntExprElement::SetRange(int64_t mi, int64_t ma) {
158 if (mi > ma) {
159 solver()->Fail();
160 }
161 UPDATE_BASE_ELEMENT_INDEX_BOUNDS((value < mi || value > ma));
162}
163
164#undef UPDATE_BASE_ELEMENT_INDEX_BOUNDS
165
166void BaseIntExprElement::UpdateSupports() const {
167 if (initial_update_ || !expr_->Contains(min_support_) ||
168 !expr_->Contains(max_support_)) {
169 const int64_t emin = ExprMin();
170 const int64_t emax = ExprMax();
171 int64_t min_value = ElementValue(emax);
172 int64_t max_value = min_value;
173 int min_support = emax;
174 int max_support = emax;
175 const uint64_t expr_size = expr_->Size();
176 if (expr_size > 1) {
177 if (expr_size == emax - emin + 1) {
178 // Value(emax) already stored in min_value, max_value.
179 for (int64_t index = emin; index < emax; ++index) {
180 const int64_t value = ElementValue(index);
181 if (value > max_value) {
182 max_value = value;
183 max_support = index;
184 } else if (value < min_value) {
185 min_value = value;
186 min_support = index;
187 }
188 }
189 } else {
190 for (const int64_t index : InitAndGetValues(expr_iterator_)) {
191 if (index >= emin && index <= emax) {
192 const int64_t value = ElementValue(index);
193 if (value > max_value) {
194 max_value = value;
195 max_support = index;
196 } else if (value < min_value) {
197 min_value = value;
198 min_support = index;
199 }
200 }
201 }
202 }
203 }
204 Solver* s = solver();
205 s->SaveAndSetValue(&min_, min_value);
206 s->SaveAndSetValue(&min_support_, min_support);
207 s->SaveAndSetValue(&max_, max_value);
208 s->SaveAndSetValue(&max_support_, max_support);
209 s->SaveAndSetValue(&initial_update_, false);
210 }
211}
212
213// ----- IntElementConstraint -----
214
215// This constraint implements 'elem' == 'values'['index'].
216// It scans the bounds of 'elem' to propagate on the domain of 'index'.
217// It scans the domain of 'index' to compute the new bounds of 'elem'.
218class IntElementConstraint : public CastConstraint {
219 public:
220 IntElementConstraint(Solver* const s, const std::vector<int64_t>& values,
221 IntVar* const index, IntVar* const elem)
222 : CastConstraint(s, elem),
223 values_(values),
224 index_(index),
225 index_iterator_(index_->MakeDomainIterator(true)) {
226 CHECK(index != nullptr);
227 }
228
229 void Post() override {
230 Demon* const d =
231 solver()->MakeDelayedConstraintInitialPropagateCallback(this);
232 index_->WhenDomain(d);
233 target_var_->WhenRange(d);
234 }
235
236 void InitialPropagate() override {
237 index_->SetRange(0, values_.size() - 1);
238 const int64_t target_var_min = target_var_->Min();
239 const int64_t target_var_max = target_var_->Max();
240 int64_t new_min = target_var_max;
241 int64_t new_max = target_var_min;
242 to_remove_.clear();
243 for (const int64_t index : InitAndGetValues(index_iterator_)) {
244 const int64_t value = values_[index];
245 if (value < target_var_min || value > target_var_max) {
246 to_remove_.push_back(index);
247 } else {
248 if (value < new_min) {
249 new_min = value;
250 }
251 if (value > new_max) {
252 new_max = value;
253 }
254 }
255 }
256 target_var_->SetRange(new_min, new_max);
257 if (!to_remove_.empty()) {
258 index_->RemoveValues(to_remove_);
259 }
260 }
261
262 std::string DebugString() const override {
263 return absl::StrFormat("IntElementConstraint(%s, %s, %s)",
264 absl::StrJoin(values_, ", "), index_->DebugString(),
265 target_var_->DebugString());
266 }
267
268 void Accept(ModelVisitor* const visitor) const override {
269 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
270 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
271 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
272 index_);
273 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
275 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
276 }
277
278 private:
279 const std::vector<int64_t> values_;
280 IntVar* const index_;
281 IntVarIterator* const index_iterator_;
282 std::vector<int64_t> to_remove_;
283};
284
285// ----- IntExprElement
286
287IntVar* BuildDomainIntVar(Solver* const solver, std::vector<int64_t>* values);
288
289class IntExprElement : public BaseIntExprElement {
290 public:
291 IntExprElement(Solver* const s, const std::vector<int64_t>& vals,
292 IntVar* const expr)
293 : BaseIntExprElement(s, expr), values_(vals) {}
294
295 ~IntExprElement() override {}
296
297 std::string name() const override {
298 const int size = values_.size();
299 if (size > 10) {
300 return absl::StrFormat("IntElement(array of size %d, %s)", size,
301 expr_->name());
302 } else {
303 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
304 expr_->name());
305 }
306 }
307
308 std::string DebugString() const override {
309 const int size = values_.size();
310 if (size > 10) {
311 return absl::StrFormat("IntElement(array of size %d, %s)", size,
312 expr_->DebugString());
313 } else {
314 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
315 expr_->DebugString());
316 }
317 }
318
319 IntVar* CastToVar() override {
320 Solver* const s = solver();
321 IntVar* const var = s->MakeIntVar(values_);
322 s->AddCastConstraint(
323 s->RevAlloc(new IntElementConstraint(s, values_, expr_, var)), var,
324 this);
325 return var;
326 }
327
328 void Accept(ModelVisitor* const visitor) const override {
329 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
330 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
331 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
332 expr_);
333 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
334 }
335
336 protected:
337 int64_t ElementValue(int index) const override {
338 DCHECK_LT(index, values_.size());
339 return values_[index];
340 }
341 int64_t ExprMin() const override {
342 return std::max<int64_t>(0, expr_->Min());
343 }
344 int64_t ExprMax() const override {
345 return values_.empty()
346 ? 0
347 : std::min<int64_t>(values_.size() - 1, expr_->Max());
348 }
349
350 private:
351 const std::vector<int64_t> values_;
352};
353
354// ----- Range Minimum Query-based Element -----
355
356class RangeMinimumQueryExprElement : public BaseIntExpr {
357 public:
358 RangeMinimumQueryExprElement(Solver* solver,
359 const std::vector<int64_t>& values,
360 IntVar* index);
361 ~RangeMinimumQueryExprElement() override {}
362 int64_t Min() const override;
363 int64_t Max() const override;
364 void Range(int64_t* mi, int64_t* ma) override;
365 void SetMin(int64_t m) override;
366 void SetMax(int64_t m) override;
367 void SetRange(int64_t mi, int64_t ma) override;
368 bool Bound() const override { return (index_->Bound()); }
369 // TODO(user) : improve me, the previous test is not always true
370 void WhenRange(Demon* d) override { index_->WhenRange(d); }
371 IntVar* CastToVar() override {
372 // TODO(user): Should we try to make holes in the domain of index_, as we
373 // do here, or should we only propagate bounds as we do in
374 // IncreasingIntExprElement ?
375 IntVar* const var = solver()->MakeIntVar(min_rmq_.array());
376 solver()->AddCastConstraint(solver()->RevAlloc(new IntElementConstraint(
377 solver(), min_rmq_.array(), index_, var)),
378 var, this);
379 return var;
380 }
381 void Accept(ModelVisitor* const visitor) const override {
382 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
383 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
384 min_rmq_.array());
385 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
386 index_);
387 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
388 }
389
390 private:
391 int64_t IndexMin() const { return std::max<int64_t>(0, index_->Min()); }
392 int64_t IndexMax() const {
393 return std::min<int64_t>(min_rmq_.array().size() - 1, index_->Max());
394 }
395
396 IntVar* const index_;
397 const RangeMinimumQuery<int64_t, std::less<int64_t>> min_rmq_;
398 const RangeMinimumQuery<int64_t, std::greater<int64_t>> max_rmq_;
399};
400
401RangeMinimumQueryExprElement::RangeMinimumQueryExprElement(
402 Solver* solver, const std::vector<int64_t>& values, IntVar* index)
403 : BaseIntExpr(solver), index_(index), min_rmq_(values), max_rmq_(values) {
404 CHECK(solver != nullptr);
405 CHECK(index != nullptr);
406}
407
408int64_t RangeMinimumQueryExprElement::Min() const {
409 return min_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
410}
411
412int64_t RangeMinimumQueryExprElement::Max() const {
413 return max_rmq_.GetMinimumFromRange(IndexMin(), IndexMax() + 1);
414}
415
416void RangeMinimumQueryExprElement::Range(int64_t* mi, int64_t* ma) {
417 const int64_t range_min = IndexMin();
418 const int64_t range_max = IndexMax() + 1;
419 *mi = min_rmq_.GetMinimumFromRange(range_min, range_max);
420 *ma = max_rmq_.GetMinimumFromRange(range_min, range_max);
421}
422
423#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test) \
424 const std::vector<int64_t>& values = min_rmq_.array(); \
425 int64_t index_min = IndexMin(); \
426 int64_t index_max = IndexMax(); \
427 int64_t value = values[index_min]; \
428 while (index_min < index_max && (test)) { \
429 index_min++; \
430 value = values[index_min]; \
431 } \
432 if (index_min == index_max && (test)) { \
433 solver()->Fail(); \
434 } \
435 value = values[index_max]; \
436 while (index_max >= index_min && (test)) { \
437 index_max--; \
438 value = values[index_max]; \
439 } \
440 index_->SetRange(index_min, index_max);
441
442void RangeMinimumQueryExprElement::SetMin(int64_t m) {
444}
445
446void RangeMinimumQueryExprElement::SetMax(int64_t m) {
448}
449
450void RangeMinimumQueryExprElement::SetRange(int64_t mi, int64_t ma) {
451 if (mi > ma) {
452 solver()->Fail();
453 }
454 UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(value < mi || value > ma);
455}
456
457#undef UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS
458
459// ----- Increasing Element -----
460
461class IncreasingIntExprElement : public BaseIntExpr {
462 public:
463 IncreasingIntExprElement(Solver* const s, const std::vector<int64_t>& values,
464 IntVar* const index);
465 ~IncreasingIntExprElement() override {}
466
467 int64_t Min() const override;
468 void SetMin(int64_t m) override;
469 int64_t Max() const override;
470 void SetMax(int64_t m) override;
471 void SetRange(int64_t mi, int64_t ma) override;
472 bool Bound() const override { return (index_->Bound()); }
473 // TODO(user) : improve me, the previous test is not always true
474 std::string name() const override {
475 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
476 index_->name());
477 }
478 std::string DebugString() const override {
479 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),
480 index_->DebugString());
481 }
482
483 void Accept(ModelVisitor* const visitor) const override {
484 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
485 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
486 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
487 index_);
488 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
489 }
490
491 void WhenRange(Demon* d) override { index_->WhenRange(d); }
492
493 IntVar* CastToVar() override {
494 Solver* const s = solver();
495 IntVar* const var = s->MakeIntVar(values_);
496 LinkVarExpr(s, this, var);
497 return var;
498 }
499
500 private:
501 const std::vector<int64_t> values_;
502 IntVar* const index_;
503};
504
505IncreasingIntExprElement::IncreasingIntExprElement(
506 Solver* const s, const std::vector<int64_t>& values, IntVar* const index)
507 : BaseIntExpr(s), values_(values), index_(index) {
508 DCHECK(index);
509 DCHECK(s);
510}
511
512int64_t IncreasingIntExprElement::Min() const {
513 const int64_t expression_min = std::max<int64_t>(0, index_->Min());
514 return (expression_min < values_.size()
515 ? values_[expression_min]
517}
518
519void IncreasingIntExprElement::SetMin(int64_t m) {
520 const int64_t index_min = std::max<int64_t>(0, index_->Min());
521 const int64_t index_max =
522 std::min<int64_t>(values_.size() - 1, index_->Max());
523
524 if (index_min > index_max || m > values_[index_max]) {
525 solver()->Fail();
526 }
527
528 const std::vector<int64_t>::const_iterator first =
529 std::lower_bound(values_.begin(), values_.end(), m);
530 const int64_t new_index_min = first - values_.begin();
531 index_->SetMin(new_index_min);
532}
533
534int64_t IncreasingIntExprElement::Max() const {
535 const int64_t expression_max =
536 std::min<int64_t>(values_.size() - 1, index_->Max());
537 return (expression_max >= 0 ? values_[expression_max]
539}
540
541void IncreasingIntExprElement::SetMax(int64_t m) {
542 int64_t index_min = std::max<int64_t>(0, index_->Min());
543 if (m < values_[index_min]) {
544 solver()->Fail();
545 }
546
547 const std::vector<int64_t>::const_iterator last_after =
548 std::upper_bound(values_.begin(), values_.end(), m);
549 const int64_t new_index_max = (last_after - values_.begin()) - 1;
550 index_->SetRange(0, new_index_max);
551}
552
553void IncreasingIntExprElement::SetRange(int64_t mi, int64_t ma) {
554 if (mi > ma) {
555 solver()->Fail();
556 }
557 const int64_t index_min = std::max<int64_t>(0, index_->Min());
558 const int64_t index_max =
559 std::min<int64_t>(values_.size() - 1, index_->Max());
560
561 if (mi > ma || ma < values_[index_min] || mi > values_[index_max]) {
562 solver()->Fail();
563 }
564
565 const std::vector<int64_t>::const_iterator first =
566 std::lower_bound(values_.begin(), values_.end(), mi);
567 const int64_t new_index_min = first - values_.begin();
568
569 const std::vector<int64_t>::const_iterator last_after =
570 std::upper_bound(first, values_.end(), ma);
571 const int64_t new_index_max = (last_after - values_.begin()) - 1;
572
573 // Assign.
574 index_->SetRange(new_index_min, new_index_max);
575}
576
577// ----- Solver::MakeElement(int array, int var) -----
578IntExpr* BuildElement(Solver* const solver, const std::vector<int64_t>& values,
579 IntVar* const index) {
580 // Various checks.
581 // Is array constant?
582 if (IsArrayConstant(values, values[0])) {
583 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
584 return solver->MakeIntConst(values[0]);
585 }
586 // Is array built with booleans only?
587 // TODO(user): We could maintain the index of the first one.
588 if (IsArrayBoolean(values)) {
589 std::vector<int64_t> ones;
590 int first_zero = -1;
591 for (int i = 0; i < values.size(); ++i) {
592 if (values[i] == 1) {
593 ones.push_back(i);
594 } else {
595 first_zero = i;
596 }
597 }
598 if (ones.size() == 1) {
599 DCHECK_EQ(int64_t{1}, values[ones.back()]);
600 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
601 return solver->MakeIsEqualCstVar(index, ones.back());
602 } else if (ones.size() == values.size() - 1) {
603 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
604 return solver->MakeIsDifferentCstVar(index, first_zero);
605 } else if (ones.size() == ones.back() - ones.front() + 1) { // contiguous.
606 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
607 IntVar* const b = solver->MakeBoolVar("ContiguousBooleanElementVar");
608 solver->AddConstraint(
609 solver->MakeIsBetweenCt(index, ones.front(), ones.back(), b));
610 return b;
611 } else {
612 IntVar* const b = solver->MakeBoolVar("NonContiguousBooleanElementVar");
613 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
614 solver->AddConstraint(solver->MakeIsMemberCt(index, ones, b));
615 return b;
616 }
617 }
618 IntExpr* cache = nullptr;
619 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
620 cache = solver->Cache()->FindVarConstantArrayExpression(
621 index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
622 }
623 if (cache != nullptr) {
624 return cache;
625 } else {
626 IntExpr* result = nullptr;
627 if (values.size() >= 2 && index->Min() == 0 && index->Max() == 1) {
628 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
629 values[0]);
630 } else if (values.size() == 2 && index->Contains(0) && index->Contains(1)) {
631 solver->AddConstraint(solver->MakeBetweenCt(index, 0, 1));
632 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
633 values[0]);
634 } else if (IsIncreasingContiguous(values)) {
635 result = solver->MakeSum(index, values[0]);
636 } else if (IsIncreasing(values)) {
637 result = solver->RegisterIntExpr(solver->RevAlloc(
638 new IncreasingIntExprElement(solver, values, index)));
639 } else {
640 if (solver->parameters().use_element_rmq()) {
641 result = solver->RegisterIntExpr(solver->RevAlloc(
642 new RangeMinimumQueryExprElement(solver, values, index)));
643 } else {
644 result = solver->RegisterIntExpr(
645 solver->RevAlloc(new IntExprElement(solver, values, index)));
646 }
647 }
648 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
649 solver->Cache()->InsertVarConstantArrayExpression(
650 result, index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
651 }
652 return result;
653 }
654}
655} // namespace
656
657IntExpr* Solver::MakeElement(const std::vector<int64_t>& values,
658 IntVar* const index) {
659 DCHECK(index);
660 DCHECK_EQ(this, index->solver());
661 if (index->Bound()) {
662 return MakeIntConst(values[index->Min()]);
663 }
664 return BuildElement(this, values, index);
665}
666
667IntExpr* Solver::MakeElement(const std::vector<int>& values,
668 IntVar* const index) {
669 DCHECK(index);
670 DCHECK_EQ(this, index->solver());
671 if (index->Bound()) {
672 return MakeIntConst(values[index->Min()]);
673 }
674 return BuildElement(this, ToInt64Vector(values), index);
675}
676
677// ----- IntExprFunctionElement -----
678
679namespace {
680class IntExprFunctionElement : public BaseIntExprElement {
681 public:
682 IntExprFunctionElement(Solver* const s, Solver::IndexEvaluator1 values,
683 IntVar* const e);
684 ~IntExprFunctionElement() override;
685
686 std::string name() const override {
687 return absl::StrFormat("IntFunctionElement(%s)", expr_->name());
688 }
689
690 std::string DebugString() const override {
691 return absl::StrFormat("IntFunctionElement(%s)", expr_->DebugString());
692 }
693
694 void Accept(ModelVisitor* const visitor) const override {
695 // Warning: This will expand all values into a vector.
696 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
697 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
698 expr_);
699 visitor->VisitInt64ToInt64Extension(values_, expr_->Min(), expr_->Max());
700 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
701 }
702
703 protected:
704 int64_t ElementValue(int index) const override { return values_(index); }
705 int64_t ExprMin() const override { return expr_->Min(); }
706 int64_t ExprMax() const override { return expr_->Max(); }
707
708 private:
709 Solver::IndexEvaluator1 values_;
710};
711
712IntExprFunctionElement::IntExprFunctionElement(Solver* const s,
713 Solver::IndexEvaluator1 values,
714 IntVar* const e)
715 : BaseIntExprElement(s, e), values_(std::move(values)) {
716 CHECK(values_ != nullptr);
717}
718
719IntExprFunctionElement::~IntExprFunctionElement() {}
720
721// ----- Increasing Element -----
722
723class IncreasingIntExprFunctionElement : public BaseIntExpr {
724 public:
725 IncreasingIntExprFunctionElement(Solver* const s,
727 IntVar* const index)
728 : BaseIntExpr(s), values_(std::move(values)), index_(index) {
729 DCHECK(values_ != nullptr);
730 DCHECK(index);
731 DCHECK(s);
732 }
733
734 ~IncreasingIntExprFunctionElement() override {}
735
736 int64_t Min() const override { return values_(index_->Min()); }
737
738 void SetMin(int64_t m) override {
739 const int64_t index_min = index_->Min();
740 const int64_t index_max = index_->Max();
741 if (m > values_(index_max)) {
742 solver()->Fail();
743 }
744 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, m);
745 index_->SetMin(new_index_min);
746 }
747
748 int64_t Max() const override { return values_(index_->Max()); }
749
750 void SetMax(int64_t m) override {
751 int64_t index_min = index_->Min();
752 int64_t index_max = index_->Max();
753 if (m < values_(index_min)) {
754 solver()->Fail();
755 }
756 const int64_t new_index_max = FindNewIndexMax(index_min, index_max, m);
757 index_->SetMax(new_index_max);
758 }
759
760 void SetRange(int64_t mi, int64_t ma) override {
761 const int64_t index_min = index_->Min();
762 const int64_t index_max = index_->Max();
763 const int64_t value_min = values_(index_min);
764 const int64_t value_max = values_(index_max);
765 if (mi > ma || ma < value_min || mi > value_max) {
766 solver()->Fail();
767 }
768 if (mi <= value_min && ma >= value_max) {
769 // Nothing to do.
770 return;
771 }
772
773 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, mi);
774 const int64_t new_index_max = FindNewIndexMax(new_index_min, index_max, ma);
775 // Assign.
776 index_->SetRange(new_index_min, new_index_max);
777 }
778
779 std::string name() const override {
780 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
781 index_->name());
782 }
783
784 std::string DebugString() const override {
785 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",
786 index_->DebugString());
787 }
788
789 void WhenRange(Demon* d) override { index_->WhenRange(d); }
790
791 void Accept(ModelVisitor* const visitor) const override {
792 // Warning: This will expand all values into a vector.
793 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
794 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
795 index_);
796 if (index_->Min() == 0) {
797 visitor->VisitInt64ToInt64AsArray(values_, ModelVisitor::kValuesArgument,
798 index_->Max());
799 } else {
800 visitor->VisitInt64ToInt64Extension(values_, index_->Min(),
801 index_->Max());
802 }
803 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
804 }
805
806 private:
807 int64_t FindNewIndexMin(int64_t index_min, int64_t index_max, int64_t m) {
808 if (m <= values_(index_min)) {
809 return index_min;
810 }
811
812 DCHECK_LT(values_(index_min), m);
813 DCHECK_GE(values_(index_max), m);
814
815 int64_t index_lower_bound = index_min;
816 int64_t index_upper_bound = index_max;
817 while (index_upper_bound - index_lower_bound > 1) {
818 DCHECK_LT(values_(index_lower_bound), m);
819 DCHECK_GE(values_(index_upper_bound), m);
820 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
821 const int64_t pivot_value = values_(pivot);
822 if (pivot_value < m) {
823 index_lower_bound = pivot;
824 } else {
825 index_upper_bound = pivot;
826 }
827 }
828 DCHECK(values_(index_upper_bound) >= m);
829 return index_upper_bound;
830 }
831
832 int64_t FindNewIndexMax(int64_t index_min, int64_t index_max, int64_t m) {
833 if (m >= values_(index_max)) {
834 return index_max;
835 }
836
837 DCHECK_LE(values_(index_min), m);
838 DCHECK_GT(values_(index_max), m);
839
840 int64_t index_lower_bound = index_min;
841 int64_t index_upper_bound = index_max;
842 while (index_upper_bound - index_lower_bound > 1) {
843 DCHECK_LE(values_(index_lower_bound), m);
844 DCHECK_GT(values_(index_upper_bound), m);
845 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;
846 const int64_t pivot_value = values_(pivot);
847 if (pivot_value > m) {
848 index_upper_bound = pivot;
849 } else {
850 index_lower_bound = pivot;
851 }
852 }
853 DCHECK(values_(index_lower_bound) <= m);
854 return index_lower_bound;
855 }
856
858 IntVar* const index_;
859};
860} // namespace
861
863 IntVar* const index) {
864 CHECK_EQ(this, index->solver());
865 return RegisterIntExpr(
866 RevAlloc(new IntExprFunctionElement(this, std::move(values), index)));
867}
868
870 bool increasing, IntVar* const index) {
871 CHECK_EQ(this, index->solver());
872 if (increasing) {
873 return RegisterIntExpr(
874 RevAlloc(new IncreasingIntExprFunctionElement(this, values, index)));
875 } else {
876 // You need to pass by copy such that opposite_value does not include a
877 // dandling reference when leaving this scope.
878 Solver::IndexEvaluator1 opposite_values = [values](int64_t i) {
879 return -values(i);
880 };
882 new IncreasingIntExprFunctionElement(this, opposite_values, index))));
883 }
884}
885
886// ----- IntIntExprFunctionElement -----
887
888namespace {
889class IntIntExprFunctionElement : public BaseIntExpr {
890 public:
891 IntIntExprFunctionElement(Solver* const s, Solver::IndexEvaluator2 values,
892 IntVar* const expr1, IntVar* const expr2);
893 ~IntIntExprFunctionElement() override;
894 std::string DebugString() const override {
895 return absl::StrFormat("IntIntFunctionElement(%s,%s)",
896 expr1_->DebugString(), expr2_->DebugString());
897 }
898 int64_t Min() const override;
899 int64_t Max() const override;
900 void Range(int64_t* lower_bound, int64_t* upper_bound) override;
901 void SetMin(int64_t lower_bound) override;
902 void SetMax(int64_t upper_bound) override;
903 void SetRange(int64_t lower_bound, int64_t upper_bound) override;
904 bool Bound() const override { return expr1_->Bound() && expr2_->Bound(); }
905 // TODO(user) : improve me, the previous test is not always true
906 void WhenRange(Demon* d) override {
907 expr1_->WhenRange(d);
908 expr2_->WhenRange(d);
909 }
910
911 void Accept(ModelVisitor* const visitor) const override {
912 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
913 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
914 expr1_);
915 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndex2Argument,
916 expr2_);
917 // Warning: This will expand all values into a vector.
918 const int64_t expr1_min = expr1_->Min();
919 const int64_t expr1_max = expr1_->Max();
920 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, expr1_min);
921 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, expr1_max);
922 for (int i = expr1_min; i <= expr1_max; ++i) {
923 visitor->VisitInt64ToInt64Extension(
924 [this, i](int64_t j) { return values_(i, j); }, expr2_->Min(),
925 expr2_->Max());
926 }
927 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
928 }
929
930 private:
931 int64_t ElementValue(int index1, int index2) const {
932 return values_(index1, index2);
933 }
934 void UpdateSupports() const;
935
936 IntVar* const expr1_;
937 IntVar* const expr2_;
938 mutable int64_t min_;
939 mutable int min_support1_;
940 mutable int min_support2_;
941 mutable int64_t max_;
942 mutable int max_support1_;
943 mutable int max_support2_;
944 mutable bool initial_update_;
946 IntVarIterator* const expr1_iterator_;
947 IntVarIterator* const expr2_iterator_;
948};
949
950IntIntExprFunctionElement::IntIntExprFunctionElement(
951 Solver* const s, Solver::IndexEvaluator2 values, IntVar* const expr1,
952 IntVar* const expr2)
953 : BaseIntExpr(s),
954 expr1_(expr1),
955 expr2_(expr2),
956 min_(0),
957 min_support1_(-1),
958 min_support2_(-1),
959 max_(0),
960 max_support1_(-1),
961 max_support2_(-1),
962 initial_update_(true),
963 values_(std::move(values)),
964 expr1_iterator_(expr1_->MakeDomainIterator(true)),
965 expr2_iterator_(expr2_->MakeDomainIterator(true)) {
966 CHECK(values_ != nullptr);
967}
968
969IntIntExprFunctionElement::~IntIntExprFunctionElement() {}
970
971int64_t IntIntExprFunctionElement::Min() const {
972 UpdateSupports();
973 return min_;
974}
975
976int64_t IntIntExprFunctionElement::Max() const {
977 UpdateSupports();
978 return max_;
979}
980
981void IntIntExprFunctionElement::Range(int64_t* lower_bound,
982 int64_t* upper_bound) {
983 UpdateSupports();
984 *lower_bound = min_;
985 *upper_bound = max_;
986}
987
// UPDATE_ELEMENT_INDEX_BOUNDS(test): shrinks the ranges of expr1_ and expr2_
// to the smallest box whose borders each contain at least one pair (i, j)
// for which `test` holds. `test` is an expression over a local named `value`,
// bound to ElementValue(i, j).
//   1. nmin1: first row in [emin1, emax1] with some column in [emin2, emax2]
//      satisfying `test`; solver()->Fail() if none.
//   2. nmin2: same for the first column, scanning all rows in [emin1, emax1].
//   3. nmax1 / nmax2: last row / column (scanning downward, not below
//      nmin1 / nmin2) with a satisfying pair.
//   4. expr1_->SetRange(nmin1, nmax1); expr2_->SetRange(nmin2, nmax2).
// Scans use Min()/Max() only, so values in domain holes are also tested.
// NOTE(review): the inner loops use `int i` while the bounds are int64_t;
// domains outside the int range would truncate. Confirm callers stay small.
// NOTE(review): this assumes solver()->Fail() does not return. Otherwise the
// following scans would run on an empty row/column range.
// Comments cannot go inside the body: a `//` on a continued line would
// swallow the rest of the macro.
988#define UPDATE_ELEMENT_INDEX_BOUNDS(test) \
989 const int64_t emin1 = expr1_->Min(); \
990 const int64_t emax1 = expr1_->Max(); \
991 const int64_t emin2 = expr2_->Min(); \
992 const int64_t emax2 = expr2_->Max(); \
993 int64_t nmin1 = emin1; \
994 bool found = false; \
995 while (nmin1 <= emax1 && !found) { \
996 for (int i = emin2; i <= emax2; ++i) { \
997 int64_t value = ElementValue(nmin1, i); \
998 if (test) { \
999 found = true; \
1000 break; \
1001 } \
1002 } \
1003 if (!found) { \
1004 nmin1++; \
1005 } \
1006 } \
1007 if (nmin1 > emax1) { \
1008 solver()->Fail(); \
1009 } \
1010 int64_t nmin2 = emin2; \
1011 found = false; \
1012 while (nmin2 <= emax2 && !found) { \
1013 for (int i = emin1; i <= emax1; ++i) { \
1014 int64_t value = ElementValue(i, nmin2); \
1015 if (test) { \
1016 found = true; \
1017 break; \
1018 } \
1019 } \
1020 if (!found) { \
1021 nmin2++; \
1022 } \
1023 } \
1024 if (nmin2 > emax2) { \
1025 solver()->Fail(); \
1026 } \
1027 int64_t nmax1 = emax1; \
1028 found = false; \
1029 while (nmax1 >= nmin1 && !found) { \
1030 for (int i = emin2; i <= emax2; ++i) { \
1031 int64_t value = ElementValue(nmax1, i); \
1032 if (test) { \
1033 found = true; \
1034 break; \
1035 } \
1036 } \
1037 if (!found) { \
1038 nmax1--; \
1039 } \
1040 } \
1041 int64_t nmax2 = emax2; \
1042 found = false; \
1043 while (nmax2 >= nmin2 && !found) { \
1044 for (int i = emin1; i <= emax1; ++i) { \
1045 int64_t value = ElementValue(i, nmax2); \
1046 if (test) { \
1047 found = true; \
1048 break; \
1049 } \
1050 } \
1051 if (!found) { \
1052 nmax2--; \
1053 } \
1054 } \
1055 expr1_->SetRange(nmin1, nmax1); \
1056 expr2_->SetRange(nmin2, nmax2);
1057
// Enforces values(expr1, expr2) >= lower_bound by narrowing the index ranges.
// NOTE(review): original line 1059 is missing from this listing, so the body
// appears empty. Presumably it was UPDATE_ELEMENT_INDEX_BOUNDS(value >=
// lower_bound); restore from upstream before building.
1058void IntIntExprFunctionElement::SetMin(int64_t lower_bound) {
1060}
1061
// Enforces values(expr1, expr2) <= upper_bound by narrowing the index ranges.
// NOTE(review): original line 1063 is missing from this listing, so the body
// appears empty. Presumably it was UPDATE_ELEMENT_INDEX_BOUNDS(value <=
// upper_bound); restore from upstream before building.
1062void IntIntExprFunctionElement::SetMax(int64_t upper_bound) {
1064}
1065
// Enforces lower_bound <= values(expr1, expr2) <= upper_bound.
// An empty interval fails immediately.
// NOTE(review): original line 1071 is missing from this listing. Presumably it
// was the UPDATE_ELEMENT_INDEX_BOUNDS invocation testing both bounds; restore
// from upstream before building.
1066void IntIntExprFunctionElement::SetRange(int64_t lower_bound,
1067 int64_t upper_bound) {
1068 if (lower_bound > upper_bound) {
1069 solver()->Fail();
1070 }
1072}
1073
1074#undef UPDATE_ELEMENT_INDEX_BOUNDS
1075
1076void IntIntExprFunctionElement::UpdateSupports() const {
1077 if (initial_update_ || !expr1_->Contains(min_support1_) ||
1078 !expr1_->Contains(max_support1_) || !expr2_->Contains(min_support2_) ||
1079 !expr2_->Contains(max_support2_)) {
1080 const int64_t emax1 = expr1_->Max();
1081 const int64_t emax2 = expr2_->Max();
1082 int64_t min_value = ElementValue(emax1, emax2);
1083 int64_t max_value = min_value;
1084 int min_support1 = emax1;
1085 int max_support1 = emax1;
1086 int min_support2 = emax2;
1087 int max_support2 = emax2;
1088 for (const int64_t index1 : InitAndGetValues(expr1_iterator_)) {
1089 for (const int64_t index2 : InitAndGetValues(expr2_iterator_)) {
1090 const int64_t value = ElementValue(index1, index2);
1091 if (value > max_value) {
1092 max_value = value;
1093 max_support1 = index1;
1094 max_support2 = index2;
1095 } else if (value < min_value) {
1096 min_value = value;
1097 min_support1 = index1;
1098 min_support2 = index2;
1099 }
1100 }
1101 }
1102 Solver* s = solver();
1103 s->SaveAndSetValue(&min_, min_value);
1104 s->SaveAndSetValue(&min_support1_, min_support1);
1105 s->SaveAndSetValue(&min_support2_, min_support2);
1106 s->SaveAndSetValue(&max_, max_value);
1107 s->SaveAndSetValue(&max_support1_, max_support1);
1108 s->SaveAndSetValue(&max_support2_, max_support2);
1109 s->SaveAndSetValue(&initial_update_, false);
1110 }
1111}
1112} // namespace
1113
// Builds values(index1, index2) as an IntIntExprFunctionElement, after
// checking that both index variables belong to this solver.
// NOTE(review): the numbering skips 1114 and 1118, so this fragment is
// incomplete. 1114 is presumably the `IntExpr* Solver::MakeElement(
// IndexEvaluator2 values,` signature line and 1118 the opening of the return
// statement that wraps the `new` expression. Restore from upstream.
1115 IntVar* const index1, IntVar* const index2) {
1116 CHECK_EQ(this, index1->solver());
1117 CHECK_EQ(this, index2->solver());
1119 new IntIntExprFunctionElement(this, std::move(values), index1, index2)));
1120}
1121
1122// ---------- Generalized element ----------
1123
1124// ----- IfThenElseCt -----
1125
1127 public:
1128 IfThenElseCt(Solver* const solver, IntVar* const condition,
1129 IntExpr* const one, IntExpr* const zero, IntVar* const target)
1130 : CastConstraint(solver, target),
1131 condition_(condition),
1132 zero_(zero),
1133 one_(one) {}
1134
1135 ~IfThenElseCt() override {}
1136
1137 void Post() override {
1138 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1139 condition_->WhenBound(demon);
1140 one_->WhenRange(demon);
1141 zero_->WhenRange(demon);
1142 target_var_->WhenRange(demon);
1143 }
1144
1145 void InitialPropagate() override {
1146 condition_->SetRange(0, 1);
1147 const int64_t target_var_min = target_var_->Min();
1148 const int64_t target_var_max = target_var_->Max();
1149 int64_t new_min = std::numeric_limits<int64_t>::min();
1150 int64_t new_max = std::numeric_limits<int64_t>::max();
1151 if (condition_->Max() == 0) {
1152 zero_->SetRange(target_var_min, target_var_max);
1153 zero_->Range(&new_min, &new_max);
1154 } else if (condition_->Min() == 1) {
1155 one_->SetRange(target_var_min, target_var_max);
1156 one_->Range(&new_min, &new_max);
1157 } else {
1158 if (target_var_max < zero_->Min() || target_var_min > zero_->Max()) {
1159 condition_->SetValue(1);
1160 one_->SetRange(target_var_min, target_var_max);
1161 one_->Range(&new_min, &new_max);
1162 } else if (target_var_max < one_->Min() || target_var_min > one_->Max()) {
1163 condition_->SetValue(0);
1164 zero_->SetRange(target_var_min, target_var_max);
1165 zero_->Range(&new_min, &new_max);
1166 } else {
1167 int64_t zl = 0;
1168 int64_t zu = 0;
1169 int64_t ol = 0;
1170 int64_t ou = 0;
1171 zero_->Range(&zl, &zu);
1172 one_->Range(&ol, &ou);
1173 new_min = std::min(zl, ol);
1174 new_max = std::max(zu, ou);
1175 }
1176 }
1177 target_var_->SetRange(new_min, new_max);
1178 }
1179
// Renders the constraint as "(condition ? one : zero) == X".
// NOTE(review): original line 1183 is missing from this listing. The format
// has four %s but only three arguments are visible; X is presumably
// target_var_->DebugString() plus the closing `);`. Restore from upstream.
1180 std::string DebugString() const override {
1181 return absl::StrFormat("(%s ? %s : %s) == %s", condition_->DebugString(),
1182 one_->DebugString(), zero_->DebugString(),
1184 }
1185
1186 void Accept(ModelVisitor* const visitor) const override {}
1187
1188 private:
// Selector; InitialPropagate restricts it to [0, 1].
1189 IntVar* const condition_;
// Expression the target must equal when condition_ == 0.
1190 IntExpr* const zero_;
// Expression the target must equal when condition_ == 1.
1191 IntExpr* const one_;
1192};
1193
1194// ----- IntExprEvaluatorElementCt -----
1195
1196// This constraint implements evaluator(index) == var. It is delayed such
1197// that propagation only occurs when all variables have been touched.
1198// The range of the evaluator is [range_start, range_end).
1199
1200namespace {
// evaluator(index) == target_var, with evaluator defined on the half-open
// range [range_start, range_end).
// NOTE(review): original line 1222 is missing from this listing. Since the
// constructor initializes evaluator_ between index_ and range_start_, it is
// presumably the `Solver::Int64ToIntVar evaluator_;` declaration. Restore it.
1201class IntExprEvaluatorElementCt : public CastConstraint {
1202 public:
1203 IntExprEvaluatorElementCt(Solver* const s, Solver::Int64ToIntVar evaluator,
1204 int64_t range_start, int64_t range_end,
1205 IntVar* const index, IntVar* const target_var);
1206 ~IntExprEvaluatorElementCt() override {}
1207
1208 void Post() override;
1209 void InitialPropagate() override;
1210
// Delayed demon: filters index_ and tightens target_var_.
1211 void Propagate();
// Immediate demon for evaluator_(index): invalidates supports if needed.
1212 void Update(int index);
// Immediate demon for index_: invalidates supports that left its domain.
1213 void UpdateExpr();
1214
1215 std::string DebugString() const override;
1216 void Accept(ModelVisitor* const visitor) const override;
1217
1218 protected:
1219 IntVar* const index_;
1220
1221 private:
1223 const int64_t range_start_;
1224 const int64_t range_end_;
// Index positions attaining the min / max of the candidate variables;
// -1 means "unknown, rescan in Propagate".
1225 int min_support_;
1226 int max_support_;
1227};
1228
1229IntExprEvaluatorElementCt::IntExprEvaluatorElementCt(
1230 Solver* const s, Solver::Int64ToIntVar evaluator, int64_t range_start,
1231 int64_t range_end, IntVar* const index, IntVar* const target_var)
1232 : CastConstraint(s, target_var),
1233 index_(index),
1234 evaluator_(std::move(evaluator)),
1235 range_start_(range_start),
1236 range_end_(range_end),
1237 min_support_(-1),
1238 max_support_(-1) {}
1239
1240void IntExprEvaluatorElementCt::Post() {
1241 Demon* const delayed_propagate_demon = MakeDelayedConstraintDemon0(
1242 solver(), this, &IntExprEvaluatorElementCt::Propagate, "Propagate");
1243 for (int i = range_start_; i < range_end_; ++i) {
1244 IntVar* const current_var = evaluator_(i);
1245 current_var->WhenRange(delayed_propagate_demon);
1246 Demon* const update_demon = MakeConstraintDemon1(
1247 solver(), this, &IntExprEvaluatorElementCt::Update, "Update", i);
1248 current_var->WhenRange(update_demon);
1249 }
1250 index_->WhenRange(delayed_propagate_demon);
1251 Demon* const update_expr_demon = MakeConstraintDemon0(
1252 solver(), this, &IntExprEvaluatorElementCt::UpdateExpr, "UpdateExpr");
1253 index_->WhenRange(update_expr_demon);
1254 Demon* const update_var_demon = MakeConstraintDemon0(
1255 solver(), this, &IntExprEvaluatorElementCt::Propagate, "UpdateVar");
1256
1257 target_var_->WhenRange(update_var_demon);
1258}
1259
1260void IntExprEvaluatorElementCt::InitialPropagate() { Propagate(); }
1261
1262void IntExprEvaluatorElementCt::Propagate() {
1263 const int64_t emin = std::max(range_start_, index_->Min());
1264 const int64_t emax = std::min<int64_t>(range_end_ - 1, index_->Max());
1265 const int64_t vmin = target_var_->Min();
1266 const int64_t vmax = target_var_->Max();
1267 if (emin == emax) {
1268 index_->SetValue(emin); // in case it was reduced by the above min/max.
1269 evaluator_(emin)->SetRange(vmin, vmax);
1270 } else {
1271 int64_t nmin = emin;
1272 for (; nmin <= emax; nmin++) {
1273 // break if the intersection of
1274 // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1275 // is non-empty.
1276 IntVar* const nmin_var = evaluator_(nmin);
1277 if (nmin_var->Min() <= vmax && nmin_var->Max() >= vmin) break;
1278 }
1279 int64_t nmax = emax;
1280 for (; nmin <= nmax; nmax--) {
1281 // break if the intersection of
1282 // [evaluator_(nmin)->Min(), evaluator_(nmin)->Max()] and [vmin, vmax]
1283 // is non-empty.
1284 IntExpr* const nmax_var = evaluator_(nmax);
1285 if (nmax_var->Min() <= vmax && nmax_var->Max() >= vmin) break;
1286 }
1287 index_->SetRange(nmin, nmax);
1288 if (nmin == nmax) {
1289 evaluator_(nmin)->SetRange(vmin, vmax);
1290 }
1291 }
1292 if (min_support_ == -1 || max_support_ == -1) {
1293 int min_support = -1;
1294 int max_support = -1;
1295 int64_t gmin = std::numeric_limits<int64_t>::max();
1296 int64_t gmax = std::numeric_limits<int64_t>::min();
1297 for (int i = index_->Min(); i <= index_->Max(); ++i) {
1298 IntExpr* const var_i = evaluator_(i);
1299 const int64_t vmin = var_i->Min();
1300 if (vmin < gmin) {
1301 gmin = vmin;
1302 }
1303 const int64_t vmax = var_i->Max();
1304 if (vmax > gmax) {
1305 gmax = vmax;
1306 }
1307 }
1308 solver()->SaveAndSetValue(&min_support_, min_support);
1309 solver()->SaveAndSetValue(&max_support_, max_support);
1310 target_var_->SetRange(gmin, gmax);
1311 }
1312}
1313
1314void IntExprEvaluatorElementCt::Update(int index) {
1315 if (index == min_support_ || index == max_support_) {
1316 solver()->SaveAndSetValue(&min_support_, -1);
1317 solver()->SaveAndSetValue(&max_support_, -1);
1318 }
1319}
1320
1321void IntExprEvaluatorElementCt::UpdateExpr() {
1322 if (!index_->Contains(min_support_) || !index_->Contains(max_support_)) {
1323 solver()->SaveAndSetValue(&min_support_, -1);
1324 solver()->SaveAndSetValue(&max_support_, -1);
1325 }
1326}
1327
1328namespace {
1329std::string StringifyEvaluatorBare(const Solver::Int64ToIntVar& evaluator,
1330 int64_t range_start, int64_t range_end) {
1331 std::string out;
1332 for (int64_t i = range_start; i < range_end; ++i) {
1333 if (i != range_start) {
1334 out += ", ";
1335 }
1336 out += absl::StrFormat("%d -> %s", i, evaluator(i)->DebugString());
1337 }
1338 return out;
1339}
1340
1341std::string StringifyInt64ToIntVar(const Solver::Int64ToIntVar& evaluator,
1342 int64_t range_begin, int64_t range_end) {
1343 std::string out;
1344 if (range_end - range_begin > 10) {
1345 out = absl::StrFormat(
1346 "IntToIntVar(%s, ...%s)",
1347 StringifyEvaluatorBare(evaluator, range_begin, range_begin + 5),
1348 StringifyEvaluatorBare(evaluator, range_end - 5, range_end));
1349 } else {
1350 out = absl::StrFormat(
1351 "IntToIntVar(%s)",
1352 StringifyEvaluatorBare(evaluator, range_begin, range_end));
1353 }
1354 return out;
1355}
1356} // namespace
1357
1358std::string IntExprEvaluatorElementCt::DebugString() const {
1359 return StringifyInt64ToIntVar(evaluator_, range_start_, range_end_);
1360}
1361
// Reports the constraint to `visitor` as kElementEqual: the evaluator
// argument, then the index and target expressions.
// NOTE(review): original line 1365 is missing from this listing. It holds the
// arguments of VisitIntegerVariableEvaluatorArgument (presumably the argument
// tag, evaluator_, range_start_, range_end_). Restore from upstream.
1362void IntExprEvaluatorElementCt::Accept(ModelVisitor* const visitor) const {
1363 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1364 visitor->VisitIntegerVariableEvaluatorArgument(
1366 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1367 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1368 target_var_);
1369 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1370}
1371
1372// ----- IntExprArrayElementCt -----
1373
1374// This constraint implements vars[index] == var. It is delayed such
1375// that propagation only occurs when all variables have been touched.
1376
1377class IntExprArrayElementCt : public IntExprEvaluatorElementCt {
1378 public:
1379 IntExprArrayElementCt(Solver* const s, std::vector<IntVar*> vars,
1380 IntVar* const index, IntVar* const target_var);
1381
1382 std::string DebugString() const override;
1383 void Accept(ModelVisitor* const visitor) const override;
1384
1385 private:
1386 const std::vector<IntVar*> vars_;
1387};
1388
1389IntExprArrayElementCt::IntExprArrayElementCt(Solver* const s,
1390 std::vector<IntVar*> vars,
1391 IntVar* const index,
1392 IntVar* const target_var)
1393 : IntExprEvaluatorElementCt(
1394 s, [this](int64_t idx) { return vars_[idx]; }, 0, vars.size(), index,
1395 target_var),
1396 vars_(std::move(vars)) {}
1397
1398std::string IntExprArrayElementCt::DebugString() const {
1399 int64_t size = vars_.size();
1400 if (size > 10) {
1401 return absl::StrFormat(
1402 "IntExprArrayElement(var array of size %d, %s) == %s", size,
1403 index_->DebugString(), target_var_->DebugString());
1404 } else {
1405 return absl::StrFormat("IntExprArrayElement([%s], %s) == %s",
1406 JoinDebugStringPtr(vars_, ", "),
1407 index_->DebugString(), target_var_->DebugString());
1408 }
1409}
1410
1411void IntExprArrayElementCt::Accept(ModelVisitor* const visitor) const {
1412 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1413 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1414 vars_);
1415 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1416 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1417 target_var_);
1418 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1419}
1420
1421// ----- IntExprArrayElementCstCt -----
1422
1423// This constraint implements vars[index] == constant.
1424
1425class IntExprArrayElementCstCt : public Constraint {
1426 public:
1427 IntExprArrayElementCstCt(Solver* const s, const std::vector<IntVar*>& vars,
1428 IntVar* const index, int64_t target)
1429 : Constraint(s),
1430 vars_(vars),
1431 index_(index),
1432 target_(target),
1433 demons_(vars.size()) {}
1434
1435 ~IntExprArrayElementCstCt() override {}
1436
1437 void Post() override {
1438 for (int i = 0; i < vars_.size(); ++i) {
1439 demons_[i] = MakeConstraintDemon1(
1440 solver(), this, &IntExprArrayElementCstCt::Propagate, "Propagate", i);
1441 vars_[i]->WhenDomain(demons_[i]);
1442 }
1443 Demon* const index_demon = MakeConstraintDemon0(
1444 solver(), this, &IntExprArrayElementCstCt::PropagateIndex,
1445 "PropagateIndex");
1446 index_->WhenBound(index_demon);
1447 }
1448
1449 void InitialPropagate() override {
1450 for (int i = 0; i < vars_.size(); ++i) {
1451 Propagate(i);
1452 }
1453 PropagateIndex();
1454 }
1455
1456 void Propagate(int index) {
1457 if (!vars_[index]->Contains(target_)) {
1458 index_->RemoveValue(index);
1459 demons_[index]->inhibit(solver());
1460 }
1461 }
1462
1463 void PropagateIndex() {
1464 if (index_->Bound()) {
1465 vars_[index_->Min()]->SetValue(target_);
1466 }
1467 }
1468
1469 std::string DebugString() const override {
1470 return absl::StrFormat("IntExprArrayElement([%s], %s) == %d",
1471 JoinDebugStringPtr(vars_, ", "),
1472 index_->DebugString(), target_);
1473 }
1474
1475 void Accept(ModelVisitor* const visitor) const override {
1476 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1477 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1478 vars_);
1479 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1480 index_);
1481 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1482 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1483 }
1484
1485 private:
1486 const std::vector<IntVar*> vars_;
1487 IntVar* const index_;
1488 const int64_t target_;
1489 std::vector<Demon*> demons_;
1490};
1491
1492// This constraint implements index == position(constant in vars).
1493
1494class IntExprIndexOfCt : public Constraint {
1495 public:
1496 IntExprIndexOfCt(Solver* const s, const std::vector<IntVar*>& vars,
1497 IntVar* const index, int64_t target)
1498 : Constraint(s),
1499 vars_(vars),
1500 index_(index),
1501 target_(target),
1502 demons_(vars_.size()),
1503 index_iterator_(index->MakeHoleIterator(true)) {}
1504
1505 ~IntExprIndexOfCt() override {}
1506
1507 void Post() override {
1508 for (int i = 0; i < vars_.size(); ++i) {
1509 demons_[i] = MakeConstraintDemon1(
1510 solver(), this, &IntExprIndexOfCt::Propagate, "Propagate", i);
1511 vars_[i]->WhenDomain(demons_[i]);
1512 }
1513 Demon* const index_demon = MakeConstraintDemon0(
1514 solver(), this, &IntExprIndexOfCt::PropagateIndex, "PropagateIndex");
1515 index_->WhenDomain(index_demon);
1516 }
1517
1518 void InitialPropagate() override {
1519 for (int i = 0; i < vars_.size(); ++i) {
1520 if (!index_->Contains(i)) {
1521 vars_[i]->RemoveValue(target_);
1522 } else if (!vars_[i]->Contains(target_)) {
1523 index_->RemoveValue(i);
1524 demons_[i]->inhibit(solver());
1525 } else if (vars_[i]->Bound()) {
1526 index_->SetValue(i);
1527 demons_[i]->inhibit(solver());
1528 }
1529 }
1530 }
1531
1532 void Propagate(int index) {
1533 if (!vars_[index]->Contains(target_)) {
1534 index_->RemoveValue(index);
1535 demons_[index]->inhibit(solver());
1536 } else if (vars_[index]->Bound()) {
1537 index_->SetValue(index);
1538 }
1539 }
1540
1541 void PropagateIndex() {
1542 const int64_t oldmax = index_->OldMax();
1543 const int64_t vmin = index_->Min();
1544 const int64_t vmax = index_->Max();
1545 for (int64_t value = index_->OldMin(); value < vmin; ++value) {
1546 vars_[value]->RemoveValue(target_);
1547 demons_[value]->inhibit(solver());
1548 }
1549 for (const int64_t value : InitAndGetValues(index_iterator_)) {
1550 vars_[value]->RemoveValue(target_);
1551 demons_[value]->inhibit(solver());
1552 }
1553 for (int64_t value = vmax + 1; value <= oldmax; ++value) {
1554 vars_[value]->RemoveValue(target_);
1555 demons_[value]->inhibit(solver());
1556 }
1557 if (index_->Bound()) {
1558 vars_[index_->Min()]->SetValue(target_);
1559 }
1560 }
1561
1562 std::string DebugString() const override {
1563 return absl::StrFormat("IntExprIndexOf([%s], %s) == %d",
1564 JoinDebugStringPtr(vars_, ", "),
1565 index_->DebugString(), target_);
1566 }
1567
1568 void Accept(ModelVisitor* const visitor) const override {
1569 visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this);
1570 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1571 vars_);
1572 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
1573 index_);
1574 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);
1575 visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this);
1576 }
1577
1578 private:
1579 const std::vector<IntVar*> vars_;
1580 IntVar* const index_;
1581 const int64_t target_;
1582 std::vector<Demon*> demons_;
1583 IntVarIterator* const index_iterator_;
1584};
1585
1586// Factory helper.
1587
1588Constraint* MakeElementEqualityFunc(Solver* const solver,
1589 const std::vector<int64_t>& vals,
1590 IntVar* const index, IntVar* const target) {
1591 if (index->Bound()) {
1592 const int64_t val = index->Min();
1593 if (val < 0 || val >= vals.size()) {
1594 return solver->MakeFalseConstraint();
1595 } else {
1596 return solver->MakeEquality(target, vals[val]);
1597 }
1598 } else {
1599 if (IsIncreasingContiguous(vals)) {
1600 return solver->MakeEquality(target, solver->MakeSum(index, vals[0]));
1601 } else {
1602 return solver->RevAlloc(
1603 new IntElementConstraint(solver, vals, index, target));
1604 }
1605 }
1606}
1607} // namespace
1608
// target_var == (condition ? then_expr : else_expr), backed by IfThenElseCt.
// then_expr is IfThenElseCt's `one` branch and else_expr its `zero` branch.
// NOTE(review): original line 1609 (presumably the
// `Constraint* Solver::MakeIfThenElseCt(IntVar* const condition,` signature
// line) is missing from this listing. Restore from upstream.
1610 IntExpr* const then_expr,
1611 IntExpr* const else_expr,
1612 IntVar* const target_var) {
1613 return RevAlloc(
1614 new IfThenElseCt(this, condition, then_expr, else_expr, target_var));
1615}
1616
// Returns an expression equal to vars[index].
//  * Bound index: returns the selected variable directly.
//  * All variables bound: delegates to the constant-table MakeElement.
//  * Index with exactly two consecutive in-range values: a fresh target
//    linked by IfThenElseCt on the index shifted to {0, 1}.
//  * Otherwise: a fresh variable spanning the ranges of the in-range
//    candidates, linked by IntExprArrayElementCt.
// NOTE(review): original lines 1640 and 1659 are missing from this listing
// (presumably the `AddConstraint(` openers of the two RevAlloc lines).
1617IntExpr* Solver::MakeElement(const std::vector<IntVar*>& vars,
1618 IntVar* const index) {
// NOTE(review): no range check here. A bound index outside [0, vars.size())
// reads vars out of bounds; the other factories return MakeFalseConstraint().
1619 if (index->Bound()) {
1620 return vars[index->Min()];
1621 }
1622 const int size = vars.size();
1623 if (AreAllBound(vars)) {
1624 std::vector<int64_t> values(size);
1625 for (int i = 0; i < size; ++i) {
1626 values[i] = vars[i]->Value();
1627 }
1628 return MakeElement(values, index);
1629 }
1630 if (index->Size() == 2 && index->Min() + 1 == index->Max() &&
1631 index->Min() >= 0 && index->Max() < vars.size()) {
1632 // Let's get the index between 0 and 1.
1633 IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();
1634 IntVar* const zero = vars[index->Min()];
1635 IntVar* const one = vars[index->Max()];
1636 const std::string name = absl::StrFormat(
1637 "ElementVar([%s], %s)", JoinNamePtr(vars, ", "), index->name());
1638 IntVar* const target = MakeIntVar(std::min(zero->Min(), one->Min()),
1639 std::max(zero->Max(), one->Max()), name);
1641 RevAlloc(new IfThenElseCt(this, scaled_index, one, zero, target)));
1642 return target;
1643 }
// Hull of the candidate ranges. Out-of-range index values are ignored.
// NOTE(review): if no index value lies in [0, size), emin/emax keep their
// initial max/min and MakeIntVar receives emin > emax. Confirm how
// MakeIntVar handles that.
1644 int64_t emin = std::numeric_limits<int64_t>::max();
1645 int64_t emax = std::numeric_limits<int64_t>::min();
1646 std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));
1647 for (const int64_t index_value : InitAndGetValues(iterator.get())) {
1648 if (index_value >= 0 && index_value < size) {
1649 emin = std::min(emin, vars[index_value]->Min());
1650 emax = std::max(emax, vars[index_value]->Max());
1651 }
1652 }
1653 const std::string vname =
1654 size > 10 ? absl::StrFormat("ElementVar(var array of size %d, %s)", size,
1655 index->DebugString())
1656 : absl::StrFormat("ElementVar([%s], %s)",
1657 JoinNamePtr(vars, ", "), index->name());
1658 IntVar* const element_var = MakeIntVar(emin, emax, vname);
1660 RevAlloc(new IntExprArrayElementCt(this, vars, index, element_var)));
1661 return element_var;
1662}
1663
// Returns a fresh variable equal to vars(argument), where `vars` is defined
// on [range_start, range_end). The link is an IntExprEvaluatorElementCt,
// which is added to the model and propagated once right away.
// Building the name calls StringifyInt64ToIntVar, which invokes `vars` on up
// to ten positions of the range at construction time.
// NOTE(review): original lines 1672-1673 (the initializer of element_var) are
// missing from this listing. Restore from upstream.
1664IntExpr* Solver::MakeElement(Int64ToIntVar vars, int64_t range_start,
1665 int64_t range_end, IntVar* argument) {
1666 const std::string index_name =
1667 !argument->name().empty() ? argument->name() : argument->DebugString();
1668 const std::string vname = absl::StrFormat(
1669 "ElementVar(%s, %s)",
1670 StringifyInt64ToIntVar(vars, range_start, range_end), index_name);
1671 IntVar* const element_var =
1674 IntExprEvaluatorElementCt* evaluation_ct = new IntExprEvaluatorElementCt(
1675 this, std::move(vars), range_start, range_end, argument, element_var);
1676 AddConstraint(RevAlloc(evaluation_ct));
1677 evaluation_ct->Propagate();
1678 return element_var;
1679}
1680
1681Constraint* Solver::MakeElementEquality(const std::vector<int64_t>& vals,
1682 IntVar* const index,
1683 IntVar* const target) {
1684 return MakeElementEqualityFunc(this, vals, index, target);
1685}
1686
1687Constraint* Solver::MakeElementEquality(const std::vector<int>& vals,
1688 IntVar* const index,
1689 IntVar* const target) {
1690 return MakeElementEqualityFunc(this, ToInt64Vector(vals), index, target);
1691}
1692
1693Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1694 IntVar* const index,
1695 IntVar* const target) {
1696 if (AreAllBound(vars)) {
1697 std::vector<int64_t> values(vars.size());
1698 for (int i = 0; i < vars.size(); ++i) {
1699 values[i] = vars[i]->Value();
1700 }
1701 return MakeElementEquality(values, index, target);
1702 }
1703 if (index->Bound()) {
1704 const int64_t val = index->Min();
1705 if (val < 0 || val >= vars.size()) {
1706 return MakeFalseConstraint();
1707 } else {
1708 return MakeEquality(target, vars[val]);
1709 }
1710 } else {
1711 if (target->Bound()) {
1712 return RevAlloc(
1713 new IntExprArrayElementCstCt(this, vars, index, target->Min()));
1714 } else {
1715 return RevAlloc(new IntExprArrayElementCt(this, vars, index, target));
1716 }
1717 }
1718}
1719
1720Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
1721 IntVar* const index, int64_t target) {
1722 if (AreAllBound(vars)) {
1723 std::vector<int> valid_indices;
1724 for (int i = 0; i < vars.size(); ++i) {
1725 if (vars[i]->Value() == target) {
1726 valid_indices.push_back(i);
1727 }
1728 }
1729 return MakeMemberCt(index, valid_indices);
1730 }
1731 if (index->Bound()) {
1732 const int64_t pos = index->Min();
1733 if (pos >= 0 && pos < vars.size()) {
1734 IntVar* const var = vars[pos];
1735 return MakeEquality(var, target);
1736 } else {
1737 return MakeFalseConstraint();
1738 }
1739 } else {
1740 return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target));
1741 }
1742}
1743
1744Constraint* Solver::MakeIndexOfConstraint(const std::vector<IntVar*>& vars,
1745 IntVar* const index, int64_t target) {
1746 if (index->Bound()) {
1747 const int64_t pos = index->Min();
1748 if (pos >= 0 && pos < vars.size()) {
1749 IntVar* const var = vars[pos];
1750 return MakeEquality(var, target);
1751 } else {
1752 return MakeFalseConstraint();
1753 }
1754 } else {
1755 return RevAlloc(new IntExprIndexOfCt(this, vars, index, target));
1756 }
1757}
1758
// Returns an expression `index` such that vars[index] == value, reusing a
// model-cache entry for the same (vars, value) pair when one exists.
// Otherwise it creates index in [0, vars.size() - 1] and caches it.
// NOTE(review): original lines 1762, 1769 and 1771 are missing from this
// listing: the remaining cache-lookup arguments, presumably the
// AddConstraint(MakeIndexOfConstraint(...)) call, and the cache-insert
// arguments. Restore from upstream.
// NOTE(review): with an empty `vars`, `vars.size() - 1` wraps around as
// size_t before reaching MakeIntVar. Confirm the intended behavior.
1759IntExpr* Solver::MakeIndexExpression(const std::vector<IntVar*>& vars,
1760 int64_t value) {
1761 IntExpr* const cache = model_cache_->FindVarArrayConstantExpression(
1763 if (cache != nullptr) {
1764 return cache->Var();
1765 } else {
1766 const std::string name =
1767 absl::StrFormat("Index(%s, %d)", JoinNamePtr(vars, ", "), value);
1768 IntVar* const index = MakeIntVar(0, vars.size() - 1, name);
1770 model_cache_->InsertVarArrayConstantExpression(
1772 return index;
1773 }
1774}
1775} // namespace operations_research
const std::vector< IntVar * > vars_
Definition: alldiff_cst.cc:44
int64_t max
Definition: alldiff_cst.cc:140
int64_t min
Definition: alldiff_cst.cc:139
#define CHECK(condition)
Definition: base/logging.h:491
#define DCHECK_LE(val1, val2)
Definition: base/logging.h:888
#define CHECK_EQ(val1, val2)
Definition: base/logging.h:698
#define DCHECK_GE(val1, val2)
Definition: base/logging.h:890
#define DCHECK_GT(val1, val2)
Definition: base/logging.h:891
#define DCHECK_LT(val1, val2)
Definition: base/logging.h:889
#define DCHECK(condition)
Definition: base/logging.h:885
#define DCHECK_EQ(val1, val2)
Definition: base/logging.h:886
Cast constraints are special channeling constraints designed to keep a variable in sync with an expre...
A constraint is the main modeling object.
A Demon is the base element of a propagation queue.
void Post() override
This method is called when the constraint is processed by the solver.
Definition: element.cc:1137
void InitialPropagate() override
This method performs the initial propagation of the constraint.
Definition: element.cc:1145
IfThenElseCt(Solver *const solver, IntVar *const condition, IntExpr *const one, IntExpr *const zero, IntVar *const target)
Definition: element.cc:1128
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
Definition: element.cc:1186
std::string DebugString() const override
Definition: element.cc:1180
Utility class to encapsulate an IntVarIterator and use it in a range-based loop.
The class IntExpr is the base of all integer expressions in constraint programming.
virtual void SetRange(int64_t l, int64_t u)
This method sets both the min and the max of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual int64_t Min() const =0
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual int64_t Max() const =0
virtual void Range(int64_t *l, int64_t *u)
By default calls Min() and Max(), but can be redefined when Min and Max code can be factorized.
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
The class IntVar is a subset of IntExpr.
virtual bool Contains(int64_t v) const =0
This method returns whether the value 'v' is in the domain of the variable.
virtual void WhenBound(Demon *d)=0
This method attaches a demon that will be awakened when the variable is bound.
virtual std::string name() const
Object naming.
IntExpr * MakeIndexExpression(const std::vector< IntVar * > &vars, int64_t value)
Returns the expression expr such that vars[expr] == value.
Definition: element.cc:1759
IntExpr * RegisterIntExpr(IntExpr *const expr)
Registers a new IntExpr and wraps it inside a TraceIntExpr if necessary.
Definition: trace.cc:849
Constraint * MakeFalseConstraint()
This constraint always fails.
Definition: constraints.cc:523
Constraint * MakeEquality(IntExpr *const left, IntExpr *const right)
left == right
Definition: range_cst.cc:512
Constraint * MakeElementEquality(const std::vector< int64_t > &vals, IntVar *const index, IntVar *const target)
Definition: element.cc:1681
IntVar * MakeIntVar(int64_t min, int64_t max, const std::string &name)
MakeIntVar will create the best range based int var for the bounds given.
Constraint * MakeMemberCt(IntExpr *const expr, const std::vector< int64_t > &values)
expr in set.
Definition: expr_cst.cc:1163
std::function< int64_t(int64_t, int64_t)> IndexEvaluator2
void AddConstraint(Constraint *const c)
Adds the constraint 'c' to the model.
IntExpr * MakeOpposite(IntExpr *const expr)
-expr
Constraint * MakeIfThenElseCt(IntVar *const condition, IntExpr *const then_expr, IntExpr *const else_expr, IntVar *const target_var)
Special cases with arrays of size two.
Definition: element.cc:1609
Demon * MakeConstraintInitialPropagateCallback(Constraint *const ct)
This method is a specialized case of the MakeConstraintDemon method to call the InitiatePropagate of ...
Definition: constraints.cc:35
Constraint * MakeIndexOfConstraint(const std::vector< IntVar * > &vars, IntVar *const index, int64_t target)
This constraint is a special case of the element constraint with an array of integer variables,...
Definition: element.cc:1744
IntExpr * MakeElement(const std::vector< int64_t > &values, IntVar *const index)
values[index]
Definition: element.cc:657
IntExpr * MakeSum(IntExpr *const left, IntExpr *const right)
left + right.
std::function< int64_t(int64_t)> IndexEvaluator1
Callback typedefs.
std::function< IntVar *(int64_t)> Int64ToIntVar
T * RevAlloc(T *object)
Registers the given object as being reversible.
IntExpr * MakeMonotonicElement(IndexEvaluator1 values, bool increasing, IntVar *const index)
Function based element.
Definition: element.cc:869
int64_t b
std::vector< int64_t > to_remove_
const std::string name
int64_t value
#define UPDATE_ELEMENT_INDEX_BOUNDS(test)
Definition: element.cc:988
IntVar *const expr_
Definition: element.cc:87
ABSL_FLAG(bool, cp_disable_element_cache, true, "If true, caching for IntElement is disabled.")
#define UPDATE_BASE_ELEMENT_INDEX_BOUNDS(test)
Definition: element.cc:129
#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test)
Definition: element.cc:423
IntVar * var
Definition: expr_array.cc:1874
double upper_bound
double lower_bound
std::function< int64_t(const Model &)> Value(IntegerVariable v)
Definition: integer.h:1544
Collection of objects used to extend the Constraint Solver library.
bool IsArrayConstant(const std::vector< T > &values, const T &value)
bool IsIncreasing(const std::vector< T > &values)
bool IsArrayBoolean(const std::vector< T > &values)
Demon * MakeDelayedConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
std::string JoinDebugStringPtr(const std::vector< T > &v, const std::string &separator)
Definition: string_array.h:45
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
bool IsIncreasingContiguous(const std::vector< T > &values)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
Definition: utilities.cc:828
void LinkVarExpr(Solver *const s, IntExpr *const expr, IntVar *const var)
Demon * MakeConstraintDemon1(Solver *const s, T *const ct, void(T::*method)(P), const std::string &name, P param1)
bool AreAllBound(const std::vector< IntVar * > &vars)
std::string JoinNamePtr(const std::vector< T > &v, const std::string &separator)
Definition: string_array.h:52
STL namespace.
int index
Definition: pack.cc:509
IntervalVar *const target_var_
std::function< int64_t(int64_t, int64_t)> evaluator_
Definition: search.cc:1368