OR-Tools  9.2
expressions.cc
Go to the documentation of this file.
1// Copyright 2010-2021 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14#include <algorithm>
15#include <cmath>
16#include <cstdint>
17#include <limits>
18#include <memory>
19#include <string>
20#include <utility>
21#include <vector>
22
23#include "absl/container/flat_hash_map.h"
24#include "absl/strings/str_cat.h"
25#include "absl/strings/str_format.h"
34#include "ortools/util/bitset.h"
37
// Boolean flag; per its help text, disables special optimizations when
// creating expressions. Its readers are not visible in this excerpt.
ABSL_FLAG(bool, cp_disable_expression_optimization, false,
          "Disable special optimization when creating expressions.");
// Boolean flag (default true); per its help text, IntConst objects with the
// same value are shared. Its readers are not visible in this excerpt.
ABSL_FLAG(bool, cp_share_int_consts, true,
          "Share IntConst's with the same value.");
42
43#if defined(_MSC_VER)
44#pragma warning(disable : 4351 4355)
45#endif
46
47namespace operations_research {
48
49// ---------- IntExpr ----------
50
51IntVar* IntExpr::VarWithName(const std::string& name) {
52 IntVar* const var = Var();
53 var->set_name(name);
54 return var;
55}
56
57// ---------- IntVar ----------
58
59IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
60
61IntVar::IntVar(Solver* const s, const std::string& name)
62 : IntExpr(s), index_(s->GetNewIntVarIndex()) {
64}
65
66// ----- Boolean variable -----
67
69
70void BooleanVar::SetMin(int64_t m) {
71 if (m <= 0) return;
72 if (m > 1) solver()->Fail();
73 SetValue(1);
74}
75
76void BooleanVar::SetMax(int64_t m) {
77 if (m >= 1) return;
78 if (m < 0) solver()->Fail();
79 SetValue(0);
80}
81
82void BooleanVar::SetRange(int64_t mi, int64_t ma) {
83 if (mi > 1 || ma < 0 || mi > ma) {
84 solver()->Fail();
85 }
86 if (mi == 1) {
87 SetValue(1);
88 } else if (ma == 0) {
89 SetValue(0);
90 }
91}
92
93void BooleanVar::RemoveValue(int64_t v) {
95 if (v == 0) {
96 SetValue(1);
97 } else if (v == 1) {
98 SetValue(0);
99 }
100 } else if (v == value_) {
101 solver()->Fail();
102 }
103}
104
105void BooleanVar::RemoveInterval(int64_t l, int64_t u) {
106 if (u < l) return;
107 if (l <= 0 && u >= 1) {
108 solver()->Fail();
109 } else if (l == 1) {
110 SetValue(0);
111 } else if (u == 0) {
112 SetValue(1);
113 }
114}
115
119 delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
120 } else {
121 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
122 }
123 }
124}
125
126uint64_t BooleanVar::Size() const {
127 return (1 + (value_ == kUnboundBooleanVarValue));
128}
129
130bool BooleanVar::Contains(int64_t v) const {
131 return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
132}
133
135 if (constant > 1 || constant < 0) {
136 return solver()->MakeIntConst(0);
137 }
138 if (constant == 1) {
139 return this;
140 } else { // constant == 0.
141 return solver()->MakeDifference(1, this)->Var();
142 }
143}
144
146 if (constant > 1 || constant < 0) {
147 return solver()->MakeIntConst(1);
148 }
149 if (constant == 1) {
150 return solver()->MakeDifference(1, this)->Var();
151 } else { // constant == 0.
152 return this;
153 }
154}
155
157 if (constant > 1) {
158 return solver()->MakeIntConst(0);
159 } else if (constant <= 0) {
160 return solver()->MakeIntConst(1);
161 } else {
162 return this;
163 }
164}
165
167 if (constant < 0) {
168 return solver()->MakeIntConst(0);
169 } else if (constant >= 1) {
170 return solver()->MakeIntConst(1);
171 } else {
172 return IsEqual(0);
173 }
174}
175
176std::string BooleanVar::DebugString() const {
177 std::string out;
178 const std::string& var_name = name();
179 if (!var_name.empty()) {
180 out = var_name + "(";
181 } else {
182 out = "BooleanVar(";
183 }
184 switch (value_) {
185 case 0:
186 out += "0";
187 break;
188 case 1:
189 out += "1";
190 break;
192 out += "0 .. 1";
193 break;
194 }
195 out += ")";
196 return out;
197}
198
199namespace {
200// ---------- Subclasses of IntVar ----------
201
202// ----- Domain Int Var: base class for variables -----
// It contains the bounds and a bitset representation of the possible values.
204class DomainIntVar : public IntVar {
205 public:
206 // Utility classes
207 class BitSetIterator : public BaseObject {
208 public:
209 BitSetIterator(uint64_t* const bitset, int64_t omin)
210 : bitset_(bitset),
211 omin_(omin),
212 max_(std::numeric_limits<int64_t>::min()),
213 current_(std::numeric_limits<int64_t>::max()) {}
214
215 ~BitSetIterator() override {}
216
217 void Init(int64_t min, int64_t max) {
218 max_ = max;
219 current_ = min;
220 }
221
222 bool Ok() const { return current_ <= max_; }
223
224 int64_t Value() const { return current_; }
225
226 void Next() {
227 if (++current_ <= max_) {
229 bitset_, current_ - omin_, max_ - omin_) +
230 omin_;
231 }
232 }
233
234 std::string DebugString() const override { return "BitSetIterator"; }
235
236 private:
237 uint64_t* const bitset_;
238 const int64_t omin_;
239 int64_t max_;
240 int64_t current_;
241 };
242
243 class BitSet : public BaseObject {
244 public:
245 explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
246 ~BitSet() override {}
247
248 virtual int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) = 0;
249 virtual int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) = 0;
250 virtual bool Contains(int64_t val) const = 0;
251 virtual bool SetValue(int64_t val) = 0;
252 virtual bool RemoveValue(int64_t val) = 0;
253 virtual uint64_t Size() const = 0;
254 virtual void DelayRemoveValue(int64_t val) = 0;
255 virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
256 virtual void ClearRemovedValues() = 0;
257 virtual std::string pretty_DebugString(int64_t min, int64_t max) const = 0;
258 virtual BitSetIterator* MakeIterator() = 0;
259
260 void InitHoles() {
261 const uint64_t current_stamp = solver_->stamp();
262 if (holes_stamp_ < current_stamp) {
263 holes_.clear();
264 holes_stamp_ = current_stamp;
265 }
266 }
267
268 virtual void ClearHoles() { holes_.clear(); }
269
270 const std::vector<int64_t>& Holes() { return holes_; }
271
272 void AddHole(int64_t value) { holes_.push_back(value); }
273
274 int NumHoles() const {
275 return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
276 }
277
278 protected:
279 Solver* const solver_;
280
281 private:
282 std::vector<int64_t> holes_;
283 uint64_t holes_stamp_;
284 };
285
286 class QueueHandler : public Demon {
287 public:
288 explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
289 ~QueueHandler() override {}
290 void Run(Solver* const s) override {
291 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
292 var_->Process();
293 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
294 }
295 Solver::DemonPriority priority() const override {
297 }
298 std::string DebugString() const override {
299 return absl::StrFormat("Handler(%s)", var_->DebugString());
300 }
301
302 private:
303 DomainIntVar* const var_;
304 };
305
306 // Bounds and Value watchers
307
308 // This class stores the watchers variables attached to values. It is
309 // reversible and it helps maintaining the set of 'active' watchers
310 // (variables not bound to a single value).
311 template <class T>
312 class RevIntPtrMap {
313 public:
314 RevIntPtrMap(Solver* const solver, int64_t rmin, int64_t rmax)
315 : solver_(solver), range_min_(rmin), start_(0) {}
316
317 ~RevIntPtrMap() {}
318
319 bool Empty() const { return start_.Value() == elements_.size(); }
320
321 void SortActive() { std::sort(elements_.begin(), elements_.end()); }
322
323 // Access with value API.
324
325 // Add the pointer to the map attached to the given value.
326 void UnsafeRevInsert(int64_t value, T* elem) {
327 elements_.push_back(std::make_pair(value, elem));
328 if (solver_->state() != Solver::OUTSIDE_SEARCH) {
329 solver_->AddBacktrackAction(
330 [this, value](Solver* s) { Uninsert(value); }, false);
331 }
332 }
333
334 T* FindPtrOrNull(int64_t value, int* position) {
335 for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
336 if (elements_[pos].first == value) {
337 if (position != nullptr) *position = pos;
338 return At(pos).second;
339 }
340 }
341 return nullptr;
342 }
343
344 // Access map through the underlying vector.
345 void RemoveAt(int position) {
346 const int start = start_.Value();
347 DCHECK_GE(position, start);
348 DCHECK_LT(position, elements_.size());
349 if (position > start) {
350 // Swap the current element with the one at the start position, and
351 // increase start.
352 const std::pair<int64_t, T*> copy = elements_[start];
353 elements_[start] = elements_[position];
354 elements_[position] = copy;
355 }
356 start_.Incr(solver_);
357 }
358
359 const std::pair<int64_t, T*>& At(int position) const {
360 DCHECK_GE(position, start_.Value());
361 DCHECK_LT(position, elements_.size());
362 return elements_[position];
363 }
364
365 void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
366
367 int start() const { return start_.Value(); }
368 int end() const { return elements_.size(); }
369 // Number of active elements.
370 int Size() const { return elements_.size() - start_.Value(); }
371
372 // Removes the object permanently from the map.
373 void Uninsert(int64_t value) {
374 for (int pos = 0; pos < elements_.size(); ++pos) {
375 if (elements_[pos].first == value) {
376 DCHECK_GE(pos, start_.Value());
377 const int last = elements_.size() - 1;
378 if (pos != last) { // Swap the current with the last.
379 elements_[pos] = elements_.back();
380 }
381 elements_.pop_back();
382 return;
383 }
384 }
385 LOG(FATAL) << "The element should have been removed";
386 }
387
388 private:
389 Solver* const solver_;
390 const int64_t range_min_;
391 NumericalRev<int> start_;
392 std::vector<std::pair<int64_t, T*>> elements_;
393 };
394
395 // Base class for value watchers
396 class BaseValueWatcher : public Constraint {
397 public:
398 explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
399
400 ~BaseValueWatcher() override {}
401
402 virtual IntVar* GetOrMakeValueWatcher(int64_t value) = 0;
403
404 virtual void SetValueWatcher(IntVar* const boolvar, int64_t value) = 0;
405 };
406
// This class monitors the domain of the variable and updates the
// IsEqual/IsDifferent boolean variables accordingly.
//
// Each watcher is a (value, boolvar) pair kept in a reversible RevIntPtrMap
// (sparse storage, linear lookup). Propagation goes both ways:
//  - when a boolvar becomes bound, variable_ is fixed to `value` (boolvar=1)
//    or `value` is removed from it (boolvar=0): ProcessValueWatcher();
//  - when the domain of variable_ changes, boolvars whose value left the
//    domain are set to 0, and once variable_ is bound every boolvar is set
//    to (value == bound value): ProcessVar() / VariableBound().
// Decided watchers are removed from the active set, and the domain demon is
// inhibited when none remain (CheckInhibit()).
class ValueWatcher : public BaseValueWatcher {
 public:
  // Attached to one watcher boolvar; forwards its binding to
  // ProcessValueWatcher().
  class WatchDemon : public Demon {
   public:
    WatchDemon(ValueWatcher* const watcher, int64_t value, IntVar* var)
        : value_watcher_(watcher), value_(value), var_(var) {}
    ~WatchDemon() override {}

    void Run(Solver* const solver) override {
      value_watcher_->ProcessValueWatcher(value_, var_);
    }

   private:
    ValueWatcher* const value_watcher_;
    const int64_t value_;
    IntVar* const var_;
  };

  // Attached to the domain of the watched variable (see Post()); forwards
  // domain events to ProcessVar().
  class VarDemon : public Demon {
   public:
    explicit VarDemon(ValueWatcher* const watcher)
        : value_watcher_(watcher) {}

    ~VarDemon() override {}

    void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

   private:
    ValueWatcher* const value_watcher_;
  };

  ValueWatcher(Solver* const solver, DomainIntVar* const variable)
      : BaseValueWatcher(solver),
        variable_(variable),
        hole_iterator_(variable_->MakeHoleIterator(true)),
        var_demon_(nullptr),
        watchers_(solver, variable->Min(), variable->Max()) {}

  ~ValueWatcher() override {}

  // Returns the existing watcher for `value`. Otherwise it returns:
  // constant 1 if variable_ is already bound to `value`; constant 0 if
  // `value` is not in the domain; or a fresh Boolean variable named
  // "Watch<name == value>". The fresh variable is registered in watchers_
  // and, if the constraint is already posted, gets its WatchDemon and
  // re-enables the domain demon.
  IntVar* GetOrMakeValueWatcher(int64_t value) override {
    IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
    if (watcher != nullptr) return watcher;
    if (variable_->Contains(value)) {
      if (variable_->Bound()) {
        return solver()->MakeIntConst(1);
      } else {
        const std::string vname = variable_->HasName()
                                      ? variable_->name()
                                      : variable_->DebugString();
        const std::string bname =
            absl::StrFormat("Watch<%s == %d>", vname, value);
        IntVar* const boolvar = solver()->MakeBoolVar(bname);
        watchers_.UnsafeRevInsert(value, boolvar);
        if (posted_.Switched()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
          var_demon_->desinhibit(solver());
        }
        return boolvar;
      }
    } else {
      return variable_->solver()->MakeIntConst(0);
    }
  }

  // Registers `boolvar` as the watcher of `value`. CHECK-fails if `value`
  // already has an active watcher. NOTE(review): an already-bound boolvar is
  // silently ignored; its value is not propagated to variable_ here.
  void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
    CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
    if (!boolvar->Bound()) {
      watchers_.UnsafeRevInsert(value, boolvar);
      if (posted_.Switched() && !boolvar->Bound()) {
        boolvar->WhenBound(
            solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        var_demon_->desinhibit(solver());
      }
    }
  }

  // Installs the domain demon on variable_. Attaches a WatchDemon to every
  // registered watcher that is still unbound and whose value is still in
  // the domain. Watchers created after this point attach their own demon.
  void Post() override {
    var_demon_ = solver()->RevAlloc(new VarDemon(this));
    variable_->WhenDomain(var_demon_);
    for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
      const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
      const int64_t value = w.first;
      IntVar* const boolvar = w.second;
      if (!boolvar->Bound() && variable_->Contains(value)) {
        boolvar->WhenBound(
            solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
      }
    }
    posted_.Switch(solver());
  }

  // Initial synchronization. Watchers of values outside the domain are set
  // to 0, and already-bound watchers are pushed to variable_. Both kinds are
  // then deactivated. RemoveAt(pos) inside the loop is safe: it swaps `pos`
  // with start(), an element this scan has already visited.
  void InitialPropagate() override {
    if (variable_->Bound()) {
      VariableBound();
    } else {
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        const int64_t value = w.first;
        IntVar* const boolvar = w.second;
        if (!variable_->Contains(value)) {
          boolvar->SetValue(0);
          watchers_.RemoveAt(pos);
        } else {
          if (boolvar->Bound()) {
            ProcessValueWatcher(value, boolvar);
            watchers_.RemoveAt(pos);
          }
        }
      }
      CheckInhibit();
    }
  }

  // Pushes a bound watcher to variable_: boolvar == 0 removes `value`, and
  // boolvar == 1 fixes variable_ to `value`.
  void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
    if (boolvar->Min() == 0) {
      // For domains of size >= 0xFFFFFF the removal is posted as a separate
      // NonEquality constraint instead of being applied right away
      // (rationale not visible here).
      if (variable_->Size() < 0xFFFFFF) {
        variable_->RemoveValue(value);
      } else {
        // Delay removal.
        solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
      }
    } else {
      variable_->SetValue(value);
    }
  }

  // Domain demon body: reacts to a change of variable_'s domain.
  void ProcessVar() {
    const int kSmallList = 16;
    if (variable_->Bound()) {
      VariableBound();
    } else if (watchers_.Size() <= kSmallList ||
               variable_->Min() != variable_->OldMin() ||
               variable_->Max() != variable_->OldMax()) {
      // Brute force loop for small numbers of watchers, or if the bounds have
      // changed, which would have required a sort (n log(n)) anyway to take
      // advantage of.
      ScanWatchers();
      CheckInhibit();
    } else {
      // Bounds are unchanged, so only holes can have appeared. Without a
      // bitset there are no holes and there is nothing to do. Otherwise,
      // with few holes (relative to the active watchers), only the watchers
      // of those holes are looked up; else all watchers are rescanned.
      // NOTE(review): FindPtrOrNull() is a linear scan, so the hole path
      // costs O(#holes * #watchers).
      BitSet* const bitset = variable_->bitset();
      if (bitset != nullptr && !watchers_.Empty()) {
        if (bitset->NumHoles() * 2 < watchers_.Size()) {
          for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
            int pos = 0;
            IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
            if (boolvar != nullptr) {
              boolvar->SetValue(0);
              watchers_.RemoveAt(pos);
            }
          }
        } else {
          ScanWatchers();
        }
      }
      CheckInhibit();
    }
  }

  // Optimized case if the variable is bound: every watcher is decided
  // (1 for the bound value, 0 otherwise), all are deactivated, and the
  // domain demon is inhibited.
  void VariableBound() {
    DCHECK(variable_->Bound());
    const int64_t value = variable_->Min();
    for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
      const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
      w.second->SetValue(w.first == value);
    }
    watchers_.RemoveAll();
    var_demon_->inhibit(solver());
  }

  // Scans all the active watchers; those whose value left the domain are set
  // to 0 and deactivated.
  void ScanWatchers() {
    for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
      const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
      if (!variable_->Contains(w.first)) {
        IntVar* const boolvar = w.second;
        boolvar->SetValue(0);
        watchers_.RemoveAt(pos);
      }
    }
  }

  // If the set of active watchers is empty, we can inhibit the demon on the
  // main variable.
  void CheckInhibit() {
    if (watchers_.Empty()) {
      var_demon_->inhibit(solver());
    }
  }

  // Visits the watched variable plus the active (value, boolvar) pairs.
  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                            variable_);
    std::vector<int64_t> all_coefficients;
    std::vector<IntVar*> all_bool_vars;
    for (int position = watchers_.start(); position < watchers_.end();
         ++position) {
      const std::pair<int64_t, IntVar*>& w = watchers_.At(position);
      all_coefficients.push_back(w.first);
      all_bool_vars.push_back(w.second);
    }
    visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                               all_bool_vars);
    visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                       all_coefficients);
    visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
  }

  std::string DebugString() const override {
    return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
  }

 private:
  DomainIntVar* const variable_;
  IntVarIterator* const hole_iterator_;
  // Switched (reversibly) at the end of Post().
  RevSwitch posted_;
  // Domain demon created in Post(); nullptr before that.
  Demon* var_demon_;
  RevIntPtrMap<IntVar> watchers_;
};
634
// Optimized case for small maps.
//
// Same contract as ValueWatcher, but watchers are stored densely:
// watchers_[value - offset_], where offset_ is variable->Min() at
// construction and the vector spans [Min, Max] of that time. A nullptr slot
// means "no active watcher". Slots are updated reversibly (SaveValue) and
// active_watchers_ counts the non-null slots.
class DenseValueWatcher : public BaseValueWatcher {
 public:
  // Attached to one watcher boolvar; forwards its binding to
  // ProcessValueWatcher().
  class WatchDemon : public Demon {
   public:
    WatchDemon(DenseValueWatcher* const watcher, int64_t value, IntVar* var)
        : value_watcher_(watcher), value_(value), var_(var) {}
    ~WatchDemon() override {}

    void Run(Solver* const solver) override {
      value_watcher_->ProcessValueWatcher(value_, var_);
    }

   private:
    DenseValueWatcher* const value_watcher_;
    const int64_t value_;
    IntVar* const var_;
  };

  // Attached to the domain of the watched variable (see Post()); forwards
  // domain events to ProcessVar().
  class VarDemon : public Demon {
   public:
    explicit VarDemon(DenseValueWatcher* const watcher)
        : value_watcher_(watcher) {}

    ~VarDemon() override {}

    void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

   private:
    DenseValueWatcher* const value_watcher_;
  };

  DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
      : BaseValueWatcher(solver),
        variable_(variable),
        hole_iterator_(variable_->MakeHoleIterator(true)),
        var_demon_(nullptr),
        offset_(variable->Min()),
        watchers_(variable->Max() - variable->Min() + 1, nullptr),
        active_watchers_(0) {}

  ~DenseValueWatcher() override {}

  // Returns constant 0 for values outside the stored range or outside the
  // domain; the existing watcher if any; constant 1 if variable_ is bound to
  // `value`. Otherwise returns a fresh Boolean variable named
  // "Watch<name == value>", stored in its slot.
  IntVar* GetOrMakeValueWatcher(int64_t value) override {
    // offset_ + size() is computed in unsigned arithmetic before being
    // converted back to int64_t (hence "Bad cast").
    const int64_t var_max = offset_ + watchers_.size() - 1;  // Bad cast.
    if (value < offset_ || value > var_max) {
      return solver()->MakeIntConst(0);
    }
    const int index = value - offset_;
    IntVar* const watcher = watchers_[index];
    if (watcher != nullptr) return watcher;
    if (variable_->Contains(value)) {
      if (variable_->Bound()) {
        return solver()->MakeIntConst(1);
      } else {
        const std::string vname = variable_->HasName()
                                      ? variable_->name()
                                      : variable_->DebugString();
        const std::string bname =
            absl::StrFormat("Watch<%s == %d>", vname, value);
        IntVar* const boolvar = solver()->MakeBoolVar(bname);
        RevInsert(index, boolvar);
        if (posted_.Switched()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
          var_demon_->desinhibit(solver());
        }
        return boolvar;
      }
    } else {
      return variable_->solver()->MakeIntConst(0);
    }
  }

  // Registers `boolvar` as the watcher of `value` (ignored if already bound).
  // NOTE(review): unlike GetOrMakeValueWatcher(), `value` is not range-
  // checked here; callers must pass a value within the stored range.
  void SetValueWatcher(IntVar* const boolvar, int64_t value) override {
    const int index = value - offset_;
    CHECK(watchers_[index] == nullptr);
    if (!boolvar->Bound()) {
      RevInsert(index, boolvar);
      if (posted_.Switched() && !boolvar->Bound()) {
        boolvar->WhenBound(
            solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        var_demon_->desinhibit(solver());
      }
    }
  }

  // Installs the domain demon. Attaches a WatchDemon to every stored watcher
  // that is still unbound and whose value is still in the domain.
  void Post() override {
    var_demon_ = solver()->RevAlloc(new VarDemon(this));
    variable_->WhenDomain(var_demon_);
    for (int pos = 0; pos < watchers_.size(); ++pos) {
      const int64_t value = pos + offset_;
      IntVar* const boolvar = watchers_[pos];
      if (boolvar != nullptr && !boolvar->Bound() &&
          variable_->Contains(value)) {
        boolvar->WhenBound(
            solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
      }
    }
    posted_.Switch(solver());
  }

  // Initial synchronization. Watchers of values outside the domain are set
  // to 0, and already-bound watchers are pushed to variable_. Both are then
  // cleared from their slots.
  void InitialPropagate() override {
    if (variable_->Bound()) {
      VariableBound();
    } else {
      for (int pos = 0; pos < watchers_.size(); ++pos) {
        IntVar* const boolvar = watchers_[pos];
        if (boolvar == nullptr) continue;
        const int64_t value = pos + offset_;
        if (!variable_->Contains(value)) {
          boolvar->SetValue(0);
          RevRemove(pos);
        } else if (boolvar->Bound()) {
          ProcessValueWatcher(value, boolvar);
          RevRemove(pos);
        }
      }
      if (active_watchers_.Value() == 0) {
        var_demon_->inhibit(solver());
      }
    }
  }

  // Pushes a bound watcher to variable_: boolvar == 0 removes `value`, and
  // boolvar == 1 fixes variable_ to `value`. (No delayed-removal path here,
  // unlike ValueWatcher.)
  void ProcessValueWatcher(int64_t value, IntVar* boolvar) {
    if (boolvar->Min() == 0) {
      variable_->RemoveValue(value);
    } else {
      variable_->SetValue(value);
    }
  }

  // Domain demon body: reacts to a change of variable_'s domain.
  void ProcessVar() {
    if (variable_->Bound()) {
      VariableBound();
    } else {
      // Incremental update: bound moves and holes are handled by
      // ScanWatchers().
      ScanWatchers();
      if (active_watchers_.Value() == 0) {
        var_demon_->inhibit(solver());
      }
    }
  }

  // Optimized case if the variable is bound: every stored watcher is set to
  // (its value == bound value) and cleared; the domain demon is inhibited.
  void VariableBound() {
    DCHECK(variable_->Bound());
    const int64_t value = variable_->Min();
    for (int pos = 0; pos < watchers_.size(); ++pos) {
      IntVar* const boolvar = watchers_[pos];
      if (boolvar != nullptr) {
        boolvar->SetValue(pos + offset_ == value);
        RevRemove(pos);
      }
    }
    var_demon_->inhibit(solver());
  }

  // Updates the watchers affected by the last domain change, in three
  // steps: (1) values cut off by the min increase [OldMin, Min); (2) values
  // cut off by the max decrease (Max, OldMax]; (3) interior holes. Step 3
  // only runs when a bitset exists. It visits only the recorded holes when
  // they are few relative to the active watchers, and otherwise scans the
  // open interval (Min, Max).
  void ScanWatchers() {
    const int64_t old_min_index = variable_->OldMin() - offset_;
    const int64_t old_max_index = variable_->OldMax() - offset_;
    const int64_t min_index = variable_->Min() - offset_;
    const int64_t max_index = variable_->Max() - offset_;
    for (int pos = old_min_index; pos < min_index; ++pos) {
      IntVar* const boolvar = watchers_[pos];
      if (boolvar != nullptr) {
        boolvar->SetValue(0);
        RevRemove(pos);
      }
    }
    for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
      IntVar* const boolvar = watchers_[pos];
      if (boolvar != nullptr) {
        boolvar->SetValue(0);
        RevRemove(pos);
      }
    }
    BitSet* const bitset = variable_->bitset();
    if (bitset != nullptr) {
      if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
        for (const int64_t hole : InitAndGetValues(hole_iterator_)) {
          IntVar* const boolvar = watchers_[hole - offset_];
          if (boolvar != nullptr) {
            boolvar->SetValue(0);
            RevRemove(hole - offset_);
          }
        }
      } else {
        for (int pos = min_index + 1; pos < max_index; ++pos) {
          IntVar* const boolvar = watchers_[pos];
          if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
            boolvar->SetValue(0);
            RevRemove(pos);
          }
        }
      }
    }
  }

  // Reversibly clears slot `pos` and decrements the active count.
  void RevRemove(int pos) {
    solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
    watchers_[pos] = nullptr;
    active_watchers_.Decr(solver());
  }

  // Reversibly stores `boolvar` in slot `pos` and increments the active
  // count.
  void RevInsert(int pos, IntVar* boolvar) {
    solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
    watchers_[pos] = boolvar;
    active_watchers_.Incr(solver());
  }

  // Visits the watched variable plus the non-null (value, boolvar) slots.
  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                            variable_);
    std::vector<int64_t> all_coefficients;
    std::vector<IntVar*> all_bool_vars;
    for (int position = 0; position < watchers_.size(); ++position) {
      if (watchers_[position] != nullptr) {
        all_coefficients.push_back(position + offset_);
        all_bool_vars.push_back(watchers_[position]);
      }
    }
    visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                               all_bool_vars);
    visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                       all_coefficients);
    visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
  }

  std::string DebugString() const override {
    return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
  }

 private:
  DomainIntVar* const variable_;
  IntVarIterator* const hole_iterator_;
  // Switched (reversibly) at the end of Post().
  RevSwitch posted_;
  // Domain demon created in Post(); nullptr before that.
  Demon* var_demon_;
  // Value stored in slot 0 of watchers_.
  const int64_t offset_;
  std::vector<IntVar*> watchers_;
  // Number of non-null slots in watchers_.
  NumericalRev<int> active_watchers_;
};
879
880 class BaseUpperBoundWatcher : public Constraint {
881 public:
882 explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
883
884 ~BaseUpperBoundWatcher() override {}
885
886 virtual IntVar* GetOrMakeUpperBoundWatcher(int64_t value) = 0;
887
888 virtual void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) = 0;
889 };
890
// This class watches the bounds of the variable and updates the
// IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
// accordingly.
//
// Each watcher is a (value, boolvar) pair with boolvar <=> (variable_ >=
// value):
//  - a bound boolvar sets Max to value - 1 (boolvar=0) or Min to value
//    (boolvar=1): ProcessUpperBoundWatcher();
//  - a range change sets boolvars with value <= Min to 1 and those with
//    value > Max to 0: ProcessVar().
// Two bookkeeping modes exist:
//  - Sorted: chosen in Post() when there are more than kTooSmallToSort
//    watchers. The watchers are sorted by value, and the undecided ones are
//    the reversible index range [start_, end_], which shrinks from both ends.
//  - Unsorted: each decided watcher is removed with watchers_.RemoveAt().
// NOTE(review): sorted_ is a plain (non-reversible) bool. Adding a watcher
// after Post() permanently switches to the unsorted mode, which does not
// consult start_/end_.
class UpperBoundWatcher : public BaseUpperBoundWatcher {
 public:
  // Attached to one watcher boolvar; forwards its binding to
  // ProcessUpperBoundWatcher(). Despite its name, index_ holds the watched
  // value (it is constructed from `value` below).
  class WatchDemon : public Demon {
   public:
    WatchDemon(UpperBoundWatcher* const watcher, int64_t index,
               IntVar* const var)
        : value_watcher_(watcher), index_(index), var_(var) {}
    ~WatchDemon() override {}

    void Run(Solver* const solver) override {
      value_watcher_->ProcessUpperBoundWatcher(index_, var_);
    }

   private:
    UpperBoundWatcher* const value_watcher_;
    const int64_t index_;
    IntVar* const var_;
  };

