OR-Tools  8.0
expressions.cc
Go to the documentation of this file.
1 // Copyright 2010-2018 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #include <algorithm>
15 #include <cmath>
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
26 #include "ortools/base/logging.h"
27 #include "ortools/base/map_util.h"
28 #include "ortools/base/mathutil.h"
29 #include "ortools/base/stl_util.h"
32 #include "ortools/util/bitset.h"
35 
36 DEFINE_bool(cp_disable_expression_optimization, false,
37  "Disable special optimization when creating expressions.");
38 DEFINE_bool(cp_share_int_consts, true, "Share IntConst's with the same value.");
39 
40 #if defined(_MSC_VER)
41 #pragma warning(disable : 4351 4355)
42 #endif
43 
44 namespace operations_research {
45 
46 // ---------- IntExpr ----------
47 
48 IntVar* IntExpr::VarWithName(const std::string& name) {
49  IntVar* const var = Var();
50  var->set_name(name);
51  return var;
52 }
53 
54 // ---------- IntVar ----------
55 
56 IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
57 
58 IntVar::IntVar(Solver* const s, const std::string& name)
59  : IntExpr(s), index_(s->GetNewIntVarIndex()) {
60  set_name(name);
61 }
62 
63 // ----- Boolean variable -----
64 
65 const int BooleanVar::kUnboundBooleanVarValue = 2;
66 
67 void BooleanVar::SetMin(int64 m) {
68  if (m <= 0) return;
69  if (m > 1) solver()->Fail();
70  SetValue(1);
71 }
72 
73 void BooleanVar::SetMax(int64 m) {
74  if (m >= 1) return;
75  if (m < 0) solver()->Fail();
76  SetValue(0);
77 }
78 
79 void BooleanVar::SetRange(int64 mi, int64 ma) {
80  if (mi > 1 || ma < 0 || mi > ma) {
81  solver()->Fail();
82  }
83  if (mi == 1) {
84  SetValue(1);
85  } else if (ma == 0) {
86  SetValue(0);
87  }
88 }
89 
90 void BooleanVar::RemoveValue(int64 v) {
91  if (value_ == kUnboundBooleanVarValue) {
92  if (v == 0) {
93  SetValue(1);
94  } else if (v == 1) {
95  SetValue(0);
96  }
97  } else if (v == value_) {
98  solver()->Fail();
99  }
100 }
101 
102 void BooleanVar::RemoveInterval(int64 l, int64 u) {
103  if (u < l) return;
104  if (l <= 0 && u >= 1) {
105  solver()->Fail();
106  } else if (l == 1) {
107  SetValue(0);
108  } else if (u == 0) {
109  SetValue(1);
110  }
111 }
112 
113 void BooleanVar::WhenBound(Demon* d) {
114  if (value_ == kUnboundBooleanVarValue) {
115  if (d->priority() == Solver::DELAYED_PRIORITY) {
116  delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
117  } else {
118  bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
119  }
120  }
121 }
122 
123 uint64 BooleanVar::Size() const {
124  return (1 + (value_ == kUnboundBooleanVarValue));
125 }
126 
127 bool BooleanVar::Contains(int64 v) const {
128  return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
129 }
130 
131 IntVar* BooleanVar::IsEqual(int64 constant) {
132  if (constant > 1 || constant < 0) {
133  return solver()->MakeIntConst(0);
134  }
135  if (constant == 1) {
136  return this;
137  } else { // constant == 0.
138  return solver()->MakeDifference(1, this)->Var();
139  }
140 }
141 
142 IntVar* BooleanVar::IsDifferent(int64 constant) {
143  if (constant > 1 || constant < 0) {
144  return solver()->MakeIntConst(1);
145  }
146  if (constant == 1) {
147  return solver()->MakeDifference(1, this)->Var();
148  } else { // constant == 0.
149  return this;
150  }
151 }
152 
153 IntVar* BooleanVar::IsGreaterOrEqual(int64 constant) {
154  if (constant > 1) {
155  return solver()->MakeIntConst(0);
156  } else if (constant <= 0) {
157  return solver()->MakeIntConst(1);
158  } else {
159  return this;
160  }
161 }
162 
163 IntVar* BooleanVar::IsLessOrEqual(int64 constant) {
164  if (constant < 0) {
165  return solver()->MakeIntConst(0);
166  } else if (constant >= 1) {
167  return solver()->MakeIntConst(1);
168  } else {
169  return IsEqual(0);
170  }
171 }
172 
173 std::string BooleanVar::DebugString() const {
174  std::string out;
175  const std::string& var_name = name();
176  if (!var_name.empty()) {
177  out = var_name + "(";
178  } else {
179  out = "BooleanVar(";
180  }
181  switch (value_) {
182  case 0:
183  out += "0";
184  break;
185  case 1:
186  out += "1";
187  break;
188  case kUnboundBooleanVarValue:
189  out += "0 .. 1";
190  break;
191  }
192  out += ")";
193  return out;
194 }
195 
196 namespace {
197 // ---------- Subclasses of IntVar ----------
198 
199 // ----- Domain Int Var: base class for variables -----
200 // It Contains bounds and a bitset representation of possible values.
201 class DomainIntVar : public IntVar {
202  public:
203  // Utility classes
204  class BitSetIterator : public BaseObject {
205  public:
206  BitSetIterator(uint64* const bitset, int64 omin)
207  : bitset_(bitset), omin_(omin), max_(kint64min), current_(kint64max) {}
208 
209  ~BitSetIterator() override {}
210 
211  void Init(int64 min, int64 max) {
212  max_ = max;
213  current_ = min;
214  }
215 
216  bool Ok() const { return current_ <= max_; }
217 
218  int64 Value() const { return current_; }
219 
220  void Next() {
221  if (++current_ <= max_) {
223  bitset_, current_ - omin_, max_ - omin_) +
224  omin_;
225  }
226  }
227 
228  std::string DebugString() const override { return "BitSetIterator"; }
229 
230  private:
231  uint64* const bitset_;
232  const int64 omin_;
233  int64 max_;
234  int64 current_;
235  };
236 
237  class BitSet : public BaseObject {
238  public:
239  explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
240  ~BitSet() override {}
241 
242  virtual int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) = 0;
243  virtual int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) = 0;
244  virtual bool Contains(int64 val) const = 0;
245  virtual bool SetValue(int64 val) = 0;
246  virtual bool RemoveValue(int64 val) = 0;
247  virtual uint64 Size() const = 0;
248  virtual void DelayRemoveValue(int64 val) = 0;
249  virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
250  virtual void ClearRemovedValues() = 0;
251  virtual std::string pretty_DebugString(int64 min, int64 max) const = 0;
252  virtual BitSetIterator* MakeIterator() = 0;
253 
254  void InitHoles() {
255  const uint64 current_stamp = solver_->stamp();
256  if (holes_stamp_ < current_stamp) {
257  holes_.clear();
258  holes_stamp_ = current_stamp;
259  }
260  }
261 
262  virtual void ClearHoles() { holes_.clear(); }
263 
264  const std::vector<int64>& Holes() { return holes_; }
265 
266  void AddHole(int64 value) { holes_.push_back(value); }
267 
268  int NumHoles() const {
269  return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
270  }
271 
272  protected:
273  Solver* const solver_;
274 
275  private:
276  std::vector<int64> holes_;
277  uint64 holes_stamp_;
278  };
279 
280  class QueueHandler : public Demon {
281  public:
282  explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
283  ~QueueHandler() override {}
284  void Run(Solver* const s) override {
285  s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
286  var_->Process();
287  s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
288  }
289  Solver::DemonPriority priority() const override {
290  return Solver::VAR_PRIORITY;
291  }
292  std::string DebugString() const override {
293  return absl::StrFormat("Handler(%s)", var_->DebugString());
294  }
295 
296  private:
297  DomainIntVar* const var_;
298  };
299 
300  // Bounds and Value watchers
301 
302  // This class stores the watchers variables attached to values. It is
303  // reversible and it helps maintaining the set of 'active' watchers
304  // (variables not bound to a single value).
305  template <class T>
306  class RevIntPtrMap {
307  public:
308  RevIntPtrMap(Solver* const solver, int64 rmin, int64 rmax)
309  : solver_(solver), range_min_(rmin), start_(0) {}
310 
311  ~RevIntPtrMap() {}
312 
313  bool Empty() const { return start_.Value() == elements_.size(); }
314 
315  void SortActive() { std::sort(elements_.begin(), elements_.end()); }
316 
317  // Access with value API.
318 
319  // Add the pointer to the map attached to the given value.
320  void UnsafeRevInsert(int64 value, T* elem) {
321  elements_.push_back(std::make_pair(value, elem));
322  if (solver_->state() != Solver::OUTSIDE_SEARCH) {
323  solver_->AddBacktrackAction(
324  [this, value](Solver* s) { Uninsert(value); }, false);
325  }
326  }
327 
328  T* FindPtrOrNull(int64 value, int* position) {
329  for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
330  if (elements_[pos].first == value) {
331  if (position != nullptr) *position = pos;
332  return At(pos).second;
333  }
334  }
335  return nullptr;
336  }
337 
338  // Access map through the underlying vector.
339  void RemoveAt(int position) {
340  const int start = start_.Value();
341  DCHECK_GE(position, start);
342  DCHECK_LT(position, elements_.size());
343  if (position > start) {
344  // Swap the current element with the one at the start position, and
345  // increase start.
346  const std::pair<int64, T*> copy = elements_[start];
347  elements_[start] = elements_[position];
348  elements_[position] = copy;
349  }
350  start_.Incr(solver_);
351  }
352 
353  const std::pair<int64, T*>& At(int position) const {
354  DCHECK_GE(position, start_.Value());
355  DCHECK_LT(position, elements_.size());
356  return elements_[position];
357  }
358 
359  void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
360 
361  int start() const { return start_.Value(); }
362  int end() const { return elements_.size(); }
363  // Number of active elements.
364  int Size() const { return elements_.size() - start_.Value(); }
365 
366  // Removes the object permanently from the map.
367  void Uninsert(int64 value) {
368  for (int pos = 0; pos < elements_.size(); ++pos) {
369  if (elements_[pos].first == value) {
370  DCHECK_GE(pos, start_.Value());
371  const int last = elements_.size() - 1;
372  if (pos != last) { // Swap the current with the last.
373  elements_[pos] = elements_.back();
374  }
375  elements_.pop_back();
376  return;
377  }
378  }
379  LOG(FATAL) << "The element should have been removed";
380  }
381 
382  private:
383  Solver* const solver_;
384  const int64 range_min_;
385  NumericalRev<int> start_;
386  std::vector<std::pair<int64, T*>> elements_;
387  };
388 
389  // Base class for value watchers
390  class BaseValueWatcher : public Constraint {
391  public:
392  explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
393 
394  ~BaseValueWatcher() override {}
395 
396  virtual IntVar* GetOrMakeValueWatcher(int64 value) = 0;
397 
398  virtual void SetValueWatcher(IntVar* const boolvar, int64 value) = 0;
399  };
400 
  // This class monitors the domain of the variable and updates the
  // IsEqual/IsDifferent boolean variables accordingly.
  // Each registered pair (value, boolvar) is kept consistent in both
  // directions: domain changes of variable_ fix boolvar, and fixing boolvar
  // removes or assigns `value` on variable_. Decided pairs are deactivated in
  // the reversible map watchers_, and the domain demon is inhibited once no
  // active pair remains.
  class ValueWatcher : public BaseValueWatcher {
   public:
    // Demon attached to one boolvar: pushes its fixed value back onto the
    // watched variable.
    class WatchDemon : public Demon {
     public:
      WatchDemon(ValueWatcher* const watcher, int64 value, IntVar* var)
          : value_watcher_(watcher), value_(value), var_(var) {}
      ~WatchDemon() override {}

      void Run(Solver* const solver) override {
        value_watcher_->ProcessValueWatcher(value_, var_);
      }

     private:
      ValueWatcher* const value_watcher_;
      const int64 value_;
      IntVar* const var_;
    };

    // Demon attached to the domain events of the watched variable.
    class VarDemon : public Demon {
     public:
      explicit VarDemon(ValueWatcher* const watcher)
          : value_watcher_(watcher) {}

      ~VarDemon() override {}

      void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }

     private:
      ValueWatcher* const value_watcher_;
    };

    ValueWatcher(Solver* const solver, DomainIntVar* const variable)
        : BaseValueWatcher(solver),
          variable_(variable),
          hole_iterator_(variable_->MakeHoleIterator(true)),
          var_demon_(nullptr),
          watchers_(solver, variable->Min(), variable->Max()) {}

    ~ValueWatcher() override {}

    // Returns a 0/1 expression for (variable_ == value). The result is one of:
    // - the existing watcher, if one is active;
    // - the constant 1 if variable_ is bound to `value`;
    // - the constant 0 if `value` is not in the domain;
    // - otherwise a new boolean variable. If the constraint is already posted,
    //   the new variable is wired and the variable demon is re-enabled.
    IntVar* GetOrMakeValueWatcher(int64 value) override {
      IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
      if (watcher != nullptr) return watcher;
      if (variable_->Contains(value)) {
        if (variable_->Bound()) {
          return solver()->MakeIntConst(1);
        } else {
          const std::string vname = variable_->HasName()
                                        ? variable_->name()
                                        : variable_->DebugString();
          const std::string bname =
              absl::StrFormat("Watch<%s == %d>", vname, value);
          IntVar* const boolvar = solver()->MakeBoolVar(bname);
          watchers_.UnsafeRevInsert(value, boolvar);
          if (posted_.Switched()) {
            boolvar->WhenBound(
                solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
            var_demon_->desinhibit(solver());
          }
          return boolvar;
        }
      } else {
        return variable_->solver()->MakeIntConst(0);
      }
    }

    // Registers a caller-supplied boolvar as the watcher of `value`.
    // CHECK-fails if `value` already has an active watcher; an already-bound
    // boolvar is not registered.
    void SetValueWatcher(IntVar* const boolvar, int64 value) override {
      CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
      if (!boolvar->Bound()) {
        watchers_.UnsafeRevInsert(value, boolvar);
        if (posted_.Switched() && !boolvar->Bound()) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
          var_demon_->desinhibit(solver());
        }
      }
    }

    // Creates the domain demon on variable_. Attaches a WatchDemon to every
    // active watcher that is unbound and whose value is still possible, then
    // marks the constraint as posted.
    void Post() override {
      var_demon_ = solver()->RevAlloc(new VarDemon(this));
      variable_->WhenDomain(var_demon_);
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64, IntVar*>& w = watchers_.At(pos);
        const int64 value = w.first;
        IntVar* const boolvar = w.second;
        if (!boolvar->Bound() && variable_->Contains(value)) {
          boolvar->WhenBound(
              solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
        }
      }
      posted_.Switch(solver());
    }

    // Initial consistency pass.
    // - Impossible values fix their watcher to 0.
    // - Watchers that are already bound are propagated onto variable_.
    // Decided watchers are deactivated. RemoveAt only swaps an
    // already-visited entry into `pos`, so none is skipped.
    void InitialPropagate() override {
      if (variable_->Bound()) {
        VariableBound();
      } else {
        for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
          const std::pair<int64, IntVar*>& w = watchers_.At(pos);
          const int64 value = w.first;
          IntVar* const boolvar = w.second;
          if (!variable_->Contains(value)) {
            boolvar->SetValue(0);
            watchers_.RemoveAt(pos);
          } else {
            if (boolvar->Bound()) {
              ProcessValueWatcher(value, boolvar);
              watchers_.RemoveAt(pos);
            }
          }
        }
        CheckInhibit();
      }
    }

    // Propagates a bound watcher: boolvar == 0 removes `value` from variable_,
    // boolvar == 1 assigns it. When variable_->Size() reaches 0xFFFFFF, the
    // removal is instead posted as a NonEquality constraint.
    void ProcessValueWatcher(int64 value, IntVar* boolvar) {
      if (boolvar->Min() == 0) {
        if (variable_->Size() < 0xFFFFFF) {
          variable_->RemoveValue(value);
        } else {
          // Delay removal.
          solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
        }
      } else {
        variable_->SetValue(value);
      }
    }

    // Reacts to a domain change of variable_.
    void ProcessVar() {
      const int kSmallList = 16;
      if (variable_->Bound()) {
        VariableBound();
      } else if (watchers_.Size() <= kSmallList ||
                 variable_->Min() != variable_->OldMin() ||
                 variable_->Max() != variable_->OldMax()) {
        // Brute force loop for small numbers of watchers, or if the bounds have
        // changed, which would have required a sort (n log(n)) anyway to take
        // advantage of.
        ScanWatchers();
        CheckInhibit();
      } else {
        // Bounds are unchanged, so only holes can have appeared. Without a
        // bitset there are no holes to process. Otherwise, when holes are few
        // compared to active watchers, look each hole up; else rescan all
        // watchers.
        BitSet* const bitset = variable_->bitset();
        if (bitset != nullptr && !watchers_.Empty()) {
          if (bitset->NumHoles() * 2 < watchers_.Size()) {
            for (const int64 hole : InitAndGetValues(hole_iterator_)) {
              int pos = 0;
              IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
              if (boolvar != nullptr) {
                boolvar->SetValue(0);
                watchers_.RemoveAt(pos);
              }
            }
          } else {
            ScanWatchers();
          }
        }
        CheckInhibit();
      }
    }

    // Optimized case if the variable is bound.
    // Every active watcher is fixed to (its value == bound value), all are
    // deactivated, and the domain demon is inhibited.
    void VariableBound() {
      DCHECK(variable_->Bound());
      const int64 value = variable_->Min();
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64, IntVar*>& w = watchers_.At(pos);
        w.second->SetValue(w.first == value);
      }
      watchers_.RemoveAll();
      var_demon_->inhibit(solver());
    }

    // Scans all the watchers to check and assign them.
    // Watchers whose value left the domain are fixed to 0 and deactivated.
    void ScanWatchers() {
      for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
        const std::pair<int64, IntVar*>& w = watchers_.At(pos);
        if (!variable_->Contains(w.first)) {
          IntVar* const boolvar = w.second;
          boolvar->SetValue(0);
          watchers_.RemoveAt(pos);
        }
      }
    }

    // If the set of active watchers is empty, we can inhibit the demon on the
    // main variable.
    void CheckInhibit() {
      if (watchers_.Empty()) {
        var_demon_->inhibit(solver());
      }
    }

    // Visits the watched variable plus the active (value, boolvar) pairs.
    void Accept(ModelVisitor* const visitor) const override {
      visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
      visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
                                              variable_);
      std::vector<int64> all_coefficients;
      std::vector<IntVar*> all_bool_vars;
      for (int position = watchers_.start(); position < watchers_.end();
           ++position) {
        const std::pair<int64, IntVar*>& w = watchers_.At(position);
        all_coefficients.push_back(w.first);
        all_bool_vars.push_back(w.second);
      }
      visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
                                                 all_bool_vars);
      visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
                                         all_coefficients);
      visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
    }

    std::string DebugString() const override {
      return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
    }

   private:
    DomainIntVar* const variable_;
    IntVarIterator* const hole_iterator_;
    RevSwitch posted_;
    Demon* var_demon_;
    RevIntPtrMap<IntVar> watchers_;
  };
628 
629  // Optimized case for small maps.
630  class DenseValueWatcher : public BaseValueWatcher {
631  public:
632  class WatchDemon : public Demon {
633  public:
634  WatchDemon(DenseValueWatcher* const watcher, int64 value, IntVar* var)
635  : value_watcher_(watcher), value_(value), var_(var) {}
636  ~WatchDemon() override {}
637 
638  void Run(Solver* const solver) override {
639  value_watcher_->ProcessValueWatcher(value_, var_);
640  }
641 
642  private:
643  DenseValueWatcher* const value_watcher_;
644  const int64 value_;
645  IntVar* const var_;
646  };
647 
648  class VarDemon : public Demon {
649  public:
650  explicit VarDemon(DenseValueWatcher* const watcher)
651  : value_watcher_(watcher) {}
652 
653  ~VarDemon() override {}
654 
655  void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
656 
657  private:
658  DenseValueWatcher* const value_watcher_;
659  };
660 
661  DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
662  : BaseValueWatcher(solver),
663  variable_(variable),
664  hole_iterator_(variable_->MakeHoleIterator(true)),
665  var_demon_(nullptr),
666  offset_(variable->Min()),
667  watchers_(variable->Max() - variable->Min() + 1, nullptr),
668  active_watchers_(0) {}
669 
670  ~DenseValueWatcher() override {}
671 
672  IntVar* GetOrMakeValueWatcher(int64 value) override {
673  const int64 var_max = offset_ + watchers_.size() - 1; // Bad cast.
674  if (value < offset_ || value > var_max) {
675  return solver()->MakeIntConst(0);
676  }
677  const int index = value - offset_;
678  IntVar* const watcher = watchers_[index];
679  if (watcher != nullptr) return watcher;
680  if (variable_->Contains(value)) {
681  if (variable_->Bound()) {
682  return solver()->MakeIntConst(1);
683  } else {
684  const std::string vname = variable_->HasName()
685  ? variable_->name()
686  : variable_->DebugString();
687  const std::string bname =
688  absl::StrFormat("Watch<%s == %d>", vname, value);
689  IntVar* const boolvar = solver()->MakeBoolVar(bname);
690  RevInsert(index, boolvar);
691  if (posted_.Switched()) {
692  boolvar->WhenBound(
693  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
694  var_demon_->desinhibit(solver());
695  }
696  return boolvar;
697  }
698  } else {
699  return variable_->solver()->MakeIntConst(0);
700  }
701  }
702 
703  void SetValueWatcher(IntVar* const boolvar, int64 value) override {
704  const int index = value - offset_;
705  CHECK(watchers_[index] == nullptr);
706  if (!boolvar->Bound()) {
707  RevInsert(index, boolvar);
708  if (posted_.Switched() && !boolvar->Bound()) {
709  boolvar->WhenBound(
710  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
711  var_demon_->desinhibit(solver());
712  }
713  }
714  }
715 
716  void Post() override {
717  var_demon_ = solver()->RevAlloc(new VarDemon(this));
718  variable_->WhenDomain(var_demon_);
719  for (int pos = 0; pos < watchers_.size(); ++pos) {
720  const int64 value = pos + offset_;
721  IntVar* const boolvar = watchers_[pos];
722  if (boolvar != nullptr && !boolvar->Bound() &&
723  variable_->Contains(value)) {
724  boolvar->WhenBound(
725  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
726  }
727  }
728  posted_.Switch(solver());
729  }
730 
731  void InitialPropagate() override {
732  if (variable_->Bound()) {
733  VariableBound();
734  } else {
735  for (int pos = 0; pos < watchers_.size(); ++pos) {
736  IntVar* const boolvar = watchers_[pos];
737  if (boolvar == nullptr) continue;
738  const int64 value = pos + offset_;
739  if (!variable_->Contains(value)) {
740  boolvar->SetValue(0);
741  RevRemove(pos);
742  } else if (boolvar->Bound()) {
743  ProcessValueWatcher(value, boolvar);
744  RevRemove(pos);
745  }
746  }
747  if (active_watchers_.Value() == 0) {
748  var_demon_->inhibit(solver());
749  }
750  }
751  }
752 
753  void ProcessValueWatcher(int64 value, IntVar* boolvar) {
754  if (boolvar->Min() == 0) {
755  variable_->RemoveValue(value);
756  } else {
757  variable_->SetValue(value);
758  }
759  }
760 
761  void ProcessVar() {
762  if (variable_->Bound()) {
763  VariableBound();
764  } else {
765  // Brute force loop for small numbers of watchers.
766  ScanWatchers();
767  if (active_watchers_.Value() == 0) {
768  var_demon_->inhibit(solver());
769  }
770  }
771  }
772 
773  // Optimized case if the variable is bound.
774  void VariableBound() {
775  DCHECK(variable_->Bound());
776  const int64 value = variable_->Min();
777  for (int pos = 0; pos < watchers_.size(); ++pos) {
778  IntVar* const boolvar = watchers_[pos];
779  if (boolvar != nullptr) {
780  boolvar->SetValue(pos + offset_ == value);
781  RevRemove(pos);
782  }
783  }
784  var_demon_->inhibit(solver());
785  }
786 
787  // Scans all the watchers to check and assign them.
788  void ScanWatchers() {
789  const int64 old_min_index = variable_->OldMin() - offset_;
790  const int64 old_max_index = variable_->OldMax() - offset_;
791  const int64 min_index = variable_->Min() - offset_;
792  const int64 max_index = variable_->Max() - offset_;
793  for (int pos = old_min_index; pos < min_index; ++pos) {
794  IntVar* const boolvar = watchers_[pos];
795  if (boolvar != nullptr) {
796  boolvar->SetValue(0);
797  RevRemove(pos);
798  }
799  }
800  for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
801  IntVar* const boolvar = watchers_[pos];
802  if (boolvar != nullptr) {
803  boolvar->SetValue(0);
804  RevRemove(pos);
805  }
806  }
807  BitSet* const bitset = variable_->bitset();
808  if (bitset != nullptr) {
809  if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
810  for (const int64 hole : InitAndGetValues(hole_iterator_)) {
811  IntVar* const boolvar = watchers_[hole - offset_];
812  if (boolvar != nullptr) {
813  boolvar->SetValue(0);
814  RevRemove(hole - offset_);
815  }
816  }
817  } else {
818  for (int pos = min_index + 1; pos < max_index; ++pos) {
819  IntVar* const boolvar = watchers_[pos];
820  if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
821  boolvar->SetValue(0);
822  RevRemove(pos);
823  }
824  }
825  }
826  }
827  }
828 
829  void RevRemove(int pos) {
830  solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
831  watchers_[pos] = nullptr;
832  active_watchers_.Decr(solver());
833  }
834 
835  void RevInsert(int pos, IntVar* boolvar) {
836  solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
837  watchers_[pos] = boolvar;
838  active_watchers_.Incr(solver());
839  }
840 
841  void Accept(ModelVisitor* const visitor) const override {
842  visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
843  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
844  variable_);
845  std::vector<int64> all_coefficients;
846  std::vector<IntVar*> all_bool_vars;
847  for (int position = 0; position < watchers_.size(); ++position) {
848  if (watchers_[position] != nullptr) {
849  all_coefficients.push_back(position + offset_);
850  all_bool_vars.push_back(watchers_[position]);
851  }
852  }
853  visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
854  all_bool_vars);
855  visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
856  all_coefficients);
857  visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
858  }
859 
860  std::string DebugString() const override {
861  return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
862  }
863 
864  private:
865  DomainIntVar* const variable_;
866  IntVarIterator* const hole_iterator_;
867  RevSwitch posted_;
868  Demon* var_demon_;
869  const int64 offset_;
870  std::vector<IntVar*> watchers_;
871  NumericalRev<int> active_watchers_;
872  };
873 
874  class BaseUpperBoundWatcher : public Constraint {
875  public:
876  explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
877 
878  ~BaseUpperBoundWatcher() override {}
879 
880  virtual IntVar* GetOrMakeUpperBoundWatcher(int64 value) = 0;
881 
882  virtual void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) = 0;
883  };
884 
885  // This class watches the bounds of the variable and updates the
886  // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
887  // accordingly.
888  class UpperBoundWatcher : public BaseUpperBoundWatcher {
889  public:
890  class WatchDemon : public Demon {
891  public:
892  WatchDemon(UpperBoundWatcher* const watcher, int64 index,
893  IntVar* const var)
894  : value_watcher_(watcher), index_(index), var_(var) {}
895  ~WatchDemon() override {}
896 
897  void Run(Solver* const solver) override {
898  value_watcher_->ProcessUpperBoundWatcher(index_, var_);
899  }
900 
901  private:
902  UpperBoundWatcher* const value_watcher_;
903  const int64 index_;
904  IntVar* const var_;
905  };
906 
907  class VarDemon : public Demon {
908  public:
909  explicit VarDemon(UpperBoundWatcher* const watcher)
910  : value_watcher_(watcher) {}
911  ~VarDemon() override {}
912 
913  void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
914 
915  private:
916  UpperBoundWatcher* const value_watcher_;
917  };
918 
919  UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
920  : BaseUpperBoundWatcher(solver),
921  variable_(variable),
922  var_demon_(nullptr),
923  watchers_(solver, variable->Min(), variable->Max()),
924  start_(0),
925  end_(0),
926  sorted_(false) {}
927 
928  ~UpperBoundWatcher() override {}
929 
930  IntVar* GetOrMakeUpperBoundWatcher(int64 value) override {
931  IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
932  if (watcher != nullptr) {
933  return watcher;
934  }
935  if (variable_->Max() >= value) {
936  if (variable_->Min() >= value) {
937  return solver()->MakeIntConst(1);
938  } else {
939  const std::string vname = variable_->HasName()
940  ? variable_->name()
941  : variable_->DebugString();
942  const std::string bname =
943  absl::StrFormat("Watch<%s >= %d>", vname, value);
944  IntVar* const boolvar = solver()->MakeBoolVar(bname);
945  watchers_.UnsafeRevInsert(value, boolvar);
946  if (posted_.Switched()) {
947  boolvar->WhenBound(
948  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
949  var_demon_->desinhibit(solver());
950  sorted_ = false;
951  }
952  return boolvar;
953  }
954  } else {
955  return variable_->solver()->MakeIntConst(0);
956  }
957  }
958 
959  void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) override {
960  CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
961  watchers_.UnsafeRevInsert(value, boolvar);
962  if (posted_.Switched() && !boolvar->Bound()) {
963  boolvar->WhenBound(
964  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
965  var_demon_->desinhibit(solver());
966  sorted_ = false;
967  }
968  }
969 
970  void Post() override {
971  const int kTooSmallToSort = 8;
972  var_demon_ = solver()->RevAlloc(new VarDemon(this));
973  variable_->WhenRange(var_demon_);
974 
975  if (watchers_.Size() > kTooSmallToSort) {
976  watchers_.SortActive();
977  sorted_ = true;
978  start_.SetValue(solver(), watchers_.start());
979  end_.SetValue(solver(), watchers_.end() - 1);
980  }
981 
982  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
983  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
984  IntVar* const boolvar = w.second;
985  const int64 value = w.first;
986  if (!boolvar->Bound() && value > variable_->Min() &&
987  value <= variable_->Max()) {
988  boolvar->WhenBound(
989  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
990  }
991  }
992  posted_.Switch(solver());
993  }
994 
995  void InitialPropagate() override {
996  const int64 var_min = variable_->Min();
997  const int64 var_max = variable_->Max();
998  if (sorted_) {
999  while (start_.Value() <= end_.Value()) {
1000  const std::pair<int64, IntVar*>& w = watchers_.At(start_.Value());
1001  if (w.first <= var_min) {
1002  w.second->SetValue(1);
1003  start_.Incr(solver());
1004  } else {
1005  break;
1006  }
1007  }
1008  while (end_.Value() >= start_.Value()) {
1009  const std::pair<int64, IntVar*>& w = watchers_.At(end_.Value());
1010  if (w.first > var_max) {
1011  w.second->SetValue(0);
1012  end_.Decr(solver());
1013  } else {
1014  break;
1015  }
1016  }
1017  for (int i = start_.Value(); i <= end_.Value(); ++i) {
1018  const std::pair<int64, IntVar*>& w = watchers_.At(i);
1019  if (w.second->Bound()) {
1020  ProcessUpperBoundWatcher(w.first, w.second);
1021  }
1022  }
1023  if (start_.Value() > end_.Value()) {
1024  var_demon_->inhibit(solver());
1025  }
1026  } else {
1027  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1028  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1029  const int64 value = w.first;
1030  IntVar* const boolvar = w.second;
1031 
1032  if (value <= var_min) {
1033  boolvar->SetValue(1);
1034  watchers_.RemoveAt(pos);
1035  } else if (value > var_max) {
1036  boolvar->SetValue(0);
1037  watchers_.RemoveAt(pos);
1038  } else if (boolvar->Bound()) {
1039  ProcessUpperBoundWatcher(value, boolvar);
1040  watchers_.RemoveAt(pos);
1041  }
1042  }
1043  }
1044  }
1045 
1046  void Accept(ModelVisitor* const visitor) const override {
1047  visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1048  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1049  variable_);
1050  std::vector<int64> all_coefficients;
1051  std::vector<IntVar*> all_bool_vars;
1052  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1053  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1054  all_coefficients.push_back(w.first);
1055  all_bool_vars.push_back(w.second);
1056  }
1057  visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1058  all_bool_vars);
1059  visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1060  all_coefficients);
1061  visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1062  }
1063 
1064  std::string DebugString() const override {
1065  return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
1066  }
1067 
1068  private:
1069  void ProcessUpperBoundWatcher(int64 value, IntVar* const boolvar) {
1070  if (boolvar->Min() == 0) {
1071  variable_->SetMax(value - 1);
1072  } else {
1073  variable_->SetMin(value);
1074  }
1075  }
1076 
1077  void ProcessVar() {
1078  const int64 var_min = variable_->Min();
1079  const int64 var_max = variable_->Max();
1080  if (sorted_) {
1081  while (start_.Value() <= end_.Value()) {
1082  const std::pair<int64, IntVar*>& w = watchers_.At(start_.Value());
1083  if (w.first <= var_min) {
1084  w.second->SetValue(1);
1085  start_.Incr(solver());
1086  } else {
1087  break;
1088  }
1089  }
1090  while (end_.Value() >= start_.Value()) {
1091  const std::pair<int64, IntVar*>& w = watchers_.At(end_.Value());
1092  if (w.first > var_max) {
1093  w.second->SetValue(0);
1094  end_.Decr(solver());
1095  } else {
1096  break;
1097  }
1098  }
1099  if (start_.Value() > end_.Value()) {
1100  var_demon_->inhibit(solver());
1101  }
1102  } else {
1103  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1104  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1105  const int64 value = w.first;
1106  IntVar* const boolvar = w.second;
1107 
1108  if (value <= var_min) {
1109  boolvar->SetValue(1);
1110  watchers_.RemoveAt(pos);
1111  } else if (value > var_max) {
1112  boolvar->SetValue(0);
1113  watchers_.RemoveAt(pos);
1114  }
1115  }
1116  if (watchers_.Empty()) {
1117  var_demon_->inhibit(solver());
1118  }
1119  }
1120  }
1121 
1122  DomainIntVar* const variable_;
1123  RevSwitch posted_;
1124  Demon* var_demon_;
1125  RevIntPtrMap<IntVar> watchers_;
1126  NumericalRev<int> start_;
1127  NumericalRev<int> end_;
1128  bool sorted_;
1129  };
1130 
1131  // Optimized case for small maps.
1132  class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1133  public:
1134  class WatchDemon : public Demon {
1135  public:
1136  WatchDemon(DenseUpperBoundWatcher* const watcher, int64 value,
1137  IntVar* var)
1138  : value_watcher_(watcher), value_(value), var_(var) {}
1139  ~WatchDemon() override {}
1140 
1141  void Run(Solver* const solver) override {
1142  value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1143  }
1144 
1145  private:
1146  DenseUpperBoundWatcher* const value_watcher_;
1147  const int64 value_;
1148  IntVar* const var_;
1149  };
1150 
1151  class VarDemon : public Demon {
1152  public:
1153  explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1154  : value_watcher_(watcher) {}
1155 
1156  ~VarDemon() override {}
1157 
1158  void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1159 
1160  private:
1161  DenseUpperBoundWatcher* const value_watcher_;
1162  };
1163 
1164  DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1165  : BaseUpperBoundWatcher(solver),
1166  variable_(variable),
1167  var_demon_(nullptr),
1168  offset_(variable->Min()),
1169  watchers_(variable->Max() - variable->Min() + 1, nullptr),
1170  active_watchers_(0) {}
1171 
1172  ~DenseUpperBoundWatcher() override {}
1173 
1174  IntVar* GetOrMakeUpperBoundWatcher(int64 value) override {
1175  if (variable_->Max() >= value) {
1176  if (variable_->Min() >= value) {
1177  return solver()->MakeIntConst(1);
1178  } else {
1179  const std::string vname = variable_->HasName()
1180  ? variable_->name()
1181  : variable_->DebugString();
1182  const std::string bname =
1183  absl::StrFormat("Watch<%s >= %d>", vname, value);
1184  IntVar* const boolvar = solver()->MakeBoolVar(bname);
1185  RevInsert(value - offset_, boolvar);
1186  if (posted_.Switched()) {
1187  boolvar->WhenBound(
1188  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1189  var_demon_->desinhibit(solver());
1190  }
1191  return boolvar;
1192  }
1193  } else {
1194  return variable_->solver()->MakeIntConst(0);
1195  }
1196  }
1197 
1198  void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) override {
1199  const int index = value - offset_;
1200  CHECK(watchers_[index] == nullptr);
1201  if (!boolvar->Bound()) {
1202  RevInsert(index, boolvar);
1203  if (posted_.Switched() && !boolvar->Bound()) {
1204  boolvar->WhenBound(
1205  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1206  var_demon_->desinhibit(solver());
1207  }
1208  }
1209  }
1210 
1211  void Post() override {
1212  var_demon_ = solver()->RevAlloc(new VarDemon(this));
1213  variable_->WhenRange(var_demon_);
1214  for (int pos = 0; pos < watchers_.size(); ++pos) {
1215  const int64 value = pos + offset_;
1216  IntVar* const boolvar = watchers_[pos];
1217  if (boolvar != nullptr && !boolvar->Bound() &&
1218  value > variable_->Min() && value <= variable_->Max()) {
1219  boolvar->WhenBound(
1220  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1221  }
1222  }
1223  posted_.Switch(solver());
1224  }
1225 
1226  void InitialPropagate() override {
1227  for (int pos = 0; pos < watchers_.size(); ++pos) {
1228  IntVar* const boolvar = watchers_[pos];
1229  if (boolvar == nullptr) continue;
1230  const int64 value = pos + offset_;
1231  if (value <= variable_->Min()) {
1232  boolvar->SetValue(1);
1233  RevRemove(pos);
1234  } else if (value > variable_->Max()) {
1235  boolvar->SetValue(0);
1236  RevRemove(pos);
1237  } else if (boolvar->Bound()) {
1238  ProcessUpperBoundWatcher(value, boolvar);
1239  RevRemove(pos);
1240  }
1241  }
1242  if (active_watchers_.Value() == 0) {
1243  var_demon_->inhibit(solver());
1244  }
1245  }
1246 
1247  void ProcessUpperBoundWatcher(int64 value, IntVar* boolvar) {
1248  if (boolvar->Min() == 0) {
1249  variable_->SetMax(value - 1);
1250  } else {
1251  variable_->SetMin(value);
1252  }
1253  }
1254 
1255  void ProcessVar() {
1256  const int64 old_min_index = variable_->OldMin() - offset_;
1257  const int64 old_max_index = variable_->OldMax() - offset_;
1258  const int64 min_index = variable_->Min() - offset_;
1259  const int64 max_index = variable_->Max() - offset_;
1260  for (int pos = old_min_index; pos <= min_index; ++pos) {
1261  IntVar* const boolvar = watchers_[pos];
1262  if (boolvar != nullptr) {
1263  boolvar->SetValue(1);
1264  RevRemove(pos);
1265  }
1266  }
1267 
1268  for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1269  IntVar* const boolvar = watchers_[pos];
1270  if (boolvar != nullptr) {
1271  boolvar->SetValue(0);
1272  RevRemove(pos);
1273  }
1274  }
1275  if (active_watchers_.Value() == 0) {
1276  var_demon_->inhibit(solver());
1277  }
1278  }
1279 
1280  void RevRemove(int pos) {
1281  solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1282  watchers_[pos] = nullptr;
1283  active_watchers_.Decr(solver());
1284  }
1285 
1286  void RevInsert(int pos, IntVar* boolvar) {
1287  solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1288  watchers_[pos] = boolvar;
1289  active_watchers_.Incr(solver());
1290  }
1291 
1292  void Accept(ModelVisitor* const visitor) const override {
1293  visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1294  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1295  variable_);
1296  std::vector<int64> all_coefficients;
1297  std::vector<IntVar*> all_bool_vars;
1298  for (int position = 0; position < watchers_.size(); ++position) {
1299  if (watchers_[position] != nullptr) {
1300  all_coefficients.push_back(position + offset_);
1301  all_bool_vars.push_back(watchers_[position]);
1302  }
1303  }
1304  visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1305  all_bool_vars);
1306  visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1307  all_coefficients);
1308  visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1309  }
1310 
1311  std::string DebugString() const override {
1312  return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1313  variable_->DebugString());
1314  }
1315 
1316  private:
1317  DomainIntVar* const variable_;
1318  RevSwitch posted_;
1319  Demon* var_demon_;
1320  const int64 offset_;
1321  std::vector<IntVar*> watchers_;
1322  NumericalRev<int> active_watchers_;
1323  };
1324 
// ----- Main Class -----