  // Attached to the range of the watched variable (see Post()); forwards
  // range events to ProcessVar().
  class VarDemon : public Demon {
   public:
    explicit VarDemon(UpperBoundWatcher* const watcher)
        : value_watcher_(watcher) {}
    ~VarDemon() override {}

    void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

   private:
    UpperBoundWatcher* const value_watcher_;
  };

  UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
      : BaseUpperBoundWatcher(solver),
        variable_(variable),
        var_demon_(nullptr),
        watchers_(solver, variable->Min(), variable->Max()),
        start_(0),
        end_(0),
        sorted_(false) {}

  ~UpperBoundWatcher() override {}

  // Returns the existing watcher for `value`. Otherwise returns constant 1
  // if Min >= value, constant 0 if Max < value, or a fresh Boolean variable
  // named "Watch<name >= value>". A watcher created after Post() also gets
  // its WatchDemon, re-enables the range demon and turns off sorted mode.
  IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
    IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
    if (watcher != nullptr) {
      return watcher;
    }
    if (variable_->Max() >= value) {
      if (variable_->Min() >= value) {
        return solver()->MakeIntConst(1);
      } else {
        const std::string vname = variable_->HasName()
                                      ? variable_->name()
                                      : variable_->DebugString();
        const std::string bname =
            absl::StrFormat("Watch<%s >= %d>", vname, value);
        IntVar* const boolvar = solver()->MakeBoolVar(bname);
        watchers_.UnsafeRevInsert(value, boolvar);
        if (posted_.Switched()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
          var_demon_->desinhibit(solver());
          sorted_ = false;
        }
        return boolvar;
      }
    } else {
      return variable_->solver()->MakeIntConst(0);
    }
  }

  // Registers `boolvar` as the watcher of `value`; CHECK-fails if `value`
  // already has an active watcher. Unlike ValueWatcher::SetValueWatcher,
  // the pair is stored even when boolvar is already bound.
  void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
    CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
    watchers_.UnsafeRevInsert(value, boolvar);
    if (posted_.Switched() && !boolvar->Bound()) {
      boolvar->WhenBound(
          solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
      var_demon_->desinhibit(solver());
      sorted_ = false;
    }
  }

  // Installs the range demon. With more than kTooSmallToSort watchers,
  // switches to sorted mode ([start_, end_] = whole active range). Attaches
  // a WatchDemon to unbound watchers whose value lies in (Min, Max].
  void Post() override {
    const int kTooSmallToSort = 8;
    var_demon_ = solver()->RevAlloc(new VarDemon(this));
    variable_->WhenRange(var_demon_);

    if (watchers_.Size() > kTooSmallToSort) {
      watchers_.SortActive();
      sorted_ = true;
      start_.SetValue(solver(), watchers_.start());
      end_.SetValue(solver(), watchers_.end() - 1);
    }

    for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
      const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
      IntVar* const boolvar = w.second;
      const int64_t value = w.first;
      if (!boolvar->Bound() && value > variable_->Min() &&
          value <= variable_->Max()) {
        boolvar->WhenBound(
            solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
      }
    }
    posted_.Switch(solver());
  }

  // Initial synchronization: decides watchers made obvious by the current
  // bounds and pushes already-bound watchers to variable_.
  void InitialPropagate() override {
    const int64_t var_min = variable_->Min();
    const int64_t var_max = variable_->Max();
    if (sorted_) {
      // Shrink [start_, end_] from the low end (value <= Min => 1)...
      while (start_.Value() <= end_.Value()) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
        if (w.first <= var_min) {
          w.second->SetValue(1);
          start_.Incr(solver());
        } else {
          break;
        }
      }
      // ...and from the high end (value > Max => 0).
      while (end_.Value() >= start_.Value()) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
        if (w.first > var_max) {
          w.second->SetValue(0);
          end_.Decr(solver());
        } else {
          break;
        }
      }
      // Bound watchers inside the remaining range are pushed to variable_
      // (they are not removed from the range here).
      for (int i = start_.Value(); i <= end_.Value(); ++i) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(i);
        if (w.second->Bound()) {
          ProcessUpperBoundWatcher(w.first, w.second);
        }
      }
      if (start_.Value() > end_.Value()) {
        var_demon_->inhibit(solver());
      }
    } else {
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        const int64_t value = w.first;
        IntVar* const boolvar = w.second;

        if (value <= var_min) {
          boolvar->SetValue(1);
          watchers_.RemoveAt(pos);
        } else if (value > var_max) {
          boolvar->SetValue(0);
          watchers_.RemoveAt(pos);
        } else if (boolvar->Bound()) {
          ProcessUpperBoundWatcher(value, boolvar);
          watchers_.RemoveAt(pos);
        }
      }
    }
  }

  // Visits the watched variable plus the active (value, boolvar) pairs of
  // watchers_. In sorted mode these are all pairs not removed via RemoveAt,
  // including those already excluded from [start_, end_].
  void Accept(ModelVisitor* const visitor) const override {
    visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
    visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                            variable_);
    std::vector<int64_t> all_coefficients;
    std::vector<IntVar*> all_bool_vars;
    for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
      const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
      all_coefficients.push_back(w.first);
      all_bool_vars.push_back(w.second);
    }
    visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                               all_bool_vars);
    visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                       all_coefficients);
    visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
  }

  std::string DebugString() const override {
    return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
  }

 private:
  // Pushes a bound watcher to variable_: boolvar == 0 means
  // variable_ < value, and boolvar == 1 means variable_ >= value.
  void ProcessUpperBoundWatcher(int64_t value, IntVar* const boolvar) {
    if (boolvar->Min() == 0) {
      variable_->SetMax(value - 1);
    } else {
      variable_->SetMin(value);
    }
  }

  // Range demon body: decides watchers made obvious by the new bounds, and
  // inhibits the demon once no undecided watcher remains.
  void ProcessVar() {
    const int64_t var_min = variable_->Min();
    const int64_t var_max = variable_->Max();
    if (sorted_) {
      while (start_.Value() <= end_.Value()) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(start_.Value());
        if (w.first <= var_min) {
          w.second->SetValue(1);
          start_.Incr(solver());
        } else {
          break;
        }
      }
      while (end_.Value() >= start_.Value()) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(end_.Value());
        if (w.first > var_max) {
          w.second->SetValue(0);
          end_.Decr(solver());
        } else {
          break;
        }
      }
      if (start_.Value() > end_.Value()) {
        var_demon_->inhibit(solver());
      }
    } else {
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64_t, IntVar*>& w = watchers_.At(pos);
        const int64_t value = w.first;
        IntVar* const boolvar = w.second;

        if (value <= var_min) {
          boolvar->SetValue(1);
          watchers_.RemoveAt(pos);
        } else if (value > var_max) {
          boolvar->SetValue(0);
          watchers_.RemoveAt(pos);
        }
      }
      if (watchers_.Empty()) {
        var_demon_->inhibit(solver());
      }
    }
  }

  DomainIntVar* const variable_;
  // Switched (reversibly) at the end of Post().
  RevSwitch posted_;
  // Range demon created in Post(); nullptr before that.
  Demon* var_demon_;
  RevIntPtrMap<IntVar> watchers_;
  // Sorted mode only: inclusive index range of undecided watchers.
  NumericalRev<int> start_;
  NumericalRev<int> end_;
  bool sorted_;
};
1136
1137 // Optimized case for small maps.
1138 class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1139 public:
1140 class WatchDemon : public Demon {
1141 public:
1142 WatchDemon(DenseUpperBoundWatcher* const watcher, int64_t value,
1143 IntVar* var)
1144 : value_watcher_(watcher), value_(value), var_(var) {}
1145 ~WatchDemon() override {}
1146
1147 void Run(Solver* const solver) override {
1148 value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1149 }
1150
1151 private:
1152 DenseUpperBoundWatcher* const value_watcher_;
1153 const int64_t value_;
1154 IntVar* const var_;
1155 };
1156
1157 class VarDemon : public Demon {
1158 public:
1159 explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1160 : value_watcher_(watcher) {}
1161
1162 ~VarDemon() override {}
1163
1164 void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1165
1166 private:
1167 DenseUpperBoundWatcher* const value_watcher_;
1168 };
1169
1170 DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1171 : BaseUpperBoundWatcher(solver),
1172 variable_(variable),
1173 var_demon_(nullptr),
1174 offset_(variable->Min()),
1175 watchers_(variable->Max() - variable->Min() + 1, nullptr),
1176 active_watchers_(0) {}
1177
1178 ~DenseUpperBoundWatcher() override {}
1179
1180 IntVar* GetOrMakeUpperBoundWatcher(int64_t value) override {
1181 if (variable_->Max() >= value) {
1182 if (variable_->Min() >= value) {
1183 return solver()->MakeIntConst(1);
1184 } else {
1185 const std::string vname = variable_->HasName()
1186 ? variable_->name()
1187 : variable_->DebugString();
1188 const std::string bname =
1189 absl::StrFormat("Watch<%s >= %d>", vname, value);
1190 IntVar* const boolvar = solver()->MakeBoolVar(bname);
1191 RevInsert(value - offset_, boolvar);
1192 if (posted_.Switched()) {
1193 boolvar->WhenBound(
1194 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1195 var_demon_->desinhibit(solver());
1196 }
1197 return boolvar;
1198 }
1199 } else {
1200 return variable_->solver()->MakeIntConst(0);
1201 }
1202 }
1203
1204 void SetUpperBoundWatcher(IntVar* const boolvar, int64_t value) override {
1205 const int index = value - offset_;
1206 CHECK(watchers_[index] == nullptr);
1207 if (!boolvar->Bound()) {
1208 RevInsert(index, boolvar);
1209 if (posted_.Switched() && !boolvar->Bound()) {
1210 boolvar->WhenBound(
1211 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1212 var_demon_->desinhibit(solver());
1213 }
1214 }
1215 }
1216
1217 void Post() override {
1218 var_demon_ = solver()->RevAlloc(new VarDemon(this));
1219 variable_->WhenRange(var_demon_);
1220 for (int pos = 0; pos < watchers_.size(); ++pos) {
1221 const int64_t value = pos + offset_;
1222 IntVar* const boolvar = watchers_[pos];
1223 if (boolvar != nullptr && !boolvar->Bound() &&
1224 value > variable_->Min() && value <= variable_->Max()) {
1225 boolvar->WhenBound(
1226 solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1227 }
1228 }
1229 posted_.Switch(solver());
1230 }
1231
1232 void InitialPropagate() override {
1233 for (int pos = 0; pos < watchers_.size(); ++pos) {
1234 IntVar* const boolvar = watchers_[pos];
1235 if (boolvar == nullptr) continue;
1236 const int64_t value = pos + offset_;
1237 if (value <= variable_->Min()) {
1238 boolvar->SetValue(1);
1239 RevRemove(pos);
1240 } else if (value > variable_->Max()) {
1241 boolvar->SetValue(0);
1242 RevRemove(pos);
1243 } else if (boolvar->Bound()) {
1244 ProcessUpperBoundWatcher(value, boolvar);
1245 RevRemove(pos);
1246 }
1247 }
1248 if (active_watchers_.Value() == 0) {
1249 var_demon_->inhibit(solver());
1250 }
1251 }
1252
1253 void ProcessUpperBoundWatcher(int64_t value, IntVar* boolvar) {
1254 if (boolvar->Min() == 0) {
1255 variable_->SetMax(value - 1);
1256 } else {
1257 variable_->SetMin(value);
1258 }
1259 }
1260
1261 void ProcessVar() {
1262 const int64_t old_min_index = variable_->OldMin() - offset_;
1263 const int64_t old_max_index = variable_->OldMax() - offset_;
1264 const int64_t min_index = variable_->Min() - offset_;
1265 const int64_t max_index = variable_->Max() - offset_;
1266 for (int pos = old_min_index; pos <= min_index; ++pos) {
1267 IntVar* const boolvar = watchers_[pos];
1268 if (boolvar != nullptr) {
1269 boolvar->SetValue(1);
1270 RevRemove(pos);
1271 }
1272 }
1273
1274 for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1275 IntVar* const boolvar = watchers_[pos];
1276 if (boolvar != nullptr) {
1277 boolvar->SetValue(0);
1278 RevRemove(pos);
1279 }
1280 }
1281 if (active_watchers_.Value() == 0) {
1282 var_demon_->inhibit(solver());
1283 }
1284 }
1285
1286 void RevRemove(int pos) {
1287 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1288 watchers_[pos] = nullptr;
1289 active_watchers_.Decr(solver());
1290 }
1291
1292 void RevInsert(int pos, IntVar* boolvar) {
1293 solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1294 watchers_[pos] = boolvar;
1295 active_watchers_.Incr(solver());
1296 }
1297
1298 void Accept(ModelVisitor* const visitor) const override {
1299 visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1300 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1301 variable_);
1302 std::vector<int64_t> all_coefficients;
1303 std::vector<IntVar*> all_bool_vars;
1304 for (int position = 0; position < watchers_.size(); ++position) {
1305 if (watchers_[position] != nullptr) {
1306 all_coefficients.push_back(position + offset_);
1307 all_bool_vars.push_back(watchers_[position]);
1308 }
1309 }
1310 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1311 all_bool_vars);
1312 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1313 all_coefficients);
1314 visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1315 }
1316
1317 std::string DebugString() const override {
1318 return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1319 variable_->DebugString());
1320 }
1321
1322 private:
1323 DomainIntVar* const variable_;
1324 RevSwitch posted_;
1325 Demon* var_demon_;
1326 const int64_t offset_;
1327 std::vector<IntVar*> watchers_;
1328 NumericalRev<int> active_watchers_;
1329 };
1330
1331 // ----- Main Class -----
1332 DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
1333 const std::string& name);
1334 DomainIntVar(Solver* const s, const std::vector<int64_t>& sorted_values,
1335 const std::string& name);
1336 ~DomainIntVar() override;
1337
1338 int64_t Min() const override { return min_.Value(); }
1339 void SetMin(int64_t m) override;
1340 int64_t Max() const override { return max_.Value(); }
1341 void SetMax(int64_t m) override;
1342 void SetRange(int64_t mi, int64_t ma) override;
1343 void SetValue(int64_t v) override;
1344 bool Bound() const override { return (min_.Value() == max_.Value()); }
1345 int64_t Value() const override {
1346 CHECK_EQ(min_.Value(), max_.Value())
1347 << " variable " << DebugString() << " is not bound.";
1348 return min_.Value();
1349 }
1350 void RemoveValue(int64_t v) override;
1351 void RemoveInterval(int64_t l, int64_t u) override;
1352 void CreateBits();
1353 void WhenBound(Demon* d) override {
1354 if (min_.Value() != max_.Value()) {
1355 if (d->priority() == Solver::DELAYED_PRIORITY) {
1356 delayed_bound_demons_.PushIfNotTop(solver(),
1357 solver()->RegisterDemon(d));
1358 } else {
1359 bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1360 }
1361 }
1362 }
1363 void WhenRange(Demon* d) override {
1364 if (min_.Value() != max_.Value()) {
1365 if (d->priority() == Solver::DELAYED_PRIORITY) {
1366 delayed_range_demons_.PushIfNotTop(solver(),
1367 solver()->RegisterDemon(d));
1368 } else {
1369 range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1370 }
1371 }
1372 }
1373 void WhenDomain(Demon* d) override {
1374 if (min_.Value() != max_.Value()) {
1375 if (d->priority() == Solver::DELAYED_PRIORITY) {
1376 delayed_domain_demons_.PushIfNotTop(solver(),
1377 solver()->RegisterDemon(d));
1378 } else {
1379 domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1380 }
1381 }
1382 }
1383
1384 IntVar* IsEqual(int64_t constant) override {
1385 Solver* const s = solver();
1386 if (constant == min_.Value() && value_watcher_ == nullptr) {
1387 return s->MakeIsLessOrEqualCstVar(this, constant);
1388 }
1389 if (constant == max_.Value() && value_watcher_ == nullptr) {
1390 return s->MakeIsGreaterOrEqualCstVar(this, constant);
1391 }
1392 if (!Contains(constant)) {
1393 return s->MakeIntConst(int64_t{0});
1394 }
1395 if (Bound() && min_.Value() == constant) {
1396 return s->MakeIntConst(int64_t{1});
1397 }
1398 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1400 if (cache != nullptr) {
1401 return cache->Var();
1402 } else {
1403 if (value_watcher_ == nullptr) {
1404 if (CapSub(Max(), Min()) <= 256) {
1405 solver()->SaveAndSetValue(
1406 reinterpret_cast<void**>(&value_watcher_),
1407 reinterpret_cast<void*>(
1408 solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1409
1410 } else {
1411 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1412 reinterpret_cast<void*>(solver()->RevAlloc(
1413 new ValueWatcher(solver(), this))));
1414 }
1415 solver()->AddConstraint(value_watcher_);
1416 }
1417 IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1418 s->Cache()->InsertExprConstantExpression(
1420 return boolvar;
1421 }
1422 }
1423
1424 Constraint* SetIsEqual(const std::vector<int64_t>& values,
1425 const std::vector<IntVar*>& vars) {
1426 if (value_watcher_ == nullptr) {
1427 solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1428 reinterpret_cast<void*>(solver()->RevAlloc(
1429 new ValueWatcher(solver(), this))));
1430 for (int i = 0; i < vars.size(); ++i) {
1431 value_watcher_->SetValueWatcher(vars[i], values[i]);
1432 }
1433 }
1434 return value_watcher_;
1435 }
1436
1437 IntVar* IsDifferent(int64_t constant) override {
1438 Solver* const s = solver();
1439 if (constant == min_.Value() && value_watcher_ == nullptr) {
1440 return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1441 }
1442 if (constant == max_.Value() && value_watcher_ == nullptr) {
1443 return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1444 }
1445 if (!Contains(constant)) {
1446 return s->MakeIntConst(int64_t{1});
1447 }
1448 if (Bound() && min_.Value() == constant) {
1449 return s->MakeIntConst(int64_t{0});
1450 }
1451 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1453 if (cache != nullptr) {
1454 return cache->Var();
1455 } else {
1456 IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1457 s->Cache()->InsertExprConstantExpression(
1459 return boolvar;
1460 }
1461 }
1462
1463 IntVar* IsGreaterOrEqual(int64_t constant) override {
1464 Solver* const s = solver();
1465 if (max_.Value() < constant) {
1466 return s->MakeIntConst(int64_t{0});
1467 }
1468 if (min_.Value() >= constant) {
1469 return s->MakeIntConst(int64_t{1});
1470 }
1471 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1473 if (cache != nullptr) {
1474 return cache->Var();
1475 } else {
1476 if (bound_watcher_ == nullptr) {
1477 if (CapSub(Max(), Min()) <= 256) {
1478 solver()->SaveAndSetValue(
1479 reinterpret_cast<void**>(&bound_watcher_),
1480 reinterpret_cast<void*>(solver()->RevAlloc(
1481 new DenseUpperBoundWatcher(solver(), this))));
1482 solver()->AddConstraint(bound_watcher_);
1483 } else {
1484 solver()->SaveAndSetValue(
1485 reinterpret_cast<void**>(&bound_watcher_),
1486 reinterpret_cast<void*>(
1487 solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1488 solver()->AddConstraint(bound_watcher_);
1489 }
1490 }
1491 IntVar* const boolvar =
1492 bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1493 s->Cache()->InsertExprConstantExpression(
1494 boolvar, this, constant,
1496 return boolvar;
1497 }
1498 }
1499
1500 Constraint* SetIsGreaterOrEqual(const std::vector<int64_t>& values,
1501 const std::vector<IntVar*>& vars) {
1502 if (bound_watcher_ == nullptr) {
1503 if (CapSub(Max(), Min()) <= 256) {
1504 solver()->SaveAndSetValue(
1505 reinterpret_cast<void**>(&bound_watcher_),
1506 reinterpret_cast<void*>(solver()->RevAlloc(
1507 new DenseUpperBoundWatcher(solver(), this))));
1508 solver()->AddConstraint(bound_watcher_);
1509 } else {
1510 solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1511 reinterpret_cast<void*>(solver()->RevAlloc(
1512 new UpperBoundWatcher(solver(), this))));
1513 solver()->AddConstraint(bound_watcher_);
1514 }
1515 for (int i = 0; i < values.size(); ++i) {
1516 bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1517 }
1518 }
1519 return bound_watcher_;
1520 }
1521
1522 IntVar* IsLessOrEqual(int64_t constant) override {
1523 Solver* const s = solver();
1524 IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1526 if (cache != nullptr) {
1527 return cache->Var();
1528 } else {
1529 IntVar* const boolvar =
1530 s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1531 s->Cache()->InsertExprConstantExpression(
1533 return boolvar;
1534 }
1535 }
1536
1537 void Process();
1538 void Push();
1539 void CleanInProcess();
1540 uint64_t Size() const override {
1541 if (bits_ != nullptr) return bits_->Size();
1542 return (static_cast<uint64_t>(max_.Value()) -
1543 static_cast<uint64_t>(min_.Value()) + 1);
1544 }
1545 bool Contains(int64_t v) const override {
1546 if (v < min_.Value() || v > max_.Value()) return false;
1547 return (bits_ == nullptr ? true : bits_->Contains(v));
1548 }
1549 IntVarIterator* MakeHoleIterator(bool reversible) const override;
1550 IntVarIterator* MakeDomainIterator(bool reversible) const override;
1551 int64_t OldMin() const override { return std::min(old_min_, min_.Value()); }
1552 int64_t OldMax() const override { return std::max(old_max_, max_.Value()); }
1553
1554 std::string DebugString() const override;
1555 BitSet* bitset() const { return bits_; }
1556 int VarType() const override { return DOMAIN_INT_VAR; }
1557 std::string BaseName() const override { return "IntegerVar"; }
1558
1559 friend class PlusCstDomainIntVar;
1560 friend class LinkExprAndDomainIntVar;
1561
1562 private:
1563 void CheckOldMin() {
1564 if (old_min_ > min_.Value()) {
1565 old_min_ = min_.Value();
1566 }
1567 }
1568 void CheckOldMax() {
1569 if (old_max_ < max_.Value()) {
1570 old_max_ = max_.Value();
1571 }
1572 }
1573 Rev<int64_t> min_;
1574 Rev<int64_t> max_;
1575 int64_t old_min_;
1576 int64_t old_max_;
1577 int64_t new_min_;
1578 int64_t new_max_;
1579 SimpleRevFIFO<Demon*> bound_demons_;
1580 SimpleRevFIFO<Demon*> range_demons_;
1581 SimpleRevFIFO<Demon*> domain_demons_;
1582 SimpleRevFIFO<Demon*> delayed_bound_demons_;
1583 SimpleRevFIFO<Demon*> delayed_range_demons_;
1584 SimpleRevFIFO<Demon*> delayed_domain_demons_;
1585 QueueHandler handler_;
1586 bool in_process_;
1587 BitSet* bits_;
1588 BaseValueWatcher* value_watcher_;
1589 BaseUpperBoundWatcher* bound_watcher_;
1590};
1591
1592// ----- BitSet -----
1593
1594// Return whether an integer interval [a..b] (inclusive) contains at most
1595// K values, i.e. b - a < K, in a way that's robust to overflows.
1596// For performance reasons, in opt mode it doesn't check that [a, b] is a
1597// valid interval, nor that K is nonnegative.
1598inline bool ClosedIntervalNoLargerThan(int64_t a, int64_t b, int64_t K) {
1599 DCHECK_LE(a, b);
1600 DCHECK_GE(K, 0);
1601 if (a > 0) {
1602 return a > b - K;
1603 } else {
1604 return a + K > b;
1605 }
1606}
1607
1608class SimpleBitSet : public DomainIntVar::BitSet {
1609 public:
1610 SimpleBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1611 : BitSet(s),
1612 bits_(nullptr),
1613 stamps_(nullptr),
1614 omin_(vmin),
1615 omax_(vmax),
1616 size_(vmax - vmin + 1),
1617 bsize_(BitLength64(size_.Value())) {
1618 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1619 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1620 bits_ = new uint64_t[bsize_];
1621 stamps_ = new uint64_t[bsize_];
1622 for (int i = 0; i < bsize_; ++i) {
1623 const int bs =
1624 (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1625 bits_[i] = kAllBits64 >> bs;
1626 stamps_[i] = s->stamp() - 1;
1627 }
1628 }
1629
1630 SimpleBitSet(Solver* const s, const std::vector<int64_t>& sorted_values,
1631 int64_t vmin, int64_t vmax)
1632 : BitSet(s),
1633 bits_(nullptr),
1634 stamps_(nullptr),
1635 omin_(vmin),
1636 omax_(vmax),
1637 size_(sorted_values.size()),
1638 bsize_(BitLength64(vmax - vmin + 1)) {
1639 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1640 << "Bitset too large: [" << vmin << ", " << vmax << "]";
1641 bits_ = new uint64_t[bsize_];
1642 stamps_ = new uint64_t[bsize_];
1643 for (int i = 0; i < bsize_; ++i) {
1644 bits_[i] = uint64_t{0};
1645 stamps_[i] = s->stamp() - 1;
1646 }
1647 for (int i = 0; i < sorted_values.size(); ++i) {
1648 const int64_t val = sorted_values[i];
1649 DCHECK(!bit(val));
1650 const int offset = BitOffset64(val - omin_);
1651 const int pos = BitPos64(val - omin_);
1652 bits_[offset] |= OneBit64(pos);
1653 }
1654 }
1655
1656 ~SimpleBitSet() override {
1657 delete[] bits_;
1658 delete[] stamps_;
1659 }
1660
1661 bool bit(int64_t val) const { return IsBitSet64(bits_, val - omin_); }
1662
1663 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1664 DCHECK_GE(nmin, cmin);
1665 DCHECK_LE(nmin, cmax);
1666 DCHECK_LE(cmin, cmax);
1667 DCHECK_GE(cmin, omin_);
1668 DCHECK_LE(cmax, omax_);
1669 const int64_t new_min =
1670 UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1671 omin_;
1672 const uint64_t removed_bits =
1673 BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1674 size_.Add(solver_, -removed_bits);
1675 return new_min;
1676 }
1677
1678 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1679 DCHECK_GE(nmax, cmin);
1680 DCHECK_LE(nmax, cmax);
1681 DCHECK_LE(cmin, cmax);
1682 DCHECK_GE(cmin, omin_);
1683 DCHECK_LE(cmax, omax_);
1684 const int64_t new_max =
1685 UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1686 omin_;
1687 const uint64_t removed_bits =
1688 BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1689 size_.Add(solver_, -removed_bits);
1690 return new_max;
1691 }
1692
1693 bool SetValue(int64_t val) override {
1694 DCHECK_GE(val, omin_);
1695 DCHECK_LE(val, omax_);
1696 if (bit(val)) {
1697 size_.SetValue(solver_, 1);
1698 return true;
1699 }
1700 return false;
1701 }
1702
1703 bool Contains(int64_t val) const override {
1704 DCHECK_GE(val, omin_);
1705 DCHECK_LE(val, omax_);
1706 return bit(val);
1707 }
1708
1709 bool RemoveValue(int64_t val) override {
1710 if (val < omin_ || val > omax_ || !bit(val)) {
1711 return false;
1712 }
1713 // Bitset.
1714 const int64_t val_offset = val - omin_;
1715 const int offset = BitOffset64(val_offset);
1716 const uint64_t current_stamp = solver_->stamp();
1717 if (stamps_[offset] < current_stamp) {
1718 stamps_[offset] = current_stamp;
1719 solver_->SaveValue(&bits_[offset]);
1720 }
1721 const int pos = BitPos64(val_offset);
1722 bits_[offset] &= ~OneBit64(pos);
1723 // Size.
1724 size_.Decr(solver_);
1725 // Holes.
1726 InitHoles();
1727 AddHole(val);
1728 return true;
1729 }
1730 uint64_t Size() const override { return size_.Value(); }
1731
1732 std::string DebugString() const override {
1733 std::string out;
1734 absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1735 for (int i = 0; i < bsize_; ++i) {
1736 absl::StrAppendFormat(&out, "%x", bits_[i]);
1737 }
1738 out += ")";
1739 return out;
1740 }
1741
1742 void DelayRemoveValue(int64_t val) override { removed_.push_back(val); }
1743
1744 void ApplyRemovedValues(DomainIntVar* var) override {
1745 std::sort(removed_.begin(), removed_.end());
1746 for (std::vector<int64_t>::iterator it = removed_.begin();
1747 it != removed_.end(); ++it) {
1748 var->RemoveValue(*it);
1749 }
1750 }
1751
1752 void ClearRemovedValues() override { removed_.clear(); }
1753
1754 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1755 std::string out;
1756 DCHECK(bit(min));
1757 DCHECK(bit(max));
1758 if (max != min) {
1759 int cumul = true;
1760 int64_t start_cumul = min;
1761 for (int64_t v = min + 1; v < max; ++v) {
1762 if (bit(v)) {
1763 if (!cumul) {
1764 cumul = true;
1765 start_cumul = v;
1766 }
1767 } else {
1768 if (cumul) {
1769 if (v == start_cumul + 1) {
1770 absl::StrAppendFormat(&out, "%d ", start_cumul);
1771 } else if (v == start_cumul + 2) {
1772 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1773 } else {
1774 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1775 }
1776 cumul = false;
1777 }
1778 }
1779 }
1780 if (cumul) {
1781 if (max == start_cumul + 1) {
1782 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1783 } else {
1784 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1785 }
1786 } else {
1787 absl::StrAppendFormat(&out, "%d", max);
1788 }
1789 } else {
1790 absl::StrAppendFormat(&out, "%d", min);
1791 }
1792 return out;
1793 }
1794
1795 DomainIntVar::BitSetIterator* MakeIterator() override {
1796 return new DomainIntVar::BitSetIterator(bits_, omin_);
1797 }
1798
1799 private:
1800 uint64_t* bits_;
1801 uint64_t* stamps_;
1802 const int64_t omin_;
1803 const int64_t omax_;
1804 NumericalRev<int64_t> size_;
1805 const int bsize_;
1806 std::vector<int64_t> removed_;
1807};
1808
1809// This is a special case where the bitset fits into one 64 bit integer.
1810// In that case, there are no offset to compute.
1811// Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1812class SmallBitSet : public DomainIntVar::BitSet {
1813 public:
1814 SmallBitSet(Solver* const s, int64_t vmin, int64_t vmax)
1815 : BitSet(s),
1816 bits_(uint64_t{0}),
1817 stamp_(s->stamp() - 1),
1818 omin_(vmin),
1819 omax_(vmax),
1820 size_(vmax - vmin + 1) {
1821 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1822 bits_ = OneRange64(0, size_.Value() - 1);
1823 }
1824
1825 SmallBitSet(Solver* const s, const std::vector<int64_t>& sorted_values,
1826 int64_t vmin, int64_t vmax)
1827 : BitSet(s),
1828 bits_(uint64_t{0}),
1829 stamp_(s->stamp() - 1),
1830 omin_(vmin),
1831 omax_(vmax),
1832 size_(sorted_values.size()) {
1833 CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1834 // We know the array is sorted and does not contains duplicate values.
1835 for (int i = 0; i < sorted_values.size(); ++i) {
1836 const int64_t val = sorted_values[i];
1837 DCHECK_GE(val, vmin);
1838 DCHECK_LE(val, vmax);
1839 DCHECK(!IsBitSet64(&bits_, val - omin_));
1840 bits_ |= OneBit64(val - omin_);
1841 }
1842 }
1843
1844 ~SmallBitSet() override {}
1845
1846 bool bit(int64_t val) const {
1847 DCHECK_GE(val, omin_);
1848 DCHECK_LE(val, omax_);
1849 return (bits_ & OneBit64(val - omin_)) != 0;
1850 }
1851
1852 int64_t ComputeNewMin(int64_t nmin, int64_t cmin, int64_t cmax) override {
1853 DCHECK_GE(nmin, cmin);
1854 DCHECK_LE(nmin, cmax);
1855 DCHECK_LE(cmin, cmax);
1856 DCHECK_GE(cmin, omin_);
1857 DCHECK_LE(cmax, omax_);
1858 // We do not clean the bits between cmin and nmin.
1859 // But we use mask to look only at 'active' bits.
1860
1861 // Create the mask and compute new bits
1862 const uint64_t new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1863 if (new_bits != uint64_t{0}) {
1864 // Compute new size and new min
1865 size_.SetValue(solver_, BitCount64(new_bits));
1866 if (bit(nmin)) { // Common case, the new min is inside the bitset
1867 return nmin;
1868 }
1869 return LeastSignificantBitPosition64(new_bits) + omin_;
1870 } else { // == 0 -> Fail()
1871 solver_->Fail();
1873 }
1874 }
1875
1876 int64_t ComputeNewMax(int64_t nmax, int64_t cmin, int64_t cmax) override {
1877 DCHECK_GE(nmax, cmin);
1878 DCHECK_LE(nmax, cmax);
1879 DCHECK_LE(cmin, cmax);
1880 DCHECK_GE(cmin, omin_);
1881 DCHECK_LE(cmax, omax_);
1882 // We do not clean the bits between nmax and cmax.
1883 // But we use mask to look only at 'active' bits.
1884
1885 // Create the mask and compute new_bits
1886 const uint64_t new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1887 if (new_bits != uint64_t{0}) {
1888 // Compute new size and new min
1889 size_.SetValue(solver_, BitCount64(new_bits));
1890 if (bit(nmax)) { // Common case, the new max is inside the bitset
1891 return nmax;
1892 }
1893 return MostSignificantBitPosition64(new_bits) + omin_;
1894 } else { // == 0 -> Fail()
1895 solver_->Fail();
1897 }
1898 }
1899
1900 bool SetValue(int64_t val) override {
1901 DCHECK_GE(val, omin_);
1902 DCHECK_LE(val, omax_);
1903 // We do not clean the bits. We will use masks to ignore the bits
1904 // that should have been cleaned.
1905 if (bit(val)) {
1906 size_.SetValue(solver_, 1);
1907 return true;
1908 }
1909 return false;
1910 }
1911
1912 bool Contains(int64_t val) const override {
1913 DCHECK_GE(val, omin_);
1914 DCHECK_LE(val, omax_);
1915 return bit(val);
1916 }
1917
1918 bool RemoveValue(int64_t val) override {
1919 DCHECK_GE(val, omin_);
1920 DCHECK_LE(val, omax_);
1921 if (bit(val)) {
1922 // Bitset.
1923 const uint64_t current_stamp = solver_->stamp();
1924 if (stamp_ < current_stamp) {
1925 stamp_ = current_stamp;
1926 solver_->SaveValue(&bits_);
1927 }
1928 bits_ &= ~OneBit64(val - omin_);
1929 DCHECK(!bit(val));
1930 // Size.
1931 size_.Decr(solver_);
1932 // Holes.
1933 InitHoles();
1934 AddHole(val);
1935 return true;
1936 } else {
1937 return false;
1938 }
1939 }
1940
1941 uint64_t Size() const override { return size_.Value(); }
1942
1943 std::string DebugString() const override {
1944 return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1945 }
1946
1947 void DelayRemoveValue(int64_t val) override {
1948 DCHECK_GE(val, omin_);
1949 DCHECK_LE(val, omax_);
1950 removed_.push_back(val);
1951 }
1952
1953 void ApplyRemovedValues(DomainIntVar* var) override {
1954 std::sort(removed_.begin(), removed_.end());
1955 for (std::vector<int64_t>::iterator it = removed_.begin();
1956 it != removed_.end(); ++it) {
1957 var->RemoveValue(*it);
1958 }
1959 }
1960
1961 void ClearRemovedValues() override { removed_.clear(); }
1962
1963 std::string pretty_DebugString(int64_t min, int64_t max) const override {
1964 std::string out;
1965 DCHECK(bit(min));
1966 DCHECK(bit(max));
1967 if (max != min) {
1968 int cumul = true;
1969 int64_t start_cumul = min;
1970 for (int64_t v = min + 1; v < max; ++v) {
1971 if (bit(v)) {
1972 if (!cumul) {
1973 cumul = true;
1974 start_cumul = v;
1975 }
1976 } else {
1977 if (cumul) {
1978 if (v == start_cumul + 1) {
1979 absl::StrAppendFormat(&out, "%d ", start_cumul);
1980 } else if (v == start_cumul + 2) {
1981 absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1982 } else {
1983 absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1984 }
1985 cumul = false;
1986 }
1987 }
1988 }
1989 if (cumul) {
1990 if (max == start_cumul + 1) {
1991 absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1992 } else {
1993 absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1994 }
1995 } else {
1996 absl::StrAppendFormat(&out, "%d", max);
1997 }
1998 } else {
1999 absl::StrAppendFormat(&out, "%d", min);
2000 }
2001 return out;
2002 }
2003
2004 DomainIntVar::BitSetIterator* MakeIterator() override {
2005 return new DomainIntVar::BitSetIterator(&bits_, omin_);
2006 }
2007
2008 private:
2009 uint64_t bits_;
2010 uint64_t stamp_;
2011 const int64_t omin_;
2012 const int64_t omax_;
2013 NumericalRev<int64_t> size_;
2014 std::vector<int64_t> removed_;
2015};
2016
2017class EmptyIterator : public IntVarIterator {
2018 public:
2019 ~EmptyIterator() override {}
2020 void Init() override {}
2021 bool Ok() const override { return false; }
2022 int64_t Value() const override {
2023 LOG(FATAL) << "Should not be called";
2024 return 0LL;
2025 }
2026 void Next() override {}
2027};
2028
2029class RangeIterator : public IntVarIterator {
2030 public:
2031 explicit RangeIterator(const IntVar* const var)
2032 : var_(var),
2033 min_(std::numeric_limits<int64_t>::max()),
2034 max_(std::numeric_limits<int64_t>::min()),
2035 current_(-1) {}
2036
2037 ~RangeIterator() override {}
2038
2039 void Init() override {
2040 min_ = var_->Min();
2041 max_ = var_->Max();
2042 current_ = min_;
2043 }
2044
2045 bool Ok() const override { return current_ <= max_; }
2046
2047 int64_t Value() const override { return current_; }
2048
2049 void Next() override { current_++; }
2050
2051 private:
2052 const IntVar* const var_;
2053 int64_t min_;
2054 int64_t max_;
2055 int64_t current_;
2056};
2057
2058class DomainIntVarHoleIterator : public IntVarIterator {
2059 public:
2060 explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2061 : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2062
2063 ~DomainIntVarHoleIterator() override {}
2064
2065 void Init() override {
2066 bits_ = var_->bitset();
2067 if (bits_ != nullptr) {
2068 bits_->InitHoles();
2069 values_ = bits_->Holes().data();
2070 size_ = bits_->Holes().size();
2071 } else {
2072 values_ = nullptr;
2073 size_ = 0;
2074 }
2075 index_ = 0;
2076 }
2077
2078 bool Ok() const override { return index_ < size_; }
2079
2080 int64_t Value() const override {
2081 DCHECK(bits_ != nullptr);
2082 DCHECK(index_ < size_);
2083 return values_[index_];
2084 }
2085
2086 void Next() override { index_++; }
2087
2088 private:
2089 const DomainIntVar* const var_;
2090 DomainIntVar::BitSet* bits_;
2091 const int64_t* values_;
2092 int size_;
2093 int index_;
2094};
2095
2096class DomainIntVarDomainIterator : public IntVarIterator {
2097 public:
2098 explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
2099 bool reversible)
2100 : var_(v),
2101 bitset_iterator_(nullptr),
2102 min_(std::numeric_limits<int64_t>::max()),
2103 max_(std::numeric_limits<int64_t>::min()),
2104 current_(-1),
2105 reversible_(reversible) {}
2106
2107 ~DomainIntVarDomainIterator() override {
2108 if (!reversible_ && bitset_iterator_) {
2109 delete bitset_iterator_;
2110 }
2111 }
2112
2113 void Init() override {
2114 if (var_->bitset() != nullptr && !var_->Bound()) {
2115 if (reversible_) {
2116 if (!bitset_iterator_) {
2117 Solver* const solver = var_->solver();
2118 solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2119 bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
2120 }
2121 } else {
2122 if (bitset_iterator_) {
2123 delete bitset_iterator_;
2124 }
2125 bitset_iterator_ = var_->bitset()->MakeIterator();
2126 }
2127 bitset_iterator_->Init(var_->Min(), var_->Max());
2128 } else {
2129 if (bitset_iterator_) {
2130 if (reversible_) {
2131 Solver* const solver = var_->solver();
2132 solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2133 } else {
2134 delete bitset_iterator_;
2135 }
2136 bitset_iterator_ = nullptr;
2137 }
2138 min_ = var_->Min();
2139 max_ = var_->Max();
2140 current_ = min_;
2141 }
2142 }
2143
2144 bool Ok() const override {
2145 return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
2146 }
2147
2148 int64_t Value() const override {
2149 return bitset_iterator_ ? bitset_iterator_->Value() : current_;
2150 }
2151
2152 void Next() override {
2153 if (bitset_iterator_) {
2154 bitset_iterator_->Next();
2155 } else {
2156 current_++;
2157 }
2158 }
2159
2160 private:
2161 const DomainIntVar* const var_;
2162 DomainIntVar::BitSetIterator* bitset_iterator_;
2163 int64_t min_;
2164 int64_t max_;
2165 int64_t current_;
2166 const bool reversible_;
2167};
2168
2169class UnaryIterator : public IntVarIterator {
2170 public:
2171 UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2172 : iterator_(hole ? v->MakeHoleIterator(reversible)
2173 : v->MakeDomainIterator(reversible)),
2174 reversible_(reversible) {}
2175
2176 ~UnaryIterator() override {
2177 if (!reversible_) {
2178 delete iterator_;
2179 }
2180 }
2181
2182 void Init() override { iterator_->Init(); }
2183
2184 bool Ok() const override { return iterator_->Ok(); }
2185
2186 void Next() override { iterator_->Next(); }
2187
2188 protected:
2189 IntVarIterator* const iterator_;
2190 const bool reversible_;
2191};
2192
2193DomainIntVar::DomainIntVar(Solver* const s, int64_t vmin, int64_t vmax,
2194 const std::string& name)
2195 : IntVar(s, name),
2196 min_(vmin),
2197 max_(vmax),
2198 old_min_(vmin),
2199 old_max_(vmax),
2200 new_min_(vmin),
2201 new_max_(vmax),
2202 handler_(this),
2203 in_process_(false),
2204 bits_(nullptr),
2205 value_watcher_(nullptr),
2206 bound_watcher_(nullptr) {}
2207
2208DomainIntVar::DomainIntVar(Solver* const s,
2209 const std::vector<int64_t>& sorted_values,
2210 const std::string& name)
2211 : IntVar(s, name),
2212 min_(std::numeric_limits<int64_t>::max()),
2213 max_(std::numeric_limits<int64_t>::min()),
2214 old_min_(std::numeric_limits<int64_t>::max()),
2215 old_max_(std::numeric_limits<int64_t>::min()),
2216 new_min_(std::numeric_limits<int64_t>::max()),
2217 new_max_(std::numeric_limits<int64_t>::min()),
2218 handler_(this),
2219 in_process_(false),
2220 bits_(nullptr),
2221 value_watcher_(nullptr),
2222 bound_watcher_(nullptr) {
2223 CHECK_GE(sorted_values.size(), 1);
2224 // We know that the vector is sorted and does not have duplicate values.
2225 const int64_t vmin = sorted_values.front();
2226 const int64_t vmax = sorted_values.back();
2227 const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2228
2229 min_.SetValue(solver(), vmin);
2230 old_min_ = vmin;
2231 new_min_ = vmin;
2232 max_.SetValue(solver(), vmax);
2233 old_max_ = vmax;
2234 new_max_ = vmax;
2235
2236 if (!contiguous) {
2237 if (vmax - vmin + 1 < 65) {
2238 bits_ = solver()->RevAlloc(
2239 new SmallBitSet(solver(), sorted_values, vmin, vmax));
2240 } else {
2241 bits_ = solver()->RevAlloc(
2242 new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2243 }
2244 }
2245}
2246
2247DomainIntVar::~DomainIntVar() {}
2248
2249void DomainIntVar::SetMin(int64_t m) {
2250 if (m <= min_.Value()) return;
2251 if (m > max_.Value()) solver()->Fail();
2252 if (in_process_) {
2253 if (m > new_min_) {
2254 new_min_ = m;
2255 if (new_min_ > new_max_) {
2256 solver()->Fail();
2257 }
2258 }
2259 } else {
2260 CheckOldMin();
2261 const int64_t new_min =
2262 (bits_ == nullptr
2263 ? m
2264 : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2265 min_.SetValue(solver(), new_min);
2266 if (min_.Value() > max_.Value()) {
2267 solver()->Fail();
2268 }
2269 Push();
2270 }
2271}
2272
2273void DomainIntVar::SetMax(int64_t m) {
2274 if (m >= max_.Value()) return;
2275 if (m < min_.Value()) solver()->Fail();
2276 if (in_process_) {
2277 if (m < new_max_) {
2278 new_max_ = m;
2279 if (new_max_ < new_min_) {
2280 solver()->Fail();
2281 }
2282 }
2283 } else {
2284 CheckOldMax();
2285 const int64_t new_max =
2286 (bits_ == nullptr
2287 ? m
2288 : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2289 max_.SetValue(solver(), new_max);
2290 if (min_.Value() > max_.Value()) {
2291 solver()->Fail();
2292 }
2293 Push();
2294 }
2295}
2296
2297void DomainIntVar::SetRange(int64_t mi, int64_t ma) {
2298 if (mi == ma) {
2299 SetValue(mi);
2300 } else {
2301 if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
2302 if (mi <= min_.Value() && ma >= max_.Value()) return;
2303 if (in_process_) {
2304 if (ma < new_max_) {
2305 new_max_ = ma;
2306 }
2307 if (mi > new_min_) {
2308 new_min_ = mi;
2309 }
2310 if (new_min_ > new_max_) {
2311 solver()->Fail();
2312 }
2313 } else {
2314 if (mi > min_.Value()) {
2315 CheckOldMin();
2316 const int64_t new_min =
2317 (bits_ == nullptr
2318 ? mi
2319 : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
2320 min_.SetValue(solver(), new_min);
2321 }
2322 if (min_.Value() > ma) {
2323 solver()->Fail();
2324 }
2325 if (ma < max_.Value()) {
2326 CheckOldMax();
2327 const int64_t new_max =
2328 (bits_ == nullptr
2329 ? ma
2330 : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
2331 max_.SetValue(solver(), new_max);
2332 }
2333 if (min_.Value() > max_.Value()) {
2334 solver()->Fail();
2335 }
2336 Push();
2337 }
2338 }
2339}
2340
2341void DomainIntVar::SetValue(int64_t v) {
2342 if (v != min_.Value() || v != max_.Value()) {
2343 if (v < min_.Value() || v > max_.Value()) {
2344 solver()->Fail();
2345 }
2346 if (in_process_) {
2347 if (v > new_max_ || v < new_min_) {
2348 solver()->Fail();
2349 }
2350 new_min_ = v;
2351 new_max_ = v;
2352 } else {
2353 if (bits_ && !bits_->SetValue(v)) {
2354 solver()->Fail();
2355 }
2356 CheckOldMin();
2357 CheckOldMax();
2358 min_.SetValue(solver(), v);
2359 max_.SetValue(solver(), v);
2360 Push();
2361 }
2362 }
2363}
2364
2365void DomainIntVar::RemoveValue(int64_t v) {
2366 if (v < min_.Value() || v > max_.Value()) return;
2367 if (v == min_.Value()) {
2368 SetMin(v + 1);
2369 } else if (v == max_.Value()) {
2370 SetMax(v - 1);
2371 } else {
2372 if (bits_ == nullptr) {
2373 CreateBits();
2374 }
2375 if (in_process_) {
2376 if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2377 bits_->DelayRemoveValue(v);
2378 }
2379 } else {
2380 if (bits_->RemoveValue(v)) {
2381 Push();
2382 }
2383 }
2384 }
2385}
2386
2387void DomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2388 if (l <= min_.Value()) {
2389 SetMin(u + 1);
2390 } else if (u >= max_.Value()) {
2391 SetMax(l - 1);
2392 } else {
2393 for (int64_t v = l; v <= u; ++v) {
2394 RemoveValue(v);
2395 }
2396 }
2397}
2398
2399void DomainIntVar::CreateBits() {
2400 solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2401 if (max_.Value() - min_.Value() < 64) {
2402 bits_ = solver()->RevAlloc(
2403 new SmallBitSet(solver(), min_.Value(), max_.Value()));
2404 } else {
2405 bits_ = solver()->RevAlloc(
2406 new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2407 }
2408}
2409
2410void DomainIntVar::CleanInProcess() {
2411 in_process_ = false;
2412 if (bits_ != nullptr) {
2413 bits_->ClearHoles();
2414 }
2415}
2416
2417void DomainIntVar::Push() {
2418 const bool in_process = in_process_;
2419 EnqueueVar(&handler_);
2420 CHECK_EQ(in_process, in_process_);
2421}
2422
2423void DomainIntVar::Process() {
2425 in_process_ = true;
2426 if (bits_ != nullptr) {
2427 bits_->ClearRemovedValues();
2428 }
2429 set_variable_to_clean_on_fail(this);
2430 new_min_ = min_.Value();
2431 new_max_ = max_.Value();
2432 const bool is_bound = min_.Value() == max_.Value();
2433 const bool range_changed =
2434 min_.Value() != OldMin() || max_.Value() != OldMax();
2435 // Process immediate demons.
2436 if (is_bound) {
2437 ExecuteAll(bound_demons_);
2438 }
2439 if (range_changed) {
2440 ExecuteAll(range_demons_);
2441 }
2442 ExecuteAll(domain_demons_);
2443
2444 // Process delayed demons.
2445 if (is_bound) {
2446 EnqueueAll(delayed_bound_demons_);
2447 }
2448 if (range_changed) {
2449 EnqueueAll(delayed_range_demons_);
2450 }
2451 EnqueueAll(delayed_domain_demons_);
2452
2453 // Everything went well if we arrive here. Let's clean the variable.
2454 set_variable_to_clean_on_fail(nullptr);
2455 CleanInProcess();
2456 old_min_ = min_.Value();
2457 old_max_ = max_.Value();
2458 if (min_.Value() < new_min_) {
2459 SetMin(new_min_);
2460 }
2461 if (max_.Value() > new_max_) {
2462 SetMax(new_max_);
2463 }
2464 if (bits_ != nullptr) {
2465 bits_->ApplyRemovedValues(this);
2466 }
2467}
2468
// Returns `alloc` registered with the solver's RevAlloc() when `rev` is
// true, and `alloc` unchanged otherwise (the caller then owns it). Must be
// used where solver() is in scope. Only one branch is evaluated, so `alloc`
// is allocated once. Arguments and the whole expression are parenthesized so
// that any argument expression composes correctly. The trailing semicolon is
// kept for compatibility with existing call sites.
#define COND_REV_ALLOC(rev, alloc) \
  ((rev) ? solver()->RevAlloc(alloc) : (alloc));
2470
2471IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2472 return COND_REV_ALLOC(reversible, new DomainIntVarHoleIterator(this));
2473}
2474
2475IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2476 return COND_REV_ALLOC(reversible,
2477 new DomainIntVarDomainIterator(this, reversible));
2478}
2479
2480std::string DomainIntVar::DebugString() const {
2481 std::string out;
2482 const std::string& var_name = name();
2483 if (!var_name.empty()) {
2484 out = var_name + "(";
2485 } else {
2486 out = "DomainIntVar(";
2487 }
2488 if (min_.Value() == max_.Value()) {
2489 absl::StrAppendFormat(&out, "%d", min_.Value());
2490 } else if (bits_ != nullptr) {
2491 out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2492 } else {
2493 absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2494 }
2495 out += ")";
2496 return out;
2497}
2498
2499// ----- Real Boolean Var -----
2500
2501class ConcreteBooleanVar : public BooleanVar {
2502 public:
2503 // Utility classes
2504 class Handler : public Demon {
2505 public:
2506 explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
2507 ~Handler() override {}
2508 void Run(Solver* const s) override {
2509 s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2510 var_->Process();
2511 s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2512 }
2513 Solver::DemonPriority priority() const override {
2514 return Solver::VAR_PRIORITY;
2515 }
2516 std::string DebugString() const override {
2517 return absl::StrFormat("Handler(%s)", var_->DebugString());
2518 }
2519
2520 private:
2521 ConcreteBooleanVar* const var_;
2522 };
2523
2524 ConcreteBooleanVar(Solver* const s, const std::string& name)
2525 : BooleanVar(s, name), handler_(this) {}
2526
2527 ~ConcreteBooleanVar() override {}
2528
2529 void SetValue(int64_t v) override {
2530 if (value_ == kUnboundBooleanVarValue) {
2531 if ((v & 0xfffffffffffffffe) == 0) {
2532 InternalSaveBooleanVarValue(solver(), this);
2533 value_ = static_cast<int>(v);
2534 EnqueueVar(&handler_);
2535 return;
2536 }
2537 } else if (v == value_) {
2538 return;
2539 }
2540 solver()->Fail();
2541 }
2542
2543 void Process() {
2544 DCHECK_NE(value_, kUnboundBooleanVarValue);
2545 ExecuteAll(bound_demons_);
2546 for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2547 ++it) {
2548 EnqueueDelayedDemon(*it);
2549 }
2550 }
2551
2552 int64_t OldMin() const override { return 0LL; }
2553 int64_t OldMax() const override { return 1LL; }
2554 void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2555
2556 private:
2557 Handler handler_;
2558};
2559
2560// ----- IntConst -----
2561
2562class IntConst : public IntVar {
2563 public:
2564 IntConst(Solver* const s, int64_t value, const std::string& name = "")
2565 : IntVar(s, name), value_(value) {}
2566 ~IntConst() override {}
2567
2568 int64_t Min() const override { return value_; }
2569 void SetMin(int64_t m) override {
2570 if (m > value_) {
2571 solver()->Fail();
2572 }
2573 }
2574 int64_t Max() const override { return value_; }
2575 void SetMax(int64_t m) override {
2576 if (m < value_) {
2577 solver()->Fail();
2578 }
2579 }
2580 void SetRange(int64_t l, int64_t u) override {
2581 if (l > value_ || u < value_) {
2582 solver()->Fail();
2583 }
2584 }
2585 void SetValue(int64_t v) override {
2586 if (v != value_) {
2587 solver()->Fail();
2588 }
2589 }
2590 bool Bound() const override { return true; }
2591 int64_t Value() const override { return value_; }
2592 void RemoveValue(int64_t v) override {
2593 if (v == value_) {
2594 solver()->Fail();
2595 }
2596 }
2597 void RemoveInterval(int64_t l, int64_t u) override {
2598 if (l <= value_ && value_ <= u) {
2599 solver()->Fail();
2600 }
2601 }
2602 void WhenBound(Demon* d) override {}
2603 void WhenRange(Demon* d) override {}
2604 void WhenDomain(Demon* d) override {}
2605 uint64_t Size() const override { return 1; }
2606 bool Contains(int64_t v) const override { return (v == value_); }
2607 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2608 return COND_REV_ALLOC(reversible, new EmptyIterator());
2609 }
2610 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2611 return COND_REV_ALLOC(reversible, new RangeIterator(this));
2612 }
2613 int64_t OldMin() const override { return value_; }
2614 int64_t OldMax() const override { return value_; }
2615 std::string DebugString() const override {
2616 std::string out;
2617 if (solver()->HasName(this)) {
2618 const std::string& var_name = name();
2619 absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2620 } else {
2621 absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2622 }
2623 return out;
2624 }
2625
2626 int VarType() const override { return CONST_VAR; }
2627
2628 IntVar* IsEqual(int64_t constant) override {
2629 if (constant == value_) {
2630 return solver()->MakeIntConst(1);
2631 } else {
2632 return solver()->MakeIntConst(0);
2633 }
2634 }
2635
2636 IntVar* IsDifferent(int64_t constant) override {
2637 if (constant == value_) {
2638 return solver()->MakeIntConst(0);
2639 } else {
2640 return solver()->MakeIntConst(1);
2641 }
2642 }
2643
2644 IntVar* IsGreaterOrEqual(int64_t constant) override {
2645 return solver()->MakeIntConst(value_ >= constant);
2646 }
2647
2648 IntVar* IsLessOrEqual(int64_t constant) override {
2649 return solver()->MakeIntConst(value_ <= constant);
2650 }
2651
2652 std::string name() const override {
2653 if (solver()->HasName(this)) {
2655 } else {
2656 return absl::StrCat(value_);
2657 }
2658 }
2659
2660 private:
2661 int64_t value_;
2662};
2663
2664// ----- x + c variable, optimized case -----
2665
2666class PlusCstVar : public IntVar {
2667 public:
2668 PlusCstVar(Solver* const s, IntVar* v, int64_t c)
2669 : IntVar(s), var_(v), cst_(c) {}
2670
2671 ~PlusCstVar() override {}
2672
2673 void WhenRange(Demon* d) override { var_->WhenRange(d); }
2674
2675 void WhenBound(Demon* d) override { var_->WhenBound(d); }
2676
2677 void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2678
2679 int64_t OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2680
2681 int64_t OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2682
2683 std::string DebugString() const override {
2684 if (HasName()) {
2685 return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2686 } else {
2687 return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2688 }
2689 }
2690
2691 int VarType() const override { return VAR_ADD_CST; }
2692
2693 void Accept(ModelVisitor* const visitor) const override {
2694 visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2695 var_);
2696 }
2697
2698 IntVar* IsEqual(int64_t constant) override {
2699 return var_->IsEqual(constant - cst_);
2700 }
2701
2702 IntVar* IsDifferent(int64_t constant) override {
2703 return var_->IsDifferent(constant - cst_);
2704 }
2705
2706 IntVar* IsGreaterOrEqual(int64_t constant) override {
2707 return var_->IsGreaterOrEqual(constant - cst_);
2708 }
2709
2710 IntVar* IsLessOrEqual(int64_t constant) override {
2711 return var_->IsLessOrEqual(constant - cst_);
2712 }
2713
2714 IntVar* SubVar() const { return var_; }
2715
2716 int64_t Constant() const { return cst_; }
2717
2718 protected:
2719 IntVar* const var_;
2720 const int64_t cst_;
2721};
2722
2723class PlusCstIntVar : public PlusCstVar {
2724 public:
2725 class PlusCstIntVarIterator : public UnaryIterator {
2726 public:
2727 PlusCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2728 : UnaryIterator(v, hole, rev), cst_(c) {}
2729
2730 ~PlusCstIntVarIterator() override {}
2731
2732 int64_t Value() const override { return iterator_->Value() + cst_; }
2733
2734 private:
2735 const int64_t cst_;
2736 };
2737
2738 PlusCstIntVar(Solver* const s, IntVar* v, int64_t c) : PlusCstVar(s, v, c) {}
2739
2740 ~PlusCstIntVar() override {}
2741
2742 int64_t Min() const override { return var_->Min() + cst_; }
2743
2744 void SetMin(int64_t m) override { var_->SetMin(CapSub(m, cst_)); }
2745
2746 int64_t Max() const override { return var_->Max() + cst_; }
2747
2748 void SetMax(int64_t m) override { var_->SetMax(CapSub(m, cst_)); }
2749
2750 void SetRange(int64_t l, int64_t u) override {
2751 var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2752 }
2753
2754 void SetValue(int64_t v) override { var_->SetValue(v - cst_); }
2755
2756 int64_t Value() const override { return var_->Value() + cst_; }
2757
2758 bool Bound() const override { return var_->Bound(); }
2759
2760 void RemoveValue(int64_t v) override { var_->RemoveValue(v - cst_); }
2761
2762 void RemoveInterval(int64_t l, int64_t u) override {
2763 var_->RemoveInterval(l - cst_, u - cst_);
2764 }
2765
2766 uint64_t Size() const override { return var_->Size(); }
2767
2768 bool Contains(int64_t v) const override { return var_->Contains(v - cst_); }
2769
2770 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2771 return COND_REV_ALLOC(
2772 reversible, new PlusCstIntVarIterator(var_, cst_, true, reversible));
2773 }
2774 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2775 return COND_REV_ALLOC(
2776 reversible, new PlusCstIntVarIterator(var_, cst_, false, reversible));
2777 }
2778};
2779
2780class PlusCstDomainIntVar : public PlusCstVar {
2781 public:
2782 class PlusCstDomainIntVarIterator : public UnaryIterator {
2783 public:
2784 PlusCstDomainIntVarIterator(const IntVar* const v, int64_t c, bool hole,
2785 bool reversible)
2786 : UnaryIterator(v, hole, reversible), cst_(c) {}
2787
2788 ~PlusCstDomainIntVarIterator() override {}
2789
2790 int64_t Value() const override { return iterator_->Value() + cst_; }
2791
2792 private:
2793 const int64_t cst_;
2794 };
2795
2796 PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64_t c)
2797 : PlusCstVar(s, v, c) {}
2798
2799 ~PlusCstDomainIntVar() override {}
2800
2801 int64_t Min() const override;
2802 void SetMin(int64_t m) override;
2803 int64_t Max() const override;
2804 void SetMax(int64_t m) override;
2805 void SetRange(int64_t l, int64_t u) override;
2806 void SetValue(int64_t v) override;
2807 bool Bound() const override;
2808 int64_t Value() const override;
2809 void RemoveValue(int64_t v) override;
2810 void RemoveInterval(int64_t l, int64_t u) override;
2811 uint64_t Size() const override;
2812 bool Contains(int64_t v) const override;
2813
2814 DomainIntVar* domain_int_var() const {
2815 return reinterpret_cast<DomainIntVar*>(var_);
2816 }
2817
2818 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2819 return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2820 var_, cst_, true, reversible));
2821 }
2822 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2823 return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2824 var_, cst_, false, reversible));
2825 }
2826};
2827
2828int64_t PlusCstDomainIntVar::Min() const {
2829 return domain_int_var()->min_.Value() + cst_;
2830}
2831
2832void PlusCstDomainIntVar::SetMin(int64_t m) {
2833 domain_int_var()->DomainIntVar::SetMin(m - cst_);
2834}
2835
2836int64_t PlusCstDomainIntVar::Max() const {
2837 return domain_int_var()->max_.Value() + cst_;
2838}
2839
2840void PlusCstDomainIntVar::SetMax(int64_t m) {
2841 domain_int_var()->DomainIntVar::SetMax(m - cst_);
2842}
2843
2844void PlusCstDomainIntVar::SetRange(int64_t l, int64_t u) {
2845 domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2846}
2847
2848void PlusCstDomainIntVar::SetValue(int64_t v) {
2849 domain_int_var()->DomainIntVar::SetValue(v - cst_);
2850}
2851
2852bool PlusCstDomainIntVar::Bound() const {
2853 return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2854}
2855
2856int64_t PlusCstDomainIntVar::Value() const {
2857 CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2858 << " variable is not bound";
2859 return domain_int_var()->min_.Value() + cst_;
2860}
2861
2862void PlusCstDomainIntVar::RemoveValue(int64_t v) {
2863 domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2864}
2865
2866void PlusCstDomainIntVar::RemoveInterval(int64_t l, int64_t u) {
2867 domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2868}
2869
2870uint64_t PlusCstDomainIntVar::Size() const {
2871 return domain_int_var()->DomainIntVar::Size();
2872}
2873
2874bool PlusCstDomainIntVar::Contains(int64_t v) const {
2875 return domain_int_var()->DomainIntVar::Contains(v - cst_);
2876}
2877
2878// c - x variable, optimized case
2879
2880class SubCstIntVar : public IntVar {
2881 public:
2882 class SubCstIntVarIterator : public UnaryIterator {
2883 public:
2884 SubCstIntVarIterator(const IntVar* const v, int64_t c, bool hole, bool rev)
2885 : UnaryIterator(v, hole, rev), cst_(c) {}
2886 ~SubCstIntVarIterator() override {}
2887
2888 int64_t Value() const override { return cst_ - iterator_->Value(); }
2889
2890 private:
2891 const int64_t cst_;
2892 };
2893
2894 SubCstIntVar(Solver* const s, IntVar* v, int64_t c);
2895 ~SubCstIntVar() override;
2896
2897 int64_t Min() const override;
2898 void SetMin(int64_t m) override;
2899 int64_t Max() const override;
2900 void SetMax(int64_t m) override;
2901 void SetRange(int64_t l, int64_t u) override;
2902 void SetValue(int64_t v) override;
2903 bool Bound() const override;
2904 int64_t Value() const override;
2905 void RemoveValue(int64_t v) override;
2906 void RemoveInterval(int64_t l, int64_t u) override;
2907 uint64_t Size() const override;
2908 bool Contains(int64_t v) const override;
2909 void WhenRange(Demon* d) override;
2910 void WhenBound(Demon* d) override;
2911 void WhenDomain(Demon* d) override;
2912 IntVarIterator* MakeHoleIterator(bool reversible) const override {
2913 return COND_REV_ALLOC(
2914 reversible, new SubCstIntVarIterator(var_, cst_, true, reversible));
2915 }
2916 IntVarIterator* MakeDomainIterator(bool reversible) const override {
2917 return COND_REV_ALLOC(
2918 reversible, new SubCstIntVarIterator(var_, cst_, false, reversible));
2919 }
2920 int64_t OldMin() const override { return CapSub(cst_, var_->OldMax()); }
2921 int64_t OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2922 std::string DebugString() const override;
2923 std::string name() const override;
2924 int VarType() const override { return CST_SUB_VAR; }
2925
2926 void Accept(ModelVisitor* const visitor) const override {
2927 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2928 cst_, var_);
2929 }
2930
2931 IntVar* IsEqual(int64_t constant) override {
2932 return var_->IsEqual(cst_ - constant);
2933 }
2934
2935 IntVar* IsDifferent(int64_t constant) override {
2936 return var_->IsDifferent(cst_ - constant);
2937 }
2938
2939 IntVar* IsGreaterOrEqual(int64_t constant) override {
2940 return var_->IsLessOrEqual(cst_ - constant);
2941 }
2942
2943 IntVar* IsLessOrEqual(int64_t constant) override {
2944 return var_->IsGreaterOrEqual(cst_ - constant);
2945 }
2946
2947 IntVar* SubVar() const { return var_; }
2948 int64_t Constant() const { return cst_; }
2949
2950 private:
2951 IntVar* const var_;
2952 const int64_t cst_;
2953};
2954
2955SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64_t c)
2956 : IntVar(s), var_(v), cst_(c) {}
2957
2958SubCstIntVar::~SubCstIntVar() {}
2959
2960int64_t SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2961
2962void SubCstIntVar::SetMin(int64_t m) { var_->SetMax(CapSub(cst_, m)); }
2963
2964int64_t SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2965
2966void SubCstIntVar::SetMax(int64_t m) { var_->SetMin(CapSub(cst_, m)); }
2967
2968void SubCstIntVar::SetRange(int64_t l, int64_t u) {
2969 var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2970}
2971
2972void SubCstIntVar::SetValue(int64_t v) { var_->SetValue(cst_ - v); }
2973
2974bool SubCstIntVar::Bound() const { return var_->Bound(); }
2975
2976void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2977
2978int64_t SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2979
2980void SubCstIntVar::RemoveValue(int64_t v) { var_->RemoveValue(cst_ - v); }
2981
2982void SubCstIntVar::RemoveInterval(int64_t l, int64_t u) {
2983 var_->RemoveInterval(cst_ - u, cst_ - l);
2984}
2985
2986void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2987
2988void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2989
2990uint64_t SubCstIntVar::Size() const { return var_->Size(); }
2991
2992bool SubCstIntVar::Contains(int64_t v) const {
2993 return var_->Contains(cst_ - v);
2994}
2995
2996std::string SubCstIntVar::DebugString() const {
2997 if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2998 return absl::StrFormat("Not(%s)", var_->DebugString());
2999 } else {
3000 return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
3001 }
3002}
3003
3004std::string SubCstIntVar::name() const {
3005 if (solver()->HasName(this)) {
3007 } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
3008 return absl::StrFormat("Not(%s)", var_->name());
3009 } else {
3010 return absl::StrFormat("(%d - %s)", cst_, var_->name());
3011 }
3012}
3013
3014// -x variable, optimized case
3015
3016class OppIntVar : public IntVar {
3017 public:
3018 class OppIntVarIterator : public UnaryIterator {
3019 public:
3020 OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3021 : UnaryIterator(v, hole, reversible) {}
3022 ~OppIntVarIterator() override {}
3023
3024 int64_t Value() const override { return -iterator_->Value(); }
3025 };
3026
3027 OppIntVar(Solver* const s, IntVar* v);
3028 ~OppIntVar() override;
3029
3030 int64_t Min() const override;
3031 void SetMin(int64_t m) override;
3032 int64_t Max() const override;
3033 void SetMax(int64_t m) override;
3034 void SetRange(int64_t l, int64_t u) override;
3035 void SetValue(int64_t v) override;
3036 bool Bound() const override;
3037 int64_t Value() const override;
3038 void RemoveValue(int64_t v) override;
3039 void RemoveInterval(int64_t l, int64_t u) override;
3040 uint64_t Size() const override;
3041 bool Contains(int64_t v) const override;
3042 void WhenRange(Demon* d) override;
3043 void WhenBound(Demon* d) override;
3044 void WhenDomain(Demon* d) override;
3045 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3046 return COND_REV_ALLOC(reversible,
3047 new OppIntVarIterator(var_, true, reversible));
3048 }
3049 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3050 return COND_REV_ALLOC(reversible,
3051 new OppIntVarIterator(var_, false, reversible));
3052 }
3053 int64_t OldMin() const override { return CapOpp(var_->OldMax()); }
3054 int64_t OldMax() const override { return CapOpp(var_->OldMin()); }
3055 std::string DebugString() const override;
3056 int VarType() const override { return OPP_VAR; }
3057
3058 void Accept(ModelVisitor* const visitor) const override {
3059 visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3060 var_);
3061 }
3062
3063 IntVar* IsEqual(int64_t constant) override {
3064 return var_->IsEqual(-constant);
3065 }
3066
3067 IntVar* IsDifferent(int64_t constant) override {
3068 return var_->IsDifferent(-constant);
3069 }
3070
3071 IntVar* IsGreaterOrEqual(int64_t constant) override {
3072 return var_->IsLessOrEqual(-constant);
3073 }
3074
3075 IntVar* IsLessOrEqual(int64_t constant) override {
3076 return var_->IsGreaterOrEqual(-constant);
3077 }
3078
3079 IntVar* SubVar() const { return var_; }
3080
3081 private:
3082 IntVar* const var_;
3083};
3084
3085OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3086
3087OppIntVar::~OppIntVar() {}
3088
3089int64_t OppIntVar::Min() const { return -var_->Max(); }
3090
3091void OppIntVar::SetMin(int64_t m) { var_->SetMax(CapOpp(m)); }
3092
3093int64_t OppIntVar::Max() const { return -var_->Min(); }
3094
3095void OppIntVar::SetMax(int64_t m) { var_->SetMin(CapOpp(m)); }
3096
3097void OppIntVar::SetRange(int64_t l, int64_t u) {
3098 var_->SetRange(CapOpp(u), CapOpp(l));
3099}
3100
3101void OppIntVar::SetValue(int64_t v) { var_->SetValue(CapOpp(v)); }
3102
3103bool OppIntVar::Bound() const { return var_->Bound(); }
3104
3105void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3106
3107int64_t OppIntVar::Value() const { return -var_->Value(); }
3108
3109void OppIntVar::RemoveValue(int64_t v) { var_->RemoveValue(-v); }
3110
3111void OppIntVar::RemoveInterval(int64_t l, int64_t u) {
3112 var_->RemoveInterval(-u, -l);
3113}
3114
3115void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3116
3117void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3118
3119uint64_t OppIntVar::Size() const { return var_->Size(); }
3120
3121bool OppIntVar::Contains(int64_t v) const { return var_->Contains(-v); }
3122
3123std::string OppIntVar::DebugString() const {
3124 return absl::StrFormat("-(%s)", var_->DebugString());
3125}
3126
3127// ----- Utility functions -----
3128
3129// x * c variable, optimized case
3130
3131class TimesCstIntVar : public IntVar {
3132 public:
3133 TimesCstIntVar(Solver* const s, IntVar* v, int64_t c)
3134 : IntVar(s), var_(v), cst_(c) {}
3135 ~TimesCstIntVar() override {}
3136
3137 IntVar* SubVar() const { return var_; }
3138 int64_t Constant() const { return cst_; }
3139
3140 void Accept(ModelVisitor* const visitor) const override {
3141 visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3142 var_);
3143 }
3144
3145 IntVar* IsEqual(int64_t constant) override {
3146 if (constant % cst_ == 0) {
3147 return var_->IsEqual(constant / cst_);
3148 } else {
3149 return solver()->MakeIntConst(0);
3150 }
3151 }
3152
3153 IntVar* IsDifferent(int64_t constant) override {
3154 if (constant % cst_ == 0) {
3155 return var_->IsDifferent(constant / cst_);
3156 } else {
3157 return solver()->MakeIntConst(1);
3158 }
3159 }
3160
3161 IntVar* IsGreaterOrEqual(int64_t constant) override {
3162 if (cst_ > 0) {
3163 return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3164 } else {
3165 return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3166 }
3167 }
3168
3169 IntVar* IsLessOrEqual(int64_t constant) override {
3170 if (cst_ > 0) {
3171 return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3172 } else {
3173 return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3174 }
3175 }
3176
3177 std::string DebugString() const override {
3178 return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3179 }
3180
3181 int VarType() const override { return VAR_TIMES_CST; }
3182
3183 protected:
3184 IntVar* const var_;
3185 const int64_t cst_;
3186};
3187
3188class TimesPosCstIntVar : public TimesCstIntVar {
3189 public:
3190 class TimesPosCstIntVarIterator : public UnaryIterator {
3191 public:
3192 TimesPosCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3193 bool reversible)
3194 : UnaryIterator(v, hole, reversible), cst_(c) {}
3195 ~TimesPosCstIntVarIterator() override {}
3196
3197 int64_t Value() const override { return iterator_->Value() * cst_; }
3198
3199 private:
3200 const int64_t cst_;
3201 };
3202
3203 TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c);
3204 ~TimesPosCstIntVar() override;
3205
3206 int64_t Min() const override;
3207 void SetMin(int64_t m) override;
3208 int64_t Max() const override;
3209 void SetMax(int64_t m) override;
3210 void SetRange(int64_t l, int64_t u) override;
3211 void SetValue(int64_t v) override;
3212 bool Bound() const override;
3213 int64_t Value() const override;
3214 void RemoveValue(int64_t v) override;
3215 void RemoveInterval(int64_t l, int64_t u) override;
3216 uint64_t Size() const override;
3217 bool Contains(int64_t v) const override;
3218 void WhenRange(Demon* d) override;
3219 void WhenBound(Demon* d) override;
3220 void WhenDomain(Demon* d) override;
3221 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3222 return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3223 var_, cst_, true, reversible));
3224 }
3225 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3226 return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3227 var_, cst_, false, reversible));
3228 }
3229 int64_t OldMin() const override { return CapProd(var_->OldMin(), cst_); }
3230 int64_t OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3231};
3232
3233// ----- TimesPosCstIntVar -----
3234
3235TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64_t c)
3236 : TimesCstIntVar(s, v, c) {}
3237
3238TimesPosCstIntVar::~TimesPosCstIntVar() {}
3239
3240int64_t TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3241
3242void TimesPosCstIntVar::SetMin(int64_t m) {
3244 var_->SetMin(PosIntDivUp(m, cst_));
3245 }
3246}
3247
3248int64_t TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3249
3250void TimesPosCstIntVar::SetMax(int64_t m) {
3252 var_->SetMax(PosIntDivDown(m, cst_));
3253 }
3254}
3255
3256void TimesPosCstIntVar::SetRange(int64_t l, int64_t u) {
3257 var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3258}
3259
3260void TimesPosCstIntVar::SetValue(int64_t v) {
3261 if (v % cst_ != 0) {
3262 solver()->Fail();
3263 }
3264 var_->SetValue(v / cst_);
3265}
3266
3267bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3268
3269void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3270
3271int64_t TimesPosCstIntVar::Value() const {
3272 return CapProd(var_->Value(), cst_);
3273}
3274
3275void TimesPosCstIntVar::RemoveValue(int64_t v) {
3276 if (v % cst_ == 0) {
3277 var_->RemoveValue(v / cst_);
3278 }
3279}
3280
3281void TimesPosCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3282 for (int64_t v = l; v <= u; ++v) {
3283 RemoveValue(v);
3284 }
3285 // TODO(user) : Improve me
3286}
3287
3288void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3289
3290void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3291
3292uint64_t TimesPosCstIntVar::Size() const { return var_->Size(); }
3293
3294bool TimesPosCstIntVar::Contains(int64_t v) const {
3295 return (v % cst_ == 0 && var_->Contains(v / cst_));
3296}
3297
3298// b * c variable, optimized case
3299
3300class TimesPosCstBoolVar : public TimesCstIntVar {
3301 public:
3302 class TimesPosCstBoolVarIterator : public UnaryIterator {
3303 public:
3304 // TODO(user) : optimize this.
3305 TimesPosCstBoolVarIterator(const IntVar* const v, int64_t c, bool hole,
3306 bool reversible)
3307 : UnaryIterator(v, hole, reversible), cst_(c) {}
3308 ~TimesPosCstBoolVarIterator() override {}
3309
3310 int64_t Value() const override { return iterator_->Value() * cst_; }
3311
3312 private:
3313 const int64_t cst_;
3314 };
3315
3316 TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64_t c);
3317 ~TimesPosCstBoolVar() override;
3318
3319 int64_t Min() const override;
3320 void SetMin(int64_t m) override;
3321 int64_t Max() const override;
3322 void SetMax(int64_t m) override;
3323 void SetRange(int64_t l, int64_t u) override;
3324 void SetValue(int64_t v) override;
3325 bool Bound() const override;
3326 int64_t Value() const override;
3327 void RemoveValue(int64_t v) override;
3328 void RemoveInterval(int64_t l, int64_t u) override;
3329 uint64_t Size() const override;
3330 bool Contains(int64_t v) const override;
3331 void WhenRange(Demon* d) override;
3332 void WhenBound(Demon* d) override;
3333 void WhenDomain(Demon* d) override;
3334 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3335 return COND_REV_ALLOC(reversible, new EmptyIterator());
3336 }
3337 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3338 return COND_REV_ALLOC(
3339 reversible,
3340 new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3341 }
3342 int64_t OldMin() const override { return 0; }
3343 int64_t OldMax() const override { return cst_; }
3344
3345 BooleanVar* boolean_var() const {
3346 return reinterpret_cast<BooleanVar*>(var_);
3347 }
3348};
3349
3350// ----- TimesPosCstBoolVar -----
3351
3352TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v,
3353 int64_t c)
3354 : TimesCstIntVar(s, v, c) {}
3355
3356TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3357
3358int64_t TimesPosCstBoolVar::Min() const {
3359 return (boolean_var()->RawValue() == 1) * cst_;
3360}
3361
3362void TimesPosCstBoolVar::SetMin(int64_t m) {
3363 if (m > cst_) {
3364 solver()->Fail();
3365 } else if (m > 0) {
3366 boolean_var()->SetMin(1);
3367 }
3368}
3369
3370int64_t TimesPosCstBoolVar::Max() const {
3371 return (boolean_var()->RawValue() != 0) * cst_;
3372}
3373
3374void TimesPosCstBoolVar::SetMax(int64_t m) {
3375 if (m < 0) {
3376 solver()->Fail();
3377 } else if (m < cst_) {
3378 boolean_var()->SetMax(0);
3379 }
3380}
3381
3382void TimesPosCstBoolVar::SetRange(int64_t l, int64_t u) {
3383 if (u < 0 || l > cst_ || l > u) {
3384 solver()->Fail();
3385 }
3386 if (l > 0) {
3387 boolean_var()->SetMin(1);
3388 } else if (u < cst_) {
3389 boolean_var()->SetMax(0);
3390 }
3391}
3392
3393void TimesPosCstBoolVar::SetValue(int64_t v) {
3394 if (v == 0) {
3395 boolean_var()->SetValue(0);
3396 } else if (v == cst_) {
3397 boolean_var()->SetValue(1);
3398 } else {
3399 solver()->Fail();
3400 }
3401}
3402
3403bool TimesPosCstBoolVar::Bound() const {
3404 return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3405}
3406
3407void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3408
3409int64_t TimesPosCstBoolVar::Value() const {
3410 CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3411 << " variable is not bound";
3412 return boolean_var()->RawValue() * cst_;
3413}
3414
3415void TimesPosCstBoolVar::RemoveValue(int64_t v) {
3416 if (v == 0) {
3417 boolean_var()->RemoveValue(0);
3418 } else if (v == cst_) {
3419 boolean_var()->RemoveValue(1);
3420 }
3421}
3422
3423void TimesPosCstBoolVar::RemoveInterval(int64_t l, int64_t u) {
3424 if (l <= 0 && u >= 0) {
3425 boolean_var()->RemoveValue(0);
3426 }
3427 if (l <= cst_ && u >= cst_) {
3428 boolean_var()->RemoveValue(1);
3429 }
3430}
3431
3432void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3433
3434void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3435
3436uint64_t TimesPosCstBoolVar::Size() const {
3437 return (1 +
3438 (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3439}
3440
3441bool TimesPosCstBoolVar::Contains(int64_t v) const {
3442 if (v == 0) {
3443 return boolean_var()->RawValue() != 1;
3444 } else if (v == cst_) {
3445 return boolean_var()->RawValue() != 0;
3446 }
3447 return false;
3448}
3449
3450// TimesNegCstIntVar
3451
3452class TimesNegCstIntVar : public TimesCstIntVar {
3453 public:
3454 class TimesNegCstIntVarIterator : public UnaryIterator {
3455 public:
3456 TimesNegCstIntVarIterator(const IntVar* const v, int64_t c, bool hole,
3457 bool reversible)
3458 : UnaryIterator(v, hole, reversible), cst_(c) {}
3459 ~TimesNegCstIntVarIterator() override {}
3460
3461 int64_t Value() const override { return iterator_->Value() * cst_; }
3462
3463 private:
3464 const int64_t cst_;
3465 };
3466
3467 TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c);
3468 ~TimesNegCstIntVar() override;
3469
3470 int64_t Min() const override;
3471 void SetMin(int64_t m) override;
3472 int64_t Max() const override;
3473 void SetMax(int64_t m) override;
3474 void SetRange(int64_t l, int64_t u) override;
3475 void SetValue(int64_t v) override;
3476 bool Bound() const override;
3477 int64_t Value() const override;
3478 void RemoveValue(int64_t v) override;
3479 void RemoveInterval(int64_t l, int64_t u) override;
3480 uint64_t Size() const override;
3481 bool Contains(int64_t v) const override;
3482 void WhenRange(Demon* d) override;
3483 void WhenBound(Demon* d) override;
3484 void WhenDomain(Demon* d) override;
3485 IntVarIterator* MakeHoleIterator(bool reversible) const override {
3486 return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3487 var_, cst_, true, reversible));
3488 }
3489 IntVarIterator* MakeDomainIterator(bool reversible) const override {
3490 return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3491 var_, cst_, false, reversible));
3492 }
3493 int64_t OldMin() const override { return CapProd(var_->OldMax(), cst_); }
3494 int64_t OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3495};
3496
3497// ----- TimesNegCstIntVar -----
3498
3499TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64_t c)
3500 : TimesCstIntVar(s, v, c) {}
3501
3502TimesNegCstIntVar::~TimesNegCstIntVar() {}
3503
3504int64_t TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3505
3506void TimesNegCstIntVar::SetMin(int64_t m) {
3508 var_->SetMax(PosIntDivDown(-m, -cst_));
3509 }
3510}
3511
3512int64_t TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3513
3514void TimesNegCstIntVar::SetMax(int64_t m) {
3516 var_->SetMin(PosIntDivUp(-m, -cst_));
3517 }
3518}
3519
3520void TimesNegCstIntVar::SetRange(int64_t l, int64_t u) {
3521 var_->SetRange(PosIntDivUp(-u, -cst_), PosIntDivDown(-l, -cst_));
3522}
3523
3524void TimesNegCstIntVar::SetValue(int64_t v) {
3525 if (v % cst_ != 0) {
3526 solver()->Fail();
3527 }
3528 var_->SetValue(v / cst_);
3529}
3530
3531bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3532
3533void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3534
3535int64_t TimesNegCstIntVar::Value() const {
3536 return CapProd(var_->Value(), cst_);
3537}
3538
3539void TimesNegCstIntVar::RemoveValue(int64_t v) {
3540 if (v % cst_ == 0) {
3541 var_->RemoveValue(v / cst_);
3542 }
3543}
3544
3545void TimesNegCstIntVar::RemoveInterval(int64_t l, int64_t u) {
3546 for (int64_t v = l; v <= u; ++v) {
3547 RemoveValue(v);
3548 }
3549 // TODO(user) : Improve me
3550}
3551
3552void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3553
3554void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3555
3556uint64_t TimesNegCstIntVar::Size() const { return var_->Size(); }
3557
3558bool TimesNegCstIntVar::Contains(int64_t v) const {
3559 return (v % cst_ == 0 && var_->Contains(v / cst_));
3560}
3561
3562// ---------- arithmetic expressions ----------
3563
3564// ----- PlusIntExpr -----
3565
3566class PlusIntExpr : public BaseIntExpr {
3567 public:
3568 PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3569 : BaseIntExpr(s), left_(l), right_(r) {}
3570
3571 ~PlusIntExpr() override {}
3572
3573 int64_t Min() const override { return left_->Min() + right_->Min(); }
3574
3575 void SetMin(int64_t m) override {
3576 if (m > left_->Min() + right_->Min()) {
3577 left_->SetMin(m - right_->Max());
3578 right_->SetMin(m - left_->Max());
3579 }
3580 }
3581
3582 void SetRange(int64_t l, int64_t u) override {
3583 const int64_t left_min = left_->Min();
3584 const int64_t right_min = right_->Min();
3585 const int64_t left_max = left_->Max();
3586 const int64_t right_max = right_->Max();
3587 if (l > left_min + right_min) {
3588 left_->SetMin(l - right_max);
3589 right_->SetMin(l - left_max);
3590 }
3591 if (u < left_max + right_max) {
3592 left_->SetMax(u - right_min);
3593 right_->SetMax(u - left_min);
3594 }
3595 }
3596
3597 int64_t Max() const override { return left_->Max() + right_->Max(); }
3598
3599 void SetMax(int64_t m) override {
3600 if (m < left_->Max() + right_->Max()) {
3601 left_->SetMax(m - right_->Min());
3602 right_->SetMax(m - left_->Min());
3603 }
3604 }
3605
3606 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3607
3608 void Range(int64_t* const mi, int64_t* const ma) override {
3609 *mi = left_->Min() + right_->Min();
3610 *ma = left_->Max() + right_->Max();
3611 }
3612
3613 std::string name() const override {
3614 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3615 }
3616
3617 std::string DebugString() const override {
3618 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3619 right_->DebugString());
3620 }
3621
3622 void WhenRange(Demon* d) override {
3623 left_->WhenRange(d);
3624 right_->WhenRange(d);
3625 }
3626
3627 void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3628 PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3629 if (casted != nullptr) {
3630 ExpandPlusIntExpr(casted->left_, subs);
3631 ExpandPlusIntExpr(casted->right_, subs);
3632 } else {
3633 subs->push_back(expr);
3634 }
3635 }
3636
3637 IntVar* CastToVar() override {
3638 if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3639 dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3640 std::vector<IntExpr*> sub_exprs;
3641 ExpandPlusIntExpr(left_, &sub_exprs);
3642 ExpandPlusIntExpr(right_, &sub_exprs);
3643 if (sub_exprs.size() >= 3) {
3644 std::vector<IntVar*> sub_vars(sub_exprs.size());
3645 for (int i = 0; i < sub_exprs.size(); ++i) {
3646 sub_vars[i] = sub_exprs[i]->Var();
3647 }
3648 return solver()->MakeSum(sub_vars)->Var();
3649 }
3650 }
3651 return BaseIntExpr::CastToVar();
3652 }
3653
3654 void Accept(ModelVisitor* const visitor) const override {
3655 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3656 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3657 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3658 right_);
3659 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3660 }
3661
3662 private:
3663 IntExpr* const left_;
3664 IntExpr* const right_;
3665};
3666
3667class SafePlusIntExpr : public BaseIntExpr {
3668 public:
3669 SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3670 : BaseIntExpr(s), left_(l), right_(r) {}
3671
3672 ~SafePlusIntExpr() override {}
3673
3674 int64_t Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3675
3676 void SetMin(int64_t m) override {
3677 left_->SetMin(CapSub(m, right_->Max()));
3678 right_->SetMin(CapSub(m, left_->Max()));
3679 }
3680
3681 void SetRange(int64_t l, int64_t u) override {
3682 const int64_t left_min = left_->Min();
3683 const int64_t right_min = right_->Min();
3684 const int64_t left_max = left_->Max();
3685 const int64_t right_max = right_->Max();
3686 if (l > CapAdd(left_min, right_min)) {
3687 left_->SetMin(CapSub(l, right_max));
3688 right_->SetMin(CapSub(l, left_max));
3689 }
3690 if (u < CapAdd(left_max, right_max)) {
3691 left_->SetMax(CapSub(u, right_min));
3692 right_->SetMax(CapSub(u, left_min));
3693 }
3694 }
3695
3696 int64_t Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3697
3698 void SetMax(int64_t m) override {
3699 left_->SetMax(CapSub(m, right_->Min()));
3700 right_->SetMax(CapSub(m, left_->Min()));
3701 }
3702
3703 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3704
3705 std::string name() const override {
3706 return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3707 }
3708
3709 std::string DebugString() const override {
3710 return absl::StrFormat("(%s + %s)", left_->DebugString(),
3711 right_->DebugString());
3712 }
3713
3714 void WhenRange(Demon* d) override {
3715 left_->WhenRange(d);
3716 right_->WhenRange(d);
3717 }
3718
3719 void Accept(ModelVisitor* const visitor) const override {
3720 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3721 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3722 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3723 right_);
3724 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3725 }
3726
3727 private:
3728 IntExpr* const left_;
3729 IntExpr* const right_;
3730};
3731
3732// ----- PlusIntCstExpr -----
3733
3734class PlusIntCstExpr : public BaseIntExpr {
3735 public:
3736 PlusIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3737 : BaseIntExpr(s), expr_(e), value_(v) {}
3738 ~PlusIntCstExpr() override {}
3739 int64_t Min() const override { return CapAdd(expr_->Min(), value_); }
3740 void SetMin(int64_t m) override { expr_->SetMin(CapSub(m, value_)); }
3741 int64_t Max() const override { return CapAdd(expr_->Max(), value_); }
3742 void SetMax(int64_t m) override { expr_->SetMax(CapSub(m, value_)); }
3743 bool Bound() const override { return (expr_->Bound()); }
3744 std::string name() const override {
3745 return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3746 }
3747 std::string DebugString() const override {
3748 return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3749 }
3750 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3751 IntVar* CastToVar() override;
3752 void Accept(ModelVisitor* const visitor) const override {
3753 visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3754 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3755 expr_);
3756 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3757 visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3758 }
3759
3760 private:
3761 IntExpr* const expr_;
3762 const int64_t value_;
3763};
3764
3765IntVar* PlusIntCstExpr::CastToVar() {
3766 Solver* const s = solver();
3767 IntVar* const var = expr_->Var();
3768 IntVar* cast = nullptr;
3769 if (AddOverflows(value_, expr_->Max()) ||
3770 AddOverflows(value_, expr_->Min())) {
3771 return BaseIntExpr::CastToVar();
3772 }
3773 switch (var->VarType()) {
3774 case DOMAIN_INT_VAR:
3775 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3776 s, reinterpret_cast<DomainIntVar*>(var), value_)));
3777 // FIXME: Break was inserted during fallthrough cleanup. Please check.
3778 break;
3779 default:
3780 cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3781 break;
3782 }
3783 return cast;
3784}
3785
3786// ----- SubIntExpr -----
3787
3788class SubIntExpr : public BaseIntExpr {
3789 public:
3790 SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3791 : BaseIntExpr(s), left_(l), right_(r) {}
3792
3793 ~SubIntExpr() override {}
3794
3795 int64_t Min() const override { return left_->Min() - right_->Max(); }
3796
3797 void SetMin(int64_t m) override {
3798 left_->SetMin(CapAdd(m, right_->Min()));
3799 right_->SetMax(CapSub(left_->Max(), m));
3800 }
3801
3802 int64_t Max() const override { return left_->Max() - right_->Min(); }
3803
3804 void SetMax(int64_t m) override {
3805 left_->SetMax(CapAdd(m, right_->Max()));
3806 right_->SetMin(CapSub(left_->Min(), m));
3807 }
3808
3809 void Range(int64_t* mi, int64_t* ma) override {
3810 *mi = left_->Min() - right_->Max();
3811 *ma = left_->Max() - right_->Min();
3812 }
3813
3814 void SetRange(int64_t l, int64_t u) override {
3815 const int64_t left_min = left_->Min();
3816 const int64_t right_min = right_->Min();
3817 const int64_t left_max = left_->Max();
3818 const int64_t right_max = right_->Max();
3819 if (l > left_min - right_max) {
3820 left_->SetMin(CapAdd(l, right_min));
3821 right_->SetMax(CapSub(left_max, l));
3822 }
3823 if (u < left_max - right_min) {
3824 left_->SetMax(CapAdd(u, right_max));
3825 right_->SetMin(CapSub(left_min, u));
3826 }
3827 }
3828
3829 bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3830
3831 std::string name() const override {
3832 return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3833 }
3834
3835 std::string DebugString() const override {
3836 return absl::StrFormat("(%s - %s)", left_->DebugString(),
3837 right_->DebugString());
3838 }
3839
3840 void WhenRange(Demon* d) override {
3841 left_->WhenRange(d);
3842 right_->WhenRange(d);
3843 }
3844
3845 void Accept(ModelVisitor* const visitor) const override {
3846 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3847 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3848 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3849 right_);
3850 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3851 }
3852
3853 IntExpr* left() const { return left_; }
3854 IntExpr* right() const { return right_; }
3855
3856 protected:
3857 IntExpr* const left_;
3858 IntExpr* const right_;
3859};
3860
3861class SafeSubIntExpr : public SubIntExpr {
3862 public:
3863 SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3864 : SubIntExpr(s, l, r) {}
3865
3866 ~SafeSubIntExpr() override {}
3867
3868 int64_t Min() const override { return CapSub(left_->Min(), right_->Max()); }
3869
3870 void SetMin(int64_t m) override {
3871 left_->SetMin(CapAdd(m, right_->Min()));
3872 right_->SetMax(CapSub(left_->Max(), m));
3873 }
3874
3875 void SetRange(int64_t l, int64_t u) override {
3876 const int64_t left_min = left_->Min();
3877 const int64_t right_min = right_->Min();
3878 const int64_t left_max = left_->Max();
3879 const int64_t right_max = right_->Max();
3880 if (l > CapSub(left_min, right_max)) {
3881 left_->SetMin(CapAdd(l, right_min));
3882 right_->SetMax(CapSub(left_max, l));
3883 }
3884 if (u < CapSub(left_max, right_min)) {
3885 left_->SetMax(CapAdd(u, right_max));
3886 right_->SetMin(CapSub(left_min, u));
3887 }
3888 }
3889
3890 void Range(int64_t* mi, int64_t* ma) override {
3891 *mi = CapSub(left_->Min(), right_->Max());
3892 *ma = CapSub(left_->Max(), right_->Min());
3893 }
3894
3895 int64_t Max() const override { return CapSub(left_->Max(), right_->Min()); }
3896
3897 void SetMax(int64_t m) override {
3898 left_->SetMax(CapAdd(m, right_->Max()));
3899 right_->SetMin(CapSub(left_->Min(), m));
3900 }
3901};
3902
// Expressions of the form c - e, where c is a constant.
3904
3905// ----- SubIntCstExpr -----
3906
3907class SubIntCstExpr : public BaseIntExpr {
3908 public:
3909 SubIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3910 : BaseIntExpr(s), expr_(e), value_(v) {}
3911 ~SubIntCstExpr() override {}
3912 int64_t Min() const override { return CapSub(value_, expr_->Max()); }
3913 void SetMin(int64_t m) override { expr_->SetMax(CapSub(value_, m)); }
3914 int64_t Max() const override { return CapSub(value_, expr_->Min()); }
3915 void SetMax(int64_t m) override { expr_->SetMin(CapSub(value_, m)); }
3916 bool Bound() const override { return (expr_->Bound()); }
3917 std::string name() const override {
3918 return absl::StrFormat("(%d - %s)", value_, expr_->name());
3919 }
3920 std::string DebugString() const override {
3921 return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3922 }
3923 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3924 IntVar* CastToVar() override;
3925
3926 void Accept(ModelVisitor* const visitor) const override {
3927 visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3928 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3929 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3930 expr_);
3931 visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3932 }
3933
3934 private:
3935 IntExpr* const expr_;
3936 const int64_t value_;
3937};
3938
3939IntVar* SubIntCstExpr::CastToVar() {
3940 if (SubOverflows(value_, expr_->Min()) ||
3941 SubOverflows(value_, expr_->Max())) {
3942 return BaseIntExpr::CastToVar();
3943 }
3944 Solver* const s = solver();
3945 IntVar* const var =
3946 s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3947 return var;
3948}
3949
3950// ----- OppIntExpr -----
3951
3952class OppIntExpr : public BaseIntExpr {
3953 public:
3954 OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
3955 ~OppIntExpr() override {}
3956 int64_t Min() const override { return (CapOpp(expr_->Max())); }
3957 void SetMin(int64_t m) override { expr_->SetMax(CapOpp(m)); }
3958 int64_t Max() const override { return (CapOpp(expr_->Min())); }
3959 void SetMax(int64_t m) override { expr_->SetMin(CapOpp(m)); }
3960 bool Bound() const override { return (expr_->Bound()); }
3961 std::string name() const override {
3962 return absl::StrFormat("(-%s)", expr_->name());
3963 }
3964 std::string DebugString() const override {
3965 return absl::StrFormat("(-%s)", expr_->DebugString());
3966 }
3967 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3968 IntVar* CastToVar() override;
3969
3970 void Accept(ModelVisitor* const visitor) const override {
3971 visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3972 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3973 expr_);
3974 visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3975 }
3976
3977 private:
3978 IntExpr* const expr_;
3979};
3980
3981IntVar* OppIntExpr::CastToVar() {
3982 Solver* const s = solver();
3983 IntVar* const var =
3984 s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
3985 return var;
3986}
3987
3988// ----- TimesIntCstExpr -----
3989
3990class TimesIntCstExpr : public BaseIntExpr {
3991 public:
3992 TimesIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
3993 : BaseIntExpr(s), expr_(e), value_(v) {}
3994
3995 ~TimesIntCstExpr() override {}
3996
3997 bool Bound() const override { return (expr_->Bound()); }
3998
3999 std::string name() const override {
4000 return absl::StrFormat("(%s * %d)", expr_->name(), value_);
4001 }
4002
4003 std::string DebugString() const override {
4004 return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
4005 }
4006
4007 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4008
4009 IntExpr* Expr() const { return expr_; }
4010
4011 int64_t Constant() const { return value_; }
4012
4013 void Accept(ModelVisitor* const visitor) const override {
4014 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4015 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4016 expr_);
4017 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4018 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4019 }
4020
4021 protected:
4022 IntExpr* const expr_;
4023 const int64_t value_;
4024};
4025
4026// ----- TimesPosIntCstExpr -----
4027
4028class TimesPosIntCstExpr : public TimesIntCstExpr {
4029 public:
4030 TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4031 : TimesIntCstExpr(s, e, v) {
4032 CHECK_GT(v, 0);
4033 }
4034
4035 ~TimesPosIntCstExpr() override {}
4036
4037 int64_t Min() const override { return expr_->Min() * value_; }
4038
4039 void SetMin(int64_t m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4040
4041 int64_t Max() const override { return expr_->Max() * value_; }
4042
4043 void SetMax(int64_t m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4044
4045 IntVar* CastToVar() override {
4046 Solver* const s = solver();
4047 IntVar* var = nullptr;
4048 if (expr_->IsVar() &&
4049 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4050 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4051 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4052 } else {
4053 var = s->RegisterIntVar(
4054 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4055 }
4056 return var;
4057 }
4058};
4059
4060// This expressions adds safe arithmetic (w.r.t. overflows) compared
4061// to the previous one.
4062class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4063 public:
4064 SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4065 : TimesIntCstExpr(s, e, v) {
4066 CHECK_GT(v, 0);
4067 }
4068
4069 ~SafeTimesPosIntCstExpr() override {}
4070
4071 int64_t Min() const override { return CapProd(expr_->Min(), value_); }
4072
4073 void SetMin(int64_t m) override {
4075 expr_->SetMin(PosIntDivUp(m, value_));
4076 }
4077 }
4078
4079 int64_t Max() const override { return CapProd(expr_->Max(), value_); }
4080
4081 void SetMax(int64_t m) override {
4083 expr_->SetMax(PosIntDivDown(m, value_));
4084 }
4085 }
4086
4087 IntVar* CastToVar() override {
4088 Solver* const s = solver();
4089 IntVar* var = nullptr;
4090 if (expr_->IsVar() &&
4091 reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4092 var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4093 s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4094 } else {
4095 // TODO(user): Check overflows.
4096 var = s->RegisterIntVar(
4097 s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4098 }
4099 return var;
4100 }
4101};
4102
4103// ----- TimesIntNegCstExpr -----
4104
4105class TimesIntNegCstExpr : public TimesIntCstExpr {
4106 public:
4107 TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4108 : TimesIntCstExpr(s, e, v) {
4109 CHECK_LT(v, 0);
4110 }
4111
4112 ~TimesIntNegCstExpr() override {}
4113
4114 int64_t Min() const override { return CapProd(expr_->Max(), value_); }
4115
4116 void SetMin(int64_t m) override {
4118 expr_->SetMax(PosIntDivDown(-m, -value_));
4119 }
4120 }
4121
4122 int64_t Max() const override { return CapProd(expr_->Min(), value_); }
4123
4124 void SetMax(int64_t m) override {
4126 expr_->SetMin(PosIntDivUp(-m, -value_));
4127 }
4128 }
4129
4130 IntVar* CastToVar() override {
4131 Solver* const s = solver();
4132 IntVar* var = nullptr;
4133 var = s->RegisterIntVar(
4134 s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4135 return var;
4136 }
4137};
4138
4139// ----- Utilities for product expression -----
4140
4141// Propagates set_min on left * right, left and right >= 0.
4142void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4143 DCHECK_GE(left->Min(), 0);
4144 DCHECK_GE(right->Min(), 0);
4145 const int64_t lmax = left->Max();
4146 const int64_t rmax = right->Max();
4147 if (m > CapProd(lmax, rmax)) {
4148 left->solver()->Fail();
4149 }
4150 if (m > CapProd(left->Min(), right->Min())) {
4151 // Ok for m == 0 due to left and right being positive
4152 if (0 != rmax) {
4153 left->SetMin(PosIntDivUp(m, rmax));
4154 }
4155 if (0 != lmax) {
4156 right->SetMin(PosIntDivUp(m, lmax));
4157 }
4158 }
4159}
4160
4161// Propagates set_max on left * right, left and right >= 0.
4162void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4163 DCHECK_GE(left->Min(), 0);
4164 DCHECK_GE(right->Min(), 0);
4165 const int64_t lmin = left->Min();
4166 const int64_t rmin = right->Min();
4167 if (m < CapProd(lmin, rmin)) {
4168 left->solver()->Fail();
4169 }
4170 if (m < CapProd(left->Max(), right->Max())) {
4171 if (0 != lmin) {
4172 right->SetMax(PosIntDivDown(m, lmin));
4173 }
4174 if (0 != rmin) {
4175 left->SetMax(PosIntDivDown(m, rmin));
4176 }
4177 // else do nothing: 0 is supporting any value from other expr.
4178 }
4179}
4180
4181// Propagates set_min on left * right, left >= 0, right across 0.
4182void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4183 DCHECK_GE(left->Min(), 0);
4184 DCHECK_GT(right->Max(), 0);
4185 DCHECK_LT(right->Min(), 0);
4186 const int64_t lmax = left->Max();
4187 const int64_t rmax = right->Max();
4188 if (m > CapProd(lmax, rmax)) {
4189 left->solver()->Fail();
4190 }
4191 if (left->Max() == 0) { // left is bound to 0, product is bound to 0.
4192 DCHECK_EQ(0, left->Min());
4193 DCHECK_LE(m, 0);
4194 } else {
4195 if (m > 0) { // We deduce right > 0.
4196 left->SetMin(PosIntDivUp(m, rmax));
4197 right->SetMin(PosIntDivUp(m, lmax));
4198 } else if (m == 0) {
4199 const int64_t lmin = left->Min();
4200 if (lmin > 0) {
4201 right->SetMin(0);
4202 }
4203 } else { // m < 0
4204 const int64_t lmin = left->Min();
4205 if (0 != lmin) { // We cannot deduce anything if 0 is in the domain.
4206 right->SetMin(-PosIntDivDown(-m, lmin));
4207 }
4208 }
4209 }
4210}
4211
4212// Propagates set_min on left * right, left and right across 0.
4213void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64_t m) {
4214 DCHECK_LT(left->Min(), 0);
4215 DCHECK_GT(left->Max(), 0);
4216 DCHECK_GT(right->Max(), 0);
4217 DCHECK_LT(right->Min(), 0);
4218 const int64_t lmin = left->Min();
4219 const int64_t lmax = left->Max();
4220 const int64_t rmin = right->Min();
4221 const int64_t rmax = right->Max();
4222 if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4223 left->solver()->Fail();
4224 }
4225 if (m > lmin * rmin) { // Must be positive section * positive section.
4226 left->SetMin(PosIntDivUp(m, rmax));
4227 right->SetMin(PosIntDivUp(m, lmax));
4228 } else if (m > CapProd(lmax, rmax)) { // Negative section * negative section.
4229 left->SetMax(-PosIntDivUp(m, -rmin));
4230 right->SetMax(-PosIntDivUp(m, -lmin));
4231 }
4232}
4233
4234void TimesSetMin(IntExpr* const left, IntExpr* const right,
4235 IntExpr* const minus_left, IntExpr* const minus_right,
4236 int64_t m) {
4237 if (left->Min() >= 0) {
4238 if (right->Min() >= 0) {
4239 SetPosPosMinExpr(left, right, m);
4240 } else if (right->Max() <= 0) {
4241 SetPosPosMaxExpr(left, minus_right, -m);
4242 } else { // right->Min() < 0 && right->Max() > 0
4243 SetPosGenMinExpr(left, right, m);
4244 }
4245 } else if (left->Max() <= 0) {
4246 if (right->Min() >= 0) {
4247 SetPosPosMaxExpr(right, minus_left, -m);
4248 } else if (right->Max() <= 0) {
4249 SetPosPosMinExpr(minus_left, minus_right, m);
4250 } else { // right->Min() < 0 && right->Max() > 0
4251 SetPosGenMinExpr(minus_left, minus_right, m);
4252 }
4253 } else if (right->Min() >= 0) { // left->Min() < 0 && left->Max() > 0
4254 SetPosGenMinExpr(right, left, m);
4255 } else if (right->Max() <= 0) { // left->Min() < 0 && left->Max() > 0
4256 SetPosGenMinExpr(minus_right, minus_left, m);
4257 } else { // left->Min() < 0 && left->Max() > 0 &&
4258 // right->Min() < 0 && right->Max() > 0
4259 SetGenGenMinExpr(left, right, m);
4260 }
4261}
4262
4263class TimesIntExpr : public BaseIntExpr {
4264 public:
4265 TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4266 : BaseIntExpr(s),
4267 left_(l),
4268 right_(r),
4269 minus_left_(s->MakeOpposite(left_)),
4270 minus_right_(s->MakeOpposite(right_)) {}
4271 ~TimesIntExpr() override {}
4272 int64_t Min() const override {
4273 const int64_t lmin = left_->Min();
4274 const int64_t lmax = left_->Max();
4275 const int64_t rmin = right_->Min();
4276 const int64_t rmax = right_->Max();
4277 return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4278 std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4279 }
4280 void SetMin(int64_t m) override;
4281 int64_t Max() const override {
4282 const int64_t lmin = left_->Min();
4283 const int64_t lmax = left_->Max();
4284 const int64_t rmin = right_->Min();
4285 const int64_t rmax = right_->Max();
4286 return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4287 std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4288 }
4289 void SetMax(int64_t m) override;
4290 bool Bound() const override;
4291 std::string name() const override {
4292 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4293 }
4294 std::string DebugString() const override {
4295 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4296 right_->DebugString());
4297 }
4298 void WhenRange(Demon* d) override {
4299 left_->WhenRange(d);
4300 right_->WhenRange(d);
4301 }
4302
4303 void Accept(ModelVisitor* const visitor) const override {
4304 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4305 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4306 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4307 right_);
4308 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4309 }
4310
4311 private:
4312 IntExpr* const left_;
4313 IntExpr* const right_;
4314 IntExpr* const minus_left_;
4315 IntExpr* const minus_right_;
4316};
4317
4318void TimesIntExpr::SetMin(int64_t m) {
4320 TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4321 }
4322}
4323
4324void TimesIntExpr::SetMax(int64_t m) {
4326 TimesSetMin(left_, minus_right_, minus_left_, right_, CapOpp(m));
4327 }
4328}
4329
4330bool TimesIntExpr::Bound() const {
4331 const bool left_bound = left_->Bound();
4332 const bool right_bound = right_->Bound();
4333 return ((left_bound && left_->Max() == 0) ||
4334 (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4335}
4336
4337// ----- TimesPosIntExpr -----
4338
4339class TimesPosIntExpr : public BaseIntExpr {
4340 public:
4341 TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4342 : BaseIntExpr(s), left_(l), right_(r) {}
4343 ~TimesPosIntExpr() override {}
4344 int64_t Min() const override { return (left_->Min() * right_->Min()); }
4345 void SetMin(int64_t m) override;
4346 int64_t Max() const override { return (left_->Max() * right_->Max()); }
4347 void SetMax(int64_t m) override;
4348 bool Bound() const override;
4349 std::string name() const override {
4350 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4351 }
4352 std::string DebugString() const override {
4353 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4354 right_->DebugString());
4355 }
4356 void WhenRange(Demon* d) override {
4357 left_->WhenRange(d);
4358 right_->WhenRange(d);
4359 }
4360
4361 void Accept(ModelVisitor* const visitor) const override {
4362 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4363 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4364 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4365 right_);
4366 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4367 }
4368
4369 private:
4370 IntExpr* const left_;
4371 IntExpr* const right_;
4372};
4373
4374void TimesPosIntExpr::SetMin(int64_t m) { SetPosPosMinExpr(left_, right_, m); }
4375
4376void TimesPosIntExpr::SetMax(int64_t m) { SetPosPosMaxExpr(left_, right_, m); }
4377
4378bool TimesPosIntExpr::Bound() const {
4379 return (left_->Max() == 0 || right_->Max() == 0 ||
4380 (left_->Bound() && right_->Bound()));
4381}
4382
4383// ----- SafeTimesPosIntExpr -----
4384
4385class SafeTimesPosIntExpr : public BaseIntExpr {
4386 public:
4387 SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4388 : BaseIntExpr(s), left_(l), right_(r) {}
4389 ~SafeTimesPosIntExpr() override {}
4390 int64_t Min() const override { return CapProd(left_->Min(), right_->Min()); }
4391 void SetMin(int64_t m) override {
4393 SetPosPosMinExpr(left_, right_, m);
4394 }
4395 }
4396 int64_t Max() const override { return CapProd(left_->Max(), right_->Max()); }
4397 void SetMax(int64_t m) override {
4399 SetPosPosMaxExpr(left_, right_, m);
4400 }
4401 }
4402 bool Bound() const override {
4403 return (left_->Max() == 0 || right_->Max() == 0 ||
4404 (left_->Bound() && right_->Bound()));
4405 }
4406 std::string name() const override {
4407 return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4408 }
4409 std::string DebugString() const override {
4410 return absl::StrFormat("(%s * %s)", left_->DebugString(),
4411 right_->DebugString());
4412 }
4413 void WhenRange(Demon* d) override {
4414 left_->WhenRange(d);
4415 right_->WhenRange(d);
4416 }
4417
4418 void Accept(ModelVisitor* const visitor) const override {
4419 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4420 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4421 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4422 right_);
4423 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4424 }
4425
4426 private:
4427 IntExpr* const left_;
4428 IntExpr* const right_;
4429};
4430
4431// ----- TimesBooleanPosIntExpr -----
4432
4433class TimesBooleanPosIntExpr : public BaseIntExpr {
4434 public:
4435 TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4436 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4437 ~TimesBooleanPosIntExpr() override {}
4438 int64_t Min() const override {
4439 return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4440 }
4441 void SetMin(int64_t m) override;
4442 int64_t Max() const override {
4443 return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4444 }
4445 void SetMax(int64_t m) override;
4446 void Range(int64_t* mi, int64_t* ma) override;
4447 void SetRange(int64_t mi, int64_t ma) override;
4448 bool Bound() const override;
4449 std::string name() const override {
4450 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4451 }
4452 std::string DebugString() const override {
4453 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4454 expr_->DebugString());
4455 }
4456 void WhenRange(Demon* d) override {
4457 boolvar_->WhenRange(d);
4458 expr_->WhenRange(d);
4459 }
4460
4461 void Accept(ModelVisitor* const visitor) const override {
4462 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4463 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4464 boolvar_);
4465 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4466 expr_);
4467 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4468 }
4469
4470 private:
4471 BooleanVar* const boolvar_;
4472 IntExpr* const expr_;
4473};
4474
4475void TimesBooleanPosIntExpr::SetMin(int64_t m) {
4476 if (m > 0) {
4477 boolvar_->SetValue(1);
4478 expr_->SetMin(m);
4479 }
4480}
4481
4482void TimesBooleanPosIntExpr::SetMax(int64_t m) {
4483 if (m < 0) {
4484 solver()->Fail();
4485 }
4486 if (m < expr_->Min()) {
4487 boolvar_->SetValue(0);
4488 }
4489 if (boolvar_->RawValue() == 1) {
4490 expr_->SetMax(m);
4491 }
4492}
4493
4494void TimesBooleanPosIntExpr::Range(int64_t* mi, int64_t* ma) {
4495 const int value = boolvar_->RawValue();
4496 if (value == 0) {
4497 *mi = 0;
4498 *ma = 0;
4499 } else if (value == 1) {
4500 expr_->Range(mi, ma);
4501 } else {
4502 *mi = 0;
4503 *ma = expr_->Max();
4504 }
4505}
4506
4507void TimesBooleanPosIntExpr::SetRange(int64_t mi, int64_t ma) {
4508 if (ma < 0 || mi > ma) {
4509 solver()->Fail();
4510 }
4511 if (mi > 0) {
4512 boolvar_->SetValue(1);
4513 expr_->SetMin(mi);
4514 }
4515 if (ma < expr_->Min()) {
4516 boolvar_->SetValue(0);
4517 }
4518 if (boolvar_->RawValue() == 1) {
4519 expr_->SetMax(ma);
4520 }
4521}
4522
4523bool TimesBooleanPosIntExpr::Bound() const {
4524 return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4525 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4526 expr_->Bound()));
4527}
4528
4529// ----- TimesBooleanIntExpr -----
4530
4531class TimesBooleanIntExpr : public BaseIntExpr {
4532 public:
4533 TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4534 : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4535 ~TimesBooleanIntExpr() override {}
4536 int64_t Min() const override {
4537 switch (boolvar_->RawValue()) {
4538 case 0: {
4539 return 0LL;
4540 }
4541 case 1: {
4542 return expr_->Min();
4543 }
4544 default: {
4545 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4546 return std::min(int64_t{0}, expr_->Min());
4547 }
4548 }
4549 }
4550 void SetMin(int64_t m) override;
4551 int64_t Max() const override {
4552 switch (boolvar_->RawValue()) {
4553 case 0: {
4554 return 0LL;
4555 }
4556 case 1: {
4557 return expr_->Max();
4558 }
4559 default: {
4560 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4561 return std::max(int64_t{0}, expr_->Max());
4562 }
4563 }
4564 }
4565 void SetMax(int64_t m) override;
4566 void Range(int64_t* mi, int64_t* ma) override;
4567 void SetRange(int64_t mi, int64_t ma) override;
4568 bool Bound() const override;
4569 std::string name() const override {
4570 return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4571 }
4572 std::string DebugString() const override {
4573 return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4574 expr_->DebugString());
4575 }
4576 void WhenRange(Demon* d) override {
4577 boolvar_->WhenRange(d);
4578 expr_->WhenRange(d);
4579 }
4580
4581 void Accept(ModelVisitor* const visitor) const override {
4582 visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4583 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4584 boolvar_);
4585 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4586 expr_);
4587 visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4588 }
4589
4590 private:
4591 BooleanVar* const boolvar_;
4592 IntExpr* const expr_;
4593};
4594
4595void TimesBooleanIntExpr::SetMin(int64_t m) {
4596 switch (boolvar_->RawValue()) {
4597 case 0: {
4598 if (m > 0) {
4599 solver()->Fail();
4600 }
4601 break;
4602 }
4603 case 1: {
4604 expr_->SetMin(m);
4605 break;
4606 }
4607 default: {
4608 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4609 if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4610 boolvar_->SetValue(1);
4611 expr_->SetMin(m);
4612 } else if (m <= 0 && expr_->Max() < m) {
4613 boolvar_->SetValue(0);
4614 }
4615 }
4616 }
4617}
4618
4619void TimesBooleanIntExpr::SetMax(int64_t m) {
4620 switch (boolvar_->RawValue()) {
4621 case 0: {
4622 if (m < 0) {
4623 solver()->Fail();
4624 }
4625 break;
4626 }
4627 case 1: {
4628 expr_->SetMax(m);
4629 break;
4630 }
4631 default: {
4632 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4633 if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4634 boolvar_->SetValue(1);
4635 expr_->SetMax(m);
4636 } else if (m >= 0 && expr_->Min() > m) {
4637 boolvar_->SetValue(0);
4638 }
4639 }
4640 }
4641}
4642
4643void TimesBooleanIntExpr::Range(int64_t* mi, int64_t* ma) {
4644 switch (boolvar_->RawValue()) {
4645 case 0: {
4646 *mi = 0;
4647 *ma = 0;
4648 break;
4649 }
4650 case 1: {
4651 *mi = expr_->Min();
4652 *ma = expr_->Max();
4653 break;
4654 }
4655 default: {
4656 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4657 *mi = std::min(int64_t{0}, expr_->Min());
4658 *ma = std::max(int64_t{0}, expr_->Max());
4659 break;
4660 }
4661 }
4662}
4663
4664void TimesBooleanIntExpr::SetRange(int64_t mi, int64_t ma) {
4665 if (mi > ma) {
4666 solver()->Fail();
4667 }
4668 switch (boolvar_->RawValue()) {
4669 case 0: {
4670 if (mi > 0 || ma < 0) {
4671 solver()->Fail();
4672 }
4673 break;
4674 }
4675 case 1: {
4676 expr_->SetRange(mi, ma);
4677 break;
4678 }
4679 default: {
4680 DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4681 if (mi > 0) {
4682 boolvar_->SetValue(1);
4683 expr_->SetMin(mi);
4684 } else if (mi == 0 && expr_->Max() < 0) {
4685 boolvar_->SetValue(0);
4686 }
4687 if (ma < 0) {
4688 boolvar_->SetValue(1);
4689 expr_->SetMax(ma);
4690 } else if (ma == 0 && expr_->Min() > 0) {
4691 boolvar_->SetValue(0);
4692 }
4693 break;
4694 }
4695 }
4696}
4697
4698bool TimesBooleanIntExpr::Bound() const {
4699 return (boolvar_->RawValue() == 0 ||
4700 (expr_->Bound() &&
4701 (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4702 expr_->Max() == 0)));
4703}
4704
4705// ----- DivPosIntCstExpr -----
4706
4707class DivPosIntCstExpr : public BaseIntExpr {
4708 public:
4709 DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64_t v)
4710 : BaseIntExpr(s), expr_(e), value_(v) {
4711 CHECK_GE(v, 0);
4712 }
4713 ~DivPosIntCstExpr() override {}
4714
4715 int64_t Min() const override { return expr_->Min() / value_; }
4716
4717 void SetMin(int64_t m) override {
4718 if (m > 0) {
4719 expr_->SetMin(m * value_);
4720 } else {
4721 expr_->SetMin((m - 1) * value_ + 1);
4722 }
4723 }
4724 int64_t Max() const override { return expr_->Max() / value_; }
4725
4726 void SetMax(int64_t m) override {
4727 if (m >= 0) {
4728 expr_->SetMax((m + 1) * value_ - 1);
4729 } else {
4730 expr_->SetMax(m * value_);
4731 }
4732 }
4733
4734 std::string name() const override {
4735 return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4736 }
4737
4738 std::string DebugString() const override {
4739 return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4740 }
4741
4742 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4743
4744 void Accept(ModelVisitor* const visitor) const override {
4745 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4746 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4747 expr_);
4748 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4749 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4750 }
4751
4752 private:
4753 IntExpr* const expr_;
4754 const int64_t value_;
4755};
4756
4757// DivPosIntExpr
4758
4759class DivPosIntExpr : public BaseIntExpr {
4760 public:
4761 DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4762 : BaseIntExpr(s),
4763 num_(num),
4764 denom_(denom),
4765 opp_num_(s->MakeOpposite(num)) {}
4766
4767 ~DivPosIntExpr() override {}
4768
4769 int64_t Min() const override {
4770 return num_->Min() >= 0
4771 ? num_->Min() / denom_->Max()
4772 : (denom_->Min() == 0 ? num_->Min()
4773 : num_->Min() / denom_->Min());
4774 }
4775
4776 int64_t Max() const override {
4777 return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4778 : num_->Max() / denom_->Min())
4779 : num_->Max() / denom_->Max();
4780 }
4781
4782 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4783 num->SetMin(m * denom->Min());
4784 denom->SetMax(num->Max() / m);
4785 }
4786
4787 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
4788 num->SetMax((m + 1) * denom->Max() - 1);
4789 denom->SetMin(num->Min() / (m + 1) + 1);
4790 }
4791
4792 void SetMin(int64_t m) override {
4793 if (m > 0) {
4794 SetPosMin(num_, denom_, m);
4795 } else {
4796 SetPosMax(opp_num_, denom_, -m);
4797 }
4798 }
4799
4800 void SetMax(int64_t m) override {
4801 if (m >= 0) {
4802 SetPosMax(num_, denom_, m);
4803 } else {
4804 SetPosMin(opp_num_, denom_, -m);
4805 }
4806 }
4807
4808 std::string name() const override {
4809 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4810 }
4811 std::string DebugString() const override {
4812 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4813 denom_->DebugString());
4814 }
4815 void WhenRange(Demon* d) override {
4816 num_->WhenRange(d);
4817 denom_->WhenRange(d);
4818 }
4819
4820 void Accept(ModelVisitor* const visitor) const override {
4821 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4822 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4823 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4824 denom_);
4825 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4826 }
4827
4828 private:
4829 IntExpr* const num_;
4830 IntExpr* const denom_;
4831 IntExpr* const opp_num_;
4832};
4833
4834class DivPosPosIntExpr : public BaseIntExpr {
4835 public:
4836 DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4837 : BaseIntExpr(s), num_(num), denom_(denom) {}
4838
4839 ~DivPosPosIntExpr() override {}
4840
4841 int64_t Min() const override {
4842 if (denom_->Max() == 0) {
4843 solver()->Fail();
4844 }
4845 return num_->Min() / denom_->Max();
4846 }
4847
4848 int64_t Max() const override {
4849 if (denom_->Min() == 0) {
4850 return num_->Max();
4851 } else {
4852 return num_->Max() / denom_->Min();
4853 }
4854 }
4855
4856 void SetMin(int64_t m) override {
4857 if (m > 0) {
4858 num_->SetMin(m * denom_->Min());
4859 denom_->SetMax(num_->Max() / m);
4860 }
4861 }
4862
4863 void SetMax(int64_t m) override {
4864 if (m >= 0) {
4865 num_->SetMax((m + 1) * denom_->Max() - 1);
4866 denom_->SetMin(num_->Min() / (m + 1) + 1);
4867 } else {
4868 solver()->Fail();
4869 }
4870 }
4871
4872 std::string name() const override {
4873 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4874 }
4875
4876 std::string DebugString() const override {
4877 return absl::StrFormat("(%s div %s)", num_->DebugString(),
4878 denom_->DebugString());
4879 }
4880
4881 void WhenRange(Demon* d) override {
4882 num_->WhenRange(d);
4883 denom_->WhenRange(d);
4884 }
4885
4886 void Accept(ModelVisitor* const visitor) const override {
4887 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4888 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4889 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4890 denom_);
4891 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4892 }
4893
4894 private:
4895 IntExpr* const num_;
4896 IntExpr* const denom_;
4897};
4898
4899// DivIntExpr
4900
4901class DivIntExpr : public BaseIntExpr {
4902 public:
4903 DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4904 : BaseIntExpr(s),
4905 num_(num),
4906 denom_(denom),
4907 opp_num_(s->MakeOpposite(num)) {}
4908
4909 ~DivIntExpr() override {}
4910
4911 int64_t Min() const override {
4912 const int64_t num_min = num_->Min();
4913 const int64_t num_max = num_->Max();
4914 const int64_t denom_min = denom_->Min();
4915 const int64_t denom_max = denom_->Max();
4916
4917 if (denom_min == 0 && denom_max == 0) {
4918 return std::numeric_limits<int64_t>::max(); // TODO(user): Check this
4919 // convention.
4920 }
4921
4922 if (denom_min >= 0) { // Denominator strictly positive.
4923 DCHECK_GT(denom_max, 0);
4924 const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4925 return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
4926 } else if (denom_max <= 0) { // Denominator strictly negative.
4927 DCHECK_LT(denom_min, 0);
4928 const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4929 return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
4930 } else { // Denominator across 0.
4931 return std::min(num_min, -num_max);
4932 }
4933 }
4934
4935 int64_t Max() const override {
4936 const int64_t num_min = num_->Min();
4937 const int64_t num_max = num_->Max();
4938 const int64_t denom_min = denom_->Min();
4939 const int64_t denom_max = denom_->Max();
4940
4941 if (denom_min == 0 && denom_max == 0) {
4942 return std::numeric_limits<int64_t>::min(); // TODO(user): Check this
4943 // convention.
4944 }
4945
4946 if (denom_min >= 0) { // Denominator strictly positive.
4947 DCHECK_GT(denom_max, 0);
4948 const int64_t adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4949 return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
4950 } else if (denom_max <= 0) { // Denominator strictly negative.
4951 DCHECK_LT(denom_min, 0);
4952 const int64_t adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4953 return num_min >= 0 ? num_min / denom_min
4954 : -num_min / -adjusted_denom_max;
4955 } else { // Denominator across 0.
4956 return std::max(num_max, -num_min);
4957 }
4958 }
4959
4960 void AdjustDenominator() {
4961 if (denom_->Min() == 0) {
4962 denom_->SetMin(1);
4963 } else if (denom_->Max() == 0) {
4964 denom_->SetMax(-1);
4965 }
4966 }
4967
4968 // m > 0.
4969 static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64_t m) {
4970 DCHECK_GT(m, 0);
4971 const int64_t num_min = num->Min();
4972 const int64_t num_max = num->Max();
4973 const int64_t denom_min = denom->Min();
4974 const int64_t denom_max = denom->Max();
4975 DCHECK_NE(denom_min, 0);
4976 DCHECK_NE(denom_max, 0);
4977 if (denom_min > 0) { // Denominator strictly positive.
4978 num->SetMin(m * denom_min);
4979 denom->SetMax(num_max / m);
4980 } else if (denom_max < 0) { // Denominator strictly negative.
4981 num->SetMax(m * denom_max);
4982 denom->SetMin(num_min / m);
4983 } else { // Denominator across 0.
4984 if (num_min >= 0) {
4985 num->SetMin(m);
4986 denom->SetRange(1, num_max / m);
4987 } else if (num_max <= 0) {
4988 num->SetMax(-m);
4989 denom->SetRange(num_min / m, -1);
4990 } else {
4991 if (m > -num_min) { // Denominator is forced positive.
4992 num->SetMin(m);
4993 denom->SetRange(1, num_max / m);
4994 } else if (m > num_max) { // Denominator is forced negative.
4995 num->SetMax(-m);
4996 denom->SetRange(num_min / m, -1);
4997 } else {
4998 denom->SetRange(num_min / m, num_max / m);
4999 }
5000 }
5001 }
5002 }
5003
5004 // m >= 0.
5005 static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64_t m) {
5006 DCHECK_GE(m, 0);
5007 const int64_t num_min = num->Min();
5008 const int64_t num_max = num->Max();
5009 const int64_t denom_min = denom->Min();
5010 const int64_t denom_max = denom->Max();
5011 DCHECK_NE(denom_min, 0);
5012 DCHECK_NE(denom_max, 0);
5013 if (denom_min > 0) { // Denominator strictly positive.
5014 num->SetMax((m + 1) * denom_max - 1);
5015 denom->SetMin((num_min / (m + 1)) + 1);
5016 } else if (denom_max < 0) {
5017 num->SetMin((m + 1) * denom_min + 1);
5018 denom->SetMax(num_max / (m + 1) - 1);
5019 } else if (num_min > (m + 1) * denom_max - 1) {
5020 denom->SetMax(-1);
5021 } else if (num_max < (m + 1) * denom_min + 1) {
5022 denom->SetMin(1);
5023 }
5024 }
5025
5026 void SetMin(int64_t m) override {
5027 AdjustDenominator();
5028 if (m > 0) {
5029 SetPosMin(num_, denom_, m);
5030 } else {
5031 SetPosMax(opp_num_, denom_, -m);
5032 }
5033 }
5034
5035 void SetMax(int64_t m) override {
5036 AdjustDenominator();
5037 if (m >= 0) {
5038 SetPosMax(num_, denom_, m);
5039 } else {
5040 SetPosMin(opp_num_, denom_, -m);
5041 }
5042 }
5043
5044 std::string name() const override {
5045 return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
5046 }
5047 std::string DebugString() const override {
5048 return absl::StrFormat("(%s div %s)", num_->DebugString(),
5049 denom_->DebugString());
5050 }
5051 void WhenRange(Demon* d) override {
5052 num_->WhenRange(d);
5053 denom_->WhenRange(d);
5054 }
5055
5056 void Accept(ModelVisitor* const visitor) const override {
5057 visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
5058 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
5059 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5060 denom_);
5061 visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
5062 }
5063
5064 private:
5065 IntExpr* const num_;
5066 IntExpr* const denom_;
5067 IntExpr* const opp_num_;
5068};
5069
5070// ----- IntAbs And IntAbsConstraint ------
5071
5072class IntAbsConstraint : public CastConstraint {
5073 public:
5074 IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5075 : CastConstraint(s, target), sub_(sub) {}
5076
5077 ~IntAbsConstraint() override {}
5078
5079 void Post() override {
5080 Demon* const sub_demon = MakeConstraintDemon0(
5081 solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5082 sub_->WhenRange(sub_demon);
5083 Demon* const target_demon = MakeConstraintDemon0(
5084 solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5085 target_var_->WhenRange(target_demon);
5086 }
5087
5088 void InitialPropagate() override {
5089 PropagateSub();
5090 PropagateTarget();
5091 }
5092
5093 void PropagateSub() {
5094 const int64_t smin = sub_->Min();
5095 const int64_t smax = sub_->Max();
5096 if (smax <= 0) {
5097 target_var_->SetRange(-smax, -smin);
5098 } else if (smin >= 0) {
5099 target_var_->SetRange(smin, smax);
5100 } else {
5101 target_var_->SetRange(0, std::max(-smin, smax));
5102 }
5103 }
5104
5105 void PropagateTarget() {
5106 const int64_t target_max = target_var_->Max();
5107 sub_->SetRange(-target_max, target_max);
5108 const int64_t target_min = target_var_->Min();
5109 if (target_min > 0) {
5110 if (sub_->Min() > -target_min) {
5111 sub_->SetMin(target_min);
5112 } else if (sub_->Max() < target_min) {
5113 sub_->SetMax(-target_min);
5114 }
5115 }
5116 }
5117
5118 std::string DebugString() const override {
5119 return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5120 target_var_->DebugString());
5121 }
5122
5123 void Accept(ModelVisitor* const visitor) const override {
5124 visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5125 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5126 sub_);
5127 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5128 target_var_);
5129 visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5130 }
5131
5132 private:
5133 IntVar* const sub_;
5134};
5135
5136class IntAbs : public BaseIntExpr {
5137 public:
5138 IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5139
5140 ~IntAbs() override {}
5141
5142 int64_t Min() const override {
5143 int64_t emin = 0;
5144 int64_t emax = 0;
5145 expr_->Range(&emin, &emax);
5146 if (emin >= 0) {
5147 return emin;
5148 }
5149 if (emax <= 0) {
5150 return -emax;
5151 }
5152 return 0;
5153 }
5154
5155 void SetMin(int64_t m) override {
5156 if (m > 0) {
5157 int64_t emin = 0;
5158 int64_t emax = 0;
5159 expr_->Range(&emin, &emax);
5160 if (emin > -m) {
5161 expr_->SetMin(m);
5162 } else if (emax < m) {
5163 expr_->SetMax(-m);
5164 }
5165 }
5166 }
5167
5168 int64_t Max() const override {
5169 int64_t emin = 0;
5170 int64_t emax = 0;
5171 expr_->Range(&emin, &emax);
5172 return std::max(-emin, emax);
5173 }
5174
5175 void SetMax(int64_t m) override { expr_->SetRange(-m, m); }
5176
5177 void SetRange(int64_t mi, int64_t ma) override {
5178 expr_->SetRange(-ma, ma);
5179 if (mi > 0) {
5180 int64_t emin = 0;
5181 int64_t emax = 0;
5182 expr_->Range(&emin, &emax);
5183 if (emin > -mi) {
5184 expr_->SetMin(mi);
5185 } else if (emax < mi) {
5186 expr_->SetMax(-mi);
5187 }
5188 }
5189 }
5190
5191 void Range(int64_t* mi, int64_t* ma) override {
5192 int64_t emin = 0;
5193 int64_t emax = 0;
5194 expr_->Range(&emin, &emax);
5195 if (emin >= 0) {
5196 *mi = emin;
5197 *ma = emax;
5198 } else if (emax <= 0) {
5199 *mi = -emax;
5200 *ma = -emin;
5201 } else {
5202 *mi = 0;
5203 *ma = std::max(-emin, emax);
5204 }
5205 }
5206
5207 bool Bound() const override { return expr_->Bound(); }
5208
5209 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5210
5211 std::string name() const override {
5212 return absl::StrFormat("IntAbs(%s)", expr_->name());
5213 }
5214
5215 std::string DebugString() const override {
5216 return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5217 }
5218
5219 void Accept(ModelVisitor* const visitor) const override {
5220 visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5221 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5222 expr_);
5223 visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5224 }
5225
5226 IntVar* CastToVar() override {
5227 int64_t min_value = 0;
5228 int64_t max_value = 0;
5229 Range(&min_value, &max_value);
5230 Solver* const s = solver();
5231 const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5232 IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5233 CastConstraint* const ct =
5234 s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5235 s->AddCastConstraint(ct, target, this);
5236 return target;
5237 }
5238
5239 private:
5240 IntExpr* const expr_;
5241};
5242
5243// ----- Square -----
5244
5245// TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5246class IntSquare : public BaseIntExpr {
5247 public:
5248 IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5249 ~IntSquare() override {}
5250
5251 int64_t Min() const override {
5252 const int64_t emin = expr_->Min();
5253 if (emin >= 0) {
5254 return emin >= std::numeric_limits<int32_t>::max()
5256 : emin * emin;
5257 }
5258 const int64_t emax = expr_->Max();
5259 if (emax < 0) {
5260 return emax <= -std::numeric_limits<int32_t>::max()
5262 : emax * emax;
5263 }
5264 return 0LL;
5265 }
5266 void SetMin(int64_t m) override {
5267 if (m <= 0) {
5268 return;
5269 }
5270 // TODO(user): What happens if m is kint64max?
5271 const int64_t emin = expr_->Min();
5272 const int64_t emax = expr_->Max();
5273 const int64_t root =
5274 static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5275 if (emin >= 0) {
5276 expr_->SetMin(root);
5277 } else if (emax <= 0) {
5278 expr_->SetMax(-root);
5279 } else if (expr_->IsVar()) {
5280 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5281 }
5282 }
5283 int64_t Max() const override {
5284 const int64_t emax = expr_->Max();
5285 const int64_t emin = expr_->Min();
5286 if (emax >= std::numeric_limits<int32_t>::max() ||
5289 }
5290 return std::max(emin * emin, emax * emax);
5291 }
5292 void SetMax(int64_t m) override {
5293 if (m < 0) {
5294 solver()->Fail();
5295 }
5297 return;
5298 }
5299 const int64_t root =
5300 static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5301 expr_->SetRange(-root, root);
5302 }
5303 bool Bound() const override { return expr_->Bound(); }
5304 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5305 std::string name() const override {
5306 return absl::StrFormat("IntSquare(%s)", expr_->name());
5307 }
5308 std::string DebugString() const override {
5309 return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5310 }
5311
5312 void Accept(ModelVisitor* const visitor) const override {
5313 visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5314 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5315 expr_);
5316 visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5317 }
5318
5319 IntExpr* expr() const { return expr_; }
5320
5321 protected:
5322 IntExpr* const expr_;
5323};
5324
5325class PosIntSquare : public IntSquare {
5326 public:
5327 PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
5328 ~PosIntSquare() override {}
5329
5330 int64_t Min() const override {
5331 const int64_t emin = expr_->Min();
5332 return emin >= std::numeric_limits<int32_t>::max()
5334 : emin * emin;
5335 }
5336 void SetMin(int64_t m) override {
5337 if (m <= 0) {
5338 return;
5339 }
5340 const int64_t root =
5341 static_cast<int64_t>(ceil(sqrt(static_cast<double>(m))));
5342 expr_->SetMin(root);
5343 }
5344 int64_t Max() const override {
5345 const int64_t emax = expr_->Max();
5346 return emax >= std::numeric_limits<int32_t>::max()
5348 : emax * emax;
5349 }
5350 void SetMax(int64_t m) override {
5351 if (m < 0) {
5352 solver()->Fail();
5353 }
5355 return;
5356 }
5357 const int64_t root =
5358 static_cast<int64_t>(floor(sqrt(static_cast<double>(m))));
5359 expr_->SetMax(root);
5360 }
5361};
5362
5363// ----- EvenPower -----
5364
// Returns value^power using exponentiation by squaring (O(log power)
// multiplications). The caller must ensure the result fits in int64_t;
// BasePower::Pown() saturates before calling. A square of the running base is
// only formed while a higher bit of `power` remains, so no intermediate value
// exceeds the magnitude of the final result, and the loop cannot overflow when
// the result itself does not.
// For power <= 1, `value` is returned unchanged.
int64_t IntPower(int64_t value, int64_t power) {
  if (power <= 1) {
    return value;
  }
  int64_t result = 1;
  int64_t base = value;
  int64_t remaining = power;
  while (true) {
    if (remaining & 1) {
      result *= base;
    }
    remaining >>= 1;
    if (remaining == 0) {
      break;
    }
    base *= base;
  }
  return result;
}
5373
// Returns floor(kint64max^(1/power)), computed in floating point. Pown() uses
// it as a threshold: |x| >= OverflowLimit(power) is treated as overflowing.
// For power <= 1, no x^power can overflow, so kint64max is returned directly.
// The floating-point formula would otherwise yield 2^63 (power == 1) or +inf
// (power == 0), and converting either to int64_t is undefined behavior.
// BasePower evaluates this in its member-initializer list, i.e. before its
// CHECK_GT(n, 0) runs.
int64_t OverflowLimit(int64_t power) {
  if (power <= 1) {
    return std::numeric_limits<int64_t>::max();
  }
  return static_cast<int64_t>(std::floor(std::exp(
      std::log(static_cast<double>(std::numeric_limits<int64_t>::max())) /
      power)));
}
5378
5379class BasePower : public BaseIntExpr {
5380 public:
5381 BasePower(Solver* const s, IntExpr* const e, int64_t n)
5382 : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
5383 CHECK_GT(n, 0);
5384 }
5385
5386 ~BasePower() override {}
5387
5388 bool Bound() const override { return expr_->Bound(); }
5389
5390 IntExpr* expr() const { return expr_; }
5391
5392 int64_t exponant() const { return pow_; }
5393
5394 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5395
5396 std::string name() const override {
5397 return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
5398 }
5399
5400 std::string DebugString() const override {
5401 return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
5402 }
5403
5404 void Accept(ModelVisitor* const visitor) const override {
5405 visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
5406 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5407 expr_);
5408 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
5409 visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
5410 }
5411
5412 protected:
5413 int64_t Pown(int64_t value) const {
5414 if (value >= limit_) {
5416 }
5417 if (value <= -limit_) {
5418 if (pow_ % 2 == 0) {
5420 } else {
5422 }
5423 }
5424 return IntPower(value, pow_);
5425 }
5426
5427 int64_t SqrnDown(int64_t value) const {
5430 }
5433 }
5434 int64_t res = 0;
5435 const double d_value = static_cast<double>(value);
5436 if (value >= 0) {
5437 const double sq = exp(log(d_value) / pow_);
5438 res = static_cast<int64_t>(floor(sq));
5439 } else {
5440 CHECK_EQ(1, pow_ % 2);
5441 const double sq = exp(log(-d_value) / pow_);
5442 res = -static_cast<int64_t>(ceil(sq));
5443 }
5444 const int64_t pow_res = Pown(res + 1);
5445 if (pow_res <= value) {
5446 return res + 1;
5447 } else {
5448 return res;
5449 }
5450 }
5451
5452 int64_t SqrnUp(int64_t value) const {
5455 }
5458 }
5459 int64_t res = 0;
5460 const double d_value = static_cast<double>(value);
5461 if (value >= 0) {
5462 const double sq = exp(log(d_value) / pow_);
5463 res = static_cast<int64_t>(ceil(sq));
5464 } else {
5465 CHECK_EQ(1, pow_ % 2);
5466 const double sq = exp(log(-d_value) / pow_);
5467 res = -static_cast<int64_t>(floor(sq));
5468 }
5469 const int64_t pow_res = Pown(res - 1);
5470 if (pow_res >= value) {
5471 return res - 1;
5472 } else {
5473 return res;
5474 }
5475 }
5476
5477 IntExpr* const expr_;
5478 const int64_t pow_;
5479 const int64_t limit_;
5480};
5481
5482class IntEvenPower : public BasePower {
5483 public:
5484 IntEvenPower(Solver* const s, IntExpr* const e, int64_t n)
5485 : BasePower(s, e, n) {
5486 CHECK_EQ(0, n % 2);
5487 }
5488
5489 ~IntEvenPower() override {}
5490
5491 int64_t Min() const override {
5492 int64_t emin = 0;
5493 int64_t emax = 0;
5494 expr_->Range(&emin, &emax);
5495 if (emin >= 0) {
5496 return Pown(emin);
5497 }
5498 if (emax < 0) {
5499 return Pown(emax);
5500 }
5501 return 0LL;
5502 }
5503 void SetMin(int64_t m) override {
5504 if (m <= 0) {
5505 return;
5506 }
5507 int64_t emin = 0;
5508 int64_t emax = 0;
5509 expr_->Range(&emin, &emax);
5510 const int64_t root = SqrnUp(m);
5511 if (emin > -root) {
5512 expr_->SetMin(root);
5513 } else if (emax < root) {
5514 expr_->SetMax(-root);
5515 } else if (expr_->IsVar()) {
5516 reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5517 }
5518 }
5519
5520 int64_t Max() const override {
5521 return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5522 }
5523
5524 void SetMax(int64_t m) override {
5525 if (m < 0) {
5526 solver()->Fail();
5527 }
5529 return;
5530 }
5531 const int64_t root = SqrnDown(m);
5532 expr_->SetRange(-root, root);
5533 }
5534};
5535
5536class PosIntEvenPower : public BasePower {
5537 public:
5538 PosIntEvenPower(Solver* const s, IntExpr* const e, int64_t pow)
5539 : BasePower(s, e, pow) {
5540 CHECK_EQ(0, pow % 2);
5541 }
5542
5543 ~PosIntEvenPower() override {}
5544
5545 int64_t Min() const override { return Pown(expr_->Min()); }
5546
5547 void SetMin(int64_t m) override {
5548 if (m <= 0) {
5549 return;
5550 }
5551 expr_->SetMin(SqrnUp(m));
5552 }
5553 int64_t Max() const override { return Pown(expr_->Max()); }
5554
5555 void SetMax(int64_t m) override {
5556 if (m < 0) {
5557 solver()->Fail();
5558 }
5560 return;
5561 }
5562 expr_->SetMax(SqrnDown(m));
5563 }
5564};
5565
5566class IntOddPower : public BasePower {
5567 public:
5568 IntOddPower(Solver* const s, IntExpr* const e, int64_t n)
5569 : BasePower(s, e, n) {
5570 CHECK_EQ(1, n % 2);
5571 }
5572
5573 ~IntOddPower() override {}
5574
5575 int64_t Min() const override { return Pown(expr_->Min()); }
5576
5577 void SetMin(int64_t m) override { expr_->SetMin(SqrnUp(m)); }
5578
5579 int64_t Max() const override { return Pown(expr_->Max()); }
5580
5581 void SetMax(int64_t m) override { expr_->SetMax(SqrnDown(m)); }
5582};
5583
5584// ----- Min(expr, expr) -----
5585
5586class MinIntExpr : public BaseIntExpr {
5587 public:
5588 MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5589 : BaseIntExpr(s), left_(l), right_(r) {}
5590 ~MinIntExpr() override {}
5591 int64_t Min() const override {
5592 const int64_t lmin = left_->Min();
5593 const int64_t rmin = right_->Min();
5594 return std::min(lmin, rmin);
5595 }
5596 void SetMin(int64_t m) override {
5597 left_->SetMin(m);
5598 right_->SetMin(m);
5599 }
5600 int64_t Max() const override {
5601 const int64_t lmax = left_->Max();
5602 const int64_t rmax = right_->Max();
5603 return std::min(lmax, rmax);
5604 }
5605 void SetMax(int64_t m) override {
5606 if (left_->Min() > m) {
5607 right_->SetMax(m);
5608 }
5609 if (right_->Min() > m) {
5610 left_->SetMax(m);
5611 }
5612 }
5613 std::string name() const override {
5614 return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5615 }
5616 std::string DebugString() const override {
5617 return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5618 right_->DebugString());
5619 }
5620 void WhenRange(Demon* d) override {
5621 left_->WhenRange(d);
5622 right_->WhenRange(d);
5623 }
5624
5625 void Accept(ModelVisitor* const visitor) const override {
5626 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5627 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5628 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5629 right_);
5630 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5631 }
5632
5633 private:
5634 IntExpr* const left_;
5635 IntExpr* const right_;
5636};
5637
5638// ----- Min(expr, constant) -----
5639
5640class MinCstIntExpr : public BaseIntExpr {
5641 public:
5642 MinCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5643 : BaseIntExpr(s), expr_(e), value_(v) {}
5644
5645 ~MinCstIntExpr() override {}
5646
5647 int64_t Min() const override { return std::min(expr_->Min(), value_); }
5648
5649 void SetMin(int64_t m) override {
5650 if (m > value_) {
5651 solver()->Fail();
5652 }
5653 expr_->SetMin(m);
5654 }
5655
5656 int64_t Max() const override { return std::min(expr_->Max(), value_); }
5657
5658 void SetMax(int64_t m) override {
5659 if (value_ > m) {
5660 expr_->SetMax(m);
5661 }
5662 }
5663
5664 bool Bound() const override {
5665 return (expr_->Bound() || expr_->Min() >= value_);
5666 }
5667
5668 std::string name() const override {
5669 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5670 }
5671
5672 std::string DebugString() const override {
5673 return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5674 value_);
5675 }
5676
5677 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5678
5679 void Accept(ModelVisitor* const visitor) const override {
5680 visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5681 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5682 expr_);
5683 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5684 visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5685 }
5686
5687 private:
5688 IntExpr* const expr_;
5689 const int64_t value_;
5690};
5691
5692// ----- Max(expr, expr) -----
5693
5694class MaxIntExpr : public BaseIntExpr {
5695 public:
5696 MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5697 : BaseIntExpr(s), left_(l), right_(r) {}
5698
5699 ~MaxIntExpr() override {}
5700
5701 int64_t Min() const override { return std::max(left_->Min(), right_->Min()); }
5702
5703 void SetMin(int64_t m) override {
5704 if (left_->Max() < m) {
5705 right_->SetMin(m);
5706 } else {
5707 if (right_->Max() < m) {
5708 left_->SetMin(m);
5709 }
5710 }
5711 }
5712
5713 int64_t Max() const override { return std::max(left_->Max(), right_->Max()); }
5714
5715 void SetMax(int64_t m) override {
5716 left_->SetMax(m);
5717 right_->SetMax(m);
5718 }
5719
5720 std::string name() const override {
5721 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5722 }
5723
5724 std::string DebugString() const override {
5725 return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5726 right_->DebugString());
5727 }
5728
5729 void WhenRange(Demon* d) override {
5730 left_->WhenRange(d);
5731 right_->WhenRange(d);
5732 }
5733
5734 void Accept(ModelVisitor* const visitor) const override {
5735 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5736 visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5737 visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5738 right_);
5739 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5740 }
5741
5742 private:
5743 IntExpr* const left_;
5744 IntExpr* const right_;
5745};
5746
5747// ----- Max(expr, constant) -----
5748
5749class MaxCstIntExpr : public BaseIntExpr {
5750 public:
5751 MaxCstIntExpr(Solver* const s, IntExpr* const e, int64_t v)
5752 : BaseIntExpr(s), expr_(e), value_(v) {}
5753
5754 ~MaxCstIntExpr() override {}
5755
5756 int64_t Min() const override { return std::max(expr_->Min(), value_); }
5757
5758 void SetMin(int64_t m) override {
5759 if (value_ < m) {
5760 expr_->SetMin(m);
5761 }
5762 }
5763
5764 int64_t Max() const override { return std::max(expr_->Max(), value_); }
5765
5766 void SetMax(int64_t m) override {
5767 if (m < value_) {
5768 solver()->Fail();
5769 }
5770 expr_->SetMax(m);
5771 }
5772
5773 bool Bound() const override {
5774 return (expr_->Bound() || expr_->Max() <= value_);
5775 }
5776
5777 std::string name() const override {
5778 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5779 }
5780
5781 std::string DebugString() const override {
5782 return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5783 value_);
5784 }
5785
5786 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5787
5788 void Accept(ModelVisitor* const visitor) const override {
5789 visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5790 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5791 expr_);
5792 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5793 visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5794 }
5795
5796 private:
5797 IntExpr* const expr_;
5798 const int64_t value_;
5799};
5800
5801// ----- Convex Piecewise -----
5802
5803// This class is a very simple convex piecewise linear function. The
5804// argument of the function is the expression. Between early_date and
5805// late_date, the value of the function is 0. Before early date, it
5806// is affine and the cost is early_cost * (early_date - x). After
5807// late_date, the cost is late_cost * (x - late_date).
5808
5809class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5810 public:
5811 SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64_t ec,
5812 int64_t ed, int64_t ld, int64_t lc)
5813 : BaseIntExpr(s),
5814 expr_(e),
5815 early_cost_(ec),
5816 early_date_(ec == 0 ? std::numeric_limits<int64_t>::min() : ed),
5817 late_date_(lc == 0 ? std::numeric_limits<int64_t>::max() : ld),
5818 late_cost_(lc) {
5819 DCHECK_GE(ec, int64_t{0});
5820 DCHECK_GE(lc, int64_t{0});
5821 DCHECK_GE(ld, ed);
5822
5823 // If the penalty is 0, we can push the "confort zone or zone
5824 // of no cost towards infinity.
5825 }
5826
5827 ~SimpleConvexPiecewiseExpr() override {}
5828
5829 int64_t Min() const override {
5830 const int64_t vmin = expr_->Min();
5831 const int64_t vmax = expr_->Max();
5832 if (vmin >= late_date_) {
5833 return (vmin - late_date_) * late_cost_;
5834 } else if (vmax <= early_date_) {
5835 return (early_date_ - vmax) * early_cost_;
5836 } else {
5837 return 0LL;
5838 }
5839 }
5840
5841 void SetMin(int64_t m) override {
5842 if (m <= 0) {
5843 return;
5844 }
5845 int64_t vmin = 0;
5846 int64_t vmax = 0;
5847 expr_->Range(&vmin, &vmax);
5848
5849 const int64_t rb =
5850 (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5851 const int64_t lb =
5852 (early_cost_ == 0 ? vmin
5853 : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5854
5855 if (expr_->IsVar()) {
5856 expr_->Var()->RemoveInterval(lb, rb);
5857 }
5858 }
5859
5860 int64_t Max() const override {
5861 const int64_t vmin = expr_->Min();
5862 const int64_t vmax = expr_->Max();
5863 const int64_t mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5864 const int64_t ml =
5865 vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5866 return std::max(mr, ml);
5867 }
5868
5869 void SetMax(int64_t m) override {
5870 if (m < 0) {
5871 solver()->Fail();
5872 }
5873 if (late_cost_ != 0LL) {
5874 const int64_t rb = late_date_ + PosIntDivDown(m, late_cost_);
5875 if (early_cost_ != 0LL) {
5876 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5877 expr_->SetRange(lb, rb);
5878 } else {
5879 expr_->SetMax(rb);
5880 }
5881 } else {
5882 if (early_cost_ != 0LL) {
5883 const int64_t lb = early_date_ - PosIntDivDown(m, early_cost_);
5884 expr_->SetMin(lb);
5885 }
5886 }
5887 }
5888
5889 std::string name() const override {
5890 return absl::StrFormat(
5891 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5892 expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5893 }
5894
5895 std::string DebugString() const override {
5896 return absl::StrFormat(
5897 "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5898 expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5899 }
5900
5901 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5902
5903 void Accept(ModelVisitor* const visitor) const override {
5904 visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5905 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5906 expr_);
5907 visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5908 early_cost_);
5909 visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5910 early_date_);
5911 visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5912 visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5913 visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5914 }
5915
5916 private:
5917 IntExpr* const expr_;
5918 const int64_t early_cost_;
5919 const int64_t early_date_;
5920 const int64_t late_date_;
5921 const int64_t late_cost_;
5922};
5923
5924// ----- Semi Continuous -----
5925
5926class SemiContinuousExpr : public BaseIntExpr {
5927 public:
5928 SemiContinuousExpr(Solver* const s, IntExpr* const e, int64_t fixed_charge,
5929 int64_t step)
5930 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5931 DCHECK_GE(fixed_charge, int64_t{0});
5932 DCHECK_GT(step, int64_t{0});
5933 }
5934
5935 ~SemiContinuousExpr() override {}
5936
5937 int64_t Value(int64_t x) const {
5938 if (x <= 0) {
5939 return 0;
5940 } else {
5941 return CapAdd(fixed_charge_, CapProd(x, step_));
5942 }
5943 }
5944
5945 int64_t Min() const override { return Value(expr_->Min()); }
5946
5947 void SetMin(int64_t m) override {
5948 if (m >= CapAdd(fixed_charge_, step_)) {
5949 const int64_t y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5950 expr_->SetMin(y);
5951 } else if (m > 0) {
5952 expr_->SetMin(1);
5953 }
5954 }
5955
5956 int64_t Max() const override { return Value(expr_->Max()); }
5957
5958 void SetMax(int64_t m) override {
5959 if (m < 0) {
5960 solver()->Fail();
5961 }
5963 return;
5964 }
5965 if (m < CapAdd(fixed_charge_, step_)) {
5966 expr_->SetMax(0);
5967 } else {
5968 const int64_t y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5969 expr_->SetMax(y);
5970 }
5971 }
5972
5973 std::string name() const override {
5974 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5975 expr_->name(), fixed_charge_, step_);
5976 }
5977
5978 std::string DebugString() const override {
5979 return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5980 expr_->DebugString(), fixed_charge_, step_);
5981 }
5982
5983 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5984
5985 void Accept(ModelVisitor* const visitor) const override {
5986 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5987 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5988 expr_);
5989 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
5990 fixed_charge_);
5991 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
5992 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5993 }
5994
5995 private:
5996 IntExpr* const expr_;
5997 const int64_t fixed_charge_;
5998 const int64_t step_;
5999};
6000
6001class SemiContinuousStepOneExpr : public BaseIntExpr {
6002 public:
6003 SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
6004 int64_t fixed_charge)
6005 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6006 DCHECK_GE(fixed_charge, int64_t{0});
6007 }
6008
6009 ~SemiContinuousStepOneExpr() override {}
6010
6011 int64_t Value(int64_t x) const {
6012 if (x <= 0) {
6013 return 0;
6014 } else {
6015 return fixed_charge_ + x;
6016 }
6017 }
6018
6019 int64_t Min() const override { return Value(expr_->Min()); }
6020
6021 void SetMin(int64_t m) override {
6022 if (m >= fixed_charge_ + 1) {
6023 expr_->SetMin(m - fixed_charge_);
6024 } else if (m > 0) {
6025 expr_->SetMin(1);
6026 }
6027 }
6028
6029 int64_t Max() const override { return Value(expr_->Max()); }
6030
6031 void SetMax(int64_t m) override {
6032 if (m < 0) {
6033 solver()->Fail();
6034 }
6035 if (m < fixed_charge_ + 1) {
6036 expr_->SetMax(0);
6037 } else {
6038 expr_->SetMax(m - fixed_charge_);
6039 }
6040 }
6041
6042 std::string name() const override {
6043 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6044 expr_->name(), fixed_charge_);
6045 }
6046
6047 std::string DebugString() const override {
6048 return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6049 expr_->DebugString(), fixed_charge_);
6050 }
6051
6052 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6053
6054 void Accept(ModelVisitor* const visitor) const override {
6055 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6056 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6057 expr_);
6058 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6059 fixed_charge_);
6060 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6061 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6062 }
6063
6064 private:
6065 IntExpr* const expr_;
6066 const int64_t fixed_charge_;
6067};
6068
6069class SemiContinuousStepZeroExpr : public BaseIntExpr {
6070 public:
6071 SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6072 int64_t fixed_charge)
6073 : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6074 DCHECK_GT(fixed_charge, int64_t{0});
6075 }
6076
6077 ~SemiContinuousStepZeroExpr() override {}
6078
6079 int64_t Value(int64_t x) const {
6080 if (x <= 0) {
6081 return 0;
6082 } else {
6083 return fixed_charge_;
6084 }
6085 }
6086
6087 int64_t Min() const override { return Value(expr_->Min()); }
6088
6089 void SetMin(int64_t m) override {
6090 if (m >= fixed_charge_) {
6091 solver()->Fail();
6092 } else if (m > 0) {
6093 expr_->SetMin(1);
6094 }
6095 }
6096
6097 int64_t Max() const override { return Value(expr_->Max()); }
6098
6099 void SetMax(int64_t m) override {
6100 if (m < 0) {
6101 solver()->Fail();
6102 }
6103 if (m < fixed_charge_) {
6104 expr_->SetMax(0);
6105 }
6106 }
6107
6108 std::string name() const override {
6109 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6110 expr_->name(), fixed_charge_);
6111 }
6112
6113 std::string DebugString() const override {
6114 return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6115 expr_->DebugString(), fixed_charge_);
6116 }
6117
6118 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6119
6120 void Accept(ModelVisitor* const visitor) const override {
6121 visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6122 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6123 expr_);
6124 visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6125 fixed_charge_);
6126 visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6127 visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6128 }
6129
6130 private:
6131 IntExpr* const expr_;
6132 const int64_t fixed_charge_;
6133};
6134
6135// This constraints links an expression and the variable it is casted into
6136class LinkExprAndVar : public CastConstraint {
6137 public:
6138 LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6139 : CastConstraint(s, var), expr_(expr) {}
6140
6141 ~LinkExprAndVar() override {}
6142
6143 void Post() override {
6144 Solver* const s = solver();
6145 Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6146 expr_->WhenRange(d);
6147 target_var_->WhenRange(d);
6148 }
6149
6150 void InitialPropagate() override {
6151 expr_->SetRange(target_var_->Min(), target_var_->Max());
6152 int64_t l, u;
6153 expr_->Range(&l, &u);
6154 target_var_->SetRange(l, u);
6155 }
6156
6157 std::string DebugString() const override {
6158 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6159 target_var_->DebugString());
6160 }
6161
6162 void Accept(ModelVisitor* const visitor) const override {
6163 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6164 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6165 expr_);
6166 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6167 target_var_);
6168 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6169 }
6170
6171 private:
6172 IntExpr* const expr_;
6173};
6174
6175// ----- Conditional Expression -----
6176
6177class ExprWithEscapeValue : public BaseIntExpr {
6178 public:
6179 ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6180 int64_t unperformed_value)
6181 : BaseIntExpr(s),
6182 condition_(c),
6183 expression_(e),
6184 unperformed_value_(unperformed_value) {}
6185
6186 ~ExprWithEscapeValue() override {}
6187
6188 int64_t Min() const override {
6189 if (condition_->Min() == 1) {
6190 return expression_->Min();
6191 } else if (condition_->Max() == 1) {
6192 return std::min(unperformed_value_, expression_->Min());
6193 } else {
6194 return unperformed_value_;
6195 }
6196 }
6197
6198 void SetMin(int64_t m) override {
6199 if (m > unperformed_value_) {
6200 condition_->SetValue(1);
6201 expression_->SetMin(m);
6202 } else if (condition_->Min() == 1) {
6203 expression_->SetMin(m);
6204 } else if (m > expression_->Max()) {
6205 condition_->SetValue(0);
6206 }
6207 }
6208
6209 int64_t Max() const override {
6210 if (condition_->Min() == 1) {
6211 return expression_->Max();
6212 } else if (condition_->Max() == 1) {
6213 return std::max(unperformed_value_, expression_->Max());
6214 } else {
6215 return unperformed_value_;
6216 }
6217 }
6218
6219 void SetMax(int64_t m) override {
6220 if (m < unperformed_value_) {
6221 condition_->SetValue(1);
6222 expression_->SetMax(m);
6223 } else if (condition_->Min() == 1) {
6224 expression_->SetMax(m);
6225 } else if (m < expression_->Min()) {
6226 condition_->SetValue(0);
6227 }
6228 }
6229
6230 void SetRange(int64_t mi, int64_t ma) override {
6231 if (ma < unperformed_value_ || mi > unperformed_value_) {
6232 condition_->SetValue(1);
6233 expression_->SetRange(mi, ma);
6234 } else if (condition_->Min() == 1) {
6235 expression_->SetRange(mi, ma);
6236 } else if (ma < expression_->Min() || mi > expression_->Max()) {
6237 condition_->SetValue(0);
6238 }
6239 }
6240
6241 void SetValue(int64_t v) override {
6242 if (v != unperformed_value_) {
6243 condition_->SetValue(1);
6244 expression_->SetValue(v);
6245 } else if (condition_->Min() == 1) {
6246 expression_->SetValue(v);
6247 } else if (v < expression_->Min() || v > expression_->Max()) {
6248 condition_->SetValue(0);
6249 }
6250 }
6251
6252 bool Bound() const override {
6253 return condition_->Max() == 0 || expression_->Bound();
6254 }
6255
6256 void WhenRange(Demon* d) override {
6257 expression_->WhenRange(d);
6258 condition_->WhenBound(d);
6259 }
6260
6261 std::string DebugString() const override {
6262 return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6263 condition_->DebugString(),
6264 expression_->DebugString(), unperformed_value_);
6265 }
6266
6267 void Accept(ModelVisitor* const visitor) const override {
6268 visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6269 visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6270 condition_);
6271 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6272 expression_);
6273 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6274 unperformed_value_);
6275 visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6276 }
6277
6278 private:
6279 IntVar* const condition_;
6280 IntExpr* const expression_;
6281 const int64_t unperformed_value_;
6282 DISALLOW_COPY_AND_ASSIGN(ExprWithEscapeValue);
6283};
6284
6285// ----- This is a specialized case when the variable exact type is known -----
6286class LinkExprAndDomainIntVar : public CastConstraint {
6287 public:
6288 LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6289 DomainIntVar* const var)
6290 : CastConstraint(s, var),
6291 expr_(expr),
6292 cached_min_(std::numeric_limits<int64_t>::min()),
6293 cached_max_(std::numeric_limits<int64_t>::max()),
6294 fail_stamp_(uint64_t{0}) {}
6295
6296 ~LinkExprAndDomainIntVar() override {}
6297
6298 DomainIntVar* var() const {
6299 return reinterpret_cast<DomainIntVar*>(target_var_);
6300 }
6301
6302 void Post() override {
6303 Solver* const s = solver();
6304 Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6305 expr_->WhenRange(d);
6306 Demon* const target_var_demon = MakeConstraintDemon0(
6307 solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6308 target_var_->WhenRange(target_var_demon);
6309 }
6310
6311 void InitialPropagate() override {
6312 expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6313 expr_->Range(&cached_min_, &cached_max_);
6314 var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6315 }
6316
6317 void Propagate() {
6318 if (var()->min_.Value() > cached_min_ ||
6319 var()->max_.Value() < cached_max_ ||
6320 solver()->fail_stamp() != fail_stamp_) {
6321 InitialPropagate();
6322 fail_stamp_ = solver()->fail_stamp();
6323 }
6324 }
6325
6326 std::string DebugString() const override {
6327 return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6328 target_var_->DebugString());
6329 }
6330
6331 void Accept(ModelVisitor* const visitor) const override {
6332 visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6333 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6334 expr_);
6335 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6336 target_var_);
6337 visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6338 }
6339
6340 private:
6341 IntExpr* const expr_;
6342 int64_t cached_min_;
6343 int64_t cached_max_;
6344 uint64_t fail_stamp_;
6345};
6346} // namespace
6347
6348// ----- Misc -----
6349
6350IntVarIterator* BooleanVar::MakeHoleIterator(bool reversible) const {
6351 return COND_REV_ALLOC(reversible, new EmptyIterator());
6352}
6353IntVarIterator* BooleanVar::MakeDomainIterator(bool reversible) const {
6354 return COND_REV_ALLOC(reversible, new RangeIterator(this));
6355}
6356
6357// ----- API -----
6358
6360 DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6361 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6362 dvar->CleanInProcess();
6363}
6364
6365Constraint* SetIsEqual(IntVar* const var, const std::vector<int64_t>& values,
6366 const std::vector<IntVar*>& vars) {
6367 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6368 CHECK(dvar != nullptr);
6369 return dvar->SetIsEqual(values, vars);
6370}
6371
6373 const std::vector<int64_t>& values,
6374 const std::vector<IntVar*>& vars) {
6375 DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6376 CHECK(dvar != nullptr);
6377 return dvar->SetIsGreaterOrEqual(values, vars);
6378}
6379
6381 DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6382 BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6383 boolean_var->RestoreValue();
6384}
6385
6386// ----- API -----
6387
6388IntVar* Solver::MakeIntVar(int64_t min, int64_t max, const std::string& name) {
6389 if (min == max) {
6390 return MakeIntConst(min, name);
6391 }
6392 if (min == 0 && max == 1) {
6393 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6394 } else if (CapSub(max, min) == 1) {
6395 const std::string inner_name = "inner_" + name;
6396 return RegisterIntVar(
6397 MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6398 ->VarWithName(name));
6399 } else {
6400 return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6401 }
6402}
6403
6404IntVar* Solver::MakeIntVar(int64_t min, int64_t max) {
6405 return MakeIntVar(min, max, "");
6406}
6407
6408IntVar* Solver::MakeBoolVar(const std::string& name) {
6409 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6410}
6411
6412IntVar* Solver::MakeBoolVar() {
6413 return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6414}
6415
6416IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values,
6417 const std::string& name) {
6418 DCHECK(!values.empty());
6419 // Fast-track the case where we have a single value.
6420 if (values.size() == 1) return MakeIntConst(values[0], name);
6421 // Sort and remove duplicates.
6422 std::vector<int64_t> unique_sorted_values = values;
6423 gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6424 // Case when we have a single value, after clean-up.
6425 if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6426 // Case when the values are a dense interval of integers.
6427 if (unique_sorted_values.size() ==
6428 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6429 return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6430 name);
6431 }
6432 // Compute the GCD: if it's not 1, we can express the variable's domain as
6433 // the product of the GCD and of a domain with smaller values.
6434 int64_t gcd = 0;
6435 for (const int64_t v : unique_sorted_values) {
6436 if (gcd == 0) {
6437 gcd = std::abs(v);
6438 } else {
6439 gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6440 }
6441 if (gcd == 1) {
6442 // If it's 1, though, we can't do anything special, so we
6443 // immediately return a new DomainIntVar.
6444 return RegisterIntVar(
6445 RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6446 }
6447 }
6448 DCHECK_GT(gcd, 1);
6449 for (int64_t& v : unique_sorted_values) {
6450 DCHECK_EQ(0, v % gcd);
6451 v /= gcd;
6452 }
6453 const std::string new_name = name.empty() ? "" : "inner_" + name;
6454 // Catch the case where the divided values are a dense set of integers.
6455 IntVar* inner_intvar = nullptr;
6456 if (unique_sorted_values.size() ==
6457 unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6458 inner_intvar = MakeIntVar(unique_sorted_values.front(),
6459 unique_sorted_values.back(), new_name);
6460 } else {
6461 inner_intvar = RegisterIntVar(
6462 RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6463 }
6464 return MakeProd(inner_intvar, gcd)->Var();
6465}
6466
6467IntVar* Solver::MakeIntVar(const std::vector<int64_t>& values) {
6468 return MakeIntVar(values, "");
6469}
6470
6471IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6472 const std::string& name) {
6473 return MakeIntVar(ToInt64Vector(values), name);
6474}
6475
6476IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6477 return MakeIntVar(values, "");
6478}
6479
6480IntVar* Solver::MakeIntConst(int64_t val, const std::string& name) {
6481 // If IntConst is going to be named after its creation,
6482 // cp_share_int_consts should be set to false otherwise names can potentially
6483 // be overwritten.
6484 if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6485 val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6486 return cached_constants_[val - MIN_CACHED_INT_CONST];
6487 }
6488 return RevAlloc(new IntConst(this, val, name));
6489}
6490
6491IntVar* Solver::MakeIntConst(int64_t val) { return MakeIntConst(val, ""); }
6492
6493// ----- Int Var and associated methods -----
6494
6495namespace {
6496std::string IndexedName(const std::string& prefix, int index, int max_index) {
6497#if 0
6498#if defined(_MSC_VER)
6499 const int digits = max_index > 0 ?
6500 static_cast<int>(log(1.0L * max_index) / log(10.0L)) + 1 :
6501 1;
6502#else
6503 const int digits = max_index > 0 ? static_cast<int>(log10(max_index)) + 1: 1;
6504#endif
6505 return absl::StrFormat("%s%0*d", prefix, digits, index);
6506#else
6507 return absl::StrCat(prefix, index);
6508#endif
6509}
6510} // namespace
6511
6512void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6513 const std::string& name,
6514 std::vector<IntVar*>* vars) {
6515 for (int i = 0; i < var_count; ++i) {
6516 vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6517 }
6518}
6519
6520void Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6521 std::vector<IntVar*>* vars) {
6522 for (int i = 0; i < var_count; ++i) {
6523 vars->push_back(MakeIntVar(vmin, vmax));
6524 }
6525}
6526
6527IntVar** Solver::MakeIntVarArray(int var_count, int64_t vmin, int64_t vmax,
6528 const std::string& name) {
6529 IntVar** vars = new IntVar*[var_count];
6530 for (int i = 0; i < var_count; ++i) {
6531 vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6532 }
6533 return vars;
6534}
6535
6536void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6537 std::vector<IntVar*>* vars) {
6538 for (int i = 0; i < var_count; ++i) {
6539 vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6540 }
6541}
6542
6543void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6544 for (int i = 0; i < var_count; ++i) {
6545 vars->push_back(MakeBoolVar());
6546 }
6547}
6548
6549IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6550 IntVar** vars = new IntVar*[var_count];
6551 for (int i = 0; i < var_count; ++i) {
6552 vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6553 }
6554 return vars;
6555}
6556
6557void Solver::InitCachedIntConstants() {
6558 for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6559 cached_constants_[i - MIN_CACHED_INT_CONST] =
6560 RevAlloc(new IntConst(this, i, "")); // note the empty name
6561 }
6562}
6563
6564IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6565 CHECK_EQ(this, left->solver());
6566 CHECK_EQ(this, right->solver());
6567 if (right->Bound()) {
6568 return MakeSum(left, right->Min());
6569 }
6570 if (left->Bound()) {
6571 return MakeSum(right, left->Min());
6572 }
6573 if (left == right) {
6574 return MakeProd(left, 2);
6575 }
6576 IntExpr* cache = model_cache_->FindExprExprExpression(
6577 left, right, ModelCache::EXPR_EXPR_SUM);
6578 if (cache == nullptr) {
6579 cache = model_cache_->FindExprExprExpression(right, left,
6580 ModelCache::EXPR_EXPR_SUM);
6581 }
6582 if (cache != nullptr) {
6583 return cache;
6584 } else {
6585 IntExpr* const result =
6586 AddOverflows(left->Max(), right->Max()) ||
6587 AddOverflows(left->Min(), right->Min())
6588 ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6589 : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6590 model_cache_->InsertExprExprExpression(result, left, right,
6591 ModelCache::EXPR_EXPR_SUM);
6592 return result;
6593 }
6594}
6595
6596IntExpr* Solver::MakeSum(IntExpr* const expr, int64_t value) {
6597 CHECK_EQ(this, expr->solver());
6598 if (expr->Bound()) {
6599 return MakeIntConst(expr->Min() + value);
6600 }
6601 if (value == 0) {
6602 return expr;
6603 }
6604 IntExpr* result = Cache()->FindExprConstantExpression(
6605 expr, value, ModelCache::EXPR_CONSTANT_SUM);
6606 if (result == nullptr) {
6607 if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6608 !AddOverflows(value, expr->Min())) {
6609 IntVar* const var = expr->Var();
6610 switch (var->VarType()) {
6611 case DOMAIN_INT_VAR: {
6612 result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6613 this, reinterpret_cast<DomainIntVar*>(var), value)));
6614 break;
6615 }
6616 case CONST_VAR: {
6617 result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6618 break;
6619 }
6620 case VAR_ADD_CST: {
6621 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6622 IntVar* const sub_var = add_var->SubVar();
6623 const int64_t new_constant = value + add_var->Constant();
6624 if (new_constant == 0) {
6625 result = sub_var;
6626 } else {
6627 if (sub_var->VarType() == DOMAIN_INT_VAR) {
6628 DomainIntVar* const dvar =
6629 reinterpret_cast<DomainIntVar*>(sub_var);
6630 result = RegisterIntExpr(
6631 RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6632 } else {
6633 result = RegisterIntExpr(
6634 RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6635 }
6636 }
6637 break;
6638 }
6639 case CST_SUB_VAR: {
6640 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6641 IntVar* const sub_var = add_var->SubVar();
6642 const int64_t new_constant = value + add_var->Constant();
6643 result = RegisterIntExpr(
6644 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6645 break;
6646 }
6647 case OPP_VAR: {
6648 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6649 IntVar* const sub_var = add_var->SubVar();
6650 result =
6651 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6652 break;
6653 }
6654 default:
6655 result =
6656 RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6657 }
6658 } else {
6659 result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6660 }
6661 Cache()->InsertExprConstantExpression(result, expr, value,
6662 ModelCache::EXPR_CONSTANT_SUM);
6663 }
6664 return result;
6665}
6666
6667IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6668 CHECK_EQ(this, left->solver());
6669 CHECK_EQ(this, right->solver());
6670 if (left->Bound()) {
6671 return MakeDifference(left->Min(), right);
6672 }
6673 if (right->Bound()) {
6674 return MakeSum(left, -right->Min());
6675 }
6676 IntExpr* sub_left = nullptr;
6677 IntExpr* sub_right = nullptr;
6678 int64_t left_coef = 1;
6679 int64_t right_coef = 1;
6680 if (IsProduct(left, &sub_left, &left_coef) &&
6681 IsProduct(right, &sub_right, &right_coef)) {
6682 const int64_t abs_gcd =
6683 MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6684 if (abs_gcd != 0 && abs_gcd != 1) {
6685 return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6686 MakeProd(sub_right, right_coef / abs_gcd)),
6687 abs_gcd);
6688 }
6689 }
6690
6691 IntExpr* result = Cache()->FindExprExprExpression(
6692 left, right, ModelCache::EXPR_EXPR_DIFFERENCE);
6693 if (result == nullptr) {
6694 if (!SubOverflows(left->Min(), right->Max()) &&
6695 !SubOverflows(left->Max(), right->Min())) {
6696 result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6697 } else {
6698 result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6699 }
6700 Cache()->InsertExprExprExpression(result, left, right,
6701 ModelCache::EXPR_EXPR_DIFFERENCE);
6702 }
6703 return result;
6704}
6705
6706// warning: this is 'value - expr'.
6707IntExpr* Solver::MakeDifference(int64_t value, IntExpr* const expr) {
6708 CHECK_EQ(this, expr->solver());
6709 if (expr->Bound()) {
6710 return MakeIntConst(value - expr->Min());
6711 }
6712 if (value == 0) {
6713 return MakeOpposite(expr);
6714 }
6715 IntExpr* result = Cache()->FindExprConstantExpression(
6716 expr, value, ModelCache::EXPR_CONSTANT_DIFFERENCE);
6717 if (result == nullptr) {
6718 if (expr->IsVar() && expr->Min() != std::numeric_limits<int64_t>::min() &&
6719 !SubOverflows(value, expr->Min()) &&
6720 !SubOverflows(value, expr->Max())) {
6721 IntVar* const var = expr->Var();
6722 switch (var->VarType()) {
6723 case VAR_ADD_CST: {
6724 PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6725 IntVar* const sub_var = add_var->SubVar();
6726 const int64_t new_constant = value - add_var->Constant();
6727 if (new_constant == 0) {
6728 result = sub_var;
6729 } else {
6730 result = RegisterIntExpr(
6731 RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6732 }
6733 break;
6734 }
6735 case CST_SUB_VAR: {
6736 SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6737 IntVar* const sub_var = add_var->SubVar();
6738 const int64_t new_constant = value - add_var->Constant();
6739 result = MakeSum(sub_var, new_constant);
6740 break;
6741 }
6742 case OPP_VAR: {
6743 OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6744 IntVar* const sub_var = add_var->SubVar();
6745 result = MakeSum(sub_var, value);
6746 break;
6747 }
6748 default:
6749 result =
6750 RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6751 }
6752 } else {
6753 result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6754 }
6755 Cache()->InsertExprConstantExpression(result, expr, value,
6756 ModelCache::EXPR_CONSTANT_DIFFERENCE);
6757 }
6758 return result;
6759}
6760
6761IntExpr* Solver::MakeOpposite(IntExpr* const expr) {
6762 CHECK_EQ(this, expr->solver());
6763 if (expr->Bound()) {
6764 return MakeIntConst(-expr->Min());
6765 }
6766 IntExpr* result =
6767 Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6768 if (result == nullptr) {
6769 if (expr->IsVar()) {
6770 result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6771 } else {
6772 result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6773 }
6774 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6775 }
6776 return result;
6777}
6778
6779IntExpr* Solver::MakeProd(IntExpr* const expr, int64_t value) {
6780 CHECK_EQ(this, expr->solver());
6781 IntExpr* result = Cache()->FindExprConstantExpression(
6782 expr, value, ModelCache::EXPR_CONSTANT_PROD);
6783 if (result != nullptr) {
6784 return result;
6785 } else {
6786 IntExpr* m_expr = nullptr;
6787 int64_t coefficient = 1;
6788 if (IsProduct(expr, &m_expr, &coefficient)) {
6789 coefficient *= value;
6790 } else {
6791 m_expr = expr;
6793 }
6794 if (m_expr->Bound()) {
6795 return MakeIntConst(coefficient * m_expr->Min());
6796 } else if (coefficient == 1) {
6797 return m_expr;
6798 } else if (coefficient == -1) {
6799 return MakeOpposite(m_expr);
6800 } else if (coefficient > 0) {
6803 result = RegisterIntExpr(
6804 RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6805 } else {
6806 result = RegisterIntExpr(
6807 RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6808 }
6809 } else if (coefficient == 0) {
6810 result = MakeIntConst(0);
6811 } else { // coefficient < 0.
6812 result = RegisterIntExpr(
6813 RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6814 }
6815 if (m_expr->IsVar() &&
6816 !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
6817 result = result->Var();
6818 }
6819 Cache()->InsertExprConstantExpression(result, expr, value,
6820 ModelCache::EXPR_CONSTANT_PROD);
6821 return result;
6822 }
6823}
6824
6825namespace {
6826void ExtractPower(IntExpr** const expr, int64_t* const exponant) {
6827 if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6828 BasePower* const power = dynamic_cast<BasePower*>(*expr);
6829 *expr = power->expr();
6830 *exponant = power->exponant();
6831 }
6832 if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6833 IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6834 *expr = power->expr();
6835 *exponant = 2;
6836 }
6837 if ((*expr)->IsVar()) {
6838 IntVar* const var = (*expr)->Var();
6839 IntExpr* const sub = var->solver()->CastExpression(var);
6840 if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6841 BasePower* const power = dynamic_cast<BasePower*>(sub);
6842 *expr = power->expr();
6843 *exponant = power->exponant();
6844 }
6845 if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6846 IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6847 *expr = power->expr();
6848 *exponant = 2;
6849 }
6850 }
6851}
6852
6853void ExtractProduct(IntExpr** const expr, int64_t* const coefficient,
6854 bool* modified) {
6855 if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6856 TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6857 *coefficient *= left_prod->Constant();
6858 *expr = left_prod->SubVar();
6859 *modified = true;
6860 } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6861 TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6862 *coefficient *= left_prod->Constant();
6863 *expr = left_prod->Expr();
6864 *modified = true;
6865 }
6866}
6867} // namespace
6868
6869IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6870 if (left->Bound()) {
6871 return MakeProd(right, left->Min());
6872 }
6873
6874 if (right->Bound()) {
6875 return MakeProd(left, right->Min());
6876 }
6877
6878 // ----- Discover squares and powers -----
6879
6880 IntExpr* m_left = left;
6881 IntExpr* m_right = right;
6882 int64_t left_exponant = 1;
6883 int64_t right_exponant = 1;
6884 ExtractPower(&m_left, &left_exponant);
6885 ExtractPower(&m_right, &right_exponant);
6886
6887 if (m_left == m_right) {
6888 return MakePower(m_left, left_exponant + right_exponant);
6889 }
6890
6891 // ----- Discover nested products -----
6892
6893 m_left = left;
6894 m_right = right;
6895 int64_t coefficient = 1;
6896 bool modified = false;
6897
6898 ExtractProduct(&m_left, &coefficient, &modified);
6899 ExtractProduct(&m_right, &coefficient, &modified);
6900 if (modified) {
6901 return MakeProd(MakeProd(m_left, m_right), coefficient);
6902 }
6903
6904 // ----- Standard build -----
6905
6906 CHECK_EQ(this, left->solver());
6907 CHECK_EQ(this, right->solver());
6908 IntExpr* result = model_cache_->FindExprExprExpression(
6909 left, right, ModelCache::EXPR_EXPR_PROD);
6910 if (result == nullptr) {
6911 result = model_cache_->FindExprExprExpression(right, left,
6912 ModelCache::EXPR_EXPR_PROD);
6913 }
6914 if (result != nullptr) {
6915 return result;
6916 }
6917 if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6918 if (right->Min() >= 0) {
6919 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6920 this, reinterpret_cast<BooleanVar*>(left), right)));
6921 } else {
6922 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6923 this, reinterpret_cast<BooleanVar*>(left), right)));
6924 }
6925 } else if (right->IsVar() &&
6926 reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6927 if (left->Min() >= 0) {
6928 result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6929 this, reinterpret_cast<BooleanVar*>(right), left)));
6930 } else {
6931 result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6932 this, reinterpret_cast<BooleanVar*>(right), left)));
6933 }
6934 } else if (left->Min() >= 0 && right->Min() >= 0) {
6935 if (CapProd(left->Max(), right->Max()) ==
6936 std::numeric_limits<int64_t>::max()) { // Potential overflow.
6937 result =
6938 RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6939 } else {
6940 result =
6941 RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6942 }
6943 } else {
6944 result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6945 }
6946 model_cache_->InsertExprExprExpression(result, left, right,
6947 ModelCache::EXPR_EXPR_PROD);
6948 return result;
6949}
6950
6951IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6952 CHECK(numerator != nullptr);
6953 CHECK(denominator != nullptr);
6954 if (denominator->Bound()) {
6955 return MakeDiv(numerator, denominator->Min());
6956 }
6957 IntExpr* result = model_cache_->FindExprExprExpression(
6958 numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6959 if (result != nullptr) {
6960 return result;
6961 }
6962
6963 if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6964 AddConstraint(MakeNonEquality(denominator, 0));
6965 }
6966
6967 if (denominator->Min() >= 0) {
6968 if (numerator->Min() >= 0) {
6969 result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
6970 } else {
6971 result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
6972 }
6973 } else if (denominator->Max() <= 0) {
6974 if (numerator->Max() <= 0) {
6975 result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
6976 MakeOpposite(denominator)));
6977 } else {
6978 result = MakeOpposite(RevAlloc(
6979 new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
6980 }
6981 } else {
6982 result = RevAlloc(new DivIntExpr(this, numerator, denominator));
6983 }
6984 model_cache_->InsertExprExprExpression(result, numerator, denominator,
6985 ModelCache::EXPR_EXPR_DIV);
6986 return result;
6987}
6988
6989IntExpr* Solver::MakeDiv(IntExpr* const expr, int64_t value) {
6990 CHECK(expr != nullptr);
6991 CHECK_EQ(this, expr->solver());
6992 if (expr->Bound()) {
6993 return MakeIntConst(expr->Min() / value);
6994 } else if (value == 1) {
6995 return expr;
6996 } else if (value == -1) {
6997 return MakeOpposite(expr);
6998 } else if (value > 0) {
6999 return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
7000 } else if (value == 0) {
7001 LOG(FATAL) << "Cannot divide by 0";
7002 return nullptr;
7003 } else {
7004 return RegisterIntExpr(
7005 MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
7006 // TODO(user) : implement special case.
7007 }
7008}
7009
7010Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
7011 if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
7012 Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
7013 }
7014 return RevAlloc(new IntAbsConstraint(this, var, abs_var));
7015}
7016
7017IntExpr* Solver::MakeAbs(IntExpr* const e) {
7018 CHECK_EQ(this, e->solver());
7019 if (e->Min() >= 0) {
7020 return e;
7021 } else if (e->Max() <= 0) {
7022 return MakeOpposite(e);
7023 }
7024 IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
7025 if (result == nullptr) {
7026 int64_t coefficient = 1;
7027 IntExpr* expr = nullptr;
7028 if (IsProduct(e, &expr, &coefficient)) {
7029 result = MakeProd(MakeAbs(expr), std::abs(coefficient));
7030 } else {
7031 result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
7032 }
7033 Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7034 }
7035 return result;
7036}
7037
7038IntExpr* Solver::MakeSquare(IntExpr* const expr) {
7039 CHECK_EQ(this, expr->solver());
7040 if (expr->Bound()) {
7041 const int64_t v = expr->Min();
7042 return MakeIntConst(v * v);
7043 }
7044 IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7045 if (result == nullptr) {
7046 if (expr->Min() >= 0) {
7047 result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7048 } else {
7049 result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7050 }
7051 Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7052 }
7053 return result;
7054}
7055
7056IntExpr* Solver::MakePower(IntExpr* const expr, int64_t n) {
7057 CHECK_EQ(this, expr->solver());
7058 CHECK_GE(n, 0);
7059 if (expr->Bound()) {
7060 const int64_t v = expr->Min();
7061 if (v >= OverflowLimit(n)) { // Overflow.
7062 return MakeIntConst(std::numeric_limits<int64_t>::max());
7063 }
7064 return MakeIntConst(IntPower(v, n));
7065 }
7066 switch (n) {
7067 case 0:
7068 return MakeIntConst(1);
7069 case 1:
7070 return expr;
7071 case 2:
7072 return MakeSquare(expr);
7073 default: {
7074 IntExpr* result = nullptr;
7075 if (n % 2 == 0) { // even.
7076 if (expr->Min() >= 0) {
7077 result =
7078 RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7079 } else {
7080 result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7081 }
7082 } else {
7083 result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7084 }
7085 return result;
7086 }
7087 }
7088}
7089
7090IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7091 CHECK_EQ(this, left->solver());
7092 CHECK_EQ(this, right->solver());
7093 if (left->Bound()) {
7094 return MakeMin(right, left->Min());
7095 }
7096 if (right->Bound()) {
7097 return MakeMin(left, right->Min());
7098 }
7099 if (left->Min() >= right->Max()) {
7100 return right;
7101 }
7102 if (right->Min() >= left->Max()) {
7103 return left;
7104 }
7105 return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7106}
7107
7108IntExpr* Solver::MakeMin(IntExpr* const expr, int64_t value) {
7109 CHECK_EQ(this, expr->solver());
7110 if (value <= expr->Min()) {
7111 return MakeIntConst(value);
7112 }
7113 if (expr->Bound()) {
7114 return MakeIntConst(std::min(expr->Min(), value));
7115 }
7116 if (expr->Max() <= value) {
7117 return expr;
7118 }
7119 return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7120}
7121
7122IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7123 return MakeMin(expr, static_cast<int64_t>(value));
7124}
7125
7126IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7127 CHECK_EQ(this, left->solver());
7128 CHECK_EQ(this, right->solver());
7129 if (left->Bound()) {
7130 return MakeMax(right, left->Min());
7131 }
7132 if (right->Bound()) {
7133 return MakeMax(left, right->Min());
7134 }
7135 if (left->Min() >= right->Max()) {
7136 return left;
7137 }
7138 if (right->Min() >= left->Max()) {
7139 return right;
7140 }
7141 return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7142}
7143
7144IntExpr* Solver::MakeMax(IntExpr* const expr, int64_t value) {
7145 CHECK_EQ(this, expr->solver());
7146 if (expr->Bound()) {
7147 return MakeIntConst(std::max(expr->Min(), value));
7148 }
7149 if (value <= expr->Min()) {
7150 return expr;
7151 }
7152 if (expr->Max() <= value) {
7153 return MakeIntConst(value);
7154 }
7155 return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7156}
7157
7158IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7159 return MakeMax(expr, static_cast<int64_t>(value));
7160}
7161
7162IntExpr* Solver::MakeConvexPiecewiseExpr(IntExpr* expr, int64_t early_cost,
7163 int64_t early_date, int64_t late_date,
7164 int64_t late_cost) {
7165 return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7166 this, expr, early_cost, early_date, late_date, late_cost)));
7167}
7168
7169IntExpr* Solver::MakeSemiContinuousExpr(IntExpr* const expr,
7170 int64_t fixed_charge, int64_t step) {
7171 if (step == 0) {
7172 if (fixed_charge == 0) {
7173 return MakeIntConst(int64_t{0});
7174 } else {
7175 return RegisterIntExpr(
7176 RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7177 }
7178 } else if (step == 1) {
7179 return RegisterIntExpr(
7180 RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7181 } else {
7182 return RegisterIntExpr(
7183 RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7184 }
7185 // TODO(user) : benchmark with virtualization of
7186 // PosIntDivDown and PosIntDivUp - or function pointers.
7187}
7188
7189// ----- Piecewise Linear -----
7190
7192 public:
7194 const PiecewiseLinearFunction& f)
7195 : BaseIntExpr(solver), expr_(expr), f_(f) {}
7197 int64_t Min() const override {
7198 return f_.GetMinimum(expr_->Min(), expr_->Max());
7199 }
7200 void SetMin(int64_t m) override {
7201 const auto& range =
7202 f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7203 expr_->SetRange(range.first, range.second);
7204 }
7205
7206 int64_t Max() const override {
7207 return f_.GetMaximum(expr_->Min(), expr_->Max());
7208 }
7209
7210 void SetMax(int64_t m) override {
7211 const auto& range =
7212 f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7213 expr_->SetRange(range.first, range.second);
7214 }
7215
7216 void SetRange(int64_t l, int64_t u) override {
7217 const auto& range =
7218 f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7219 expr_->SetRange(range.first, range.second);
7220 }
7221 std::string name() const override {
7222 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7223 f_.DebugString());
7224 }
7225
7226 std::string DebugString() const override {
7227 return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7228 f_.DebugString());
7229 }
7230
7231 void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7232
7233 void Accept(ModelVisitor* const visitor) const override {
7234 // TODO(user): Implement visitor.
7235 }
7236
7237 private:
7238 IntExpr* const expr_;
7239 const PiecewiseLinearFunction f_;
7240};
7241
7242IntExpr* Solver::MakePiecewiseLinearExpr(IntExpr* expr,
7243 const PiecewiseLinearFunction& f) {
7244 return RegisterIntExpr(RevAlloc(new PiecewiseLinearExpr(this, expr, f)));
7245}
7246
7247// ----- Conditional Expression -----
7248
7249IntExpr* Solver::MakeConditionalExpression(IntVar* const condition,
7250 IntExpr* const expr,
7251 int64_t unperformed_value) {
7252 if (condition->Min() == 1) {
7253 return expr;
7254 } else if (condition->Max() == 0) {
7255 return MakeIntConst(unperformed_value);
7256 } else {
7257 IntExpr* cache = Cache()->FindExprExprConstantExpression(
7258 condition, expr, unperformed_value,
7259 ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7260 if (cache == nullptr) {
7261 cache = RevAlloc(
7262 new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7263 Cache()->InsertExprExprConstantExpression(
7264 cache, condition, expr, unperformed_value,
7265 ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7266 }
7267 return cache;
7268 }
7269}
7270
7271// ----- Modulo -----
7272
7273IntExpr* Solver::MakeModulo(IntExpr* const x, int64_t mod) {
7274 IntVar* const result =
7275 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7276 if (mod >= 0) {
7277 AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7278 } else {
7279 AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7280 }
7281 return result;
7282}
7283
7284IntExpr* Solver::MakeModulo(IntExpr* const x, IntExpr* const mod) {
7285 if (mod->Bound()) {
7286 return MakeModulo(x, mod->Min());
7287 }
7288 IntVar* const result =
7289 MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7290 AddConstraint(MakeLess(result, MakeAbs(mod)));
7291 AddConstraint(MakeGreater(result, MakeOpposite(MakeAbs(mod))));
7292 return result;
7293}
7294
7295// --------- IntVar ---------
7296
7297int IntVar::VarType() const { return UNSPECIFIED; }
7298
7299void IntVar::RemoveValues(const std::vector<int64_t>& values) {
7300 // TODO(user): Check and maybe inline this code.
7301 const int size = values.size();
7302 DCHECK_GE(size, 0);
7303 switch (size) {
7304 case 0: {
7305 return;
7306 }
7307 case 1: {
7308 RemoveValue(values[0]);
7309 return;
7310 }
7311 case 2: {
7312 RemoveValue(values[0]);
7313 RemoveValue(values[1]);
7314 return;
7315 }
7316 case 3: {
7317 RemoveValue(values[0]);
7318 RemoveValue(values[1]);
7319 RemoveValue(values[2]);
7320 return;
7321 }
7322 default: {
7323 // 4 values, let's start doing some more clever things.
7324 // TODO(user) : Sort values!
7325 int start_index = 0;
7326 int64_t new_min = Min();
7327 if (values[start_index] <= new_min) {
7328 while (start_index < size - 1 &&
7329 values[start_index + 1] == values[start_index] + 1) {
7330 new_min = values[start_index + 1] + 1;
7331 start_index++;
7332 }
7333 }
7334 int end_index = size - 1;
7335 int64_t new_max = Max();
7336 if (values[end_index] >= new_max) {
7337 while (end_index > start_index + 1 &&
7338 values[end_index - 1] == values[end_index] - 1) {
7339 new_max = values[end_index - 1] - 1;
7340 end_index--;
7341 }
7342 }
7343 SetRange(new_min, new_max);
7344 for (int i = start_index; i <= end_index; ++i) {
7345 RemoveValue(values[i]);
7346 }
7347 }
7348 }
7349}
7350
7351void IntVar::Accept(ModelVisitor* const visitor) const {
7352 IntExpr* const casted = solver()->CastExpression(this);
7353 visitor->VisitIntegerVariable(this, casted);
7354}
7355
7356void IntVar::SetValues(const std::vector<int64_t>& values) {
7357 switch (values.size()) {
7358 case 0: {
7359 solver()->Fail();
7360 break;
7361 }
7362 case 1: {
7363 SetValue(values.back());
7364 break;
7365 }
7366 case 2: {
7367 if (Contains(values[0])) {
7368 if (Contains(values[1])) {
7369 const int64_t l = std::min(values[0], values[1]);
7370 const int64_t u = std::max(values[0], values[1]);
7371 SetRange(l, u);
7372 if (u > l + 1) {
7373 RemoveInterval(l + 1, u - 1);
7374 }
7375 } else {
7376 SetValue(values[0]);
7377 }
7378 } else {
7379 SetValue(values[1]);
7380 }
7381 break;
7382 }
7383 default: {
7384 // TODO(user): use a clean and safe SortedUniqueCopy() class
7385 // that uses a global, static shared (and locked) storage.
7386 // TODO(user): [optional] consider porting
7387 // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7388 // existing open_source/base/stl_util.h and using it here.
7389 // TODO(user): We could filter out values not in the var.
7390 std::vector<int64_t>& tmp = solver()->tmp_vector_;
7391 tmp.clear();
7392 tmp.insert(tmp.end(), values.begin(), values.end());
7393 std::sort(tmp.begin(), tmp.end());
7394 tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7395 const int size = tmp.size();
7396 const int64_t vmin = Min();
7397 const int64_t vmax = Max();
7398 int first = 0;
7399 int last = size - 1;
7400 if (tmp.front() > vmax || tmp.back() < vmin) {
7401 solver()->Fail();
7402 }
7403 // TODO(user) : We could find the first position >= vmin by dichotomy.
7404 while (tmp[first] < vmin || !Contains(tmp[first])) {
7405 ++first;
7406 if (first > last || tmp[first] > vmax) {
7407 solver()->Fail();
7408 }
7409 }
7410 while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7411 // Note that last >= first implies tmp[last] >= vmin.
7412 --last;
7413 }
7414 DCHECK_GE(last, first);
7415 SetRange(tmp[first], tmp[last]);
7416 while (first < last) {
7417 const int64_t start = tmp[first] + 1;
7418 const int64_t end = tmp[first + 1] - 1;
7419 if (start <= end) {
7420 RemoveInterval(start, end);
7421 }
7422 first++;
7423 }
7424 }
7425 }
7426}
7427// ---------- BaseIntExpr ---------
7428
7429void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var) {
7430 if (!var->Bound()) {
7431 if (var->VarType() == DOMAIN_INT_VAR) {
7432 DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7434 s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7435 } else {
7436 s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7437 expr);
7438 }
7439 }
7440}
7441
7442IntVar* BaseIntExpr::Var() {
7443 if (var_ == nullptr) {
7444 solver()->SaveValue(reinterpret_cast<void**>(&var_));
7445 var_ = CastToVar();
7446 }
7447 return var_;
7448}
7449
7450IntVar* BaseIntExpr::CastToVar() {
7451 int64_t vmin, vmax;
7452 Range(&vmin, &vmax);
7453 IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7454 LinkVarExpr(solver(), this, var);
7455 return var;
7456}
7457
7458// Discovery methods
7459bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7460 IntExpr** const right) {
7461 if (expr->IsVar()) {
7462 IntVar* const expr_var = expr->Var();
7463 expr = CastExpression(expr_var);
7464 }
7465 // This is a dynamic cast to check the type of expr.
7466 // It returns nullptr is expr is not a subclass of SubIntExpr.
7467 SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7468 if (sub_expr != nullptr) {
7469 *left = sub_expr->left();
7470 *right = sub_expr->right();
7471 return true;
7472 }
7473 return false;
7474}
7475
7476bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7477 bool* is_negated) const {
7478 if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7479 *inner_var = expr->Var();
7480 *is_negated = false;
7481 return true;
7482 } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7483 SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7484 if (sub_var != nullptr && sub_var->Constant() == 1 &&
7485 sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7486 *is_negated = true;
7487 *inner_var = sub_var->SubVar();
7488 return true;
7489 }
7490 }
7491 return false;
7492}
7493
7494bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7495 int64_t* coefficient) {
7496 if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7497 TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7498 *coefficient = var->Constant();
7499 *inner_expr = var->SubVar();
7500 return true;
7501 } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7502 TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7503 *coefficient = prod->Constant();
7504 *inner_expr = prod->Expr();
7505 return true;
7506 }
7507 *inner_expr = expr;
7508 *coefficient = 1;
7509 return false;
7510}
7511
7512#undef COND_REV_ALLOC
7513
7514} // namespace operations_research
int64_t max
Definition: alldiff_cst.cc:140
int64_t min
Definition: alldiff_cst.cc:139
#define CHECK(condition)
Definition: base/logging.h:495
#define DCHECK_LE(val1, val2)
Definition: base/logging.h:892
#define DCHECK_NE(val1, val2)
Definition: base/logging.h:891
#define CHECK_LT(val1, val2)
Definition: base/logging.h:705
#define CHECK_EQ(val1, val2)
Definition: base/logging.h:702
#define CHECK_GE(val1, val2)
Definition: base/logging.h:706
#define CHECK_GT(val1, val2)
Definition: base/logging.h:707
#define DCHECK_GE(val1, val2)
Definition: base/logging.h:894
#define CHECK_NE(val1, val2)
Definition: base/logging.h:703
#define DCHECK_GT(val1, val2)
Definition: base/logging.h:895
#define DCHECK_LT(val1, val2)
Definition: base/logging.h:893
#define LOG(severity)
Definition: base/logging.h:420
#define DCHECK(condition)
Definition: base/logging.h:889
#define DCHECK_EQ(val1, val2)
Definition: base/logging.h:890
A BaseObject is the root of all reversibly allocated objects.
void WhenBound(Demon *d) override
This method attaches a demon that will be awakened when the variable is bound.
Definition: expressions.cc:116
IntVar * IsLessOrEqual(int64_t constant) override
Definition: expressions.cc:166
uint64_t Size() const override
This method returns the number of values in the domain of the variable.
Definition: expressions.cc:126
void SetRange(int64_t mi, int64_t ma) override
This method sets both the min and the max of the expression.
Definition: expressions.cc:82
SimpleRevFIFO< Demon * > delayed_bound_demons_
bool Contains(int64_t v) const override
This method returns whether the value 'v' is in the domain of the variable.
Definition: expressions.cc:130
void RemoveValue(int64_t v) override
This method removes the value 'v' from the domain of the variable.
Definition: expressions.cc:93
IntVar * IsEqual(int64_t constant) override
IsEqual.
Definition: expressions.cc:134
IntVar * IsGreaterOrEqual(int64_t constant) override
Definition: expressions.cc:156
void SetMax(int64_t m) override
Definition: expressions.cc:76
SimpleRevFIFO< Demon * > bound_demons_
void RemoveInterval(int64_t l, int64_t u) override
This method removes the interval 'l' .. 'u' from the domain of the variable.
Definition: expressions.cc:105
void SetMin(int64_t m) override
Definition: expressions.cc:70
IntVar * IsDifferent(int64_t constant) override
Definition: expressions.cc:145
std::string DebugString() const override
Definition: expressions.cc:176
A constraint is the main modeling object.
A Demon is the base element of a propagation queue.
virtual Solver::DemonPriority priority() const
This method returns the priority of the demon.
The class IntExpr is the base of all integer expressions in constraint programming.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual void SetValue(int64_t v)
This method sets the value of the expression.
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual int64_t Min() const =0
virtual IntVar * Var()=0
Creates a variable from the expression.
IntVar * VarWithName(const std::string &name)
Creates a variable from the expression and set the name of the resulting var.
Definition: expressions.cc:51
virtual int64_t Max() const =0
The class IntVar is a subset of IntExpr.
IntVar(Solver *const s)
Definition: expressions.cc:59
IntVar * Var() override
Creates a variable from the expression.
virtual int VarType() const
The class Iterator has two direct subclasses.
virtual void VisitIntegerVariable(const IntVar *const variable, IntExpr *const delegate)
PiecewiseLinearExpr(Solver *solver, IntExpr *expr, const PiecewiseLinearFunction &f)
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
void SetRange(int64_t l, int64_t u) override
This method sets both the min and the max of the expression.
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
std::string name() const override
Object naming.
std::string DebugString() const override
virtual std::string name() const
Object naming.
void SetValue(Solver *const s, const T &val)
DemonPriority
This enum represents the three possible priorities for a demon in the Solver queue.
@ VAR_PRIORITY
VAR_PRIORITY is between DELAYED_PRIORITY and NORMAL_PRIORITY.
@ DELAYED_PRIORITY
DELAYED_PRIORITY is the lowest priority: Demons will be processed after VAR_PRIORITY and NORMAL_PRIORITY demons.
@ OUTSIDE_SEARCH
Before search, after search.
IntExpr * MakeDifference(IntExpr *const left, IntExpr *const right)
left - right
void AddCastConstraint(CastConstraint *const constraint, IntVar *const target_var, IntExpr *const expr)
Adds 'constraint' to the solver and marks it as a cast constraint, that is, a constraint created calling Var() on an expression. This is used internally.
IntVar * MakeIntConst(int64_t val, const std::string &name)
IntConst will create a constant expression.
void Fail()
Abandon the current branch in the search tree. A backtrack will follow.
T * RevAlloc(T *object)
Registers the given object as being reversible.
int64_t b
std::vector< IntVarIterator * > holes_
int64_t a
const std::string name
const Constraint * ct
int64_t value
IntVar *const expr_
Definition: element.cc:87
IntVar * var
Definition: expr_array.cc:1874
#define COND_REV_ALLOC(rev, alloc)
const int64_t limit_
Solver *const solver_
Definition: expressions.cc:279
const int64_t pow_
ABSL_FLAG(bool, cp_disable_expression_optimization, false, "Disable special optimization when creating expressions.")
const int64_t cst_
IntVarIterator *const iterator_
Handler handler_
Definition: interval.cc:429
const int64_t offset_
Definition: interval.cc:2108
bool in_process_
Definition: interval.cc:428
const int FATAL
Definition: log_severity.h:32
#define DISALLOW_COPY_AND_ASSIGN(TypeName)
Definition: macros.h:29
int RemoveAt(RepeatedType *array, const IndexContainer &indices)
Definition: protobuf_util.h:50
const Collection::value_type::second_type FindPtrOrNull(const Collection &collection, const typename Collection::value_type::first_type &key)
Definition: map_util.h:89
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition: stl_util.h:58
std::function< int64_t(const Model &)> Value(IntegerVariable v)
Definition: integer.h:1673
Collection of objects used to extend the Constraint Solver library.
int64_t SubOverflows(int64_t x, int64_t y)
static const uint64_t kAllBits64
Definition: bitset.h:33
void InternalSaveBooleanVarValue(Solver *const solver, IntVar *const var)
int64_t CapAdd(int64_t x, int64_t y)
void CleanVariableOnFail(IntVar *const var)
int64_t CapSub(int64_t x, int64_t y)
int64_t UnsafeMostSignificantBitPosition64(const uint64_t *const bitset, uint64_t start, uint64_t end)
uint64_t BitCountRange64(const uint64_t *const bitset, uint64_t start, uint64_t end)
Constraint * SetIsEqual(IntVar *const var, const std::vector< int64_t > &values, const std::vector< IntVar * > &vars)
int64_t UnsafeLeastSignificantBitPosition64(const uint64_t *const bitset, uint64_t start, uint64_t end)
bool AddOverflows(int64_t x, int64_t y)
void RegisterDemon(Solver *const solver, Demon *const demon, DemonProfiler *const monitor)
void RestoreBoolValue(IntVar *const var)
Constraint * SetIsGreaterOrEqual(IntVar *const var, const std::vector< int64_t > &values, const std::vector< IntVar * > &vars)
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
int64_t CapProd(int64_t x, int64_t y)
uint64_t OneRange64(uint64_t s, uint64_t e)
Definition: bitset.h:285
uint32_t BitPos64(uint64_t pos)
Definition: bitset.h:330
uint64_t BitCount64(uint64_t n)
Definition: bitset.h:42
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
Definition: utilities.cc:828
void LinkVarExpr(Solver *const s, IntExpr *const expr, IntVar *const var)
bool IsBitSet64(const uint64_t *const bitset, uint64_t pos)
Definition: bitset.h:346
uint64_t OneBit64(int pos)
Definition: bitset.h:38
uint64_t BitOffset64(uint64_t pos)
Definition: bitset.h:334
int64_t PosIntDivDown(int64_t e, int64_t v)
uint64_t BitLength64(uint64_t size)
Definition: bitset.h:338
int LeastSignificantBitPosition64(uint64_t n)
Definition: bitset.h:127
int64_t CapOpp(int64_t v)
int MostSignificantBitPosition64(uint64_t n)
Definition: bitset.h:231
int64_t PosIntDivUp(int64_t e, int64_t v)
STL namespace.
int index
Definition: pack.cc:509
int64_t coefficient
IntervalVar *const target_var_
int64_t step_
Definition: search.cc:3018
int64_t current_
Definition: search.cc:3019
const int64_t stamp_
Definition: search.cc:3105
std::optional< int64_t > end
int64_t start
const double constant