// Constructor taking an interval [vmin, vmax] (presumably the initial
// domain; the definition is out of class).
DomainIntVar(Solver* const s, int64 vmin, int64 vmax,
             const std::string& name);
// Constructor taking an explicit value list; the parameter name suggests the
// values must be sorted -- TODO confirm against the out-of-class definition.
DomainIntVar(Solver* const s, const std::vector<int64>& sorted_values,
             const std::string& name);
~DomainIntVar() override;

// Current lower bound, read from the reversible min_.
int64 Min() const override { return min_.Value(); }
void SetMin(int64 m) override;
// Current upper bound, read from the reversible max_.
int64 Max() const override { return max_.Value(); }
void SetMax(int64 m) override;
void SetRange(int64 mi, int64 ma) override;
void SetValue(int64 v) override;
// True when both bounds coincide.
bool Bound() const override { return (min_.Value() == max_.Value()); }
// Value of a bound variable; CHECK-fails with the variable's DebugString()
// when the bounds differ.
int64 Value() const override {
  CHECK_EQ(min_.Value(), max_.Value())
      << " variable " << DebugString() << " is not bound.";
  return min_.Value();
}
void RemoveValue(int64 v) override;
void RemoveInterval(int64 l, int64 u) override;
// Declared here; the body is defined outside this section.
void CreateBits();
1347  void WhenBound(Demon* d) override {
1348  if (min_.Value() != max_.Value()) {
1349  if (d->priority() == Solver::DELAYED_PRIORITY) {
1350  delayed_bound_demons_.PushIfNotTop(solver(),
1351  solver()->RegisterDemon(d));
1352  } else {
1353  bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1354  }
1355  }
1356  }
1357  void WhenRange(Demon* d) override {
1358  if (min_.Value() != max_.Value()) {
1359  if (d->priority() == Solver::DELAYED_PRIORITY) {
1360  delayed_range_demons_.PushIfNotTop(solver(),
1361  solver()->RegisterDemon(d));
1362  } else {
1363  range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1364  }
1365  }
1366  }
1367  void WhenDomain(Demon* d) override {
1368  if (min_.Value() != max_.Value()) {
1369  if (d->priority() == Solver::DELAYED_PRIORITY) {
1370  delayed_domain_demons_.PushIfNotTop(solver(),
1371  solver()->RegisterDemon(d));
1372  } else {
1373  domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1374  }
1375  }
1376  }
1377 
1378  IntVar* IsEqual(int64 constant) override {
1379  Solver* const s = solver();
1380  if (constant == min_.Value() && value_watcher_ == nullptr) {
1381  return s->MakeIsLessOrEqualCstVar(this, constant);
1382  }
1383  if (constant == max_.Value() && value_watcher_ == nullptr) {
1384  return s->MakeIsGreaterOrEqualCstVar(this, constant);
1385  }
1386  if (!Contains(constant)) {
1387  return s->MakeIntConst(int64{0});
1388  }
1389  if (Bound() && min_.Value() == constant) {
1390  return s->MakeIntConst(int64{1});
1391  }
1392  IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1393  this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1394  if (cache != nullptr) {
1395  return cache->Var();
1396  } else {
1397  if (value_watcher_ == nullptr) {
1398  if (CapSub(Max(), Min()) <= 256) {
1399  solver()->SaveAndSetValue(
1400  reinterpret_cast<void**>(&value_watcher_),
1401  reinterpret_cast<void*>(
1402  solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1403 
1404  } else {
1405  solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1406  reinterpret_cast<void*>(solver()->RevAlloc(
1407  new ValueWatcher(solver(), this))));
1408  }
1409  solver()->AddConstraint(value_watcher_);
1410  }
1411  IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1412  s->Cache()->InsertExprConstantExpression(
1413  boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1414  return boolvar;
1415  }
1416  }
1417 
1418  Constraint* SetIsEqual(const std::vector<int64>& values,
1419  const std::vector<IntVar*>& vars) {
1420  if (value_watcher_ == nullptr) {
1421  solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1422  reinterpret_cast<void*>(solver()->RevAlloc(
1423  new ValueWatcher(solver(), this))));
1424  for (int i = 0; i < vars.size(); ++i) {
1425  value_watcher_->SetValueWatcher(vars[i], values[i]);
1426  }
1427  }
1428  return value_watcher_;
1429  }
1430 
1431  IntVar* IsDifferent(int64 constant) override {
1432  Solver* const s = solver();
1433  if (constant == min_.Value() && value_watcher_ == nullptr) {
1434  return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1435  }
1436  if (constant == max_.Value() && value_watcher_ == nullptr) {
1437  return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1438  }
1439  if (!Contains(constant)) {
1440  return s->MakeIntConst(int64{1});
1441  }
1442  if (Bound() && min_.Value() == constant) {
1443  return s->MakeIntConst(int64{0});
1444  }
1445  IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1446  this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1447  if (cache != nullptr) {
1448  return cache->Var();
1449  } else {
1450  IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1451  s->Cache()->InsertExprConstantExpression(
1452  boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1453  return boolvar;
1454  }
1455  }
1456 
1457  IntVar* IsGreaterOrEqual(int64 constant) override {
1458  Solver* const s = solver();
1459  if (max_.Value() < constant) {
1460  return s->MakeIntConst(int64{0});
1461  }
1462  if (min_.Value() >= constant) {
1463  return s->MakeIntConst(int64{1});
1464  }
1465  IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1466  this, constant, ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
1467  if (cache != nullptr) {
1468  return cache->Var();
1469  } else {
1470  if (bound_watcher_ == nullptr) {
1471  if (CapSub(Max(), Min()) <= 256) {
1472  solver()->SaveAndSetValue(
1473  reinterpret_cast<void**>(&bound_watcher_),
1474  reinterpret_cast<void*>(solver()->RevAlloc(
1475  new DenseUpperBoundWatcher(solver(), this))));
1476  solver()->AddConstraint(bound_watcher_);
1477  } else {
1478  solver()->SaveAndSetValue(
1479  reinterpret_cast<void**>(&bound_watcher_),
1480  reinterpret_cast<void*>(
1481  solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1482  solver()->AddConstraint(bound_watcher_);
1483  }
1484  }
1485  IntVar* const boolvar =
1486  bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1487  s->Cache()->InsertExprConstantExpression(
1488  boolvar, this, constant,
1489  ModelCache::EXPR_CONSTANT_IS_GREATER_OR_EQUAL);
1490  return boolvar;
1491  }
1492  }
1493 
1494  Constraint* SetIsGreaterOrEqual(const std::vector<int64>& values,
1495  const std::vector<IntVar*>& vars) {
1496  if (bound_watcher_ == nullptr) {
1497  if (CapSub(Max(), Min()) <= 256) {
1498  solver()->SaveAndSetValue(
1499  reinterpret_cast<void**>(&bound_watcher_),
1500  reinterpret_cast<void*>(solver()->RevAlloc(
1501  new DenseUpperBoundWatcher(solver(), this))));
1502  solver()->AddConstraint(bound_watcher_);
1503  } else {
1504  solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1505  reinterpret_cast<void*>(solver()->RevAlloc(
1506  new UpperBoundWatcher(solver(), this))));
1507  solver()->AddConstraint(bound_watcher_);
1508  }
1509  for (int i = 0; i < values.size(); ++i) {
1510  bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1511  }
1512  }
1513  return bound_watcher_;
1514  }
1515 
1516  IntVar* IsLessOrEqual(int64 constant) override {
1517  Solver* const s = solver();
1518  IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1519  this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1520  if (cache != nullptr) {
1521  return cache->Var();
1522  } else {
1523  IntVar* const boolvar =
1524  s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1525  s->Cache()->InsertExprConstantExpression(
1526  boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1527  return boolvar;
1528  }
1529  }
1530 
1531  void Process();
1532  void Push();
1533  void CleanInProcess();
1534  uint64 Size() const override {
1535  if (bits_ != nullptr) return bits_->Size();
1536  return (static_cast<uint64>(max_.Value()) -
1537  static_cast<uint64>(min_.Value()) + 1);
1538  }
1539  bool Contains(int64 v) const override {
1540  if (v < min_.Value() || v > max_.Value()) return false;
1541  return (bits_ == nullptr ? true : bits_->Contains(v));
1542  }
1543  IntVarIterator* MakeHoleIterator(bool reversible) const override;
1544  IntVarIterator* MakeDomainIterator(bool reversible) const override;
1545  int64 OldMin() const override { return std::min(old_min_, min_.Value()); }
1546  int64 OldMax() const override { return std::max(old_max_, max_.Value()); }
1547 
1548  std::string DebugString() const override;
1549  BitSet* bitset() const { return bits_; }
1550  int VarType() const override { return DOMAIN_INT_VAR; }
1551  std::string BaseName() const override { return "IntegerVar"; }
1552 
1553  friend class PlusCstDomainIntVar;
1554  friend class LinkExprAndDomainIntVar;
1555 
1556  private:
1557  void CheckOldMin() {
1558  if (old_min_ > min_.Value()) {
1559  old_min_ = min_.Value();
1560  }
1561  }
1562  void CheckOldMax() {
1563  if (old_max_ < max_.Value()) {
1564  old_max_ = max_.Value();
1565  }
1566  }
1567  Rev<int64> min_;
1568  Rev<int64> max_;
1569  int64 old_min_;
1570  int64 old_max_;
1571  int64 new_min_;
1572  int64 new_max_;
1573  SimpleRevFIFO<Demon*> bound_demons_;
1574  SimpleRevFIFO<Demon*> range_demons_;
1575  SimpleRevFIFO<Demon*> domain_demons_;
1576  SimpleRevFIFO<Demon*> delayed_bound_demons_;
1577  SimpleRevFIFO<Demon*> delayed_range_demons_;
1578  SimpleRevFIFO<Demon*> delayed_domain_demons_;
1579  QueueHandler handler_;
1580  bool in_process_;
1581  BitSet* bits_;
1582  BaseValueWatcher* value_watcher_;
1583  BaseUpperBoundWatcher* bound_watcher_;
1584 };
1585 
1586 // ----- BitSet -----
1587 
1588 // Return whether an integer interval [a..b] (inclusive) contains at most
1589 // K values, i.e. b - a < K, in a way that's robust to overflows.
1590 // For performance reasons, in opt mode it doesn't check that [a, b] is a
1591 // valid interval, nor that K is nonnegative.
1592 inline bool ClosedIntervalNoLargerThan(int64 a, int64 b, int64 K) {
1593  DCHECK_LE(a, b);
1594  DCHECK_GE(K, 0);
1595  if (a > 0) {
1596  return a > b - K;
1597  } else {
1598  return a + K > b;
1599  }
1600 }
1601 
1602 class SimpleBitSet : public DomainIntVar::BitSet {
1603  public:
1604  SimpleBitSet(Solver* const s, int64 vmin, int64 vmax)
1605  : BitSet(s),
1606  bits_(nullptr),
1607  stamps_(nullptr),
1608  omin_(vmin),
1609  omax_(vmax),
1610  size_(vmax - vmin + 1),
1611  bsize_(BitLength64(size_.Value())) {
1612  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1613  << "Bitset too large: [" << vmin << ", " << vmax << "]";
1614  bits_ = new uint64[bsize_];
1615  stamps_ = new uint64[bsize_];
1616  for (int i = 0; i < bsize_; ++i) {
1617  const int bs =
1618  (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1619  bits_[i] = kAllBits64 >> bs;
1620  stamps_[i] = s->stamp() - 1;
1621  }
1622  }
1623 
1624  SimpleBitSet(Solver* const s, const std::vector<int64>& sorted_values,
1625  int64 vmin, int64 vmax)
1626  : BitSet(s),
1627  bits_(nullptr),
1628  stamps_(nullptr),
1629  omin_(vmin),
1630  omax_(vmax),
1631  size_(sorted_values.size()),
1632  bsize_(BitLength64(vmax - vmin + 1)) {
1633  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1634  << "Bitset too large: [" << vmin << ", " << vmax << "]";
1635  bits_ = new uint64[bsize_];
1636  stamps_ = new uint64[bsize_];
1637  for (int i = 0; i < bsize_; ++i) {
1638  bits_[i] = GG_ULONGLONG(0);
1639  stamps_[i] = s->stamp() - 1;
1640  }
1641  for (int i = 0; i < sorted_values.size(); ++i) {
1642  const int64 val = sorted_values[i];
1643  DCHECK(!bit(val));
1644  const int offset = BitOffset64(val - omin_);
1645  const int pos = BitPos64(val - omin_);
1646  bits_[offset] |= OneBit64(pos);
1647  }
1648  }
1649 
1650  ~SimpleBitSet() override {
1651  delete[] bits_;
1652  delete[] stamps_;
1653  }
1654 
1655  bool bit(int64 val) const { return IsBitSet64(bits_, val - omin_); }
1656 
1657  int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) override {
1658  DCHECK_GE(nmin, cmin);
1659  DCHECK_LE(nmin, cmax);
1660  DCHECK_LE(cmin, cmax);
1661  DCHECK_GE(cmin, omin_);
1662  DCHECK_LE(cmax, omax_);
1663  const int64 new_min =
1664  UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1665  omin_;
1666  const uint64 removed_bits =
1667  BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1668  size_.Add(solver_, -removed_bits);
1669  return new_min;
1670  }
1671 
1672  int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) override {
1673  DCHECK_GE(nmax, cmin);
1674  DCHECK_LE(nmax, cmax);
1675  DCHECK_LE(cmin, cmax);
1676  DCHECK_GE(cmin, omin_);
1677  DCHECK_LE(cmax, omax_);
1678  const int64 new_max =
1679  UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1680  omin_;
1681  const uint64 removed_bits =
1682  BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1683  size_.Add(solver_, -removed_bits);
1684  return new_max;
1685  }
1686 
1687  bool SetValue(int64 val) override {
1688  DCHECK_GE(val, omin_);
1689  DCHECK_LE(val, omax_);
1690  if (bit(val)) {
1691  size_.SetValue(solver_, 1);
1692  return true;
1693  }
1694  return false;
1695  }
1696 
1697  bool Contains(int64 val) const override {
1698  DCHECK_GE(val, omin_);
1699  DCHECK_LE(val, omax_);
1700  return bit(val);
1701  }
1702 
1703  bool RemoveValue(int64 val) override {
1704  if (val < omin_ || val > omax_ || !bit(val)) {
1705  return false;
1706  }
1707  // Bitset.
1708  const int64 val_offset = val - omin_;
1709  const int offset = BitOffset64(val_offset);
1710  const uint64 current_stamp = solver_->stamp();
1711  if (stamps_[offset] < current_stamp) {
1712  stamps_[offset] = current_stamp;
1713  solver_->SaveValue(&bits_[offset]);
1714  }
1715  const int pos = BitPos64(val_offset);
1716  bits_[offset] &= ~OneBit64(pos);
1717  // Size.
1718  size_.Decr(solver_);
1719  // Holes.
1720  InitHoles();
1721  AddHole(val);
1722  return true;
1723  }
1724  uint64 Size() const override { return size_.Value(); }
1725 
1726  std::string DebugString() const override {
1727  std::string out;
1728  absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1729  for (int i = 0; i < bsize_; ++i) {
1730  absl::StrAppendFormat(&out, "%x", bits_[i]);
1731  }
1732  out += ")";
1733  return out;
1734  }
1735 
1736  void DelayRemoveValue(int64 val) override { removed_.push_back(val); }
1737 
1738  void ApplyRemovedValues(DomainIntVar* var) override {
1739  std::sort(removed_.begin(), removed_.end());
1740  for (std::vector<int64>::iterator it = removed_.begin();
1741  it != removed_.end(); ++it) {
1742  var->RemoveValue(*it);
1743  }
1744  }
1745 
1746  void ClearRemovedValues() override { removed_.clear(); }
1747 
1748  std::string pretty_DebugString(int64 min, int64 max) const override {
1749  std::string out;
1750  DCHECK(bit(min));
1751  DCHECK(bit(max));
1752  if (max != min) {
1753  int cumul = true;
1754  int64 start_cumul = min;
1755  for (int64 v = min + 1; v < max; ++v) {
1756  if (bit(v)) {
1757  if (!cumul) {
1758  cumul = true;
1759  start_cumul = v;
1760  }
1761  } else {
1762  if (cumul) {
1763  if (v == start_cumul + 1) {
1764  absl::StrAppendFormat(&out, "%d ", start_cumul);
1765  } else if (v == start_cumul + 2) {
1766  absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1767  } else {
1768  absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1769  }
1770  cumul = false;
1771  }
1772  }
1773  }
1774  if (cumul) {
1775  if (max == start_cumul + 1) {
1776  absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1777  } else {
1778  absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1779  }
1780  } else {
1781  absl::StrAppendFormat(&out, "%d", max);
1782  }
1783  } else {
1784  absl::StrAppendFormat(&out, "%d", min);
1785  }
1786  return out;
1787  }
1788 
1789  DomainIntVar::BitSetIterator* MakeIterator() override {
1790  return new DomainIntVar::BitSetIterator(bits_, omin_);
1791  }
1792 
1793  private:
1794  uint64* bits_;
1795  uint64* stamps_;
1796  const int64 omin_;
1797  const int64 omax_;
1798  NumericalRev<int64> size_;
1799  const int bsize_;
1800  std::vector<int64> removed_;
1801 };
1802 
1803 // This is a special case where the bitset fits into one 64 bit integer.
1804 // In that case, there are no offset to compute.
1805 // Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1806 class SmallBitSet : public DomainIntVar::BitSet {
1807  public:
1808  SmallBitSet(Solver* const s, int64 vmin, int64 vmax)
1809  : BitSet(s),
1810  bits_(GG_ULONGLONG(0)),
1811  stamp_(s->stamp() - 1),
1812  omin_(vmin),
1813  omax_(vmax),
1814  size_(vmax - vmin + 1) {
1815  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1816  bits_ = OneRange64(0, size_.Value() - 1);
1817  }
1818 
1819  SmallBitSet(Solver* const s, const std::vector<int64>& sorted_values,
1820  int64 vmin, int64 vmax)
1821  : BitSet(s),
1822  bits_(GG_ULONGLONG(0)),
1823  stamp_(s->stamp() - 1),
1824  omin_(vmin),
1825  omax_(vmax),
1826  size_(sorted_values.size()) {
1827  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1828  // We know the array is sorted and does not contains duplicate values.
1829  for (int i = 0; i < sorted_values.size(); ++i) {
1830  const int64 val = sorted_values[i];
1831  DCHECK_GE(val, vmin);
1832  DCHECK_LE(val, vmax);
1833  DCHECK(!IsBitSet64(&bits_, val - omin_));
1834  bits_ |= OneBit64(val - omin_);
1835  }
1836  }
1837 
1838  ~SmallBitSet() override {}
1839 
1840  bool bit(int64 val) const {
1841  DCHECK_GE(val, omin_);
1842  DCHECK_LE(val, omax_);
1843  return (bits_ & OneBit64(val - omin_)) != 0;
1844  }
1845 
1846  int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) override {
1847  DCHECK_GE(nmin, cmin);
1848  DCHECK_LE(nmin, cmax);
1849  DCHECK_LE(cmin, cmax);
1850  DCHECK_GE(cmin, omin_);
1851  DCHECK_LE(cmax, omax_);
1852  // We do not clean the bits between cmin and nmin.
1853  // But we use mask to look only at 'active' bits.
1854 
1855  // Create the mask and compute new bits
1856  const uint64 new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1857  if (new_bits != GG_ULONGLONG(0)) {
1858  // Compute new size and new min
1859  size_.SetValue(solver_, BitCount64(new_bits));
1860  if (bit(nmin)) { // Common case, the new min is inside the bitset
1861  return nmin;
1862  }
1863  return LeastSignificantBitPosition64(new_bits) + omin_;
1864  } else { // == 0 -> Fail()
1865  solver_->Fail();
1866  return kint64max;
1867  }
1868  }
1869 
1870  int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) override {
1871  DCHECK_GE(nmax, cmin);
1872  DCHECK_LE(nmax, cmax);
1873  DCHECK_LE(cmin, cmax);
1874  DCHECK_GE(cmin, omin_);
1875  DCHECK_LE(cmax, omax_);
1876  // We do not clean the bits between nmax and cmax.
1877  // But we use mask to look only at 'active' bits.
1878 
1879  // Create the mask and compute new_bits
1880  const uint64 new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1881  if (new_bits != GG_ULONGLONG(0)) {
1882  // Compute new size and new min
1883  size_.SetValue(solver_, BitCount64(new_bits));
1884  if (bit(nmax)) { // Common case, the new max is inside the bitset
1885  return nmax;
1886  }
1887  return MostSignificantBitPosition64(new_bits) + omin_;
1888  } else { // == 0 -> Fail()
1889  solver_->Fail();
1890  return kint64min;
1891  }
1892  }
1893 
1894  bool SetValue(int64 val) override {
1895  DCHECK_GE(val, omin_);
1896  DCHECK_LE(val, omax_);
1897  // We do not clean the bits. We will use masks to ignore the bits
1898  // that should have been cleaned.
1899  if (bit(val)) {
1900  size_.SetValue(solver_, 1);
1901  return true;
1902  }
1903  return false;
1904  }
1905 
1906  bool Contains(int64 val) const override {
1907  DCHECK_GE(val, omin_);
1908  DCHECK_LE(val, omax_);
1909  return bit(val);
1910  }
1911 
1912  bool RemoveValue(int64 val) override {
1913  DCHECK_GE(val, omin_);
1914  DCHECK_LE(val, omax_);
1915  if (bit(val)) {
1916  // Bitset.
1917  const uint64 current_stamp = solver_->stamp();
1918  if (stamp_ < current_stamp) {
1919  stamp_ = current_stamp;
1920  solver_->SaveValue(&bits_);
1921  }
1922  bits_ &= ~OneBit64(val - omin_);
1923  DCHECK(!bit(val));
1924  // Size.
1925  size_.Decr(solver_);
1926  // Holes.
1927  InitHoles();
1928  AddHole(val);
1929  return true;
1930  } else {
1931  return false;
1932  }
1933  }
1934 
1935  uint64 Size() const override { return size_.Value(); }
1936 
1937  std::string DebugString() const override {
1938  return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1939  }
1940 
1941  void DelayRemoveValue(int64 val) override {
1942  DCHECK_GE(val, omin_);
1943  DCHECK_LE(val, omax_);
1944  removed_.push_back(val);
1945  }
1946 
1947  void ApplyRemovedValues(DomainIntVar* var) override {
1948  std::sort(removed_.begin(), removed_.end());
1949  for (std::vector<int64>::iterator it = removed_.begin();
1950  it != removed_.end(); ++it) {
1951  var->RemoveValue(*it);
1952  }
1953  }
1954 
1955  void ClearRemovedValues() override { removed_.clear(); }
1956 
1957  std::string pretty_DebugString(int64 min, int64 max) const override {
1958  std::string out;
1959  DCHECK(bit(min));
1960  DCHECK(bit(max));
1961  if (max != min) {
1962  int cumul = true;
1963  int64 start_cumul = min;
1964  for (int64 v = min + 1; v < max; ++v) {
1965  if (bit(v)) {
1966  if (!cumul) {
1967  cumul = true;
1968  start_cumul = v;
1969  }
1970  } else {
1971  if (cumul) {
1972  if (v == start_cumul + 1) {
1973  absl::StrAppendFormat(&out, "%d ", start_cumul);
1974  } else if (v == start_cumul + 2) {
1975  absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1976  } else {
1977  absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1978  }
1979  cumul = false;
1980  }
1981  }
1982  }
1983  if (cumul) {
1984  if (max == start_cumul + 1) {
1985  absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1986  } else {
1987  absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1988  }
1989  } else {
1990  absl::StrAppendFormat(&out, "%d", max);
1991  }
1992  } else {
1993  absl::StrAppendFormat(&out, "%d", min);
1994  }
1995  return out;
1996  }
1997 
1998  DomainIntVar::BitSetIterator* MakeIterator() override {
1999  return new DomainIntVar::BitSetIterator(&bits_, omin_);
2000  }
2001 
2002  private:
2003  uint64 bits_;
2004  uint64 stamp_;
2005  const int64 omin_;
2006  const int64 omax_;
2007  NumericalRev<int64> size_;
2008  std::vector<int64> removed_;
2009 };
2010 
2011 class EmptyIterator : public IntVarIterator {
2012  public:
2013  ~EmptyIterator() override {}
2014  void Init() override {}
2015  bool Ok() const override { return false; }
2016  int64 Value() const override {
2017  LOG(FATAL) << "Should not be called";
2018  return 0LL;
2019  }
2020  void Next() override {}
2021 };
2022 
2023 class RangeIterator : public IntVarIterator {
2024  public:
2025  explicit RangeIterator(const IntVar* const var)
2026  : var_(var), min_(kint64max), max_(kint64min), current_(-1) {}
2027 
2028  ~RangeIterator() override {}
2029 
2030  void Init() override {
2031  min_ = var_->Min();
2032  max_ = var_->Max();
2033  current_ = min_;
2034  }
2035 
2036  bool Ok() const override { return current_ <= max_; }
2037 
2038  int64 Value() const override { return current_; }
2039 
2040  void Next() override { current_++; }
2041 
2042  private:
2043  const IntVar* const var_;
2044  int64 min_;
2045  int64 max_;
2046  int64 current_;
2047 };
2048 
2049 class DomainIntVarHoleIterator : public IntVarIterator {
2050  public:
2051  explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2052  : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2053 
2054  ~DomainIntVarHoleIterator() override {}
2055 
2056  void Init() override {
2057  bits_ = var_->bitset();
2058  if (bits_ != nullptr) {
2059  bits_->InitHoles();
2060  values_ = bits_->Holes().data();
2061  size_ = bits_->Holes().size();
2062  } else {
2063  values_ = nullptr;
2064  size_ = 0;
2065  }
2066  index_ = 0;
2067  }
2068 
2069  bool Ok() const override { return index_ < size_; }
2070 
2071  int64 Value() const override {
2072  DCHECK(bits_ != nullptr);
2073  DCHECK(index_ < size_);
2074  return values_[index_];
2075  }
2076 
2077  void Next() override { index_++; }
2078 
2079  private:
2080  const DomainIntVar* const var_;
2081  DomainIntVar::BitSet* bits_;
2082  const int64* values_;
2083  int size_;
2084  int index_;
2085 };
2086 
2087 class DomainIntVarDomainIterator : public IntVarIterator {
2088  public:
2089  explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
2090  bool reversible)
2091  : var_(v),
2092  bitset_iterator_(nullptr),
2093  min_(kint64max),
2094  max_(kint64min),
2095  current_(-1),
2096  reversible_(reversible) {}
2097 
2098  ~DomainIntVarDomainIterator() override {
2099  if (!reversible_ && bitset_iterator_) {
2100  delete bitset_iterator_;
2101  }
2102  }
2103 
2104  void Init() override {
2105  if (var_->bitset() != nullptr && !var_->Bound()) {
2106  if (reversible_) {
2107  if (!bitset_iterator_) {
2108  Solver* const solver = var_->solver();
2109  solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2110  bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
2111  }
2112  } else {
2113  if (bitset_iterator_) {
2114  delete bitset_iterator_;
2115  }
2116  bitset_iterator_ = var_->bitset()->MakeIterator();
2117  }
2118  bitset_iterator_->Init(var_->Min(), var_->Max());
2119  } else {
2120  if (bitset_iterator_) {
2121  if (reversible_) {
2122  Solver* const solver = var_->solver();
2123  solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2124  } else {
2125  delete bitset_iterator_;
2126  }
2127  bitset_iterator_ = nullptr;
2128  }
2129  min_ = var_->Min();
2130  max_ = var_->Max();
2131  current_ = min_;
2132  }
2133  }
2134 
2135  bool Ok() const override {
2136  return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
2137  }
2138 
2139  int64 Value() const override {
2140  return bitset_iterator_ ? bitset_iterator_->Value() : current_;
2141  }
2142 
2143  void Next() override {
2144  if (bitset_iterator_) {
2145  bitset_iterator_->Next();
2146  } else {
2147  current_++;
2148  }
2149  }
2150 
2151  private:
2152  const DomainIntVar* const var_;
2153  DomainIntVar::BitSetIterator* bitset_iterator_;
2154  int64 min_;
2155  int64 max_;
2156  int64 current_;
2157  const bool reversible_;
2158 };
2159 
2160 class UnaryIterator : public IntVarIterator {
2161  public:
2162  UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2163  : iterator_(hole ? v->MakeHoleIterator(reversible)
2164  : v->MakeDomainIterator(reversible)),
2165  reversible_(reversible) {}
2166 
2167  ~UnaryIterator() override {
2168  if (!reversible_) {
2169  delete iterator_;
2170  }
2171  }
2172 
2173  void Init() override { iterator_->Init(); }
2174 
2175  bool Ok() const override { return iterator_->Ok(); }
2176 
2177  void Next() override { iterator_->Next(); }
2178 
2179  protected:
2180  IntVarIterator* const iterator_;
2181  const bool reversible_;
2182 };
2183 
// Creates a variable whose domain is the whole interval [vmin, vmax].
// All bound trackers (current, old, pending) start at the interval's ends.
// No bitset is allocated here (bits_ is null); RemoveValue() creates one
// lazily via CreateBits() when a value strictly inside the bounds is removed.
DomainIntVar::DomainIntVar(Solver* const s, int64 vmin, int64 vmax,
                           const std::string& name)
    : IntVar(s, name),
      min_(vmin),
      max_(vmax),
      old_min_(vmin),
      old_max_(vmax),
      new_min_(vmin),
      new_max_(vmax),
      handler_(this),
      in_process_(false),
      bits_(nullptr),
      value_watcher_(nullptr),
      bound_watcher_(nullptr) {}
2198 
2199 DomainIntVar::DomainIntVar(Solver* const s,
2200  const std::vector<int64>& sorted_values,
2201  const std::string& name)
2202  : IntVar(s, name),
2203  min_(kint64max),
2204  max_(kint64min),
2205  old_min_(kint64max),
2206  old_max_(kint64min),
2207  new_min_(kint64max),
2208  new_max_(kint64min),
2209  handler_(this),
2210  in_process_(false),
2211  bits_(nullptr),
2212  value_watcher_(nullptr),
2213  bound_watcher_(nullptr) {
2214  CHECK_GE(sorted_values.size(), 1);
2215  // We know that the vector is sorted and does not have duplicate values.
2216  const int64 vmin = sorted_values.front();
2217  const int64 vmax = sorted_values.back();
2218  const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2219 
2220  min_.SetValue(solver(), vmin);
2221  old_min_ = vmin;
2222  new_min_ = vmin;
2223  max_.SetValue(solver(), vmax);
2224  old_max_ = vmax;
2225  new_max_ = vmax;
2226 
2227  if (!contiguous) {
2228  if (vmax - vmin + 1 < 65) {
2229  bits_ = solver()->RevAlloc(
2230  new SmallBitSet(solver(), sorted_values, vmin, vmax));
2231  } else {
2232  bits_ = solver()->RevAlloc(
2233  new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2234  }
2235  }
2236 }
2237 
2238 DomainIntVar::~DomainIntVar() {}
2239 
2240 void DomainIntVar::SetMin(int64 m) {
2241  if (m <= min_.Value()) return;
2242  if (m > max_.Value()) solver()->Fail();
2243  if (in_process_) {
2244  if (m > new_min_) {
2245  new_min_ = m;
2246  if (new_min_ > new_max_) {
2247  solver()->Fail();
2248  }
2249  }
2250  } else {
2251  CheckOldMin();
2252  const int64 new_min =
2253  (bits_ == nullptr
2254  ? m
2255  : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2256  min_.SetValue(solver(), new_min);
2257  if (min_.Value() > max_.Value()) {
2258  solver()->Fail();
2259  }
2260  Push();
2261  }
2262 }
2263 
2264 void DomainIntVar::SetMax(int64 m) {
2265  if (m >= max_.Value()) return;
2266  if (m < min_.Value()) solver()->Fail();
2267  if (in_process_) {
2268  if (m < new_max_) {
2269  new_max_ = m;
2270  if (new_max_ < new_min_) {
2271  solver()->Fail();
2272  }
2273  }
2274  } else {
2275  CheckOldMax();
2276  const int64 new_max =
2277  (bits_ == nullptr
2278  ? m
2279  : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2280  max_.SetValue(solver(), new_max);
2281  if (min_.Value() > max_.Value()) {
2282  solver()->Fail();
2283  }
2284  Push();
2285  }
2286 }
2287 
// Restricts the domain to [mi, ma].
//
// A degenerate range (mi == ma) is delegated to SetValue(). Fails when the
// range is empty or disjoint from [Min(), Max()], and returns immediately when
// it already contains the current bounds. During processing (in_process_)
// only the pending bounds new_min_/new_max_ are tightened. Otherwise, each
// bound that actually moves is first recorded via CheckOldMin()/CheckOldMax(),
// then adjusted through the bitset when one exists, and the variable is
// re-enqueued with Push().
void DomainIntVar::SetRange(int64 mi, int64 ma) {
  if (mi == ma) {
    SetValue(mi);
  } else {
    if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
    if (mi <= min_.Value() && ma >= max_.Value()) return;
    if (in_process_) {
      if (ma < new_max_) {
        new_max_ = ma;
      }
      if (mi > new_min_) {
        new_min_ = mi;
      }
      if (new_min_ > new_max_) {
        solver()->Fail();
      }
    } else {
      if (mi > min_.Value()) {
        CheckOldMin();
        const int64 new_min =
            (bits_ == nullptr
                 ? mi
                 : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
        min_.SetValue(solver(), new_min);
      }
      // The bitset adjustment may have moved the minimum beyond ma.
      if (min_.Value() > ma) {
        solver()->Fail();
      }
      if (ma < max_.Value()) {
        CheckOldMax();
        const int64 new_max =
            (bits_ == nullptr
                 ? ma
                 : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
        max_.SetValue(solver(), new_max);
      }
      if (min_.Value() > max_.Value()) {
        solver()->Fail();
      }
      Push();
    }
  }
}
2331 
2332 void DomainIntVar::SetValue(int64 v) {
2333  if (v != min_.Value() || v != max_.Value()) {
2334  if (v < min_.Value() || v > max_.Value()) {
2335  solver()->Fail();
2336  }
2337  if (in_process_) {
2338  if (v > new_max_ || v < new_min_) {
2339  solver()->Fail();
2340  }
2341  new_min_ = v;
2342  new_max_ = v;
2343  } else {
2344  if (bits_ && !bits_->SetValue(v)) {
2345  solver()->Fail();
2346  }
2347  CheckOldMin();
2348  CheckOldMax();
2349  min_.SetValue(solver(), v);
2350  max_.SetValue(solver(), v);
2351  Push();
2352  }
2353  }
2354 }
2355 
2356 void DomainIntVar::RemoveValue(int64 v) {
2357  if (v < min_.Value() || v > max_.Value()) return;
2358  if (v == min_.Value()) {
2359  SetMin(v + 1);
2360  } else if (v == max_.Value()) {
2361  SetMax(v - 1);
2362  } else {
2363  if (bits_ == nullptr) {
2364  CreateBits();
2365  }
2366  if (in_process_) {
2367  if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2368  bits_->DelayRemoveValue(v);
2369  }
2370  } else {
2371  if (bits_->RemoveValue(v)) {
2372  Push();
2373  }
2374  }
2375  }
2376 }
2377 
2378 void DomainIntVar::RemoveInterval(int64 l, int64 u) {
2379  if (l <= min_.Value()) {
2380  SetMin(u + 1);
2381  } else if (u >= max_.Value()) {
2382  SetMax(l - 1);
2383  } else {
2384  for (int64 v = l; v <= u; ++v) {
2385  RemoveValue(v);
2386  }
2387  }
2388 }
2389 
2390 void DomainIntVar::CreateBits() {
2391  solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2392  if (max_.Value() - min_.Value() < 64) {
2393  bits_ = solver()->RevAlloc(
2394  new SmallBitSet(solver(), min_.Value(), max_.Value()));
2395  } else {
2396  bits_ = solver()->RevAlloc(
2397  new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2398  }
2399 }
2400 
2401 void DomainIntVar::CleanInProcess() {
2402  in_process_ = false;
2403  if (bits_ != nullptr) {
2404  bits_->ClearHoles();
2405  }
2406 }
2407 
2408 void DomainIntVar::Push() {
2409  const bool in_process = in_process_;
2410  EnqueueVar(&handler_);
2411  CHECK_EQ(in_process, in_process_);
2412 }
2413 
// Handler entry point: propagates the changes made to this variable.
//
// While the demons run, in_process_ is true. SetMin/SetMax/SetRange/SetValue
// then only tighten new_min_/new_max_, and RemoveValue only queues removals on
// the bitset. Once the demons are done, old_min_/old_max_ are committed and
// the pending changes are applied. Since in_process_ is false again by then,
// those SetMin/SetMax calls re-enqueue the variable through Push().
void DomainIntVar::Process() {
  CHECK(!in_process_);
  in_process_ = true;
  if (bits_ != nullptr) {
    bits_->ClearRemovedValues();
  }
  // Registered so the solver can reset this variable's in-process state if a
  // demon fails (presumably via CleanInProcess()); unregistered below on
  // success.
  set_variable_to_clean_on_fail(this);
  new_min_ = min_.Value();
  new_max_ = max_.Value();
  const bool is_bound = min_.Value() == max_.Value();
  // Presumably OldMin()/OldMax() report the bounds committed at the end of the
  // previous Process() call.
  const bool range_changed =
      min_.Value() != OldMin() || max_.Value() != OldMax();
  // Process immediate demons.
  if (is_bound) {
    ExecuteAll(bound_demons_);
  }
  if (range_changed) {
    ExecuteAll(range_demons_);
  }
  ExecuteAll(domain_demons_);

  // Process delayed demons.
  if (is_bound) {
    EnqueueAll(delayed_bound_demons_);
  }
  if (range_changed) {
    EnqueueAll(delayed_range_demons_);
  }
  EnqueueAll(delayed_domain_demons_);

  // Everything went well if we arrive here. Let's clean the variable.
  set_variable_to_clean_on_fail(nullptr);
  CleanInProcess();
  old_min_ = min_.Value();
  old_max_ = max_.Value();
  // Apply the bound changes requested by demons while in_process_ was true.
  if (min_.Value() < new_min_) {
    SetMin(new_min_);
  }
  if (max_.Value() > new_max_) {
    SetMax(new_max_);
  }
  // Presumably applies the removals queued by DelayRemoveValue().
  if (bits_ != nullptr) {
    bits_->ApplyRemovedValues(this);
  }
}
2459 
// Returns `alloc` itself or, when `rev` is true, the object handed to the
// solver via RevAlloc. Only the selected branch is evaluated, so a single
// object is allocated. The arguments are parenthesized so that an expression
// such as `a ? b : c` passed as `rev` is not mis-parsed. The trailing ';' is
// kept for compatibility with existing call sites.
#define COND_REV_ALLOC(rev, alloc) \
  ((rev) ? solver()->RevAlloc(alloc) : (alloc));
2462 IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2463  return COND_REV_ALLOC(reversible, new DomainIntVarHoleIterator(this));
2464 }
2465 
2466 IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2467  return COND_REV_ALLOC(reversible,
2468  new DomainIntVarDomainIterator(this, reversible));
2469 }
2470 
2471 std::string DomainIntVar::DebugString() const {
2472  std::string out;
2473  const std::string& var_name = name();
2474  if (!var_name.empty()) {
2475  out = var_name + "(";
2476  } else {
2477  out = "DomainIntVar(";
2478  }
2479  if (min_.Value() == max_.Value()) {
2480  absl::StrAppendFormat(&out, "%d", min_.Value());
2481  } else if (bits_ != nullptr) {
2482  out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2483  } else {
2484  absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2485  }
2486  out += ")";
2487  return out;
2488 }
2489 
2490 // ----- Real Boolean Var -----
2491 
2492 class ConcreteBooleanVar : public BooleanVar {
2493  public:
2494  // Utility classes
2495  class Handler : public Demon {
2496  public:
2497  explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
2498  ~Handler() override {}
2499  void Run(Solver* const s) override {
2500  s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2501  var_->Process();
2502  s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2503  }
2504  Solver::DemonPriority priority() const override {
2505  return Solver::VAR_PRIORITY;
2506  }
2507  std::string DebugString() const override {
2508  return absl::StrFormat("Handler(%s)", var_->DebugString());
2509  }
2510 
2511  private:
2512  ConcreteBooleanVar* const var_;
2513  };
2514 
2515  ConcreteBooleanVar(Solver* const s, const std::string& name)
2516  : BooleanVar(s, name), handler_(this) {}
2517 
2518  ~ConcreteBooleanVar() override {}
2519 
2520  void SetValue(int64 v) override {
2521  if (value_ == kUnboundBooleanVarValue) {
2522  if ((v & 0xfffffffffffffffe) == 0) {
2523  InternalSaveBooleanVarValue(solver(), this);
2524  value_ = static_cast<int>(v);
2525  EnqueueVar(&handler_);
2526  return;
2527  }
2528  } else if (v == value_) {
2529  return;
2530  }
2531  solver()->Fail();
2532  }
2533 
2534  void Process() {
2535  DCHECK_NE(value_, kUnboundBooleanVarValue);
2536  ExecuteAll(bound_demons_);
2537  for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2538  ++it) {
2539  EnqueueDelayedDemon(*it);
2540  }
2541  }
2542 
2543  int64 OldMin() const override { return 0LL; }
2544  int64 OldMax() const override { return 1LL; }
2545  void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2546 
2547  private:
2548  Handler handler_;
2549 };
2550 
2551 // ----- IntConst -----
2552 
2553 class IntConst : public IntVar {
2554  public:
2555  IntConst(Solver* const s, int64 value, const std::string& name = "")
2556  : IntVar(s, name), value_(value) {}
2557  ~IntConst() override {}
2558 
2559  int64 Min() const override { return value_; }
2560  void SetMin(int64 m) override {
2561  if (m > value_) {
2562  solver()->Fail();
2563  }
2564  }
2565  int64 Max() const override { return value_; }
2566  void SetMax(int64 m) override {
2567  if (m < value_) {
2568  solver()->Fail();
2569  }
2570  }
2571  void SetRange(int64 l, int64 u) override {
2572  if (l > value_ || u < value_) {
2573  solver()->Fail();
2574  }
2575  }
2576  void SetValue(int64 v) override {
2577  if (v != value_) {
2578  solver()->Fail();
2579  }
2580  }
2581  bool Bound() const override { return true; }
2582  int64 Value() const override { return value_; }
2583  void RemoveValue(int64 v) override {
2584  if (v == value_) {
2585  solver()->Fail();
2586  }
2587  }
2588  void RemoveInterval(int64 l, int64 u) override {
2589  if (l <= value_ && value_ <= u) {
2590  solver()->Fail();
2591  }
2592  }
2593  void WhenBound(Demon* d) override {}
2594  void WhenRange(Demon* d) override {}
2595  void WhenDomain(Demon* d) override {}
2596  uint64 Size() const override { return 1; }
2597  bool Contains(int64 v) const override { return (v == value_); }
2598  IntVarIterator* MakeHoleIterator(bool reversible) const override {
2599  return COND_REV_ALLOC(reversible, new EmptyIterator());
2600  }
2601  IntVarIterator* MakeDomainIterator(bool reversible) const override {
2602  return COND_REV_ALLOC(reversible, new RangeIterator(this));
2603  }
2604  int64 OldMin() const override { return value_; }
2605  int64 OldMax() const override { return value_; }
2606  std::string DebugString() const override {
2607  std::string out;
2608  if (solver()->HasName(this)) {
2609  const std::string& var_name = name();
2610  absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2611  } else {
2612  absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2613  }
2614  return out;
2615  }
2616 
2617  int VarType() const override { return CONST_VAR; }
2618 
2619  IntVar* IsEqual(int64 constant) override {
2620  if (constant == value_) {
2621  return solver()->MakeIntConst(1);
2622  } else {
2623  return solver()->MakeIntConst(0);
2624  }
2625  }
2626 
2627  IntVar* IsDifferent(int64 constant) override {
2628  if (constant == value_) {
2629  return solver()->MakeIntConst(0);
2630  } else {
2631  return solver()->MakeIntConst(1);
2632  }
2633  }
2634 
2635  IntVar* IsGreaterOrEqual(int64 constant) override {
2636  return solver()->MakeIntConst(value_ >= constant);
2637  }
2638 
2639  IntVar* IsLessOrEqual(int64 constant) override {
2640  return solver()->MakeIntConst(value_ <= constant);
2641  }
2642 
2643  std::string name() const override {
2644  if (solver()->HasName(this)) {
2645  return PropagationBaseObject::name();
2646  } else {
2647  return absl::StrCat(value_);
2648  }
2649  }
2650 
2651  private:
2652  int64 value_;
2653 };
2654 
2655 // ----- x + c variable, optimized case -----
2656 
2657 class PlusCstVar : public IntVar {
2658  public:
2659  PlusCstVar(Solver* const s, IntVar* v, int64 c)
2660  : IntVar(s), var_(v), cst_(c) {}
2661 
2662  ~PlusCstVar() override {}
2663 
2664  void WhenRange(Demon* d) override { var_->WhenRange(d); }
2665 
2666  void WhenBound(Demon* d) override { var_->WhenBound(d); }
2667 
2668  void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2669 
2670  int64 OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2671 
2672  int64 OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2673 
2674  std::string DebugString() const override {
2675  if (HasName()) {
2676  return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2677  } else {
2678  return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2679  }
2680  }
2681 
2682  int VarType() const override { return VAR_ADD_CST; }
2683 
2684  void Accept(ModelVisitor* const visitor) const override {
2685  visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2686  var_);
2687  }
2688 
2689  IntVar* IsEqual(int64 constant) override {
2690  return var_->IsEqual(constant - cst_);
2691  }
2692 
2693  IntVar* IsDifferent(int64 constant) override {
2694  return var_->IsDifferent(constant - cst_);
2695  }
2696 
2697  IntVar* IsGreaterOrEqual(int64 constant) override {
2698  return var_->IsGreaterOrEqual(constant - cst_);
2699  }
2700 
2701  IntVar* IsLessOrEqual(int64 constant) override {
2702  return var_->IsLessOrEqual(constant - cst_);
2703  }
2704 
2705  IntVar* SubVar() const { return var_; }
2706 
2707  int64 Constant() const { return cst_; }
2708 
2709  protected:
2710  IntVar* const var_;
2711  const int64 cst_;
2712 };
2713 
2714 class PlusCstIntVar : public PlusCstVar {
2715  public:
2716  class PlusCstIntVarIterator : public UnaryIterator {
2717  public:
2718  PlusCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev)
2719  : UnaryIterator(v, hole, rev), cst_(c) {}
2720 
2721  ~PlusCstIntVarIterator() override {}
2722 
2723  int64 Value() const override { return iterator_->Value() + cst_; }
2724 
2725  private:
2726  const int64 cst_;
2727  };
2728 
2729  PlusCstIntVar(Solver* const s, IntVar* v, int64 c) : PlusCstVar(s, v, c) {}
2730 
2731  ~PlusCstIntVar() override {}
2732 
2733  int64 Min() const override { return var_->Min() + cst_; }
2734 
2735  void SetMin(int64 m) override { var_->SetMin(CapSub(m, cst_)); }
2736 
2737  int64 Max() const override { return var_->Max() + cst_; }
2738 
2739  void SetMax(int64 m) override { var_->SetMax(CapSub(m, cst_)); }
2740 
2741  void SetRange(int64 l, int64 u) override {
2742  var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2743  }
2744 
2745  void SetValue(int64 v) override { var_->SetValue(v - cst_); }
2746 
2747  int64 Value() const override { return var_->Value() + cst_; }
2748 
2749  bool Bound() const override { return var_->Bound(); }
2750 
2751  void RemoveValue(int64 v) override { var_->RemoveValue(v - cst_); }
2752 
2753  void RemoveInterval(int64 l, int64 u) override {
2754  var_->RemoveInterval(l - cst_, u - cst_);
2755  }
2756 
2757  uint64 Size() const override { return var_->Size(); }
2758 
2759  bool Contains(int64 v) const override { return var_->Contains(v - cst_); }
2760 
2761  IntVarIterator* MakeHoleIterator(bool reversible) const override {
2762  return COND_REV_ALLOC(
2763  reversible, new PlusCstIntVarIterator(var_, cst_, true, reversible));
2764  }
2765  IntVarIterator* MakeDomainIterator(bool reversible) const override {
2766  return COND_REV_ALLOC(
2767  reversible, new PlusCstIntVarIterator(var_, cst_, false, reversible));
2768  }
2769 };
2770 
2771 class PlusCstDomainIntVar : public PlusCstVar {
2772  public:
2773  class PlusCstDomainIntVarIterator : public UnaryIterator {
2774  public:
2775  PlusCstDomainIntVarIterator(const IntVar* const v, int64 c, bool hole,
2776  bool reversible)
2777  : UnaryIterator(v, hole, reversible), cst_(c) {}
2778 
2779  ~PlusCstDomainIntVarIterator() override {}
2780 
2781  int64 Value() const override { return iterator_->Value() + cst_; }
2782 
2783  private:
2784  const int64 cst_;
2785  };
2786 
2787  PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64 c)
2788  : PlusCstVar(s, v, c) {}
2789 
2790  ~PlusCstDomainIntVar() override {}
2791 
2792  int64 Min() const override;
2793  void SetMin(int64 m) override;
2794  int64 Max() const override;
2795  void SetMax(int64 m) override;
2796  void SetRange(int64 l, int64 u) override;
2797  void SetValue(int64 v) override;
2798  bool Bound() const override;
2799  int64 Value() const override;
2800  void RemoveValue(int64 v) override;
2801  void RemoveInterval(int64 l, int64 u) override;
2802  uint64 Size() const override;
2803  bool Contains(int64 v) const override;
2804 
2805  DomainIntVar* domain_int_var() const {
2806  return reinterpret_cast<DomainIntVar*>(var_);
2807  }
2808 
2809  IntVarIterator* MakeHoleIterator(bool reversible) const override {
2810  return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2811  var_, cst_, true, reversible));
2812  }
2813  IntVarIterator* MakeDomainIterator(bool reversible) const override {
2814  return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2815  var_, cst_, false, reversible));
2816  }
2817 };
2818 
2819 int64 PlusCstDomainIntVar::Min() const {
2820  return domain_int_var()->min_.Value() + cst_;
2821 }
2822 
2823 void PlusCstDomainIntVar::SetMin(int64 m) {
2824  domain_int_var()->DomainIntVar::SetMin(m - cst_);
2825 }
2826 
2827 int64 PlusCstDomainIntVar::Max() const {
2828  return domain_int_var()->max_.Value() + cst_;
2829 }
2830 
2831 void PlusCstDomainIntVar::SetMax(int64 m) {
2832  domain_int_var()->DomainIntVar::SetMax(m - cst_);
2833 }
2834 
2835 void PlusCstDomainIntVar::SetRange(int64 l, int64 u) {
2836  domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2837 }
2838 
2839 void PlusCstDomainIntVar::SetValue(int64 v) {
2840  domain_int_var()->DomainIntVar::SetValue(v - cst_);
2841 }
2842 
2843 bool PlusCstDomainIntVar::Bound() const {
2844  return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2845 }
2846 
2848  CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2849  << " variable is not bound";
2850  return domain_int_var()->min_.Value() + cst_;
2851 }
2852 
2853 void PlusCstDomainIntVar::RemoveValue(int64 v) {
2854  domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2855 }
2856 
2857 void PlusCstDomainIntVar::RemoveInterval(int64 l, int64 u) {
2858  domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2859 }
2860 
2861 uint64 PlusCstDomainIntVar::Size() const {
2862  return domain_int_var()->DomainIntVar::Size();
2863 }
2864 
2865 bool PlusCstDomainIntVar::Contains(int64 v) const {
2866  return domain_int_var()->DomainIntVar::Contains(v - cst_);
2867 }
2868 
2869 // c - x variable, optimized case
2870 
2871 class SubCstIntVar : public IntVar {
2872  public:
2873  class SubCstIntVarIterator : public UnaryIterator {
2874  public:
2875  SubCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev)
2876  : UnaryIterator(v, hole, rev), cst_(c) {}
2877  ~SubCstIntVarIterator() override {}
2878 
2879  int64 Value() const override { return cst_ - iterator_->Value(); }
2880 
2881  private:
2882  const int64 cst_;
2883  };
2884 
2885  SubCstIntVar(Solver* const s, IntVar* v, int64 c);
2886  ~SubCstIntVar() override;
2887 
2888  int64 Min() const override;
2889  void SetMin(int64 m) override;
2890  int64 Max() const override;
2891  void SetMax(int64 m) override;
2892  void SetRange(int64 l, int64 u) override;
2893  void SetValue(int64 v) override;
2894  bool Bound() const override;
2895  int64 Value() const override;
2896  void RemoveValue(int64 v) override;
2897  void RemoveInterval(int64 l, int64 u) override;
2898  uint64 Size() const override;
2899  bool Contains(int64 v) const override;
2900  void WhenRange(Demon* d) override;
2901  void WhenBound(Demon* d) override;
2902  void WhenDomain(Demon* d) override;
2903  IntVarIterator* MakeHoleIterator(bool reversible) const override {
2904  return COND_REV_ALLOC(
2905  reversible, new SubCstIntVarIterator(var_, cst_, true, reversible));
2906  }
2907  IntVarIterator* MakeDomainIterator(bool reversible) const override {
2908  return COND_REV_ALLOC(
2909  reversible, new SubCstIntVarIterator(var_, cst_, false, reversible));
2910  }
2911  int64 OldMin() const override { return CapSub(cst_, var_->OldMax()); }
2912  int64 OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2913  std::string DebugString() const override;
2914  std::string name() const override;
2915  int VarType() const override { return CST_SUB_VAR; }
2916 
2917  void Accept(ModelVisitor* const visitor) const override {
2918  visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2919  cst_, var_);
2920  }
2921 
2922  IntVar* IsEqual(int64 constant) override {
2923  return var_->IsEqual(cst_ - constant);
2924  }
2925 
2926  IntVar* IsDifferent(int64 constant) override {
2927  return var_->IsDifferent(cst_ - constant);
2928  }
2929 
2930  IntVar* IsGreaterOrEqual(int64 constant) override {
2931  return var_->IsLessOrEqual(cst_ - constant);
2932  }
2933 
2934  IntVar* IsLessOrEqual(int64 constant) override {
2935  return var_->IsGreaterOrEqual(cst_ - constant);
2936  }
2937 
2938  IntVar* SubVar() const { return var_; }
2939  int64 Constant() const { return cst_; }
2940 
2941  private:
2942  IntVar* const var_;
2943  const int64 cst_;
2944 };
2945 
2946 SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64 c)
2947  : IntVar(s), var_(v), cst_(c) {}
2948 
2949 SubCstIntVar::~SubCstIntVar() {}
2950 
2951 int64 SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2952 
2953 void SubCstIntVar::SetMin(int64 m) { var_->SetMax(CapSub(cst_, m)); }
2954 
2955 int64 SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2956 
2957 void SubCstIntVar::SetMax(int64 m) { var_->SetMin(CapSub(cst_, m)); }
2958 
2959 void SubCstIntVar::SetRange(int64 l, int64 u) {
2960  var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2961 }
2962 
2963 void SubCstIntVar::SetValue(int64 v) { var_->SetValue(cst_ - v); }
2964 
2965 bool SubCstIntVar::Bound() const { return var_->Bound(); }
2966 
2967 void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2968 
2969 int64 SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2970 
2971 void SubCstIntVar::RemoveValue(int64 v) { var_->RemoveValue(cst_ - v); }
2972 
2973 void SubCstIntVar::RemoveInterval(int64 l, int64 u) {
2974  var_->RemoveInterval(cst_ - u, cst_ - l);
2975 }
2976 
2977 void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2978 
2979 void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2980 
2981 uint64 SubCstIntVar::Size() const { return var_->Size(); }
2982 
2983 bool SubCstIntVar::Contains(int64 v) const { return var_->Contains(cst_ - v); }
2984 
2985 std::string SubCstIntVar::DebugString() const {
2986  if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2987  return absl::StrFormat("Not(%s)", var_->DebugString());
2988  } else {
2989  return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
2990  }
2991 }
2992 
2993 std::string SubCstIntVar::name() const {
2994  if (solver()->HasName(this)) {
2995  return PropagationBaseObject::name();
2996  } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2997  return absl::StrFormat("Not(%s)", var_->name());
2998  } else {
2999  return absl::StrFormat("(%d - %s)", cst_, var_->name());
3000  }
3001 }
3002 
3003 // -x variable, optimized case
3004 
3005 class OppIntVar : public IntVar {
3006  public:
3007  class OppIntVarIterator : public UnaryIterator {
3008  public:
3009  OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3010  : UnaryIterator(v, hole, reversible) {}
3011  ~OppIntVarIterator() override {}
3012 
3013  int64 Value() const override { return -iterator_->Value(); }
3014  };
3015 
3016  OppIntVar(Solver* const s, IntVar* v);
3017  ~OppIntVar() override;
3018 
3019  int64 Min() const override;
3020  void SetMin(int64 m) override;
3021  int64 Max() const override;
3022  void SetMax(int64 m) override;
3023  void SetRange(int64 l, int64 u) override;
3024  void SetValue(int64 v) override;
3025  bool Bound() const override;
3026  int64 Value() const override;
3027  void RemoveValue(int64 v) override;
3028  void RemoveInterval(int64 l, int64 u) override;
3029  uint64 Size() const override;
3030  bool Contains(int64 v) const override;
3031  void WhenRange(Demon* d) override;
3032  void WhenBound(Demon* d) override;
3033  void WhenDomain(Demon* d) override;
3034  IntVarIterator* MakeHoleIterator(bool reversible) const override {
3035  return COND_REV_ALLOC(reversible,
3036  new OppIntVarIterator(var_, true, reversible));
3037  }
3038  IntVarIterator* MakeDomainIterator(bool reversible) const override {
3039  return COND_REV_ALLOC(reversible,
3040  new OppIntVarIterator(var_, false, reversible));
3041  }
3042  int64 OldMin() const override { return CapOpp(var_->OldMax()); }
3043  int64 OldMax() const override { return CapOpp(var_->OldMin()); }
3044  std::string DebugString() const override;
3045  int VarType() const override { return OPP_VAR; }
3046 
3047  void Accept(ModelVisitor* const visitor) const override {
3048  visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3049  var_);
3050  }
3051 
3052  IntVar* IsEqual(int64 constant) override { return var_->IsEqual(-constant); }
3053 
3054  IntVar* IsDifferent(int64 constant) override {
3055  return var_->IsDifferent(-constant);
3056  }
3057 
3058  IntVar* IsGreaterOrEqual(int64 constant) override {
3059  return var_->IsLessOrEqual(-constant);
3060  }
3061 
3062  IntVar* IsLessOrEqual(int64 constant) override {
3063  return var_->IsGreaterOrEqual(-constant);
3064  }
3065 
3066  IntVar* SubVar() const { return var_; }
3067 
3068  private:
3069  IntVar* const var_;
3070 };
3071 
3072 OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3073 
3074 OppIntVar::~OppIntVar() {}
3075 
3076 int64 OppIntVar::Min() const { return -var_->Max(); }
3077 
3078 void OppIntVar::SetMin(int64 m) { var_->SetMax(CapOpp(m)); }
3079 
3080 int64 OppIntVar::Max() const { return -var_->Min(); }
3081 
3082 void OppIntVar::SetMax(int64 m) { var_->SetMin(CapOpp(m)); }
3083 
3084 void OppIntVar::SetRange(int64 l, int64 u) {
3085  var_->SetRange(CapOpp(u), CapOpp(l));
3086 }
3087 
3088 void OppIntVar::SetValue(int64 v) { var_->SetValue(CapOpp(v)); }
3089 
3090 bool OppIntVar::Bound() const { return var_->Bound(); }
3091 
3092 void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3093 
3094 int64 OppIntVar::Value() const { return -var_->Value(); }
3095 
3096 void OppIntVar::RemoveValue(int64 v) { var_->RemoveValue(-v); }
3097 
3098 void OppIntVar::RemoveInterval(int64 l, int64 u) {
3099  var_->RemoveInterval(-u, -l);
3100 }
3101 
3102 void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3103 
3104 void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3105 
3106 uint64 OppIntVar::Size() const { return var_->Size(); }
3107 
3108 bool OppIntVar::Contains(int64 v) const { return var_->Contains(-v); }
3109 
3110 std::string OppIntVar::DebugString() const {
3111  return absl::StrFormat("-(%s)", var_->DebugString());
3112 }
3113 
3114 // ----- Utility functions -----
3115 
3116 // x * c variable, optimized case
3117 
3118 class TimesCstIntVar : public IntVar {
3119  public:
3120  TimesCstIntVar(Solver* const s, IntVar* v, int64 c)
3121  : IntVar(s), var_(v), cst_(c) {}
3122  ~TimesCstIntVar() override {}
3123 
3124  IntVar* SubVar() const { return var_; }
3125  int64 Constant() const { return cst_; }
3126 
3127  void Accept(ModelVisitor* const visitor) const override {
3128  visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3129  var_);
3130  }
3131 
3132  IntVar* IsEqual(int64 constant) override {
3133  if (constant % cst_ == 0) {
3134  return var_->IsEqual(constant / cst_);
3135  } else {
3136  return solver()->MakeIntConst(0);
3137  }
3138  }
3139 
3140  IntVar* IsDifferent(int64 constant) override {
3141  if (constant % cst_ == 0) {
3142  return var_->IsDifferent(constant / cst_);
3143  } else {
3144  return solver()->MakeIntConst(1);
3145  }
3146  }
3147 
3148  IntVar* IsGreaterOrEqual(int64 constant) override {
3149  if (cst_ > 0) {
3150  return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3151  } else {
3152  return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3153  }
3154  }
3155 
3156  IntVar* IsLessOrEqual(int64 constant) override {
3157  if (cst_ > 0) {
3158  return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3159  } else {
3160  return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3161  }
3162  }
3163 
3164  std::string DebugString() const override {
3165  return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3166  }
3167 
3168  int VarType() const override { return VAR_TIMES_CST; }
3169 
3170  protected:
3171  IntVar* const var_;
3172  const int64 cst_;
3173 };
3174 
3175 class TimesPosCstIntVar : public TimesCstIntVar {
3176  public:
3177  class TimesPosCstIntVarIterator : public UnaryIterator {
3178  public:
3179  TimesPosCstIntVarIterator(const IntVar* const v, int64 c, bool hole,
3180  bool reversible)
3181  : UnaryIterator(v, hole, reversible), cst_(c) {}
3182  ~TimesPosCstIntVarIterator() override {}
3183 
3184  int64 Value() const override { return iterator_->Value() * cst_; }
3185 
3186  private:
3187  const int64 cst_;
3188  };
3189 
3190  TimesPosCstIntVar(Solver* const s, IntVar* v, int64 c);
3191  ~TimesPosCstIntVar() override;
3192 
3193  int64 Min() const override;
3194  void SetMin(int64 m) override;
3195  int64 Max() const override;
3196  void SetMax(int64 m) override;
3197  void SetRange(int64 l, int64 u) override;
3198  void SetValue(int64 v) override;
3199  bool Bound() const override;
3200  int64 Value() const override;
3201  void RemoveValue(int64 v) override;
3202  void RemoveInterval(int64 l, int64 u) override;
3203  uint64 Size() const override;
3204  bool Contains(int64 v) const override;
3205  void WhenRange(Demon* d) override;
3206  void WhenBound(Demon* d) override;
3207  void WhenDomain(Demon* d) override;
3208  IntVarIterator* MakeHoleIterator(bool reversible) const override {
3209  return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3210  var_, cst_, true, reversible));
3211  }
3212  IntVarIterator* MakeDomainIterator(bool reversible) const override {
3213  return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3214  var_, cst_, false, reversible));
3215  }
3216  int64 OldMin() const override { return CapProd(var_->OldMin(), cst_); }
3217  int64 OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3218 };
3219 
3220 // ----- TimesPosCstIntVar -----
3221 
3222 TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64 c)
3223  : TimesCstIntVar(s, v, c) {}
3224 
3225 TimesPosCstIntVar::~TimesPosCstIntVar() {}
3226 
3227 int64 TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3228 
3229 void TimesPosCstIntVar::SetMin(int64 m) {
3230  if (m != kint64min) {
3231  var_->SetMin(PosIntDivUp(m, cst_));
3232  }
3233 }
3234 
3235 int64 TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3236 
3237 void TimesPosCstIntVar::SetMax(int64 m) {
3238  if (m != kint64max) {
3239  var_->SetMax(PosIntDivDown(m, cst_));
3240  }
3241 }
3242 
3243 void TimesPosCstIntVar::SetRange(int64 l, int64 u) {
3244  var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3245 }
3246 
3247 void TimesPosCstIntVar::SetValue(int64 v) {
3248  if (v % cst_ != 0) {
3249  solver()->Fail();
3250  }
3251  var_->SetValue(v / cst_);
3252 }
3253 
3254 bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3255 
3256 void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3257 
3258 int64 TimesPosCstIntVar::Value() const { return CapProd(var_->Value(), cst_); }
3259 
3260 void TimesPosCstIntVar::RemoveValue(int64 v) {
3261  if (v % cst_ == 0) {
3262  var_->RemoveValue(v / cst_);
3263  }
3264 }
3265 
3266 void TimesPosCstIntVar::RemoveInterval(int64 l, int64 u) {
3267  for (int64 v = l; v <= u; ++v) {
3268  RemoveValue(v);
3269  }
3270  // TODO(user) : Improve me
3271 }
3272 
3273 void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3274 
3275 void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3276 
3277 uint64 TimesPosCstIntVar::Size() const { return var_->Size(); }
3278 
3279 bool TimesPosCstIntVar::Contains(int64 v) const {
3280  return (v % cst_ == 0 && var_->Contains(v / cst_));
3281 }
3282 
3283 // b * c variable, optimized case
3284 
3285 class TimesPosCstBoolVar : public TimesCstIntVar {
3286  public:
3287  class TimesPosCstBoolVarIterator : public UnaryIterator {
3288  public:
3289  // TODO(user) : optimize this.
3290  TimesPosCstBoolVarIterator(const IntVar* const v, int64 c, bool hole,
3291  bool reversible)
3292  : UnaryIterator(v, hole, reversible), cst_(c) {}
3293  ~TimesPosCstBoolVarIterator() override {}
3294 
3295  int64 Value() const override { return iterator_->Value() * cst_; }
3296 
3297  private:
3298  const int64 cst_;
3299  };
3300 
3301  TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64 c);
3302  ~TimesPosCstBoolVar() override;
3303 
3304  int64 Min() const override;
3305  void SetMin(int64 m) override;
3306  int64 Max() const override;
3307  void SetMax(int64 m) override;
3308  void SetRange(int64 l, int64 u) override;
3309  void SetValue(int64 v) override;
3310  bool Bound() const override;
3311  int64 Value() const override;
3312  void RemoveValue(int64 v) override;
3313  void RemoveInterval(int64 l, int64 u) override;
3314  uint64 Size() const override;
3315  bool Contains(int64 v) const override;
3316  void WhenRange(Demon* d) override;
3317  void WhenBound(Demon* d) override;
3318  void WhenDomain(Demon* d) override;
3319  IntVarIterator* MakeHoleIterator(bool reversible) const override {
3320  return COND_REV_ALLOC(reversible, new EmptyIterator());
3321  }
3322  IntVarIterator* MakeDomainIterator(bool reversible) const override {
3323  return COND_REV_ALLOC(
3324  reversible,
3325  new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3326  }
3327  int64 OldMin() const override { return 0; }
3328  int64 OldMax() const override { return cst_; }
3329 
3330  BooleanVar* boolean_var() const {
3331  return reinterpret_cast<BooleanVar*>(var_);
3332  }
3333 };
3334 
3335 // ----- TimesPosCstBoolVar -----
3336 
3337 TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64 c)
3338  : TimesCstIntVar(s, v, c) {}
3339 
3340 TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3341 
3342 int64 TimesPosCstBoolVar::Min() const {
3343  return (boolean_var()->RawValue() == 1) * cst_;
3344 }
3345 
3346 void TimesPosCstBoolVar::SetMin(int64 m) {
3347  if (m > cst_) {
3348  solver()->Fail();
3349  } else if (m > 0) {
3350  boolean_var()->SetMin(1);
3351  }
3352 }
3353 
3354 int64 TimesPosCstBoolVar::Max() const {
3355  return (boolean_var()->RawValue() != 0) * cst_;
3356 }
3357 
3358 void TimesPosCstBoolVar::SetMax(int64 m) {
3359  if (m < 0) {
3360  solver()->Fail();
3361  } else if (m < cst_) {
3362  boolean_var()->SetMax(0);
3363  }
3364 }
3365 
3366 void TimesPosCstBoolVar::SetRange(int64 l, int64 u) {
3367  if (u < 0 || l > cst_ || l > u) {
3368  solver()->Fail();
3369  }
3370  if (l > 0) {
3371  boolean_var()->SetMin(1);
3372  } else if (u < cst_) {
3373  boolean_var()->SetMax(0);
3374  }
3375 }
3376 
3377 void TimesPosCstBoolVar::SetValue(int64 v) {
3378  if (v == 0) {
3379  boolean_var()->SetValue(0);
3380  } else if (v == cst_) {
3381  boolean_var()->SetValue(1);
3382  } else {
3383  solver()->Fail();
3384  }
3385 }
3386 
3387 bool TimesPosCstBoolVar::Bound() const {
3388  return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3389 }
3390 
3391 void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3392 
3394  CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3395  << " variable is not bound";
3396  return boolean_var()->RawValue() * cst_;
3397 }
3398 
3399 void TimesPosCstBoolVar::RemoveValue(int64 v) {
3400  if (v == 0) {
3401  boolean_var()->RemoveValue(0);
3402  } else if (v == cst_) {
3403  boolean_var()->RemoveValue(1);
3404  }
3405 }
3406 
3407 void TimesPosCstBoolVar::RemoveInterval(int64 l, int64 u) {
3408  if (l <= 0 && u >= 0) {
3409  boolean_var()->RemoveValue(0);
3410  }
3411  if (l <= cst_ && u >= cst_) {
3412  boolean_var()->RemoveValue(1);
3413  }
3414 }
3415 
3416 void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3417 
3418 void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3419 
3420 uint64 TimesPosCstBoolVar::Size() const {
3421  return (1 +
3422  (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3423 }
3424 
3425 bool TimesPosCstBoolVar::Contains(int64 v) const {
3426  if (v == 0) {
3427  return boolean_var()->RawValue() != 1;
3428  } else if (v == cst_) {
3429  return boolean_var()->RawValue() != 0;
3430  }
3431  return false;
3432 }
3433 
3434 // TimesNegCstIntVar
3435 
3436 class TimesNegCstIntVar : public TimesCstIntVar {
3437  public:
3438  class TimesNegCstIntVarIterator : public UnaryIterator {
3439  public:
3440  TimesNegCstIntVarIterator(const IntVar* const v, int64 c, bool hole,
3441  bool reversible)
3442  : UnaryIterator(v, hole, reversible), cst_(c) {}
3443  ~TimesNegCstIntVarIterator() override {}
3444 
3445  int64 Value() const override { return iterator_->Value() * cst_; }
3446 
3447  private:
3448  const int64 cst_;
3449  };
3450 
3451  TimesNegCstIntVar(Solver* const s, IntVar* v, int64 c);
3452  ~TimesNegCstIntVar() override;
3453 
3454  int64 Min() const override;
3455  void SetMin(int64 m) override;
3456  int64 Max() const override;
3457  void SetMax(int64 m) override;
3458  void SetRange(int64 l, int64 u) override;
3459  void SetValue(int64 v) override;
3460  bool Bound() const override;
3461  int64 Value() const override;
3462  void RemoveValue(int64 v) override;
3463  void RemoveInterval(int64 l, int64 u) override;
3464  uint64 Size() const override;
3465  bool Contains(int64 v) const override;
3466  void WhenRange(Demon* d) override;
3467  void WhenBound(Demon* d) override;
3468  void WhenDomain(Demon* d) override;
3469  IntVarIterator* MakeHoleIterator(bool reversible) const override {
3470  return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3471  var_, cst_, true, reversible));
3472  }
3473  IntVarIterator* MakeDomainIterator(bool reversible) const override {
3474  return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3475  var_, cst_, false, reversible));
3476  }
3477  int64 OldMin() const override { return CapProd(var_->OldMax(), cst_); }
3478  int64 OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3479 };
3480 
3481 // ----- TimesNegCstIntVar -----
3482 
3483 TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64 c)
3484  : TimesCstIntVar(s, v, c) {}
3485 
3486 TimesNegCstIntVar::~TimesNegCstIntVar() {}
3487 
3488 int64 TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3489 
3490 void TimesNegCstIntVar::SetMin(int64 m) {
3491  if (m != kint64min) {
3492  var_->SetMax(PosIntDivDown(-m, -cst_));
3493  }
3494 }
3495 
3496 int64 TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3497 
3498 void TimesNegCstIntVar::SetMax(int64 m) {
3499  if (m != kint64max) {
3500  var_->SetMin(PosIntDivUp(-m, -cst_));
3501  }
3502 }
3503 
3504 void TimesNegCstIntVar::SetRange(int64 l, int64 u) {
3505  var_->SetRange(PosIntDivUp(-u, -cst_), PosIntDivDown(-l, -cst_));
3506 }
3507 
3508 void TimesNegCstIntVar::SetValue(int64 v) {
3509  if (v % cst_ != 0) {
3510  solver()->Fail();
3511  }
3512  var_->SetValue(v / cst_);
3513 }
3514 
3515 bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3516 
3517 void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3518 
3519 int64 TimesNegCstIntVar::Value() const { return CapProd(var_->Value(), cst_); }
3520 
3521 void TimesNegCstIntVar::RemoveValue(int64 v) {
3522  if (v % cst_ == 0) {
3523  var_->RemoveValue(v / cst_);
3524  }
3525 }
3526 
3527 void TimesNegCstIntVar::RemoveInterval(int64 l, int64 u) {
3528  for (int64 v = l; v <= u; ++v) {
3529  RemoveValue(v);
3530  }
3531  // TODO(user) : Improve me
3532 }
3533 
3534 void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3535 
3536 void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3537 
3538 uint64 TimesNegCstIntVar::Size() const { return var_->Size(); }
3539 
3540 bool TimesNegCstIntVar::Contains(int64 v) const {
3541  return (v % cst_ == 0 && var_->Contains(v / cst_));
3542 }
3543 
3544 // ---------- arithmetic expressions ----------
3545 
3546 // ----- PlusIntExpr -----
3547 
3548 class PlusIntExpr : public BaseIntExpr {
3549  public:
3550  PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3551  : BaseIntExpr(s), left_(l), right_(r) {}
3552 
3553  ~PlusIntExpr() override {}
3554 
3555  int64 Min() const override { return left_->Min() + right_->Min(); }
3556 
3557  void SetMin(int64 m) override {
3558  if (m > left_->Min() + right_->Min()) {
3559  left_->SetMin(m - right_->Max());
3560  right_->SetMin(m - left_->Max());
3561  }
3562  }
3563 
3564  void SetRange(int64 l, int64 u) override {
3565  const int64 left_min = left_->Min();
3566  const int64 right_min = right_->Min();
3567  const int64 left_max = left_->Max();
3568  const int64 right_max = right_->Max();
3569  if (l > left_min + right_min) {
3570  left_->SetMin(l - right_max);
3571  right_->SetMin(l - left_max);
3572  }
3573  if (u < left_max + right_max) {
3574  left_->SetMax(u - right_min);
3575  right_->SetMax(u - left_min);
3576  }
3577  }
3578 
3579  int64 Max() const override { return left_->Max() + right_->Max(); }
3580 
3581  void SetMax(int64 m) override {
3582  if (m < left_->Max() + right_->Max()) {
3583  left_->SetMax(m - right_->Min());
3584  right_->SetMax(m - left_->Min());
3585  }
3586  }
3587 
3588  bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3589 
3590  void Range(int64* const mi, int64* const ma) override {
3591  *mi = left_->Min() + right_->Min();
3592  *ma = left_->Max() + right_->Max();
3593  }
3594 
3595  std::string name() const override {
3596  return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3597  }
3598 
3599  std::string DebugString() const override {
3600  return absl::StrFormat("(%s + %s)", left_->DebugString(),
3601  right_->DebugString());
3602  }
3603 
3604  void WhenRange(Demon* d) override {
3605  left_->WhenRange(d);
3606  right_->WhenRange(d);
3607  }
3608 
3609  void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3610  PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3611  if (casted != nullptr) {
3612  ExpandPlusIntExpr(casted->left_, subs);
3613  ExpandPlusIntExpr(casted->right_, subs);
3614  } else {
3615  subs->push_back(expr);
3616  }
3617  }
3618 
3619  IntVar* CastToVar() override {
3620  if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3621  dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3622  std::vector<IntExpr*> sub_exprs;
3623  ExpandPlusIntExpr(left_, &sub_exprs);
3624  ExpandPlusIntExpr(right_, &sub_exprs);
3625  if (sub_exprs.size() >= 3) {
3626  std::vector<IntVar*> sub_vars(sub_exprs.size());
3627  for (int i = 0; i < sub_exprs.size(); ++i) {
3628  sub_vars[i] = sub_exprs[i]->Var();
3629  }
3630  return solver()->MakeSum(sub_vars)->Var();
3631  }
3632  }
3633  return BaseIntExpr::CastToVar();
3634  }
3635 
3636  void Accept(ModelVisitor* const visitor) const override {
3637  visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3638  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3639  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3640  right_);
3641  visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3642  }
3643 
3644  private:
3645  IntExpr* const left_;
3646  IntExpr* const right_;
3647 };
3648 
3649 class SafePlusIntExpr : public BaseIntExpr {
3650  public:
3651  SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3652  : BaseIntExpr(s), left_(l), right_(r) {}
3653 
3654  ~SafePlusIntExpr() override {}
3655 
3656  int64 Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3657 
3658  void SetMin(int64 m) override {
3659  left_->SetMin(CapSub(m, right_->Max()));
3660  right_->SetMin(CapSub(m, left_->Max()));
3661  }
3662 
3663  void SetRange(int64 l, int64 u) override {
3664  const int64 left_min = left_->Min();
3665  const int64 right_min = right_->Min();
3666  const int64 left_max = left_->Max();
3667  const int64 right_max = right_->Max();
3668  if (l > CapAdd(left_min, right_min)) {
3669  left_->SetMin(CapSub(l, right_max));
3670  right_->SetMin(CapSub(l, left_max));
3671  }
3672  if (u < CapAdd(left_max, right_max)) {
3673  left_->SetMax(CapSub(u, right_min));
3674  right_->SetMax(CapSub(u, left_min));
3675  }
3676  }
3677 
3678  int64 Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3679 
3680  void SetMax(int64 m) override {
3681  left_->SetMax(CapSub(m, right_->Min()));
3682  right_->SetMax(CapSub(m, left_->Min()));
3683  }
3684 
3685  bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3686 
3687  std::string name() const override {
3688  return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3689  }
3690 
3691  std::string DebugString() const override {
3692  return absl::StrFormat("(%s + %s)", left_->DebugString(),
3693  right_->DebugString());
3694  }
3695 
3696  void WhenRange(Demon* d) override {
3697  left_->WhenRange(d);
3698  right_->WhenRange(d);
3699  }
3700 
3701  void Accept(ModelVisitor* const visitor) const override {
3702  visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3703  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3704  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3705  right_);
3706  visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3707  }
3708 
3709  private:
3710  IntExpr* const left_;
3711  IntExpr* const right_;
3712 };
3713 
3714 // ----- PlusIntCstExpr -----
3715 
3716 class PlusIntCstExpr : public BaseIntExpr {
3717  public:
3718  PlusIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3719  : BaseIntExpr(s), expr_(e), value_(v) {}
3720  ~PlusIntCstExpr() override {}
3721  int64 Min() const override { return CapAdd(expr_->Min(), value_); }
3722  void SetMin(int64 m) override { expr_->SetMin(CapSub(m, value_)); }
3723  int64 Max() const override { return CapAdd(expr_->Max(), value_); }
3724  void SetMax(int64 m) override { expr_->SetMax(CapSub(m, value_)); }
3725  bool Bound() const override { return (expr_->Bound()); }
3726  std::string name() const override {
3727  return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3728  }
3729  std::string DebugString() const override {
3730  return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3731  }
3732  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3733  IntVar* CastToVar() override;
3734  void Accept(ModelVisitor* const visitor) const override {
3735  visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3736  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3737  expr_);
3738  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3739  visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3740  }
3741 
3742  private:
3743  IntExpr* const expr_;
3744  const int64 value_;
3745 };
3746 
3747 IntVar* PlusIntCstExpr::CastToVar() {
3748  Solver* const s = solver();
3749  IntVar* const var = expr_->Var();
3750  IntVar* cast = nullptr;
3751  if (AddOverflows(value_, expr_->Max()) ||
3752  AddOverflows(value_, expr_->Min())) {
3753  return BaseIntExpr::CastToVar();
3754  }
3755  switch (var->VarType()) {
3756  case DOMAIN_INT_VAR:
3757  cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3758  s, reinterpret_cast<DomainIntVar*>(var), value_)));
3759  // FIXME: Break was inserted during fallthrough cleanup. Please check.
3760  break;
3761  default:
3762  cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3763  break;
3764  }
3765  return cast;
3766 }
3767 
3768 // ----- SubIntExpr -----
3769 
3770 class SubIntExpr : public BaseIntExpr {
3771  public:
3772  SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3773  : BaseIntExpr(s), left_(l), right_(r) {}
3774 
3775  ~SubIntExpr() override {}
3776 
3777  int64 Min() const override { return left_->Min() - right_->Max(); }
3778 
3779  void SetMin(int64 m) override {
3780  left_->SetMin(CapAdd(m, right_->Min()));
3781  right_->SetMax(CapSub(left_->Max(), m));
3782  }
3783 
3784  int64 Max() const override { return left_->Max() - right_->Min(); }
3785 
3786  void SetMax(int64 m) override {
3787  left_->SetMax(CapAdd(m, right_->Max()));
3788  right_->SetMin(CapSub(left_->Min(), m));
3789  }
3790 
3791  void Range(int64* mi, int64* ma) override {
3792  *mi = left_->Min() - right_->Max();
3793  *ma = left_->Max() - right_->Min();
3794  }
3795 
3796  void SetRange(int64 l, int64 u) override {
3797  const int64 left_min = left_->Min();
3798  const int64 right_min = right_->Min();
3799  const int64 left_max = left_->Max();
3800  const int64 right_max = right_->Max();
3801  if (l > left_min - right_max) {
3802  left_->SetMin(CapAdd(l, right_min));
3803  right_->SetMax(CapSub(left_max, l));
3804  }
3805  if (u < left_max - right_min) {
3806  left_->SetMax(CapAdd(u, right_max));
3807  right_->SetMin(CapSub(left_min, u));
3808  }
3809  }
3810 
3811  bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3812 
3813  std::string name() const override {
3814  return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3815  }
3816 
3817  std::string DebugString() const override {
3818  return absl::StrFormat("(%s - %s)", left_->DebugString(),
3819  right_->DebugString());
3820  }
3821 
3822  void WhenRange(Demon* d) override {
3823  left_->WhenRange(d);
3824  right_->WhenRange(d);
3825  }
3826 
3827  void Accept(ModelVisitor* const visitor) const override {
3828  visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3829  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3830  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3831  right_);
3832  visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3833  }
3834 
3835  IntExpr* left() const { return left_; }
3836  IntExpr* right() const { return right_; }
3837 
3838  protected:
3839  IntExpr* const left_;
3840  IntExpr* const right_;
3841 };
3842 
3843 class SafeSubIntExpr : public SubIntExpr {
3844  public:
3845  SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3846  : SubIntExpr(s, l, r) {}
3847 
3848  ~SafeSubIntExpr() override {}
3849 
3850  int64 Min() const override { return CapSub(left_->Min(), right_->Max()); }
3851 
3852  void SetMin(int64 m) override {
3853  left_->SetMin(CapAdd(m, right_->Min()));
3854  right_->SetMax(CapSub(left_->Max(), m));
3855  }
3856 
3857  void SetRange(int64 l, int64 u) override {
3858  const int64 left_min = left_->Min();
3859  const int64 right_min = right_->Min();
3860  const int64 left_max = left_->Max();
3861  const int64 right_max = right_->Max();
3862  if (l > CapSub(left_min, right_max)) {
3863  left_->SetMin(CapAdd(l, right_min));
3864  right_->SetMax(CapSub(left_max, l));
3865  }
3866  if (u < CapSub(left_max, right_min)) {
3867  left_->SetMax(CapAdd(u, right_max));
3868  right_->SetMin(CapSub(left_min, u));
3869  }
3870  }
3871 
3872  void Range(int64* mi, int64* ma) override {
3873  *mi = CapSub(left_->Min(), right_->Max());
3874  *ma = CapSub(left_->Max(), right_->Min());
3875  }
3876 
3877  int64 Max() const override { return CapSub(left_->Max(), right_->Min()); }
3878 
3879  void SetMax(int64 m) override {
3880  left_->SetMax(CapAdd(m, right_->Max()));
3881  right_->SetMin(CapSub(left_->Min(), m));
3882  }
3883 };
3884 
// value - expr, with a constant value.
3886 
3887 // ----- SubIntCstExpr -----
3888 
3889 class SubIntCstExpr : public BaseIntExpr {
3890  public:
3891  SubIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3892  : BaseIntExpr(s), expr_(e), value_(v) {}
3893  ~SubIntCstExpr() override {}
3894  int64 Min() const override { return CapSub(value_, expr_->Max()); }
3895  void SetMin(int64 m) override { expr_->SetMax(CapSub(value_, m)); }
3896  int64 Max() const override { return CapSub(value_, expr_->Min()); }
3897  void SetMax(int64 m) override { expr_->SetMin(CapSub(value_, m)); }
3898  bool Bound() const override { return (expr_->Bound()); }
3899  std::string name() const override {
3900  return absl::StrFormat("(%d - %s)", value_, expr_->name());
3901  }
3902  std::string DebugString() const override {
3903  return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3904  }
3905  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3906  IntVar* CastToVar() override;
3907 
3908  void Accept(ModelVisitor* const visitor) const override {
3909  visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3910  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3911  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3912  expr_);
3913  visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3914  }
3915 
3916  private:
3917  IntExpr* const expr_;
3918  const int64 value_;
3919 };
3920 
3921 IntVar* SubIntCstExpr::CastToVar() {
3922  if (SubOverflows(value_, expr_->Min()) ||
3923  SubOverflows(value_, expr_->Max())) {
3924  return BaseIntExpr::CastToVar();
3925  }
3926  Solver* const s = solver();
3927  IntVar* const var =
3928  s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3929  return var;
3930 }
3931 
3932 // ----- OppIntExpr -----
3933 
3934 class OppIntExpr : public BaseIntExpr {
3935  public:
3936  OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
3937  ~OppIntExpr() override {}
3938  int64 Min() const override { return (-expr_->Max()); }
3939  void SetMin(int64 m) override { expr_->SetMax(-m); }
3940  int64 Max() const override { return (-expr_->Min()); }
3941  void SetMax(int64 m) override { expr_->SetMin(-m); }
3942  bool Bound() const override { return (expr_->Bound()); }
3943  std::string name() const override {
3944  return absl::StrFormat("(-%s)", expr_->name());
3945  }
3946  std::string DebugString() const override {
3947  return absl::StrFormat("(-%s)", expr_->DebugString());
3948  }
3949  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3950  IntVar* CastToVar() override;
3951 
3952  void Accept(ModelVisitor* const visitor) const override {
3953  visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3954  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3955  expr_);
3956  visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3957  }
3958 
3959  private:
3960  IntExpr* const expr_;
3961 };
3962 
3963 IntVar* OppIntExpr::CastToVar() {
3964  Solver* const s = solver();
3965  IntVar* const var =
3966  s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
3967  return var;
3968 }
3969 
3970 // ----- TimesIntCstExpr -----
3971 
3972 class TimesIntCstExpr : public BaseIntExpr {
3973  public:
3974  TimesIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3975  : BaseIntExpr(s), expr_(e), value_(v) {}
3976 
3977  ~TimesIntCstExpr() override {}
3978 
3979  bool Bound() const override { return (expr_->Bound()); }
3980 
3981  std::string name() const override {
3982  return absl::StrFormat("(%s * %d)", expr_->name(), value_);
3983  }
3984 
3985  std::string DebugString() const override {
3986  return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
3987  }
3988 
3989  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3990 
3991  IntExpr* Expr() const { return expr_; }
3992 
3993  int64 Constant() const { return value_; }
3994 
3995  void Accept(ModelVisitor* const visitor) const override {
3996  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
3997  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3998  expr_);
3999  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4000  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4001  }
4002 
4003  protected:
4004  IntExpr* const expr_;
4005  const int64 value_;
4006 };
4007 
4008 // ----- TimesPosIntCstExpr -----
4009 
4010 class TimesPosIntCstExpr : public TimesIntCstExpr {
4011  public:
4012  TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4013  : TimesIntCstExpr(s, e, v) {
4014  CHECK_GT(v, 0);
4015  }
4016 
4017  ~TimesPosIntCstExpr() override {}
4018 
4019  int64 Min() const override { return expr_->Min() * value_; }
4020 
4021  void SetMin(int64 m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4022 
4023  int64 Max() const override { return expr_->Max() * value_; }
4024 
4025  void SetMax(int64 m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4026 
4027  IntVar* CastToVar() override {
4028  Solver* const s = solver();
4029  IntVar* var = nullptr;
4030  if (expr_->IsVar() &&
4031  reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4032  var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4033  s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4034  } else {
4035  var = s->RegisterIntVar(
4036  s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4037  }
4038  return var;
4039  }
4040 };
4041 
4042 // This expressions adds safe arithmetic (w.r.t. overflows) compared
4043 // to the previous one.
4044 class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4045  public:
4046  SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4047  : TimesIntCstExpr(s, e, v) {
4048  CHECK_GT(v, 0);
4049  }
4050 
4051  ~SafeTimesPosIntCstExpr() override {}
4052 
4053  int64 Min() const override { return CapProd(expr_->Min(), value_); }
4054 
4055  void SetMin(int64 m) override {
4056  if (m != kint64min) {
4057  expr_->SetMin(PosIntDivUp(m, value_));
4058  }
4059  }
4060 
4061  int64 Max() const override { return CapProd(expr_->Max(), value_); }
4062 
4063  void SetMax(int64 m) override {
4064  if (m != kint64max) {
4065  expr_->SetMax(PosIntDivDown(m, value_));
4066  }
4067  }
4068 
4069  IntVar* CastToVar() override {
4070  Solver* const s = solver();
4071  IntVar* var = nullptr;
4072  if (expr_->IsVar() &&
4073  reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4074  var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4075  s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4076  } else {
4077  // TODO(user): Check overflows.
4078  var = s->RegisterIntVar(
4079  s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4080  }
4081  return var;
4082  }
4083 };
4084 
4085 // ----- TimesIntNegCstExpr -----
4086 
4087 class TimesIntNegCstExpr : public TimesIntCstExpr {
4088  public:
4089  TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64 v)
4090  : TimesIntCstExpr(s, e, v) {
4091  CHECK_LT(v, 0);
4092  }
4093 
4094  ~TimesIntNegCstExpr() override {}
4095 
4096  int64 Min() const override { return CapProd(expr_->Max(), value_); }
4097 
4098  void SetMin(int64 m) override {
4099  if (m != kint64min) {
4100  expr_->SetMax(PosIntDivDown(-m, -value_));
4101  }
4102  }
4103 
4104  int64 Max() const override { return CapProd(expr_->Min(), value_); }
4105 
4106  void SetMax(int64 m) override {
4107  if (m != kint64max) {
4108  expr_->SetMin(PosIntDivUp(-m, -value_));
4109  }
4110  }
4111 
4112  IntVar* CastToVar() override {
4113  Solver* const s = solver();
4114  IntVar* var = nullptr;
4115  var = s->RegisterIntVar(
4116  s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4117  return var;
4118  }
4119 };
4120 
4121 // ----- Utilities for product expression -----
4122 
4123 // Propagates set_min on left * right, left and right >= 0.
4124 void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4125  DCHECK_GE(left->Min(), 0);
4126  DCHECK_GE(right->Min(), 0);
4127  const int64 lmax = left->Max();
4128  const int64 rmax = right->Max();
4129  if (m > CapProd(lmax, rmax)) {
4130  left->solver()->Fail();
4131  }
4132  if (m > CapProd(left->Min(), right->Min())) {
4133  // Ok for m == 0 due to left and right being positive
4134  if (0 != rmax) {
4135  left->SetMin(PosIntDivUp(m, rmax));
4136  }
4137  if (0 != lmax) {
4138  right->SetMin(PosIntDivUp(m, lmax));
4139  }
4140  }
4141 }
4142 
4143 // Propagates set_max on left * right, left and right >= 0.
4144 void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4145  DCHECK_GE(left->Min(), 0);
4146  DCHECK_GE(right->Min(), 0);
4147  const int64 lmin = left->Min();
4148  const int64 rmin = right->Min();
4149  if (m < CapProd(lmin, rmin)) {
4150  left->solver()->Fail();
4151  }
4152  if (m < CapProd(left->Max(), right->Max())) {
4153  if (0 != lmin) {
4154  right->SetMax(PosIntDivDown(m, lmin));
4155  }
4156  if (0 != rmin) {
4157  left->SetMax(PosIntDivDown(m, rmin));
4158  }
4159  // else do nothing: 0 is supporting any value from other expr.
4160  }
4161 }
4162 
4163 // Propagates set_min on left * right, left >= 0, right across 0.
4164 void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4165  DCHECK_GE(left->Min(), 0);
4166  DCHECK_GT(right->Max(), 0);
4167  DCHECK_LT(right->Min(), 0);
4168  const int64 lmax = left->Max();
4169  const int64 rmax = right->Max();
4170  if (m > CapProd(lmax, rmax)) {
4171  left->solver()->Fail();
4172  }
4173  if (left->Max() == 0) { // left is bound to 0, product is bound to 0.
4174  DCHECK_EQ(0, left->Min());
4175  DCHECK_LE(m, 0);
4176  } else {
4177  if (m > 0) { // We deduce right > 0.
4178  left->SetMin(PosIntDivUp(m, rmax));
4179  right->SetMin(PosIntDivUp(m, lmax));
4180  } else if (m == 0) {
4181  const int64 lmin = left->Min();
4182  if (lmin > 0) {
4183  right->SetMin(0);
4184  }
4185  } else { // m < 0
4186  const int64 lmin = left->Min();
4187  if (0 != lmin) { // We cannot deduce anything if 0 is in the domain.
4188  right->SetMin(-PosIntDivDown(-m, lmin));
4189  }
4190  }
4191  }
4192 }
4193 
4194 // Propagates set_min on left * right, left and right across 0.
4195 void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4196  DCHECK_LT(left->Min(), 0);
4197  DCHECK_GT(left->Max(), 0);
4198  DCHECK_GT(right->Max(), 0);
4199  DCHECK_LT(right->Min(), 0);
4200  const int64 lmin = left->Min();
4201  const int64 lmax = left->Max();
4202  const int64 rmin = right->Min();
4203  const int64 rmax = right->Max();
4204  if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4205  left->solver()->Fail();
4206  }
4207  if (m > lmin * rmin) { // Must be positive section * positive section.
4208  left->SetMin(PosIntDivUp(m, rmax));
4209  right->SetMin(PosIntDivUp(m, lmax));
4210  } else if (m > CapProd(lmax, rmax)) { // Negative section * negative section.
4211  left->SetMax(-PosIntDivUp(m, -rmin));
4212  right->SetMax(-PosIntDivUp(m, -lmin));
4213  }
4214 }
4215 
4216 void TimesSetMin(IntExpr* const left, IntExpr* const right,
4217  IntExpr* const minus_left, IntExpr* const minus_right,
4218  int64 m) {
4219  if (left->Min() >= 0) {
4220  if (right->Min() >= 0) {
4221  SetPosPosMinExpr(left, right, m);
4222  } else if (right->Max() <= 0) {
4223  SetPosPosMaxExpr(left, minus_right, -m);
4224  } else { // right->Min() < 0 && right->Max() > 0
4225  SetPosGenMinExpr(left, right, m);
4226  }
4227  } else if (left->Max() <= 0) {
4228  if (right->Min() >= 0) {
4229  SetPosPosMaxExpr(right, minus_left, -m);
4230  } else if (right->Max() <= 0) {
4231  SetPosPosMinExpr(minus_left, minus_right, m);
4232  } else { // right->Min() < 0 && right->Max() > 0
4233  SetPosGenMinExpr(minus_left, minus_right, m);
4234  }
4235  } else if (right->Min() >= 0) { // left->Min() < 0 && left->Max() > 0
4236  SetPosGenMinExpr(right, left, m);
4237  } else if (right->Max() <= 0) { // left->Min() < 0 && left->Max() > 0
4238  SetPosGenMinExpr(minus_right, minus_left, m);
4239  } else { // left->Min() < 0 && left->Max() > 0 &&
4240  // right->Min() < 0 && right->Max() > 0
4241  SetGenGenMinExpr(left, right, m);
4242  }
4243 }
4244 
4245 class TimesIntExpr : public BaseIntExpr {
4246  public:
4247  TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4248  : BaseIntExpr(s),
4249  left_(l),
4250  right_(r),
4251  minus_left_(s->MakeOpposite(left_)),
4252  minus_right_(s->MakeOpposite(right_)) {}
4253  ~TimesIntExpr() override {}
4254  int64 Min() const override {
4255  const int64 lmin = left_->Min();
4256  const int64 lmax = left_->Max();
4257  const int64 rmin = right_->Min();
4258  const int64 rmax = right_->Max();
4259  return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4260  std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4261  }
4262  void SetMin(int64 m) override;
4263  int64 Max() const override {
4264  const int64 lmin = left_->Min();
4265  const int64 lmax = left_->Max();
4266  const int64 rmin = right_->Min();
4267  const int64 rmax = right_->Max();
4268  return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4269  std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4270  }
4271  void SetMax(int64 m) override;
4272  bool Bound() const override;
4273  std::string name() const override {
4274  return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4275  }
4276  std::string DebugString() const override {
4277  return absl::StrFormat("(%s * %s)", left_->DebugString(),
4278  right_->DebugString());
4279  }
4280  void WhenRange(Demon* d) override {
4281  left_->WhenRange(d);
4282  right_->WhenRange(d);
4283  }
4284 
4285  void Accept(ModelVisitor* const visitor) const override {
4286  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4287  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4288  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4289  right_);
4290  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4291  }
4292 
4293  private:
4294  IntExpr* const left_;
4295  IntExpr* const right_;
4296  IntExpr* const minus_left_;
4297  IntExpr* const minus_right_;
4298 };
4299 
4300 void TimesIntExpr::SetMin(int64 m) {
4301  if (m != kint64min) {
4302  TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4303  }
4304 }
4305 
4306 void TimesIntExpr::SetMax(int64 m) {
4307  if (m != kint64max) {
4308  TimesSetMin(left_, minus_right_, minus_left_, right_, -m);
4309  }
4310 }
4311 
4312 bool TimesIntExpr::Bound() const {
4313  const bool left_bound = left_->Bound();
4314  const bool right_bound = right_->Bound();
4315  return ((left_bound && left_->Max() == 0) ||
4316  (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4317 }
4318 
4319 // ----- TimesPosIntExpr -----
4320 
4321 class TimesPosIntExpr : public BaseIntExpr {
4322  public:
4323  TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4324  : BaseIntExpr(s), left_(l), right_(r) {}
4325  ~TimesPosIntExpr() override {}
4326  int64 Min() const override { return (left_->Min() * right_->Min()); }
4327  void SetMin(int64 m) override;
4328  int64 Max() const override { return (left_->Max() * right_->Max()); }
4329  void SetMax(int64 m) override;
4330  bool Bound() const override;
4331  std::string name() const override {
4332  return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4333  }
4334  std::string DebugString() const override {
4335  return absl::StrFormat("(%s * %s)", left_->DebugString(),
4336  right_->DebugString());
4337  }
4338  void WhenRange(Demon* d) override {
4339  left_->WhenRange(d);
4340  right_->WhenRange(d);
4341  }
4342 
4343  void Accept(ModelVisitor* const visitor) const override {
4344  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4345  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4346  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4347  right_);
4348  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4349  }
4350 
4351  private:
4352  IntExpr* const left_;
4353  IntExpr* const right_;
4354 };
4355 
4356 void TimesPosIntExpr::SetMin(int64 m) { SetPosPosMinExpr(left_, right_, m); }
4357 
4358 void TimesPosIntExpr::SetMax(int64 m) { SetPosPosMaxExpr(left_, right_, m); }
4359 
4360 bool TimesPosIntExpr::Bound() const {
4361  return (left_->Max() == 0 || right_->Max() == 0 ||
4362  (left_->Bound() && right_->Bound()));
4363 }
4364 
4365 // ----- SafeTimesPosIntExpr -----
4366 
4367 class SafeTimesPosIntExpr : public BaseIntExpr {
4368  public:
4369  SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4370  : BaseIntExpr(s), left_(l), right_(r) {}
4371  ~SafeTimesPosIntExpr() override {}
4372  int64 Min() const override { return CapProd(left_->Min(), right_->Min()); }
4373  void SetMin(int64 m) override {
4374  if (m != kint64min) {
4375  SetPosPosMinExpr(left_, right_, m);
4376  }
4377  }
4378  int64 Max() const override { return CapProd(left_->Max(), right_->Max()); }
4379  void SetMax(int64 m) override {
4380  if (m != kint64max) {
4381  SetPosPosMaxExpr(left_, right_, m);
4382  }
4383  }
4384  bool Bound() const override {
4385  return (left_->Max() == 0 || right_->Max() == 0 ||
4386  (left_->Bound() && right_->Bound()));
4387  }
4388  std::string name() const override {
4389  return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4390  }
4391  std::string DebugString() const override {
4392  return absl::StrFormat("(%s * %s)", left_->DebugString(),
4393  right_->DebugString());
4394  }
4395  void WhenRange(Demon* d) override {
4396  left_->WhenRange(d);
4397  right_->WhenRange(d);
4398  }
4399 
4400  void Accept(ModelVisitor* const visitor) const override {
4401  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4402  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4403  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4404  right_);
4405  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4406  }
4407 
4408  private:
4409  IntExpr* const left_;
4410  IntExpr* const right_;
4411 };
4412 
4413 // ----- TimesBooleanPosIntExpr -----
4414 
4415 class TimesBooleanPosIntExpr : public BaseIntExpr {
4416  public:
4417  TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4418  : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4419  ~TimesBooleanPosIntExpr() override {}
4420  int64 Min() const override {
4421  return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4422  }
4423  void SetMin(int64 m) override;
4424  int64 Max() const override {
4425  return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4426  }
4427  void SetMax(int64 m) override;
4428  void Range(int64* mi, int64* ma) override;
4429  void SetRange(int64 mi, int64 ma) override;
4430  bool Bound() const override;
4431  std::string name() const override {
4432  return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4433  }
4434  std::string DebugString() const override {
4435  return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4436  expr_->DebugString());
4437  }
4438  void WhenRange(Demon* d) override {
4439  boolvar_->WhenRange(d);
4440  expr_->WhenRange(d);
4441  }
4442 
4443  void Accept(ModelVisitor* const visitor) const override {
4444  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4445  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4446  boolvar_);
4447  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4448  expr_);
4449  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4450  }
4451 
4452  private:
4453  BooleanVar* const boolvar_;
4454  IntExpr* const expr_;
4455 };
4456 
4457 void TimesBooleanPosIntExpr::SetMin(int64 m) {
4458  if (m > 0) {
4459  boolvar_->SetValue(1);
4460  expr_->SetMin(m);
4461  }
4462 }
4463 
4464 void TimesBooleanPosIntExpr::SetMax(int64 m) {
4465  if (m < 0) {
4466  solver()->Fail();
4467  }
4468  if (m < expr_->Min()) {
4469  boolvar_->SetValue(0);
4470  }
4471  if (boolvar_->RawValue() == 1) {
4472  expr_->SetMax(m);
4473  }
4474 }
4475 
4476 void TimesBooleanPosIntExpr::Range(int64* mi, int64* ma) {
4477  const int value = boolvar_->RawValue();
4478  if (value == 0) {
4479  *mi = 0;
4480  *ma = 0;
4481  } else if (value == 1) {
4482  expr_->Range(mi, ma);
4483  } else {
4484  *mi = 0;
4485  *ma = expr_->Max();
4486  }
4487 }
4488 
4489 void TimesBooleanPosIntExpr::SetRange(int64 mi, int64 ma) {
4490  if (ma < 0 || mi > ma) {
4491  solver()->Fail();
4492  }
4493  if (mi > 0) {
4494  boolvar_->SetValue(1);
4495  expr_->SetMin(mi);
4496  }
4497  if (ma < expr_->Min()) {
4498  boolvar_->SetValue(0);
4499  }
4500  if (boolvar_->RawValue() == 1) {
4501  expr_->SetMax(ma);
4502  }
4503 }
4504 
4505 bool TimesBooleanPosIntExpr::Bound() const {
4506  return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4507  (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4508  expr_->Bound()));
4509 }
4510 
4511 // ----- TimesBooleanIntExpr -----
4512 
4513 class TimesBooleanIntExpr : public BaseIntExpr {
4514  public:
4515  TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4516  : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4517  ~TimesBooleanIntExpr() override {}
4518  int64 Min() const override {
4519  switch (boolvar_->RawValue()) {
4520  case 0: {
4521  return 0LL;
4522  }
4523  case 1: {
4524  return expr_->Min();
4525  }
4526  default: {
4527  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4528  return std::min(int64{0}, expr_->Min());
4529  }
4530  }
4531  }
4532  void SetMin(int64 m) override;
4533  int64 Max() const override {
4534  switch (boolvar_->RawValue()) {
4535  case 0: {
4536  return 0LL;
4537  }
4538  case 1: {
4539  return expr_->Max();
4540  }
4541  default: {
4542  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4543  return std::max(int64{0}, expr_->Max());
4544  }
4545  }
4546  }
4547  void SetMax(int64 m) override;
4548  void Range(int64* mi, int64* ma) override;
4549  void SetRange(int64 mi, int64 ma) override;
4550  bool Bound() const override;
4551  std::string name() const override {
4552  return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4553  }
4554  std::string DebugString() const override {
4555  return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4556  expr_->DebugString());
4557  }
4558  void WhenRange(Demon* d) override {
4559  boolvar_->WhenRange(d);
4560  expr_->WhenRange(d);
4561  }
4562 
4563  void Accept(ModelVisitor* const visitor) const override {
4564  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4565  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4566  boolvar_);
4567  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4568  expr_);
4569  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4570  }
4571 
4572  private:
4573  BooleanVar* const boolvar_;
4574  IntExpr* const expr_;
4575 };
4576 
4577 void TimesBooleanIntExpr::SetMin(int64 m) {
4578  switch (boolvar_->RawValue()) {
4579  case 0: {
4580  if (m > 0) {
4581  solver()->Fail();
4582  }
4583  break;
4584  }
4585  case 1: {
4586  expr_->SetMin(m);
4587  break;
4588  }
4589  default: {
4590  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4591  if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4592  boolvar_->SetValue(1);
4593  expr_->SetMin(m);
4594  } else if (m <= 0 && expr_->Max() < m) {
4595  boolvar_->SetValue(0);
4596  }
4597  }
4598  }
4599 }
4600 
4601 void TimesBooleanIntExpr::SetMax(int64 m) {
4602  switch (boolvar_->RawValue()) {
4603  case 0: {
4604  if (m < 0) {
4605  solver()->Fail();
4606  }
4607  break;
4608  }
4609  case 1: {
4610  expr_->SetMax(m);
4611  break;
4612  }
4613  default: {
4614  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4615  if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4616  boolvar_->SetValue(1);
4617  expr_->SetMax(m);
4618  } else if (m >= 0 && expr_->Min() > m) {
4619  boolvar_->SetValue(0);
4620  }
4621  }
4622  }
4623 }
4624 
4625 void TimesBooleanIntExpr::Range(int64* mi, int64* ma) {
4626  switch (boolvar_->RawValue()) {
4627  case 0: {
4628  *mi = 0;
4629  *ma = 0;
4630  break;
4631  }
4632  case 1: {
4633  *mi = expr_->Min();
4634  *ma = expr_->Max();
4635  break;
4636  }
4637  default: {
4638  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4639  *mi = std::min(int64{0}, expr_->Min());
4640  *ma = std::max(int64{0}, expr_->Max());
4641  break;
4642  }
4643  }
4644 }
4645 
4646 void TimesBooleanIntExpr::SetRange(int64 mi, int64 ma) {
4647  if (mi > ma) {
4648  solver()->Fail();
4649  }
4650  switch (boolvar_->RawValue()) {
4651  case 0: {
4652  if (mi > 0 || ma < 0) {
4653  solver()->Fail();
4654  }
4655  break;
4656  }
4657  case 1: {
4658  expr_->SetRange(mi, ma);
4659  break;
4660  }
4661  default: {
4662  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4663  if (mi > 0) {
4664  boolvar_->SetValue(1);
4665  expr_->SetMin(mi);
4666  } else if (mi == 0 && expr_->Max() < 0) {
4667  boolvar_->SetValue(0);
4668  }
4669  if (ma < 0) {
4670  boolvar_->SetValue(1);
4671  expr_->SetMax(ma);
4672  } else if (ma == 0 && expr_->Min() > 0) {
4673  boolvar_->SetValue(0);
4674  }
4675  break;
4676  }
4677  }
4678 }
4679 
4680 bool TimesBooleanIntExpr::Bound() const {
4681  return (boolvar_->RawValue() == 0 ||
4682  (expr_->Bound() &&
4683  (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4684  expr_->Max() == 0)));
4685 }
4686 
4687 // ----- DivPosIntCstExpr -----
4688 
4689 class DivPosIntCstExpr : public BaseIntExpr {
4690  public:
4691  DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4692  : BaseIntExpr(s), expr_(e), value_(v) {
4693  CHECK_GE(v, 0);
4694  }
4695  ~DivPosIntCstExpr() override {}
4696 
4697  int64 Min() const override { return expr_->Min() / value_; }
4698 
4699  void SetMin(int64 m) override {
4700  if (m > 0) {
4701  expr_->SetMin(m * value_);
4702  } else {
4703  expr_->SetMin((m - 1) * value_ + 1);
4704  }
4705  }
4706  int64 Max() const override { return expr_->Max() / value_; }
4707 
4708  void SetMax(int64 m) override {
4709  if (m >= 0) {
4710  expr_->SetMax((m + 1) * value_ - 1);
4711  } else {
4712  expr_->SetMax(m * value_);
4713  }
4714  }
4715 
4716  std::string name() const override {
4717  return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4718  }
4719 
4720  std::string DebugString() const override {
4721  return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4722  }
4723 
4724  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4725 
4726  void Accept(ModelVisitor* const visitor) const override {
4727  visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4728  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4729  expr_);
4730  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4731  visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4732  }
4733 
4734  private:
4735  IntExpr* const expr_;
4736  const int64 value_;
4737 };
4738 
4739 // DivPosIntExpr
4740 
4741 class DivPosIntExpr : public BaseIntExpr {
4742  public:
4743  DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4744  : BaseIntExpr(s),
4745  num_(num),
4746  denom_(denom),
4747  opp_num_(s->MakeOpposite(num)) {}
4748 
4749  ~DivPosIntExpr() override {}
4750 
4751  int64 Min() const override {
4752  return num_->Min() >= 0
4753  ? num_->Min() / denom_->Max()
4754  : (denom_->Min() == 0 ? num_->Min()
4755  : num_->Min() / denom_->Min());
4756  }
4757 
4758  int64 Max() const override {
4759  return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4760  : num_->Max() / denom_->Min())
4761  : num_->Max() / denom_->Max();
4762  }
4763 
4764  static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64 m) {
4765  num->SetMin(m * denom->Min());
4766  denom->SetMax(num->Max() / m);
4767  }
4768 
4769  static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64 m) {
4770  num->SetMax((m + 1) * denom->Max() - 1);
4771  denom->SetMin(num->Min() / (m + 1) + 1);
4772  }
4773 
4774  void SetMin(int64 m) override {
4775  if (m > 0) {
4776  SetPosMin(num_, denom_, m);
4777  } else {
4778  SetPosMax(opp_num_, denom_, -m);
4779  }
4780  }
4781 
4782  void SetMax(int64 m) override {
4783  if (m >= 0) {
4784  SetPosMax(num_, denom_, m);
4785  } else {
4786  SetPosMin(opp_num_, denom_, -m);
4787  }
4788  }
4789 
4790  std::string name() const override {
4791  return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4792  }
4793  std::string DebugString() const override {
4794  return absl::StrFormat("(%s div %s)", num_->DebugString(),
4795  denom_->DebugString());
4796  }
4797  void WhenRange(Demon* d) override {
4798  num_->WhenRange(d);
4799  denom_->WhenRange(d);
4800  }
4801 
4802  void Accept(ModelVisitor* const visitor) const override {
4803  visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4804  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4805  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4806  denom_);
4807  visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4808  }
4809 
4810  private:
4811  IntExpr* const num_;
4812  IntExpr* const denom_;
4813  IntExpr* const opp_num_;
4814 };
4815 
4816 class DivPosPosIntExpr : public BaseIntExpr {
4817  public:
4818  DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4819  : BaseIntExpr(s), num_(num), denom_(denom) {}
4820 
4821  ~DivPosPosIntExpr() override {}
4822 
4823  int64 Min() const override {
4824  if (denom_->Max() == 0) {
4825  solver()->Fail();
4826  }
4827  return num_->Min() / denom_->Max();
4828  }
4829 
4830  int64 Max() const override {
4831  if (denom_->Min() == 0) {
4832  return num_->Max();
4833  } else {
4834  return num_->Max() / denom_->Min();
4835  }
4836  }
4837 
4838  void SetMin(int64 m) override {
4839  if (m > 0) {
4840  num_->SetMin(m * denom_->Min());
4841  denom_->SetMax(num_->Max() / m);
4842  }
4843  }
4844 
4845  void SetMax(int64 m) override {
4846  if (m >= 0) {
4847  num_->SetMax((m + 1) * denom_->Max() - 1);
4848  denom_->SetMin(num_->Min() / (m + 1) + 1);
4849  } else {
4850  solver()->Fail();
4851  }
4852  }
4853 
4854  std::string name() const override {
4855  return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4856  }
4857 
4858  std::string DebugString() const override {
4859  return absl::StrFormat("(%s div %s)", num_->DebugString(),
4860  denom_->DebugString());
4861  }
4862 
4863  void WhenRange(Demon* d) override {
4864  num_->WhenRange(d);
4865  denom_->WhenRange(d);
4866  }
4867 
4868  void Accept(ModelVisitor* const visitor) const override {
4869  visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4870  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4871  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4872  denom_);
4873  visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4874  }
4875 
4876  private:
4877  IntExpr* const num_;
4878  IntExpr* const denom_;
4879 };
4880 
4881 // DivIntExpr
4882 
4883 class DivIntExpr : public BaseIntExpr {
4884  public:
4885  DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4886  : BaseIntExpr(s),
4887  num_(num),
4888  denom_(denom),
4889  opp_num_(s->MakeOpposite(num)) {}
4890 
4891  ~DivIntExpr() override {}
4892 
4893  int64 Min() const override {
4894  const int64 num_min = num_->Min();
4895  const int64 num_max = num_->Max();
4896  const int64 denom_min = denom_->Min();
4897  const int64 denom_max = denom_->Max();
4898 
4899  if (denom_min == 0 && denom_max == 0) {
4900  return kint64max; // TODO(user): Check this convention.
4901  }
4902 
4903  if (denom_min >= 0) { // Denominator strictly positive.
4904  DCHECK_GT(denom_max, 0);
4905  const int64 adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4906  return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
4907  } else if (denom_max <= 0) { // Denominator strictly negative.
4908  DCHECK_LT(denom_min, 0);
4909  const int64 adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4910  return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
4911  } else { // Denominator across 0.
4912  return std::min(num_min, -num_max);
4913  }
4914  }
4915 
4916  int64 Max() const override {
4917  const int64 num_min = num_->Min();
4918  const int64 num_max = num_->Max();
4919  const int64 denom_min = denom_->Min();
4920  const int64 denom_max = denom_->Max();
4921 
4922  if (denom_min == 0 && denom_max == 0) {
4923  return kint64min; // TODO(user): Check this convention.
4924  }
4925 
4926  if (denom_min >= 0) { // Denominator strictly positive.
4927  DCHECK_GT(denom_max, 0);
4928  const int64 adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4929  return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
4930  } else if (denom_max <= 0) { // Denominator strictly negative.
4931  DCHECK_LT(denom_min, 0);
4932  const int64 adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4933  return num_min >= 0 ? num_min / denom_min
4934  : -num_min / -adjusted_denom_max;
4935  } else { // Denominator across 0.
4936  return std::max(num_max, -num_min);
4937  }
4938  }
4939 
4940  void AdjustDenominator() {
4941  if (denom_->Min() == 0) {
4942  denom_->SetMin(1);
4943  } else if (denom_->Max() == 0) {
4944  denom_->SetMax(-1);
4945  }
4946  }
4947 
4948  // m > 0.
4949  static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64 m) {
4950  DCHECK_GT(m, 0);
4951  const int64 num_min = num->Min();
4952  const int64 num_max = num->Max();
4953  const int64 denom_min = denom->Min();
4954  const int64 denom_max = denom->Max();
4955  DCHECK_NE(denom_min, 0);
4956  DCHECK_NE(denom_max, 0);
4957  if (denom_min > 0) { // Denominator strictly positive.
4958  num->SetMin(m * denom_min);
4959  denom->SetMax(num_max / m);
4960  } else if (denom_max < 0) { // Denominator strictly negative.
4961  num->SetMax(m * denom_max);
4962  denom->SetMin(num_min / m);
4963  } else { // Denominator across 0.
4964  if (num_min >= 0) {
4965  num->SetMin(m);
4966  denom->SetRange(1, num_max / m);
4967  } else if (num_max <= 0) {
4968  num->SetMax(-m);
4969  denom->SetRange(num_min / m, -1);
4970  } else {
4971  if (m > -num_min) { // Denominator is forced positive.
4972  num->SetMin(m);
4973  denom->SetRange(1, num_max / m);
4974  } else if (m > num_max) { // Denominator is forced negative.
4975  num->SetMax(-m);
4976  denom->SetRange(num_min / m, -1);
4977  } else {
4978  denom->SetRange(num_min / m, num_max / m);
4979  }
4980  }
4981  }
4982  }
4983 
4984  // m >= 0.
4985  static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64 m) {
4986  DCHECK_GE(m, 0);
4987  const int64 num_min = num->Min();
4988  const int64 num_max = num->Max();
4989  const int64 denom_min = denom->Min();
4990  const int64 denom_max = denom->Max();
4991  DCHECK_NE(denom_min, 0);
4992  DCHECK_NE(denom_max, 0);
4993  if (denom_min > 0) { // Denominator strictly positive.
4994  num->SetMax((m + 1) * denom_max - 1);
4995  denom->SetMin((num_min / (m + 1)) + 1);
4996  } else if (denom_max < 0) {
4997  num->SetMin((m + 1) * denom_min + 1);
4998  denom->SetMax(num_max / (m + 1) - 1);
4999  } else if (num_min > (m + 1) * denom_max - 1) {
5000  denom->SetMax(-1);
5001  } else if (num_max < (m + 1) * denom_min + 1) {
5002  denom->SetMin(1);
5003  }
5004  }
5005 
5006  void SetMin(int64 m) override {
5007  AdjustDenominator();
5008  if (m > 0) {
5009  SetPosMin(num_, denom_, m);
5010  } else {
5011  SetPosMax(opp_num_, denom_, -m);
5012  }
5013  }
5014 
5015  void SetMax(int64 m) override {
5016  AdjustDenominator();
5017  if (m >= 0) {
5018  SetPosMax(num_, denom_, m);
5019  } else {
5020  SetPosMin(opp_num_, denom_, -m);
5021  }
5022  }
5023 
5024  std::string name() const override {
5025  return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
5026  }
5027  std::string DebugString() const override {
5028  return absl::StrFormat("(%s div %s)", num_->DebugString(),
5029  denom_->DebugString());
5030  }
5031  void WhenRange(Demon* d) override {
5032  num_->WhenRange(d);
5033  denom_->WhenRange(d);
5034  }
5035 
5036  void Accept(ModelVisitor* const visitor) const override {
5037  visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
5038  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
5039  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5040  denom_);
5041  visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
5042  }
5043 
5044  private:
5045  IntExpr* const num_;
5046  IntExpr* const denom_;
5047  IntExpr* const opp_num_;
5048 };
5049 
5050 // ----- IntAbs And IntAbsConstraint ------
5051 
5052 class IntAbsConstraint : public CastConstraint {
5053  public:
5054  IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5055  : CastConstraint(s, target), sub_(sub) {}
5056 
5057  ~IntAbsConstraint() override {}
5058 
5059  void Post() override {
5060  Demon* const sub_demon = MakeConstraintDemon0(
5061  solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5062  sub_->WhenRange(sub_demon);
5063  Demon* const target_demon = MakeConstraintDemon0(
5064  solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5065  target_var_->WhenRange(target_demon);
5066  }
5067 
5068  void InitialPropagate() override {
5069  PropagateSub();
5070  PropagateTarget();
5071  }
5072 
5073  void PropagateSub() {
5074  const int64 smin = sub_->Min();
5075  const int64 smax = sub_->Max();
5076  if (smax <= 0) {
5077  target_var_->SetRange(-smax, -smin);
5078  } else if (smin >= 0) {
5079  target_var_->SetRange(smin, smax);
5080  } else {
5081  target_var_->SetRange(0, std::max(-smin, smax));
5082  }
5083  }
5084 
5085  void PropagateTarget() {
5086  const int64 target_max = target_var_->Max();
5087  sub_->SetRange(-target_max, target_max);
5088  const int64 target_min = target_var_->Min();
5089  if (target_min > 0) {
5090  if (sub_->Min() > -target_min) {
5091  sub_->SetMin(target_min);
5092  } else if (sub_->Max() < target_min) {
5093  sub_->SetMax(-target_min);
5094  }
5095  }
5096  }
5097 
5098  std::string DebugString() const override {
5099  return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5100  target_var_->DebugString());
5101  }
5102 
5103  void Accept(ModelVisitor* const visitor) const override {
5104  visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5105  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5106  sub_);
5107  visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5108  target_var_);
5109  visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5110  }
5111 
5112  private:
5113  IntVar* const sub_;
5114 };
5115 
5116 class IntAbs : public BaseIntExpr {
5117  public:
5118  IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5119 
5120  ~IntAbs() override {}
5121 
5122  int64 Min() const override {
5123  int64 emin = 0;
5124  int64 emax = 0;
5125  expr_->Range(&emin, &emax);
5126  if (emin >= 0) {
5127  return emin;
5128  }
5129  if (emax <= 0) {
5130  return -emax;
5131  }
5132  return 0;
5133  }
5134 
5135  void SetMin(int64 m) override {
5136  if (m > 0) {
5137  int64 emin = 0;
5138  int64 emax = 0;
5139  expr_->Range(&emin, &emax);
5140  if (emin > -m) {
5141  expr_->SetMin(m);
5142  } else if (emax < m) {
5143  expr_->SetMax(-m);
5144  }
5145  }
5146  }
5147 
5148  int64 Max() const override {
5149  int64 emin = 0;
5150  int64 emax = 0;
5151  expr_->Range(&emin, &emax);
5152  return std::max(-emin, emax);
5153  }
5154 
5155  void SetMax(int64 m) override { expr_->SetRange(-m, m); }
5156 
5157  void SetRange(int64 mi, int64 ma) override {
5158  expr_->SetRange(-ma, ma);
5159  if (mi > 0) {
5160  int64 emin = 0;
5161  int64 emax = 0;
5162  expr_->Range(&emin, &emax);
5163  if (emin > -mi) {
5164  expr_->SetMin(mi);
5165  } else if (emax < mi) {
5166  expr_->SetMax(-mi);
5167  }
5168  }
5169  }
5170 
5171  void Range(int64* mi, int64* ma) override {
5172  int64 emin = 0;
5173  int64 emax = 0;
5174  expr_->Range(&emin, &emax);
5175  if (emin >= 0) {
5176  *mi = emin;
5177  *ma = emax;
5178  } else if (emax <= 0) {
5179  *mi = -emax;
5180  *ma = -emin;
5181  } else {
5182  *mi = 0;
5183  *ma = std::max(-emin, emax);
5184  }
5185  }
5186 
5187  bool Bound() const override { return expr_->Bound(); }
5188 
5189  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5190 
5191  std::string name() const override {
5192  return absl::StrFormat("IntAbs(%s)", expr_->name());
5193  }
5194 
5195  std::string DebugString() const override {
5196  return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5197  }
5198 
5199  void Accept(ModelVisitor* const visitor) const override {
5200  visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5201  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5202  expr_);
5203  visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5204  }
5205 
5206  IntVar* CastToVar() override {
5207  int64 min_value = 0;
5208  int64 max_value = 0;
5209  Range(&min_value, &max_value);
5210  Solver* const s = solver();
5211  const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5212  IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5213  CastConstraint* const ct =
5214  s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5215  s->AddCastConstraint(ct, target, this);
5216  return target;
5217  }
5218 
5219  private:
5220  IntExpr* const expr_;
5221 };
5222 
5223 // ----- Square -----
5224 
5225 // TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5226 class IntSquare : public BaseIntExpr {
5227  public:
5228  IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5229  ~IntSquare() override {}
5230 
5231  int64 Min() const override {
5232  const int64 emin = expr_->Min();
5233  if (emin >= 0) {
5234  return emin >= kint32max ? kint64max : emin * emin;
5235  }
5236  const int64 emax = expr_->Max();
5237  if (emax < 0) {
5238  return emax <= -kint32max ? kint64max : emax * emax;
5239  }
5240  return 0LL;
5241  }
5242  void SetMin(int64 m) override {
5243  if (m <= 0) {
5244  return;
5245  }
5246  // TODO(user): What happens if m is kint64max?
5247  const int64 emin = expr_->Min();
5248  const int64 emax = expr_->Max();
5249  const int64 root = static_cast<int64>(ceil(sqrt(static_cast<double>(m))));
5250  if (emin >= 0) {
5251  expr_->SetMin(root);
5252  } else if (emax <= 0) {
5253  expr_->SetMax(-root);
5254  } else if (expr_->IsVar()) {
5255  reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5256  }
5257  }
5258  int64 Max() const override {
5259  const int64 emax = expr_->Max();
5260  const int64 emin = expr_->Min();
5261  if (emax >= kint32max || emin <= -kint32max) {
5262  return kint64max;
5263  }
5264  return std::max(emin * emin, emax * emax);
5265  }
5266  void SetMax(int64 m) override {
5267  if (m < 0) {
5268  solver()->Fail();
5269  }
5270  if (m == kint64max) {
5271  return;
5272  }
5273  const int64 root = static_cast<int64>(floor(sqrt(static_cast<double>(m))));
5274  expr_->SetRange(-root, root);
5275  }
5276  bool Bound() const override { return expr_->Bound(); }
5277  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5278  std::string name() const override {
5279  return absl::StrFormat("IntSquare(%s)", expr_->name());
5280  }
5281  std::string DebugString() const override {
5282  return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5283  }
5284 
5285  void Accept(ModelVisitor* const visitor) const override {
5286  visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5287  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5288  expr_);
5289  visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5290  }
5291 
5292  IntExpr* expr() const { return expr_; }
5293 
5294  protected:
5295  IntExpr* const expr_;
5296 };
5297 
5298 class PosIntSquare : public IntSquare {
5299  public:
5300  PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
5301  ~PosIntSquare() override {}
5302 
5303  int64 Min() const override {
5304  const int64 emin = expr_->Min();
5305  return emin >= kint32max ? kint64max : emin * emin;
5306  }
5307  void SetMin(int64 m) override {
5308  if (m <= 0) {
5309  return;
5310  }
5311  const int64 root = static_cast<int64>(ceil(sqrt(static_cast<double>(m))));
5312  expr_->SetMin(root);
5313  }
5314  int64 Max() const override {
5315  const int64 emax = expr_->Max();
5316  return emax >= kint32max ? kint64max : emax * emax;
5317  }
5318  void SetMax(int64 m) override {
5319  if (m < 0) {
5320  solver()->Fail();
5321  }
5322  if (m == kint64max) {
5323  return;
5324  }
5325  const int64 root = static_cast<int64>(floor(sqrt(static_cast<double>(m))));
5326  expr_->SetMax(root);
5327  }
5328 };
5329 
5330 // ----- EvenPower -----
5331 
5332 int64 IntPower(int64 value, int64 power) {
5333  int64 result = value;
5334  // TODO(user): Speed that up.
5335  for (int i = 1; i < power; ++i) {
5336  result *= value;
5337  }
5338  return result;
5339 }
5340 
5341 int64 OverflowLimit(int64 power) {
5342  return static_cast<int64>(
5343  floor(exp(log(static_cast<double>(kint64max)) / power)));
5344 }
5345 
5346 class BasePower : public BaseIntExpr {
5347  public:
5348  BasePower(Solver* const s, IntExpr* const e, int64 n)
5349  : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
5350  CHECK_GT(n, 0);
5351  }
5352 
5353  ~BasePower() override {}
5354 
5355  bool Bound() const override { return expr_->Bound(); }
5356 
5357  IntExpr* expr() const { return expr_; }
5358 
5359  int64 exponant() const { return pow_; }
5360 
5361  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5362 
5363  std::string name() const override {
5364  return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
5365  }
5366 
5367  std::string DebugString() const override {
5368  return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
5369  }
5370 
5371  void Accept(ModelVisitor* const visitor) const override {
5372  visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
5373  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5374  expr_);
5375  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
5376  visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
5377  }
5378 
5379  protected:
5380  int64 Pown(int64 value) const {
5381  if (value >= limit_) {
5382  return kint64max;
5383  }
5384  if (value <= -limit_) {
5385  if (pow_ % 2 == 0) {
5386  return kint64max;
5387  } else {
5388  return kint64min;
5389  }
5390  }
5391  return IntPower(value, pow_);
5392  }
5393 
5394  int64 SqrnDown(int64 value) const {
5395  if (value == kint64min) {
5396  return kint64min;
5397  }
5398  if (value == kint64max) {
5399  return kint64max;
5400  }
5401  int64 res = 0;
5402  const double d_value = static_cast<double>(value);
5403  if (value >= 0) {
5404  const double sq = exp(log(d_value) / pow_);
5405  res = static_cast<int64>(floor(sq));
5406  } else {
5407  CHECK_EQ(1, pow_ % 2);
5408  const double sq = exp(log(-d_value) / pow_);
5409  res = -static_cast<int64>(ceil(sq));
5410  }
5411  const int64 pow_res = Pown(res + 1);
5412  if (pow_res <= value) {
5413  return res + 1;
5414  } else {
5415  return res;
5416  }
5417  }
5418 
5419  int64 SqrnUp(int64 value) const {
5420  if (value == kint64min) {
5421  return kint64min;
5422  }
5423  if (value == kint64max) {
5424  return kint64max;
5425  }
5426  int64 res = 0;
5427  const double d_value = static_cast<double>(value);
5428  if (value >= 0) {
5429  const double sq = exp(log(d_value) / pow_);
5430  res = static_cast<int64>(ceil(sq));
5431  } else {
5432  CHECK_EQ(1, pow_ % 2);
5433  const double sq = exp(log(-d_value) / pow_);
5434  res = -static_cast<int64>(floor(sq));
5435  }
5436  const int64 pow_res = Pown(res - 1);
5437  if (pow_res >= value) {
5438  return res - 1;
5439  } else {
5440  return res;
5441  }
5442  }
5443 
5444  IntExpr* const expr_;
5445  const int64 pow_;
5446  const int64 limit_;
5447 };
5448 
5449 class IntEvenPower : public BasePower {
5450  public:
5451  IntEvenPower(Solver* const s, IntExpr* const e, int64 n)
5452  : BasePower(s, e, n) {
5453  CHECK_EQ(0, n % 2);
5454  }
5455 
5456  ~IntEvenPower() override {}
5457 
5458  int64 Min() const override {
5459  int64 emin = 0;
5460  int64 emax = 0;
5461  expr_->Range(&emin, &emax);
5462  if (emin >= 0) {
5463  return Pown(emin);
5464  }
5465  if (emax < 0) {
5466  return Pown(emax);
5467  }
5468  return 0LL;
5469  }
5470  void SetMin(int64 m) override {
5471  if (m <= 0) {
5472  return;
5473  }
5474  int64 emin = 0;
5475  int64 emax = 0;
5476  expr_->Range(&emin, &emax);
5477  const int64 root = SqrnUp(m);
5478  if (emin > -root) {
5479  expr_->SetMin(root);
5480  } else if (emax < root) {
5481  expr_->SetMax(-root);
5482  } else if (expr_->IsVar()) {
5483  reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5484  }
5485  }
5486 
5487  int64 Max() const override {
5488  return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5489  }
5490 
5491  void SetMax(int64 m) override {
5492  if (m < 0) {
5493  solver()->Fail();
5494  }
5495  if (m == kint64max) {
5496  return;
5497  }
5498  const int64 root = SqrnDown(m);
5499  expr_->SetRange(-root, root);
5500  }
5501 };
5502 
5503 class PosIntEvenPower : public BasePower {
5504  public:
5505  PosIntEvenPower(Solver* const s, IntExpr* const e, int64 pow)
5506  : BasePower(s, e, pow) {
5507  CHECK_EQ(0, pow % 2);
5508  }
5509 
5510  ~PosIntEvenPower() override {}
5511 
5512  int64 Min() const override { return Pown(expr_->Min()); }
5513 
5514  void SetMin(int64 m) override {
5515  if (m <= 0) {
5516  return;
5517  }
5518  expr_->SetMin(SqrnUp(m));
5519  }
5520  int64 Max() const override { return Pown(expr_->Max()); }
5521 
5522  void SetMax(int64 m) override {
5523  if (m < 0) {
5524  solver()->Fail();
5525  }
5526  if (m == kint64max) {
5527  return;
5528  }
5529  expr_->SetMax(SqrnDown(m));
5530  }
5531 };
5532 
5533 class IntOddPower : public BasePower {
5534  public:
5535  IntOddPower(Solver* const s, IntExpr* const e, int64 n) : BasePower(s, e, n) {
5536  CHECK_EQ(1, n % 2);
5537  }
5538 
5539  ~IntOddPower() override {}
5540 
5541  int64 Min() const override { return Pown(expr_->Min()); }
5542 
5543  void SetMin(int64 m) override { expr_->SetMin(SqrnUp(m)); }
5544 
5545  int64 Max() const override { return Pown(expr_->Max()); }
5546 
5547  void SetMax(int64 m) override { expr_->SetMax(SqrnDown(m)); }
5548 };
5549 
5550 // ----- Min(expr, expr) -----
5551 
5552 class MinIntExpr : public BaseIntExpr {
5553  public:
5554  MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5555  : BaseIntExpr(s), left_(l), right_(r) {}
5556  ~MinIntExpr() override {}
5557  int64 Min() const override {
5558  const int64 lmin = left_->Min();
5559  const int64 rmin = right_->Min();
5560  return std::min(lmin, rmin);
5561  }
5562  void SetMin(int64 m) override {
5563  left_->SetMin(m);
5564  right_->SetMin(m);
5565  }
5566  int64 Max() const override {
5567  const int64 lmax = left_->Max();
5568  const int64 rmax = right_->Max();
5569  return std::min(lmax, rmax);
5570  }
5571  void SetMax(int64 m) override {
5572  if (left_->Min() > m) {
5573  right_->SetMax(m);
5574  }
5575  if (right_->Min() > m) {
5576  left_->SetMax(m);
5577  }
5578  }
5579  std::string name() const override {
5580  return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5581  }
5582  std::string DebugString() const override {
5583  return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5584  right_->DebugString());
5585  }
5586  void WhenRange(Demon* d) override {
5587  left_->WhenRange(d);
5588  right_->WhenRange(d);
5589  }
5590 
5591  void Accept(ModelVisitor* const visitor) const override {
5592  visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5593  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5594  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5595  right_);
5596  visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5597  }
5598 
5599  private:
5600  IntExpr* const left_;
5601  IntExpr* const right_;
5602 };
5603 
5604 // ----- Min(expr, constant) -----
5605 
5606 class MinCstIntExpr : public BaseIntExpr {
5607  public:
5608  MinCstIntExpr(Solver* const s, IntExpr* const e, int64 v)
5609  : BaseIntExpr(s), expr_(e), value_(v) {}
5610 
5611  ~MinCstIntExpr() override {}
5612 
5613  int64 Min() const override { return std::min(expr_->Min(), value_); }
5614 
5615  void SetMin(int64 m) override {
5616  if (m > value_) {
5617  solver()->Fail();
5618  }
5619  expr_->SetMin(m);
5620  }
5621 
5622  int64 Max() const override { return std::min(expr_->Max(), value_); }
5623 
5624  void SetMax(int64 m) override {
5625  if (value_ > m) {
5626  expr_->SetMax(m);
5627  }
5628  }
5629 
5630  bool Bound() const override {
5631  return (expr_->Bound() || expr_->Min() >= value_);
5632  }
5633 
5634  std::string name() const override {
5635  return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5636  }
5637 
5638  std::string DebugString() const override {
5639  return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5640  value_);
5641  }
5642 
5643  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5644 
5645  void Accept(ModelVisitor* const visitor) const override {
5646  visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5647  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5648  expr_);
5649  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5650  visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5651  }
5652 
5653  private:
5654  IntExpr* const expr_;
5655  const int64 value_;
5656 };
5657 
5658 // ----- Max(expr, expr) -----
5659 
5660 class MaxIntExpr : public BaseIntExpr {
5661  public:
5662  MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5663  : BaseIntExpr(s), left_(l), right_(r) {}
5664 
5665  ~MaxIntExpr() override {}
5666 
5667  int64 Min() const override { return std::max(left_->Min(), right_->Min()); }
5668 
5669  void SetMin(int64 m) override {
5670  if (left_->Max() < m) {
5671  right_->SetMin(m);
5672  } else {
5673  if (right_->Max() < m) {
5674  left_->SetMin(m);
5675  }
5676  }
5677  }
5678 
5679  int64 Max() const override { return std::max(left_->Max(), right_->Max()); }
5680 
5681  void SetMax(int64 m) override {
5682  left_->SetMax(m);
5683  right_->SetMax(m);
5684  }
5685 
5686  std::string name() const override {
5687  return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5688  }
5689 
5690  std::string DebugString() const override {
5691  return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5692  right_->DebugString());
5693  }
5694 
5695  void WhenRange(Demon* d) override {
5696  left_->WhenRange(d);
5697  right_->WhenRange(d);
5698  }
5699 
5700  void Accept(ModelVisitor* const visitor) const override {
5701  visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5702  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5703  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5704  right_);
5705  visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5706  }
5707 
5708  private:
5709  IntExpr* const left_;
5710  IntExpr* const right_;
5711 };
5712 
5713 // ----- Max(expr, constant) -----
5714 
5715 class MaxCstIntExpr : public BaseIntExpr {
5716  public:
5717  MaxCstIntExpr(Solver* const s, IntExpr* const e, int64 v)
5718  : BaseIntExpr(s), expr_(e), value_(v) {}
5719 
5720  ~MaxCstIntExpr() override {}
5721 
5722  int64 Min() const override { return std::max(expr_->Min(), value_); }
5723 
5724  void SetMin(int64 m) override {
5725  if (value_ < m) {
5726  expr_->SetMin(m);
5727  }
5728  }
5729 
5730  int64 Max() const override { return std::max(expr_->Max(), value_); }
5731 
5732  void SetMax(int64 m) override {
5733  if (m < value_) {
5734  solver()->Fail();
5735  }
5736  expr_->SetMax(m);
5737  }
5738 
5739  bool Bound() const override {
5740  return (expr_->Bound() || expr_->Max() <= value_);
5741  }
5742 
5743  std::string name() const override {
5744  return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5745  }
5746 
5747  std::string DebugString() const override {
5748  return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5749  value_);
5750  }
5751 
5752  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5753 
5754  void Accept(ModelVisitor* const visitor) const override {
5755  visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5756  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5757  expr_);
5758  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5759  visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5760  }
5761 
5762  private:
5763  IntExpr* const expr_;
5764  const int64 value_;
5765 };
5766 
5767 // ----- Convex Piecewise -----
5768 
// This class is a very simple convex piecewise linear function of the
// expression x. Between early_date and late_date, the value of the function
// is 0. Before early_date, it is affine and the cost is
// early_cost * (early_date - x). After late_date, the cost is
// late_cost * (x - late_date).
5774 
5775 class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5776  public:
5777  SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64 ec,
5778  int64 ed, int64 ld, int64 lc)
5779  : BaseIntExpr(s),
5780  expr_(e),
5781  early_cost_(ec),
5782  early_date_(ec == 0 ? kint64min : ed),
5783  late_date_(lc == 0 ? kint64max : ld),
5784  late_cost_(lc) {
5785  DCHECK_GE(ec, int64{0});
5786  DCHECK_GE(lc, int64{0});
5787  DCHECK_GE(ld, ed);
5788 
5789  // If the penalty is 0, we can push the "confort zone or zone
5790  // of no cost towards infinity.
5791  }
5792 
5793  ~SimpleConvexPiecewiseExpr() override {}
5794 
5795  int64 Min() const override {
5796  const int64 vmin = expr_->Min();
5797  const int64 vmax = expr_->Max();
5798  if (vmin >= late_date_) {
5799  return (vmin - late_date_) * late_cost_;
5800  } else if (vmax <= early_date_) {
5801  return (early_date_ - vmax) * early_cost_;
5802  } else {
5803  return 0LL;
5804  }
5805  }
5806 
5807  void SetMin(int64 m) override {
5808  if (m <= 0) {
5809  return;
5810  }
5811  int64 vmin = 0;
5812  int64 vmax = 0;
5813  expr_->Range(&vmin, &vmax);
5814 
5815  const int64 rb =
5816  (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5817  const int64 lb =
5818  (early_cost_ == 0 ? vmin
5819  : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5820 
5821  if (expr_->IsVar()) {
5822  expr_->Var()->RemoveInterval(lb, rb);
5823  }
5824  }
5825 
5826  int64 Max() const override {
5827  const int64 vmin = expr_->Min();
5828  const int64 vmax = expr_->Max();
5829  const int64 mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5830  const int64 ml =
5831  vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5832  return std::max(mr, ml);
5833  }
5834 
5835  void SetMax(int64 m) override {
5836  if (m < 0) {
5837  solver()->Fail();
5838  }
5839  if (late_cost_ != 0LL) {
5840  const int64 rb = late_date_ + PosIntDivDown(m, late_cost_);
5841  if (early_cost_ != 0LL) {
5842  const int64 lb = early_date_ - PosIntDivDown(m, early_cost_);
5843  expr_->SetRange(lb, rb);
5844  } else {
5845  expr_->SetMax(rb);
5846  }
5847  } else {
5848  if (early_cost_ != 0LL) {
5849  const int64 lb = early_date_ - PosIntDivDown(m, early_cost_);
5850  expr_->SetMin(lb);
5851  }
5852  }
5853  }
5854 
5855  std::string name() const override {
5856  return absl::StrFormat(
5857  "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5858  expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5859  }
5860 
5861  std::string DebugString() const override {
5862  return absl::StrFormat(
5863  "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5864  expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5865  }
5866 
5867  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5868 
5869  void Accept(ModelVisitor* const visitor) const override {
5870  visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5871  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5872  expr_);
5873  visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5874  early_cost_);
5875  visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5876  early_date_);
5877  visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5878  visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5879  visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5880  }
5881 
5882  private:
5883  IntExpr* const expr_;
5884  const int64 early_cost_;
5885  const int64 early_date_;
5886  const int64 late_date_;
5887  const int64 late_cost_;
5888 };
5889 
5890 // ----- Semi Continuous -----
5891 
5892 class SemiContinuousExpr : public BaseIntExpr {
5893  public:
5894  SemiContinuousExpr(Solver* const s, IntExpr* const e, int64 fixed_charge,
5895  int64 step)
5896  : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5897  DCHECK_GE(fixed_charge, int64{0});
5898  DCHECK_GT(step, int64{0});
5899  }
5900 
5901  ~SemiContinuousExpr() override {}
5902 
5903  int64 Value(int64 x) const {
5904  if (x <= 0) {
5905  return 0;
5906  } else {
5907  return CapAdd(fixed_charge_, CapProd(x, step_));
5908  }
5909  }
5910 
5911  int64 Min() const override { return Value(expr_->Min()); }
5912 
5913  void SetMin(int64 m) override {
5914  if (m >= CapAdd(fixed_charge_, step_)) {
5915  const int64 y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5916  expr_->SetMin(y);
5917  } else if (m > 0) {
5918  expr_->SetMin(1);
5919  }
5920  }
5921 
5922  int64 Max() const override { return Value(expr_->Max()); }
5923 
5924  void SetMax(int64 m) override {
5925  if (m < 0) {
5926  solver()->Fail();
5927  }
5928  if (m == kint64max) {
5929  return;
5930  }
5931  if (m < CapAdd(fixed_charge_, step_)) {
5932  expr_->SetMax(0);
5933  } else {
5934  const int64 y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5935  expr_->SetMax(y);
5936  }
5937  }
5938 
5939  std::string name() const override {
5940  return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5941  expr_->name(), fixed_charge_, step_);
5942  }
5943 
5944  std::string DebugString() const override {
5945  return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5946  expr_->DebugString(), fixed_charge_, step_);
5947  }
5948 
5949  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5950 
5951  void Accept(ModelVisitor* const visitor) const override {
5952  visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5953  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5954  expr_);
5955  visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
5956  fixed_charge_);
5957  visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
5958  visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5959  }
5960 
5961  private:
5962  IntExpr* const expr_;
5963  const int64 fixed_charge_;
5964  const int64 step_;
5965 };
5966 
5967 class SemiContinuousStepOneExpr : public BaseIntExpr {
5968  public:
5969  SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
5970  int64 fixed_charge)
5971  : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
5972  DCHECK_GE(fixed_charge, int64{0});
5973  }
5974 
5975  ~SemiContinuousStepOneExpr() override {}
5976 
5977  int64 Value(int64 x) const {
5978  if (x <= 0) {
5979  return 0;
5980  } else {
5981  return fixed_charge_ + x;
5982  }
5983  }
5984 
5985  int64 Min() const override { return Value(expr_->Min()); }
5986 
5987  void SetMin(int64 m) override {
5988  if (m >= fixed_charge_ + 1) {
5989  expr_->SetMin(m - fixed_charge_);
5990  } else if (m > 0) {
5991  expr_->SetMin(1);
5992  }
5993  }
5994 
5995  int64 Max() const override { return Value(expr_->Max()); }
5996 
5997  void SetMax(int64 m) override {
5998  if (m < 0) {
5999  solver()->Fail();
6000  }
6001  if (m < fixed_charge_ + 1) {
6002  expr_->SetMax(0);
6003  } else {
6004  expr_->SetMax(m - fixed_charge_);
6005  }
6006  }
6007 
6008  std::string name() const override {
6009  return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6010  expr_->name(), fixed_charge_);
6011  }
6012 
6013  std::string DebugString() const override {
6014  return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6015  expr_->DebugString(), fixed_charge_);
6016  }
6017 
6018  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6019 
6020  void Accept(ModelVisitor* const visitor) const override {
6021  visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6022  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6023  expr_);
6024  visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6025  fixed_charge_);
6026  visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6027  visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6028  }
6029 
6030  private:
6031  IntExpr* const expr_;
6032  const int64 fixed_charge_;
6033 };
6034 
6035 class SemiContinuousStepZeroExpr : public BaseIntExpr {
6036  public:
6037  SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6038  int64 fixed_charge)
6039  : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6040  DCHECK_GT(fixed_charge, int64{0});
6041  }
6042 
6043  ~SemiContinuousStepZeroExpr() override {}
6044 
6045  int64 Value(int64 x) const {
6046  if (x <= 0) {
6047  return 0;
6048  } else {
6049  return fixed_charge_;
6050  }
6051  }
6052 
6053  int64 Min() const override { return Value(expr_->Min()); }
6054 
6055  void SetMin(int64 m) override {
6056  if (m >= fixed_charge_) {
6057  solver()->Fail();
6058  } else if (m > 0) {
6059  expr_->SetMin(1);
6060  }
6061  }
6062 
6063  int64 Max() const override { return Value(expr_->Max()); }
6064 
6065  void SetMax(int64 m) override {
6066  if (m < 0) {
6067  solver()->Fail();
6068  }
6069  if (m < fixed_charge_) {
6070  expr_->SetMax(0);
6071  }
6072  }
6073 
6074  std::string name() const override {
6075  return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6076  expr_->name(), fixed_charge_);
6077  }
6078 
6079  std::string DebugString() const override {
6080  return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6081  expr_->DebugString(), fixed_charge_);
6082  }
6083 
6084  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6085 
6086  void Accept(ModelVisitor* const visitor) const override {
6087  visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6088  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6089  expr_);
6090  visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6091  fixed_charge_);
6092  visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6093  visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6094  }
6095 
6096  private:
6097  IntExpr* const expr_;
6098  const int64 fixed_charge_;
6099 };
6100 
// This constraint links an expression and the variable it is cast into.
6102 class LinkExprAndVar : public CastConstraint {
6103  public:
6104  LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6105  : CastConstraint(s, var), expr_(expr) {}
6106 
6107  ~LinkExprAndVar() override {}
6108 
6109  void Post() override {
6110  Solver* const s = solver();
6111  Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6112  expr_->WhenRange(d);
6113  target_var_->WhenRange(d);
6114  }
6115 
6116  void InitialPropagate() override {
6117  expr_->SetRange(target_var_->Min(), target_var_->Max());
6118  int64 l, u;
6119  expr_->Range(&l, &u);
6120  target_var_->SetRange(l, u);
6121  }
6122 
6123  std::string DebugString() const override {
6124  return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6125  target_var_->DebugString());
6126  }
6127 
6128  void Accept(ModelVisitor* const visitor) const override {
6129  visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6130  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6131  expr_);
6132  visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6133  target_var_);
6134  visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6135  }
6136 
6137  private:
6138  IntExpr* const expr_;
6139 };
6140 
6141 // ----- Conditional Expression -----
6142 
6143 class ExprWithEscapeValue : public BaseIntExpr {
6144  public:
6145  ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6146  int64 unperformed_value)
6147  : BaseIntExpr(s),
6148  condition_(c),
6149  expression_(e),
6150  unperformed_value_(unperformed_value) {}
6151 
6152  ~ExprWithEscapeValue() override {}
6153 
6154  int64 Min() const override {
6155  if (condition_->Min() == 1) {
6156  return expression_->Min();
6157  } else if (condition_->Max() == 1) {
6158  return std::min(unperformed_value_, expression_->Min());
6159  } else {
6160  return unperformed_value_;
6161  }
6162  }
6163 
6164  void SetMin(int64 m) override {
6165  if (m > unperformed_value_) {
6166  condition_->SetValue(1);
6167  expression_->SetMin(m);
6168  } else if (condition_->Min() == 1) {
6169  expression_->SetMin(m);
6170  } else if (m > expression_->Max()) {
6171  condition_->SetValue(0);
6172  }
6173  }
6174 
6175  int64 Max() const override {
6176  if (condition_->Min() == 1) {
6177  return expression_->Max();
6178  } else if (condition_->Max() == 1) {
6179  return std::max(unperformed_value_, expression_->Max());
6180  } else {
6181  return unperformed_value_;
6182  }
6183  }
6184 
6185  void SetMax(int64 m) override {
6186  if (m < unperformed_value_) {
6187  condition_->SetValue(1);
6188  expression_->SetMax(m);
6189  } else if (condition_->Min() == 1) {
6190  expression_->SetMax(m);
6191  } else if (m < expression_->Min()) {
6192  condition_->SetValue(0);
6193  }
6194  }
6195 
6196  void SetRange(int64 mi, int64 ma) override {
6197  if (ma < unperformed_value_ || mi > unperformed_value_) {
6198  condition_->SetValue(1);
6199  expression_->SetRange(mi, ma);
6200  } else if (condition_->Min() == 1) {
6201  expression_->SetRange(mi, ma);
6202  } else if (ma < expression_->Min() || mi > expression_->Max()) {
6203  condition_->SetValue(0);
6204  }
6205  }
6206 
6207  void SetValue(int64 v) override {
6208  if (v != unperformed_value_) {
6209  condition_->SetValue(1);
6210  expression_->SetValue(v);
6211  } else if (condition_->Min() == 1) {
6212  expression_->SetValue(v);
6213  } else if (v < expression_->Min() || v > expression_->Max()) {
6214  condition_->SetValue(0);
6215  }
6216  }
6217 
6218  bool Bound() const override {
6219  return condition_->Max() == 0 || expression_->Bound();
6220  }
6221 
6222  void WhenRange(Demon* d) override {
6223  expression_->WhenRange(d);
6224  condition_->WhenBound(d);
6225  }
6226 
6227  std::string DebugString() const override {
6228  return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6229  condition_->DebugString(),
6230  expression_->DebugString(), unperformed_value_);
6231  }
6232 
6233  void Accept(ModelVisitor* const visitor) const override {
6234  visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6235  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6236  condition_);
6237  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6238  expression_);
6239  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6240  unperformed_value_);
6241  visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6242  }
6243 
6244  private:
6245  IntVar* const condition_;
6246  IntExpr* const expression_;
6247  const int64 unperformed_value_;
6248  DISALLOW_COPY_AND_ASSIGN(ExprWithEscapeValue);
6249 };
6250 
6251 // ----- This is a specialized case when the variable exact type is known -----
6252 class LinkExprAndDomainIntVar : public CastConstraint {
6253  public:
6254  LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6255  DomainIntVar* const var)
6256  : CastConstraint(s, var),
6257  expr_(expr),
6258  cached_min_(kint64min),
6259  cached_max_(kint64max),
6260  fail_stamp_(GG_ULONGLONG(0)) {}
6261 
6262  ~LinkExprAndDomainIntVar() override {}
6263 
6264  DomainIntVar* var() const {
6265  return reinterpret_cast<DomainIntVar*>(target_var_);
6266  }
6267 
6268  void Post() override {
6269  Solver* const s = solver();
6270  Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6271  expr_->WhenRange(d);
6272  Demon* const target_var_demon = MakeConstraintDemon0(
6273  solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6274  target_var_->WhenRange(target_var_demon);
6275  }
6276 
6277  void InitialPropagate() override {
6278  expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6279  expr_->Range(&cached_min_, &cached_max_);
6280  var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6281  }
6282 
6283  void Propagate() {
6284  if (var()->min_.Value() > cached_min_ ||
6285  var()->max_.Value() < cached_max_ ||
6286  solver()->fail_stamp() != fail_stamp_) {
6287  InitialPropagate();
6288  fail_stamp_ = solver()->fail_stamp();
6289  }
6290  }
6291 
6292  std::string DebugString() const override {
6293  return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6294  target_var_->DebugString());
6295  }
6296 
6297  void Accept(ModelVisitor* const visitor) const override {
6298  visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6299  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6300  expr_);
6301  visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6302  target_var_);
6303  visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6304  }
6305 
6306  private:
6307  IntExpr* const expr_;
6308  int64 cached_min_;
6309  int64 cached_max_;
6310  uint64 fail_stamp_;
6311 };
6312 } // namespace
6313 
6314 // ----- Misc -----
6315 
6316 IntVarIterator* BooleanVar::MakeHoleIterator(bool reversible) const {
6317  return COND_REV_ALLOC(reversible, new EmptyIterator());
6318 }
6319 IntVarIterator* BooleanVar::MakeDomainIterator(bool reversible) const {
6320  return COND_REV_ALLOC(reversible, new RangeIterator(this));
6321 }
6322 
6323 // ----- API -----
6324 
6325 void CleanVariableOnFail(IntVar* const var) {
6326  DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6327  DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6328  dvar->CleanInProcess();
6329 }
6330 
6331 Constraint* SetIsEqual(IntVar* const var, const std::vector<int64>& values,
6332  const std::vector<IntVar*>& vars) {
6333  DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6334  CHECK(dvar != nullptr);
6335  return dvar->SetIsEqual(values, vars);
6336 }
6337 
6338 Constraint* SetIsGreaterOrEqual(IntVar* const var,
6339  const std::vector<int64>& values,
6340  const std::vector<IntVar*>& vars) {
6341  DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6342  CHECK(dvar != nullptr);
6343  return dvar->SetIsGreaterOrEqual(values, vars);
6344 }
6345 
6346 void RestoreBoolValue(IntVar* const var) {
6347  DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6348  BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6349  boolean_var->RestoreValue();
6350 }
6351 
6352 // ----- API -----
6353 
6354 IntVar* Solver::MakeIntVar(int64 min, int64 max, const std::string& name) {
6355  if (min == max) {
6356  return MakeIntConst(min, name);
6357  }
6358  if (min == 0 && max == 1) {
6359  return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6360  } else if (CapSub(max, min) == 1) {
6361  const std::string inner_name = "inner_" + name;
6362  return RegisterIntVar(
6363  MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6364  ->VarWithName(name));
6365  } else {
6366  return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6367  }
6368 }
6369 
6370 IntVar* Solver::MakeIntVar(int64 min, int64 max) {
6371  return MakeIntVar(min, max, "");
6372 }
6373 
6374 IntVar* Solver::MakeBoolVar(const std::string& name) {
6375  return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6376 }
6377 
6378 IntVar* Solver::MakeBoolVar() {
6379  return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6380 }
6381 
6382 IntVar* Solver::MakeIntVar(const std::vector<int64>& values,
6383  const std::string& name) {
6384  DCHECK(!values.empty());
6385  // Fast-track the case where we have a single value.
6386  if (values.size() == 1) return MakeIntConst(values[0], name);
6387  // Sort and remove duplicates.
6388  std::vector<int64> unique_sorted_values = values;
6389  gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6390  // Case when we have a single value, after clean-up.
6391  if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6392  // Case when the values are a dense interval of integers.
6393  if (unique_sorted_values.size() ==
6394  unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6395  return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6396  name);
6397  }
6398  // Compute the GCD: if it's not 1, we can express the variable's domain as
6399  // the product of the GCD and of a domain with smaller values.
6400  int64 gcd = 0;
6401  for (const int64 v : unique_sorted_values) {
6402  if (gcd == 0) {
6403  gcd = std::abs(v);
6404  } else {
6405  gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6406  }
6407  if (gcd == 1) {
6408  // If it's 1, though, we can't do anything special, so we
6409  // immediately return a new DomainIntVar.
6410  return RegisterIntVar(
6411  RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6412  }
6413  }
6414  DCHECK_GT(gcd, 1);
6415  for (int64& v : unique_sorted_values) {
6416  DCHECK_EQ(0, v % gcd);
6417  v /= gcd;
6418  }
6419  const std::string new_name = name.empty() ? "" : "inner_" + name;
6420  // Catch the case where the divided values are a dense set of integers.
6421  IntVar* inner_intvar = nullptr;
6422  if (unique_sorted_values.size() ==
6423  unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6424  inner_intvar = MakeIntVar(unique_sorted_values.front(),
6425  unique_sorted_values.back(), new_name);
6426  } else {
6427  inner_intvar = RegisterIntVar(
6428  RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6429  }
6430  return MakeProd(inner_intvar, gcd)->Var();
6431 }
6432 
6433 IntVar* Solver::MakeIntVar(const std::vector<int64>& values) {
6434  return MakeIntVar(values, "");
6435 }
6436 
6437 IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6438  const std::string& name) {
6439  return MakeIntVar(ToInt64Vector(values), name);
6440 }
6441 
6442 IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6443  return MakeIntVar(values, "");
6444 }
6445 
6446 IntVar* Solver::MakeIntConst(int64 val, const std::string& name) {
6447  // If IntConst is going to be named after its creation,
6448  // cp_share_int_consts should be set to false otherwise names can potentially
6449  // be overwritten.
6450  if (FLAGS_cp_share_int_consts && name.empty() &&
6451  val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6452  return cached_constants_[val - MIN_CACHED_INT_CONST];
6453  }
6454  return RevAlloc(new IntConst(this, val, name));
6455 }
6456 
6457 IntVar* Solver::MakeIntConst(int64 val) { return MakeIntConst(val, ""); }
6458 
6459 // ----- Int Var and associated methods -----
6460 
6461 namespace {
6462 std::string IndexedName(const std::string& prefix, int index, int max_index) {
6463 #if 0
6464 #if defined(_MSC_VER)
6465  const int digits = max_index > 0 ?
6466  static_cast<int>(log(1.0L * max_index) / log(10.0L)) + 1 :
6467  1;
6468 #else
6469  const int digits = max_index > 0 ? static_cast<int>(log10(max_index)) + 1: 1;
6470 #endif
6471  return absl::StrFormat("%s%0*d", prefix, digits, index);
6472 #else
6473  return absl::StrCat(prefix, index);
6474 #endif
6475 }
6476 } // namespace
6477 
6478 void Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6479  const std::string& name,
6480  std::vector<IntVar*>* vars) {
6481  for (int i = 0; i < var_count; ++i) {
6482  vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6483  }
6484 }
6485 
6486 void Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6487  std::vector<IntVar*>* vars) {
6488  for (int i = 0; i < var_count; ++i) {
6489  vars->push_back(MakeIntVar(vmin, vmax));
6490  }
6491 }
6492 
6493 IntVar** Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6494  const std::string& name) {
6495  IntVar** vars = new IntVar*[var_count];
6496  for (int i = 0; i < var_count; ++i) {
6497  vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6498  }
6499  return vars;
6500 }
6501 
6502 void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6503  std::vector<IntVar*>* vars) {
6504  for (int i = 0; i < var_count; ++i) {
6505  vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6506  }
6507 }
6508 
6509 void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6510  for (int i = 0; i < var_count; ++i) {
6511  vars->push_back(MakeBoolVar());
6512  }
6513 }
6514 
6515 IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6516  IntVar** vars = new IntVar*[var_count];
6517  for (int i = 0; i < var_count; ++i) {
6518  vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6519  }
6520  return vars;
6521 }
6522 
6523 void Solver::InitCachedIntConstants() {
6524  for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6525  cached_constants_[i - MIN_CACHED_INT_CONST] =
6526  RevAlloc(new IntConst(this, i, "")); // note the empty name
6527  }
6528 }
6529 
6530 IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6531  CHECK_EQ(this, left->solver());
6532  CHECK_EQ(this, right->solver());
6533  if (right->Bound()) {
6534  return MakeSum(left, right->Min());
6535  }
6536  if (left->Bound()) {
6537  return MakeSum(right, left->Min());
6538  }
6539  if (left == right) {
6540  return MakeProd(left, 2);
6541  }
6542  IntExpr* cache = model_cache_->FindExprExprExpression(
6543  left, right, ModelCache::EXPR_EXPR_SUM);
6544  if (cache == nullptr) {
6545  cache = model_cache_->FindExprExprExpression(right, left,
6546  ModelCache::EXPR_EXPR_SUM);
6547  }
6548  if (cache != nullptr) {
6549  return cache;
6550  } else {
6551  IntExpr* const result =
6552  AddOverflows(left->Max(), right->Max()) ||
6553  AddOverflows(left->Min(), right->Min())
6554  ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6555  : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6556  model_cache_->InsertExprExprExpression(result, left, right,
6557  ModelCache::EXPR_EXPR_SUM);
6558  return result;
6559  }
6560 }
6561 
6562 IntExpr* Solver::MakeSum(IntExpr* const expr, int64 value) {
6563  CHECK_EQ(this, expr->solver());
6564  if (expr->Bound()) {
6565  return MakeIntConst(expr->Min() + value);
6566  }
6567  if (value == 0) {
6568  return expr;
6569  }
6570  IntExpr* result = Cache()->FindExprConstantExpression(
6571  expr, value, ModelCache::EXPR_CONSTANT_SUM);
6572  if (result == nullptr) {
6573  if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6574  !AddOverflows(value, expr->Min())) {
6575  IntVar* const var = expr->Var();
6576  switch (var->VarType()) {
6577  case DOMAIN_INT_VAR: {
6578  result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6579  this, reinterpret_cast<DomainIntVar*>(var), value)));
6580  break;
6581  }
6582  case CONST_VAR: {
6583  result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6584  break;
6585  }
6586  case VAR_ADD_CST: {
6587  PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6588  IntVar* const sub_var = add_var->SubVar();
6589  const int64 new_constant = value + add_var->Constant();
6590  if (new_constant == 0) {
6591  result = sub_var;
6592  } else {
6593  if (sub_var->VarType() == DOMAIN_INT_VAR) {
6594  DomainIntVar* const dvar =
6595  reinterpret_cast<DomainIntVar*>(sub_var);
6596  result = RegisterIntExpr(
6597  RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6598  } else {
6599  result = RegisterIntExpr(
6600  RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6601  }
6602  }
6603  break;
6604  }
6605  case CST_SUB_VAR: {
6606  SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6607  IntVar* const sub_var = add_var->SubVar();
6608  const int64 new_constant = value + add_var->Constant();
6609  result = RegisterIntExpr(
6610  RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6611  break;
6612  }
6613  case OPP_VAR: {
6614  OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6615  IntVar* const sub_var = add_var->SubVar();
6616  result =
6617  RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6618  break;
6619  }
6620  default:
6621  result =
6622  RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6623  }
6624  } else {
6625  result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6626  }
6627  Cache()->InsertExprConstantExpression(result, expr, value,
6628  ModelCache::EXPR_CONSTANT_SUM);
6629  }
6630  return result;
6631 }
6632 
6633 IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6634  CHECK_EQ(this, left->solver());
6635  CHECK_EQ(this, right->solver());
6636  if (left->Bound()) {
6637  return MakeDifference(left->Min(), right);
6638  }
6639  if (right->Bound()) {
6640  return MakeSum(left, -right->Min());
6641  }
6642  IntExpr* sub_left = nullptr;
6643  IntExpr* sub_right = nullptr;
6644  int64 left_coef = 1;
6645  int64 right_coef = 1;
6646  if (IsProduct(left, &sub_left, &left_coef) &&
6647  IsProduct(right, &sub_right, &right_coef)) {
6648  const int64 abs_gcd =
6649  MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6650  if (abs_gcd != 0 && abs_gcd != 1) {
6651  return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6652  MakeProd(sub_right, right_coef / abs_gcd)),
6653  abs_gcd);
6654  }
6655  }
6656 
6657  IntExpr* result = Cache()->FindExprExprExpression(
6658  left, right, ModelCache::EXPR_EXPR_DIFFERENCE);
6659  if (result == nullptr) {
6660  if (!SubOverflows(left->Min(), right->Max()) &&
6661  !SubOverflows(left->Max(), right->Min())) {
6662  result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6663  } else {
6664  result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6665  }
6666  Cache()->InsertExprExprExpression(result, left, right,
6667  ModelCache::EXPR_EXPR_DIFFERENCE);
6668  }
6669  return result;
6670 }
6671 
6672 // warning: this is 'value - expr'.
6673 IntExpr* Solver::MakeDifference(int64 value, IntExpr* const expr) {
6674  CHECK_EQ(this, expr->solver());
6675  if (expr->Bound()) {
6676  return MakeIntConst(value - expr->Min());
6677  }
6678  if (value == 0) {
6679  return MakeOpposite(expr);
6680  }
6681  IntExpr* result = Cache()->FindExprConstantExpression(
6682  expr, value, ModelCache::EXPR_CONSTANT_DIFFERENCE);
6683  if (result == nullptr) {
6684  if (expr->IsVar() && expr->Min() != kint64min &&
6685  !SubOverflows(value, expr->Min()) &&
6686  !SubOverflows(value, expr->Max())) {
6687  IntVar* const var = expr->Var();
6688  switch (var->VarType()) {
6689  case VAR_ADD_CST: {
6690  PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6691  IntVar* const sub_var = add_var->SubVar();
6692  const int64 new_constant = value - add_var->Constant();
6693  if (new_constant == 0) {
6694  result = sub_var;
6695  } else {
6696  result = RegisterIntExpr(
6697  RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6698  }
6699  break;
6700  }
6701  case CST_SUB_VAR: {
6702  SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6703  IntVar* const sub_var = add_var->SubVar();
6704  const int64 new_constant = value - add_var->Constant();
6705  result = MakeSum(sub_var, new_constant);
6706  break;
6707  }
6708  case OPP_VAR: {
6709  OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6710  IntVar* const sub_var = add_var->SubVar();
6711  result = MakeSum(sub_var, value);
6712  break;
6713  }
6714  default:
6715  result =
6716  RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6717  }
6718  } else {
6719  result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6720  }
6721  Cache()->InsertExprConstantExpression(result, expr, value,
6722  ModelCache::EXPR_CONSTANT_DIFFERENCE);
6723  }
6724  return result;
6725 }
6726 
6727 IntExpr* Solver::MakeOpposite(IntExpr* const expr) {
6728  CHECK_EQ(this, expr->solver());
6729  if (expr->Bound()) {
6730  return MakeIntConst(-expr->Min());
6731  }
6732  IntExpr* result =
6733  Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6734  if (result == nullptr) {
6735  if (expr->IsVar()) {
6736  result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6737  } else {
6738  result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6739  }
6740  Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6741  }
6742  return result;
6743 }
6744 
6745 IntExpr* Solver::MakeProd(IntExpr* const expr, int64 value) {
6746  CHECK_EQ(this, expr->solver());
6747  IntExpr* result = Cache()->FindExprConstantExpression(
6748  expr, value, ModelCache::EXPR_CONSTANT_PROD);
6749  if (result != nullptr) {
6750  return result;
6751  } else {
6752  IntExpr* m_expr = nullptr;
6753  int64 coefficient = 1;
6754  if (IsProduct(expr, &m_expr, &coefficient)) {
6755  coefficient *= value;
6756  } else {
6757  m_expr = expr;
6758  coefficient = value;
6759  }
6760  if (m_expr->Bound()) {
6761  return MakeIntConst(coefficient * m_expr->Min());
6762  } else if (coefficient == 1) {
6763  return m_expr;
6764  } else if (coefficient == -1) {
6765  return MakeOpposite(m_expr);
6766  } else if (coefficient > 0) {
6767  if (m_expr->Max() > kint64max / coefficient ||
6768  m_expr->Min() < kint64min / coefficient) {
6769  result = RegisterIntExpr(
6770  RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6771  } else {
6772  result = RegisterIntExpr(
6773  RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6774  }
6775  } else if (coefficient == 0) {
6776  result = MakeIntConst(0);
6777  } else { // coefficient < 0.
6778  result = RegisterIntExpr(
6779  RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6780  }
6781  if (m_expr->IsVar() && !FLAGS_cp_disable_expression_optimization) {
6782  result = result->Var();
6783  }
6784  Cache()->InsertExprConstantExpression(result, expr, value,
6785  ModelCache::EXPR_CONSTANT_PROD);
6786  return result;
6787  }
6788 }
6789 
6790 namespace {
6791 void ExtractPower(IntExpr** const expr, int64* const exponant) {
6792  if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6793  BasePower* const power = dynamic_cast<BasePower*>(*expr);
6794  *expr = power->expr();
6795  *exponant = power->exponant();
6796  }
6797  if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6798  IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6799  *expr = power->expr();
6800  *exponant = 2;
6801  }
6802  if ((*expr)->IsVar()) {
6803  IntVar* const var = (*expr)->Var();
6804  IntExpr* const sub = var->solver()->CastExpression(var);
6805  if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6806  BasePower* const power = dynamic_cast<BasePower*>(sub);
6807  *expr = power->expr();
6808  *exponant = power->exponant();
6809  }
6810  if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6811  IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6812  *expr = power->expr();
6813  *exponant = 2;
6814  }
6815  }
6816 }
6817 
6818 void ExtractProduct(IntExpr** const expr, int64* const coefficient,
6819  bool* modified) {
6820  if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6821  TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6822  *coefficient *= left_prod->Constant();
6823  *expr = left_prod->SubVar();
6824  *modified = true;
6825  } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6826  TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6827  *coefficient *= left_prod->Constant();
6828  *expr = left_prod->Expr();
6829  *modified = true;
6830  }
6831 }
6832 } // namespace
6833 
6834 IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6835  if (left->Bound()) {
6836  return MakeProd(right, left->Min());
6837  }
6838 
6839  if (right->Bound()) {
6840  return MakeProd(left, right->Min());
6841  }
6842 
6843  // ----- Discover squares and powers -----
6844 
6845  IntExpr* m_left = left;
6846  IntExpr* m_right = right;
6847  int64 left_exponant = 1;
6848  int64 right_exponant = 1;
6849  ExtractPower(&m_left, &left_exponant);
6850  ExtractPower(&m_right, &right_exponant);
6851 
6852  if (m_left == m_right) {
6853  return MakePower(m_left, left_exponant + right_exponant);
6854  }
6855 
6856  // ----- Discover nested products -----
6857 
6858  m_left = left;
6859  m_right = right;
6860  int64 coefficient = 1;
6861  bool modified = false;
6862 
6863  ExtractProduct(&m_left, &coefficient, &modified);
6864  ExtractProduct(&m_right, &coefficient, &modified);
6865  if (modified) {
6866  return MakeProd(MakeProd(m_left, m_right), coefficient);
6867  }
6868 
6869  // ----- Standard build -----
6870 
6871  CHECK_EQ(this, left->solver());
6872  CHECK_EQ(this, right->solver());
6873  IntExpr* result = model_cache_->FindExprExprExpression(
6874  left, right, ModelCache::EXPR_EXPR_PROD);
6875  if (result == nullptr) {
6876  result = model_cache_->FindExprExprExpression(right, left,
6877  ModelCache::EXPR_EXPR_PROD);
6878  }
6879  if (result != nullptr) {
6880  return result;
6881  }
6882  if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6883  if (right->Min() >= 0) {
6884  result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6885  this, reinterpret_cast<BooleanVar*>(left), right)));
6886  } else {
6887  result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6888  this, reinterpret_cast<BooleanVar*>(left), right)));
6889  }
6890  } else if (right->IsVar() &&
6891  reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6892  if (left->Min() >= 0) {
6893  result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6894  this, reinterpret_cast<BooleanVar*>(right), left)));
6895  } else {
6896  result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6897  this, reinterpret_cast<BooleanVar*>(right), left)));
6898  }
6899  } else if (left->Min() >= 0 && right->Min() >= 0) {
6900  if (CapProd(left->Max(), right->Max()) ==
6901  kint64max) { // Potential overflow.
6902  result =
6903  RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6904  } else {
6905  result =
6906  RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6907  }
6908  } else {
6909  result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6910  }
6911  model_cache_->InsertExprExprExpression(result, left, right,
6912  ModelCache::EXPR_EXPR_PROD);
6913  return result;
6914 }
6915 
6916 IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6917  CHECK(numerator != nullptr);
6918  CHECK(denominator != nullptr);
6919  if (denominator->Bound()) {
6920  return MakeDiv(numerator, denominator->Min());
6921  }
6922  IntExpr* result = model_cache_->FindExprExprExpression(
6923  numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6924  if (result != nullptr) {
6925  return result;
6926  }
6927 
6928  if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6929  AddConstraint(MakeNonEquality(denominator, 0));
6930  }
6931 
6932  if (denominator->Min() >= 0) {
6933  if (numerator->Min() >= 0) {
6934  result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
6935  } else {
6936  result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
6937  }
6938  } else if (denominator->Max() <= 0) {
6939  if (numerator->Max() <= 0) {
6940  result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
6941  MakeOpposite(denominator)));
6942  } else {
6943  result = MakeOpposite(RevAlloc(
6944  new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
6945  }
6946  } else {
6947  result = RevAlloc(new DivIntExpr(this, numerator, denominator));
6948  }
6949  model_cache_->InsertExprExprExpression(result, numerator, denominator,
6950  ModelCache::EXPR_EXPR_DIV);
6951  return result;
6952 }
6953 
6954 IntExpr* Solver::MakeDiv(IntExpr* const expr, int64 value) {
6955  CHECK(expr != nullptr);
6956  CHECK_EQ(this, expr->solver());
6957  if (expr->Bound()) {
6958  return MakeIntConst(expr->Min() / value);
6959  } else if (value == 1) {
6960  return expr;
6961  } else if (value == -1) {
6962  return MakeOpposite(expr);
6963  } else if (value > 0) {
6964  return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
6965  } else if (value == 0) {
6966  LOG(FATAL) << "Cannot divide by 0";
6967  return nullptr;
6968  } else {
6969  return RegisterIntExpr(
6970  MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
6971  // TODO(user) : implement special case.
6972  }
6973 }
6974 
6975 Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
6976  if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
6977  Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
6978  }
6979  return RevAlloc(new IntAbsConstraint(this, var, abs_var));
6980 }
6981 
6982 IntExpr* Solver::MakeAbs(IntExpr* const e) {
6983  CHECK_EQ(this, e->solver());
6984  if (e->Min() >= 0) {
6985  return e;
6986  } else if (e->Max() <= 0) {
6987  return MakeOpposite(e);
6988  }
6989  IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
6990  if (result == nullptr) {
6991  int64 coefficient = 1;
6992  IntExpr* expr = nullptr;
6993  if (IsProduct(e, &expr, &coefficient)) {
6994  result = MakeProd(MakeAbs(expr), std::abs(coefficient));
6995  } else {
6996  result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
6997  }
6998  Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
6999  }
7000  return result;
7001 }
7002 
7003 IntExpr* Solver::MakeSquare(IntExpr* const expr) {
7004  CHECK_EQ(this, expr->solver());
7005  if (expr->Bound()) {
7006  const int64 v = expr->Min();
7007  return MakeIntConst(v * v);
7008  }
7009  IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7010  if (result == nullptr) {
7011  if (expr->Min() >= 0) {
7012  result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7013  } else {
7014  result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7015  }
7016  Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7017  }
7018  return result;
7019 }
7020 
7021 IntExpr* Solver::MakePower(IntExpr* const expr, int64 n) {
7022  CHECK_EQ(this, expr->solver());
7023  CHECK_GE(n, 0);
7024  if (expr->Bound()) {
7025  const int64 v = expr->Min();
7026  if (v >= OverflowLimit(n)) { // Overflow.
7027  return MakeIntConst(kint64max);
7028  }
7029  return MakeIntConst(IntPower(v, n));
7030  }
7031  switch (n) {
7032  case 0:
7033  return MakeIntConst(1);
7034  case 1:
7035  return expr;
7036  case 2:
7037  return MakeSquare(expr);
7038  default: {
7039  IntExpr* result = nullptr;
7040  if (n % 2 == 0) { // even.
7041  if (expr->Min() >= 0) {
7042  result =
7043  RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7044  } else {
7045  result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7046  }
7047  } else {
7048  result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7049  }
7050  return result;
7051  }
7052  }
7053 }
7054 
7055 IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7056  CHECK_EQ(this, left->solver());
7057  CHECK_EQ(this, right->solver());
7058  if (left->Bound()) {
7059  return MakeMin(right, left->Min());
7060  }
7061  if (right->Bound()) {
7062  return MakeMin(left, right->Min());
7063  }
7064  if (left->Min() >= right->Max()) {
7065  return right;
7066  }
7067  if (right->Min() >= left->Max()) {
7068  return left;
7069  }
7070  return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7071 }
7072 
7073 IntExpr* Solver::MakeMin(IntExpr* const expr, int64 value) {
7074  CHECK_EQ(this, expr->solver());
7075  if (value <= expr->Min()) {
7076  return MakeIntConst(value);
7077  }
7078  if (expr->Bound()) {
7079  return MakeIntConst(std::min(expr->Min(), value));
7080  }
7081  if (expr->Max() <= value) {
7082  return expr;
7083  }
7084  return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7085 }
7086 
7087 IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7088  return MakeMin(expr, static_cast<int64>(value));
7089 }
7090 
7091 IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7092  CHECK_EQ(this, left->solver());
7093  CHECK_EQ(this, right->solver());
7094  if (left->Bound()) {
7095  return MakeMax(right, left->Min());
7096  }
7097  if (right->Bound()) {
7098  return MakeMax(left, right->Min());
7099  }
7100  if (left->Min() >= right->Max()) {
7101  return left;
7102  }
7103  if (right->Min() >= left->Max()) {
7104  return right;
7105  }
7106  return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7107 }
7108 
7109 IntExpr* Solver::MakeMax(IntExpr* const expr, int64 value) {
7110  CHECK_EQ(this, expr->solver());
7111  if (expr->Bound()) {
7112  return MakeIntConst(std::max(expr->Min(), value));
7113  }
7114  if (value <= expr->Min()) {
7115  return expr;
7116  }
7117  if (expr->Max() <= value) {
7118  return MakeIntConst(value);
7119  }
7120  return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7121 }
7122 
7123 IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7124  return MakeMax(expr, static_cast<int64>(value));
7125 }
7126 
7127 IntExpr* Solver::MakeConvexPiecewiseExpr(IntExpr* expr, int64 early_cost,
7128  int64 early_date, int64 late_date,
7129  int64 late_cost) {
7130  return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7131  this, expr, early_cost, early_date, late_date, late_cost)));
7132 }
7133 
7134 IntExpr* Solver::MakeSemiContinuousExpr(IntExpr* const expr, int64 fixed_charge,
7135  int64 step) {
7136  if (step == 0) {
7137  if (fixed_charge == 0) {
7138  return MakeIntConst(int64{0});
7139  } else {
7140  return RegisterIntExpr(
7141  RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7142  }
7143  } else if (step == 1) {
7144  return RegisterIntExpr(
7145  RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7146  } else {
7147  return RegisterIntExpr(
7148  RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7149  }
7150  // TODO(user) : benchmark with virtualization of
7151  // PosIntDivDown and PosIntDivUp - or function pointers.
7152 }
7153 
7154 // ----- Piecewise Linear -----
7155 
7157  public:
7158  PiecewiseLinearExpr(Solver* solver, IntExpr* expr,
7159  const PiecewiseLinearFunction& f)
7160  : BaseIntExpr(solver), expr_(expr), f_(f) {}
7161  ~PiecewiseLinearExpr() override {}
7162  int64 Min() const override {
7163  return f_.GetMinimum(expr_->Min(), expr_->Max());
7164  }
7165  void SetMin(int64 m) override {
7166  const auto& range =
7167  f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7168  expr_->SetRange(range.first, range.second);
7169  }
7170 
7171  int64 Max() const override {
7172  return f_.GetMaximum(expr_->Min(), expr_->Max());
7173  }
7174 
7175  void SetMax(int64 m) override {
7176  const auto& range =
7177  f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7178  expr_->SetRange(range.first, range.second);
7179  }
7180 
7181  void SetRange(int64 l, int64 u) override {
7182  const auto& range =
7183  f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7184  expr_->SetRange(range.first, range.second);
7185  }
7186  std::string name() const override {
7187  return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7188  f_.DebugString());
7189  }
7190 
7191  std::string DebugString() const override {
7192  return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7193  f_.DebugString());
7194  }
7195 
7196  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7197 
7198  void Accept(ModelVisitor* const visitor) const override {
7199  // TODO(user): Implement visitor.
7200  }
7201 
7202  private:
7203  IntExpr* const expr_;
7204  const PiecewiseLinearFunction f_;
7205 };
7206 
7207 IntExpr* Solver::MakePiecewiseLinearExpr(IntExpr* expr,
7208  const PiecewiseLinearFunction& f) {
7209  return RegisterIntExpr(RevAlloc(new PiecewiseLinearExpr(this, expr, f)));
7210 }
7211 
7212 // ----- Conditional Expression -----
7213 
7214 IntExpr* Solver::MakeConditionalExpression(IntVar* const condition,
7215  IntExpr* const expr,
7216  int64 unperformed_value) {
7217  if (condition->Min() == 1) {
7218  return expr;
7219  } else if (condition->Max() == 0) {
7220  return MakeIntConst(unperformed_value);
7221  } else {
7222  IntExpr* cache = Cache()->FindExprExprConstantExpression(
7223  condition, expr, unperformed_value,
7224  ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7225  if (cache == nullptr) {
7226  cache = RevAlloc(
7227  new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7228  Cache()->InsertExprExprConstantExpression(
7229  cache, condition, expr, unperformed_value,
7230  ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7231  }
7232  return cache;
7233  }
7234 }
7235 
7236 // ----- Modulo -----
7237 
7238 IntExpr* Solver::MakeModulo(IntExpr* const x, int64 mod) {
7239  IntVar* const result =
7240  MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7241  if (mod >= 0) {
7242  AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7243  } else {
7244  AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7245  }
7246  return result;
7247 }
7248 
7249 IntExpr* Solver::MakeModulo(IntExpr* const x, IntExpr* const mod) {
7250  if (mod->Bound()) {
7251  return MakeModulo(x, mod->Min());
7252  }
7253  IntVar* const result =
7254  MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7255  AddConstraint(MakeLess(result, MakeAbs(mod)));
7256  AddConstraint(MakeGreater(result, MakeOpposite(MakeAbs(mod))));
7257  return result;
7258 }
7259 
7260 // --------- IntVar ---------
7261 
7262 int IntVar::VarType() const { return UNSPECIFIED; }
7263 
7264 void IntVar::RemoveValues(const std::vector<int64>& values) {
7265  // TODO(user): Check and maybe inline this code.
7266  const int size = values.size();
7267  DCHECK_GE(size, 0);
7268  switch (size) {
7269  case 0: {
7270  return;
7271  }
7272  case 1: {
7273  RemoveValue(values[0]);
7274  return;
7275  }
7276  case 2: {
7277  RemoveValue(values[0]);
7278  RemoveValue(values[1]);
7279  return;
7280  }
7281  case 3: {
7282  RemoveValue(values[0]);
7283  RemoveValue(values[1]);
7284  RemoveValue(values[2]);
7285  return;
7286  }
7287  default: {
7288  // 4 values, let's start doing some more clever things.
7289  // TODO(user) : Sort values!
7290  int start_index = 0;
7291  int64 new_min = Min();
7292  if (values[start_index] <= new_min) {
7293  while (start_index < size - 1 &&
7294  values[start_index + 1] == values[start_index] + 1) {
7295  new_min = values[start_index + 1] + 1;
7296  start_index++;
7297  }
7298  }
7299  int end_index = size - 1;
7300  int64 new_max = Max();
7301  if (values[end_index] >= new_max) {
7302  while (end_index > start_index + 1 &&
7303  values[end_index - 1] == values[end_index] - 1) {
7304  new_max = values[end_index - 1] - 1;
7305  end_index--;
7306  }
7307  }
7308  SetRange(new_min, new_max);
7309  for (int i = start_index; i <= end_index; ++i) {
7310  RemoveValue(values[i]);
7311  }
7312  }
7313  }
7314 }
7315 
7316 void IntVar::Accept(ModelVisitor* const visitor) const {
7317  IntExpr* const casted = solver()->CastExpression(this);
7318  visitor->VisitIntegerVariable(this, casted);
7319 }
7320 
7321 void IntVar::SetValues(const std::vector<int64>& values) {
7322  switch (values.size()) {
7323  case 0: {
7324  solver()->Fail();
7325  break;
7326  }
7327  case 1: {
7328  SetValue(values.back());
7329  break;
7330  }
7331  case 2: {
7332  if (Contains(values[0])) {
7333  if (Contains(values[1])) {
7334  const int64 l = std::min(values[0], values[1]);
7335  const int64 u = std::max(values[0], values[1]);
7336  SetRange(l, u);
7337  if (u > l + 1) {
7338  RemoveInterval(l + 1, u - 1);
7339  }
7340  } else {
7341  SetValue(values[0]);
7342  }
7343  } else {
7344  SetValue(values[1]);
7345  }
7346  break;
7347  }
7348  default: {
7349  // TODO(user): use a clean and safe SortedUniqueCopy() class
7350  // that uses a global, static shared (and locked) storage.
7351  // TODO(user): [optional] consider porting
7352  // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7353  // existing open_source/base/stl_util.h and using it here.
7354  // TODO(user): We could filter out values not in the var.
7355  std::vector<int64>& tmp = solver()->tmp_vector_;
7356  tmp.clear();
7357  tmp.insert(tmp.end(), values.begin(), values.end());
7358  std::sort(tmp.begin(), tmp.end());
7359  tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7360  const int size = tmp.size();
7361  const int64 vmin = Min();
7362  const int64 vmax = Max();
7363  int first = 0;
7364  int last = size - 1;
7365  if (tmp.front() > vmax || tmp.back() < vmin) {
7366  solver()->Fail();
7367  }
7368  // TODO(user) : We could find the first position >= vmin by dichotomy.
7369  while (tmp[first] < vmin || !Contains(tmp[first])) {
7370  ++first;
7371  if (first > last || tmp[first] > vmax) {
7372  solver()->Fail();
7373  }
7374  }
7375  while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7376  // Note that last >= first implies tmp[last] >= vmin.
7377  --last;
7378  }
7379  DCHECK_GE(last, first);
7380  SetRange(tmp[first], tmp[last]);
7381  while (first < last) {
7382  const int64 start = tmp[first] + 1;
7383  const int64 end = tmp[first + 1] - 1;
7384  if (start <= end) {
7385  RemoveInterval(start, end);
7386  }
7387  first++;
7388  }
7389  }
7390  }
7391 }
7392 // ---------- BaseIntExpr ---------
7393 
7394 void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var) {
7395  if (!var->Bound()) {
7396  if (var->VarType() == DOMAIN_INT_VAR) {
7397  DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7398  s->AddCastConstraint(
7399  s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7400  } else {
7401  s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7402  expr);
7403  }
7404  }
7405 }
7406 
7407 IntVar* BaseIntExpr::Var() {
7408  if (var_ == nullptr) {
7409  solver()->SaveValue(reinterpret_cast<void**>(&var_));
7410  var_ = CastToVar();
7411  }
7412  return var_;
7413 }
7414 
7415 IntVar* BaseIntExpr::CastToVar() {
7416  int64 vmin, vmax;
7417  Range(&vmin, &vmax);
7418  IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7419  LinkVarExpr(solver(), this, var);
7420  return var;
7421 }
7422 
7423 // Discovery methods
7424 bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7425  IntExpr** const right) {
7426  if (expr->IsVar()) {
7427  IntVar* const expr_var = expr->Var();
7428  expr = CastExpression(expr_var);
7429  }
7430  // This is a dynamic cast to check the type of expr.
7431  // It returns nullptr is expr is not a subclass of SubIntExpr.
7432  SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7433  if (sub_expr != nullptr) {
7434  *left = sub_expr->left();
7435  *right = sub_expr->right();
7436  return true;
7437  }
7438  return false;
7439 }
7440 
7441 bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7442  bool* is_negated) const {
7443  if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7444  *inner_var = expr->Var();
7445  *is_negated = false;
7446  return true;
7447  } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7448  SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7449  if (sub_var != nullptr && sub_var->Constant() == 1 &&
7450  sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7451  *is_negated = true;
7452  *inner_var = sub_var->SubVar();
7453  return true;
7454  }
7455  }
7456  return false;
7457 }
7458 
7459 bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7460  int64* coefficient) {
7461  if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7462  TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7463  *coefficient = var->Constant();
7464  *inner_expr = var->SubVar();
7465  return true;
7466  } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7467  TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7468  *coefficient = prod->Constant();
7469  *inner_expr = prod->Expr();
7470  return true;
7471  }
7472  *inner_expr = expr;
7473  *coefficient = 1;
7474  return false;
7475 }
7476 
7477 #undef COND_REV_ALLOC
7478 
7479 } // namespace operations_research
var
IntVar * var
Definition: expr_array.cc:1858
operations_research::internal::BitLength64
uint64 BitLength64(uint64 size)
Definition: bitmap.h:26
operations_research::ToInt64Vector
std::vector< int64 > ToInt64Vector(const std::vector< int > &input)
Definition: utilities.cc:822
min
int64 min
Definition: alldiff_cst.cc:138
integral_types.h
map_util.h
operations_research::OPP_VAR
@ OPP_VAR
Definition: constraint_solveri.h:131
operations_research::SimpleRevFIFO::PushIfNotTop
void PushIfNotTop(Solver *const s, T val)
Pushes the var on top if it is not a duplicate of the current top object.
Definition: constraint_solveri.h:190
operations_research::CapSub
int64 CapSub(int64 x, int64 y)
Definition: saturated_arithmetic.h:154
max
int64 max
Definition: alldiff_cst.cc:139
operations_research::LinkVarExpr
void LinkVarExpr(Solver *const s, IntExpr *const expr, IntVar *const var)
Definition: expressions.cc:7394
limit_
const int64 limit_
Definition: expressions.cc:5446
current_
int64 current_
Definition: search.cc:2947
operations_research::PiecewiseLinearExpr::SetMax
void SetMax(int64 m) override
Definition: expressions.cc:7175
operations_research::PiecewiseLinearExpr::WhenRange
void WhenRange(Demon *d) override
Definition: expressions.cc:7196
operations_research::CapProd
int64 CapProd(int64 x, int64 y)
Definition: saturated_arithmetic.h:231
operations_research::CST_SUB_VAR
@ CST_SUB_VAR
Definition: constraint_solveri.h:130
operations_research::PiecewiseLinearExpr::Min
int64 Min() const override
Definition: expressions.cc:7162
operations_research::CapOpp
int64 CapOpp(int64 v)
Definition: saturated_arithmetic.h:163
GG_ULONGLONG
#define GG_ULONGLONG(x)
Definition: integral_types.h:47
logging.h
expr_
IntVar *const expr_
Definition: element.cc:85
operations_research::PiecewiseLinearFunction
Definition: piecewise_linear_function.h:101
stamp_
const int64 stamp_
Definition: search.cc:3033
operations_research::PiecewiseLinearExpr::SetMin
void SetMin(int64 m) override
Definition: expressions.cc:7165
operations_research::CleanVariableOnFail
void CleanVariableOnFail(IntVar *const var)
Definition: expressions.cc:6325
value
int64 value
Definition: demon_profiler.cc:43
operations_research::OneRange64
uint64 OneRange64(uint64 s, uint64 e)
Definition: bitset.h:285
saturated_arithmetic.h
operations_research
The vehicle routing library lets one model and solve generic vehicle routing problems ranging from th...
Definition: dense_doubly_linked_list.h:21
operations_research::PiecewiseLinearExpr::DebugString
std::string DebugString() const override
Definition: expressions.cc:7191
COND_REV_ALLOC
#define COND_REV_ALLOC(rev, alloc)
Definition: expressions.cc:2460
operations_research::internal::BitPos64
uint64 BitPos64(uint64 pos)
Definition: bitmap.h:24
operations_research::PosIntDivUp
int64 PosIntDivUp(int64 e, int64 v)
Definition: constraint_solveri.h:2993
operations_research::InternalSaveBooleanVarValue
void InternalSaveBooleanVarValue(Solver *const solver, IntVar *const var)
Definition: constraint_solver.cc:940
operations_research::BitCountRange64
uint64 BitCountRange64(const uint64 *const bitset, uint64 start, uint64 end)
kint64min
static const int64 kint64min
Definition: integral_types.h:60
holes_
std::vector< IntVarIterator * > holes_
Definition: constraint_solver/table.cc:224
operations_research::VAR_ADD_CST
@ VAR_ADD_CST
Definition: constraint_solveri.h:128
operations_research::SubOverflows
int64 SubOverflows(int64 x, int64 y)
Definition: saturated_arithmetic.h:80
int64
int64_t int64
Definition: integral_types.h:34
constraint_solveri.h
pow_
const int64 pow_
Definition: expressions.cc:5445
index
int index
Definition: pack.cc:508
operations_research::RegisterDemon
void RegisterDemon(Solver *const solver, Demon *const demon, DemonProfiler *const monitor)
Definition: demon_profiler.cc:460
offset_
const int64 offset_
Definition: interval.cc:2076
cst_
const int64 cst_
Definition: expressions.cc:2711
operations_research::PiecewiseLinearExpr::Accept
void Accept(ModelVisitor *const visitor) const override
Definition: expressions.cc:7198
operations_research::BOOLEAN_VAR
@ BOOLEAN_VAR
Definition: constraint_solveri.h:126
operations_research::PiecewiseLinearExpr::SetRange
void SetRange(int64 l, int64 u) override
Definition: expressions.cc:7181
operations_research::RestoreBoolValue
void RestoreBoolValue(IntVar *const var)
Definition: expressions.cc:6346
a
int64 a
Definition: constraint_solver/table.cc:42
constraint_solver.h
handler_
Handler handler_
Definition: interval.cc:420
string_array.h
operations_research::CONST_VAR
@ CONST_VAR
Definition: constraint_solveri.h:127
kint32max
static const int32 kint32max
Definition: integral_types.h:59
operations_research::PiecewiseLinearExpr::~PiecewiseLinearExpr
~PiecewiseLinearExpr() override
Definition: expressions.cc:7161
operations_research::BooleanVar::RestoreValue
virtual void RestoreValue()=0
operations_research::UnsafeLeastSignificantBitPosition64
int64 UnsafeLeastSignificantBitPosition64(const uint64 *const bitset, uint64 start, uint64 end)
operations_research::CapAdd
int64 CapAdd(int64 x, int64 y)
Definition: saturated_arithmetic.h:124
gtl::STLSortAndRemoveDuplicates
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition: stl_util.h:58
mathutil.h
operations_research::SetIsGreaterOrEqual
Constraint * SetIsGreaterOrEqual(IntVar *const var, const std::vector< int64 > &values, const std::vector< IntVar * > &vars)
Definition: expressions.cc:6338
operations_research::PiecewiseLinearExpr::name
std::string name() const override
Definition: expressions.cc:7186
target_var_
IntervalVar *const target_var_
Definition: sched_constraints.cc:230
operations_research::internal::OneBit64
uint64 OneBit64(int pos)
Definition: bitmap.h:23
ct
const Constraint * ct
Definition: demon_profiler.cc:42
uint64
uint64_t uint64
Definition: integral_types.h:39
operations_research::sat::Value
std::function< int64(const Model &)> Value(IntegerVariable v)
Definition: integer.h:1396
operations_research::UnsafeMostSignificantBitPosition64
int64 UnsafeMostSignificantBitPosition64(const uint64 *const bitset, uint64 start, uint64 end)
operations_research::MostSignificantBitPosition64
int MostSignificantBitPosition64(uint64 n)
Definition: bitset.h:231
iterator_
IntVarIterator *const iterator_
Definition: expressions.cc:2180
operations_research::PiecewiseLinearExpr::Max
int64 Max() const override
Definition: expressions.cc:7171
operations_research::PiecewiseLinearExpr::PiecewiseLinearExpr
PiecewiseLinearExpr(Solver *solver, IntExpr *expr, const PiecewiseLinearFunction &f)
Definition: expressions.cc:7158
operations_research::internal::IsBitSet64
bool IsBitSet64(const uint64 *const bitset, uint64 pos)
Definition: bitmap.h:27
operations_research::BooleanVar
Definition: constraint_solveri.h:1944
DEFINE_bool
DEFINE_bool(cp_disable_expression_optimization, false, "Disable special optimization when creating expressions.")
coefficient
int64 coefficient
Definition: routing_search.cc:973
operations_research::BitCount64
uint64 BitCount64(uint64 n)
Definition: bitset.h:42
operations_research::VAR_TIMES_CST
@ VAR_TIMES_CST
Definition: constraint_solveri.h:129
stl_util.h
operations_research::internal::BitOffset64
uint64 BitOffset64(uint64 pos)
Definition: bitmap.h:25
operations_research::PiecewiseLinearExpr
Definition: expressions.cc:7156
operations_research::UNSPECIFIED
@ UNSPECIFIED
Definition: constraint_solveri.h:124
operations_research::kAllBits64
static const uint64 kAllBits64
Definition: bitset.h:33
in_process_
bool in_process_
Definition: interval.cc:419
google::protobuf::util::RemoveAt
int RemoveAt(RepeatedType *array, const IndexContainer &indices)
Definition: protobuf_util.h:40
b
int64 b
Definition: constraint_solver/table.cc:43
DISALLOW_COPY_AND_ASSIGN
#define DISALLOW_COPY_AND_ASSIGN(TypeName)
Definition: macros.h:29
operations_research::DOMAIN_INT_VAR
@ DOMAIN_INT_VAR
Definition: constraint_solveri.h:125
operations_research::AddOverflows
bool AddOverflows(int64 x, int64 y)
Definition: saturated_arithmetic.h:76
operations_research::PosIntDivDown
int64 PosIntDivDown(int64 e, int64 v)
Definition: constraint_solveri.h:3002
gtl::FindPtrOrNull
const Collection::value_type::second_type FindPtrOrNull(const Collection &collection, const typename Collection::value_type::first_type &key)
Definition: map_util.h:70
operations_research::MakeConstraintDemon0
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
Definition: constraint_solveri.h:532
solver_
Solver *const solver_
Definition: expressions.cc:273
operations_research::BaseIntExpr
This is the base class for all expressions that are not variables.
Definition: constraint_solveri.h:109
step_
int64 step_
Definition: search.cc:2946
commandlineflags.h
operations_research::SetIsEqual
Constraint * SetIsEqual(IntVar *const var, const std::vector< int64 > &values, const std::vector< IntVar * > &vars)
Definition: expressions.cc:6331
name
const std::string name
Definition: default_search.cc:807
bitset.h
kint64max
static const int64 kint64max
Definition: integral_types.h:62
operations_research::LeastSignificantBitPosition64
int LeastSignificantBitPosition64(uint64 n)
Definition: bitset.h:127