OR-Tools  9.3
expr_cst.cc
Go to the documentation of this file.
1// Copyright 2010-2021 Google LLC
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//
15// Expression constraints
16
17#include <cstddef>
18#include <cstdint>
19#include <limits>
20#include <set>
21#include <string>
22#include <vector>
23
24#include "absl/strings/str_format.h"
25#include "absl/strings/str_join.h"
34
35ABSL_FLAG(int, cache_initial_size, 1024,
36 "Initial size of the array of the hash "
37 "table of caches for objects of type Var(x == 3)");
38
39namespace operations_research {
40
41//-----------------------------------------------------------------------------
42// Equality
43
44namespace {
45class EqualityExprCst : public Constraint {
46 public:
47 EqualityExprCst(Solver* const s, IntExpr* const e, int64_t v);
48 ~EqualityExprCst() override {}
49 void Post() override;
50 void InitialPropagate() override;
51 IntVar* Var() override {
52 return solver()->MakeIsEqualCstVar(expr_->Var(), value_);
53 }
54 std::string DebugString() const override;
55
56 void Accept(ModelVisitor* const visitor) const override {
57 visitor->BeginVisitConstraint(ModelVisitor::kEquality, this);
58 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
59 expr_);
60 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
61 visitor->EndVisitConstraint(ModelVisitor::kEquality, this);
62 }
63
64 private:
65 IntExpr* const expr_;
66 int64_t value_;
67};
68
69EqualityExprCst::EqualityExprCst(Solver* const s, IntExpr* const e, int64_t v)
70 : Constraint(s), expr_(e), value_(v) {}
71
72void EqualityExprCst::Post() {
73 if (!expr_->IsVar()) {
74 Demon* d = solver()->MakeConstraintInitialPropagateCallback(this);
75 expr_->WhenRange(d);
76 }
77}
78
79void EqualityExprCst::InitialPropagate() { expr_->SetValue(value_); }
80
81std::string EqualityExprCst::DebugString() const {
82 return absl::StrFormat("(%s == %d)", expr_->DebugString(), value_);
83}
84} // namespace
85
86Constraint* Solver::MakeEquality(IntExpr* const e, int64_t v) {
87 CHECK_EQ(this, e->solver());
88 IntExpr* left = nullptr;
89 IntExpr* right = nullptr;
90 if (IsADifference(e, &left, &right)) {
91 return MakeEquality(left, MakeSum(right, v));
92 } else if (e->IsVar() && !e->Var()->Contains(v)) {
93 return MakeFalseConstraint();
94 } else if (e->Min() == e->Max() && e->Min() == v) {
95 return MakeTrueConstraint();
96 } else {
97 return RevAlloc(new EqualityExprCst(this, e, v));
98 }
99}
100
101Constraint* Solver::MakeEquality(IntExpr* const e, int v) {
102 CHECK_EQ(this, e->solver());
103 IntExpr* left = nullptr;
104 IntExpr* right = nullptr;
105 if (IsADifference(e, &left, &right)) {
106 return MakeEquality(left, MakeSum(right, v));
107 } else if (e->IsVar() && !e->Var()->Contains(v)) {
108 return MakeFalseConstraint();
109 } else if (e->Min() == e->Max() && e->Min() == v) {
110 return MakeTrueConstraint();
111 } else {
112 return RevAlloc(new EqualityExprCst(this, e, v));
113 }
114}
115
116//-----------------------------------------------------------------------------
117// Greater or equal constraint
118
119namespace {
120class GreaterEqExprCst : public Constraint {
121 public:
122 GreaterEqExprCst(Solver* const s, IntExpr* const e, int64_t v);
123 ~GreaterEqExprCst() override {}
124 void Post() override;
125 void InitialPropagate() override;
126 std::string DebugString() const override;
127 IntVar* Var() override {
128 return solver()->MakeIsGreaterOrEqualCstVar(expr_->Var(), value_);
129 }
130
131 void Accept(ModelVisitor* const visitor) const override {
132 visitor->BeginVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
133 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
134 expr_);
135 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
136 visitor->EndVisitConstraint(ModelVisitor::kGreaterOrEqual, this);
137 }
138
139 private:
140 IntExpr* const expr_;
141 int64_t value_;
142 Demon* demon_;
143};
144
145GreaterEqExprCst::GreaterEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
146 : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
147
148void GreaterEqExprCst::Post() {
149 if (!expr_->IsVar() && expr_->Min() < value_) {
150 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
151 expr_->WhenRange(demon_);
152 } else {
153 // Let's clean the demon in case the constraint is posted during search.
154 demon_ = nullptr;
155 }
156}
157
158void GreaterEqExprCst::InitialPropagate() {
159 expr_->SetMin(value_);
160 if (demon_ != nullptr && expr_->Min() >= value_) {
161 demon_->inhibit(solver());
162 }
163}
164
165std::string GreaterEqExprCst::DebugString() const {
166 return absl::StrFormat("(%s >= %d)", expr_->DebugString(), value_);
167}
168} // namespace
169
171 CHECK_EQ(this, e->solver());
172 if (e->Min() >= v) {
173 return MakeTrueConstraint();
174 } else if (e->Max() < v) {
175 return MakeFalseConstraint();
176 } else {
177 return RevAlloc(new GreaterEqExprCst(this, e, v));
178 }
179}
180
182 CHECK_EQ(this, e->solver());
183 if (e->Min() >= v) {
184 return MakeTrueConstraint();
185 } else if (e->Max() < v) {
186 return MakeFalseConstraint();
187 } else {
188 return RevAlloc(new GreaterEqExprCst(this, e, v));
189 }
190}
191
193 CHECK_EQ(this, e->solver());
194 if (e->Min() > v) {
195 return MakeTrueConstraint();
196 } else if (e->Max() <= v) {
197 return MakeFalseConstraint();
198 } else {
199 return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
200 }
201}
202
204 CHECK_EQ(this, e->solver());
205 if (e->Min() > v) {
206 return MakeTrueConstraint();
207 } else if (e->Max() <= v) {
208 return MakeFalseConstraint();
209 } else {
210 return RevAlloc(new GreaterEqExprCst(this, e, v + 1));
211 }
212}
213
214//-----------------------------------------------------------------------------
215// Less or equal constraint
216
217namespace {
218class LessEqExprCst : public Constraint {
219 public:
220 LessEqExprCst(Solver* const s, IntExpr* const e, int64_t v);
221 ~LessEqExprCst() override {}
222 void Post() override;
223 void InitialPropagate() override;
224 std::string DebugString() const override;
225 IntVar* Var() override {
226 return solver()->MakeIsLessOrEqualCstVar(expr_->Var(), value_);
227 }
228 void Accept(ModelVisitor* const visitor) const override {
229 visitor->BeginVisitConstraint(ModelVisitor::kLessOrEqual, this);
230 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
231 expr_);
232 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
233 visitor->EndVisitConstraint(ModelVisitor::kLessOrEqual, this);
234 }
235
236 private:
237 IntExpr* const expr_;
238 int64_t value_;
239 Demon* demon_;
240};
241
242LessEqExprCst::LessEqExprCst(Solver* const s, IntExpr* const e, int64_t v)
243 : Constraint(s), expr_(e), value_(v), demon_(nullptr) {}
244
245void LessEqExprCst::Post() {
246 if (!expr_->IsVar() && expr_->Max() > value_) {
247 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
248 expr_->WhenRange(demon_);
249 } else {
250 // Let's clean the demon in case the constraint is posted during search.
251 demon_ = nullptr;
252 }
253}
254
255void LessEqExprCst::InitialPropagate() {
256 expr_->SetMax(value_);
257 if (demon_ != nullptr && expr_->Max() <= value_) {
258 demon_->inhibit(solver());
259 }
260}
261
262std::string LessEqExprCst::DebugString() const {
263 return absl::StrFormat("(%s <= %d)", expr_->DebugString(), value_);
264}
265} // namespace
266
268 CHECK_EQ(this, e->solver());
269 if (e->Max() <= v) {
270 return MakeTrueConstraint();
271 } else if (e->Min() > v) {
272 return MakeFalseConstraint();
273 } else {
274 return RevAlloc(new LessEqExprCst(this, e, v));
275 }
276}
277
279 CHECK_EQ(this, e->solver());
280 if (e->Max() <= v) {
281 return MakeTrueConstraint();
282 } else if (e->Min() > v) {
283 return MakeFalseConstraint();
284 } else {
285 return RevAlloc(new LessEqExprCst(this, e, v));
286 }
287}
288
289Constraint* Solver::MakeLess(IntExpr* const e, int64_t v) {
290 CHECK_EQ(this, e->solver());
291 if (e->Max() < v) {
292 return MakeTrueConstraint();
293 } else if (e->Min() >= v) {
294 return MakeFalseConstraint();
295 } else {
296 return RevAlloc(new LessEqExprCst(this, e, v - 1));
297 }
298}
299
301 CHECK_EQ(this, e->solver());
302 if (e->Max() < v) {
303 return MakeTrueConstraint();
304 } else if (e->Min() >= v) {
305 return MakeFalseConstraint();
306 } else {
307 return RevAlloc(new LessEqExprCst(this, e, v - 1));
308 }
309}
310
311//-----------------------------------------------------------------------------
312// Different constraints
313
314namespace {
315class DiffCst : public Constraint {
316 public:
317 DiffCst(Solver* const s, IntVar* const var, int64_t value);
318 ~DiffCst() override {}
319 void Post() override {}
320 void InitialPropagate() override;
321 void BoundPropagate();
322 std::string DebugString() const override;
323 IntVar* Var() override {
324 return solver()->MakeIsDifferentCstVar(var_, value_);
325 }
326 void Accept(ModelVisitor* const visitor) const override {
327 visitor->BeginVisitConstraint(ModelVisitor::kNonEqual, this);
328 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
329 var_);
330 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
331 visitor->EndVisitConstraint(ModelVisitor::kNonEqual, this);
332 }
333
334 private:
335 bool HasLargeDomain(IntVar* var);
336
337 IntVar* const var_;
338 int64_t value_;
339 Demon* demon_;
340};
341
342DiffCst::DiffCst(Solver* const s, IntVar* const var, int64_t value)
343 : Constraint(s), var_(var), value_(value), demon_(nullptr) {}
344
345void DiffCst::InitialPropagate() {
346 if (HasLargeDomain(var_)) {
347 demon_ = MakeConstraintDemon0(solver(), this, &DiffCst::BoundPropagate,
348 "BoundPropagate");
349 var_->WhenRange(demon_);
350 } else {
351 var_->RemoveValue(value_);
352 }
353}
354
355void DiffCst::BoundPropagate() {
356 const int64_t var_min = var_->Min();
357 const int64_t var_max = var_->Max();
358 if (var_min > value_ || var_max < value_) {
359 demon_->inhibit(solver());
360 } else if (var_min == value_) {
361 var_->SetMin(value_ + 1);
362 } else if (var_max == value_) {
363 var_->SetMax(value_ - 1);
364 } else if (!HasLargeDomain(var_)) {
365 demon_->inhibit(solver());
366 var_->RemoveValue(value_);
367 }
368}
369
370std::string DiffCst::DebugString() const {
371 return absl::StrFormat("(%s != %d)", var_->DebugString(), value_);
372}
373
374bool DiffCst::HasLargeDomain(IntVar* var) {
375 return CapSub(var->Max(), var->Min()) > 0xFFFFFF;
376}
377} // namespace
378
380 CHECK_EQ(this, e->solver());
381 IntExpr* left = nullptr;
382 IntExpr* right = nullptr;
383 if (IsADifference(e, &left, &right)) {
384 return MakeNonEquality(left, MakeSum(right, v));
385 } else if (e->IsVar() && !e->Var()->Contains(v)) {
386 return MakeTrueConstraint();
387 } else if (e->Bound() && e->Min() == v) {
388 return MakeFalseConstraint();
389 } else {
390 return RevAlloc(new DiffCst(this, e->Var(), v));
391 }
392}
393
395 CHECK_EQ(this, e->solver());
396 IntExpr* left = nullptr;
397 IntExpr* right = nullptr;
398 if (IsADifference(e, &left, &right)) {
399 return MakeNonEquality(left, MakeSum(right, v));
400 } else if (e->IsVar() && !e->Var()->Contains(v)) {
401 return MakeTrueConstraint();
402 } else if (e->Bound() && e->Min() == v) {
403 return MakeFalseConstraint();
404 } else {
405 return RevAlloc(new DiffCst(this, e->Var(), v));
406 }
407}
408// ----- is_equal_cst Constraint -----
409
410namespace {
411class IsEqualCstCt : public CastConstraint {
412 public:
413 IsEqualCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
414 : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
415 void Post() override {
416 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
417 var_->WhenDomain(demon_);
418 target_var_->WhenBound(demon_);
419 }
420 void InitialPropagate() override {
421 bool inhibit = var_->Bound();
422 int64_t u = var_->Contains(cst_);
423 int64_t l = inhibit ? u : 0;
424 target_var_->SetRange(l, u);
425 if (target_var_->Bound()) {
426 if (target_var_->Min() == 0) {
427 if (var_->Size() <= 0xFFFFFF) {
428 var_->RemoveValue(cst_);
429 inhibit = true;
430 }
431 } else {
432 var_->SetValue(cst_);
433 inhibit = true;
434 }
435 }
436 if (inhibit) {
437 demon_->inhibit(solver());
438 }
439 }
440 std::string DebugString() const override {
441 return absl::StrFormat("IsEqualCstCt(%s, %d, %s)", var_->DebugString(),
442 cst_, target_var_->DebugString());
443 }
444
445 void Accept(ModelVisitor* const visitor) const override {
446 visitor->BeginVisitConstraint(ModelVisitor::kIsEqual, this);
447 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
448 var_);
449 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
450 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
452 visitor->EndVisitConstraint(ModelVisitor::kIsEqual, this);
453 }
454
455 private:
456 IntVar* const var_;
457 int64_t cst_;
458 Demon* demon_;
459};
460} // namespace
461
463 IntExpr* left = nullptr;
464 IntExpr* right = nullptr;
465 if (IsADifference(var, &left, &right)) {
466 return MakeIsEqualVar(left, MakeSum(right, value));
467 }
468 if (CapSub(var->Max(), var->Min()) == 1) {
469 if (value == var->Min()) {
470 return MakeDifference(value + 1, var)->Var();
471 } else if (value == var->Max()) {
472 return MakeSum(var, -value + 1)->Var();
473 } else {
474 return MakeIntConst(0);
475 }
476 }
477 if (var->IsVar()) {
478 return var->Var()->IsEqual(value);
479 } else {
480 IntVar* const boolvar =
481 MakeBoolVar(absl::StrFormat("Is(%s == %d)", var->DebugString(), value));
483 return boolvar;
484 }
485}
486
488 IntVar* const boolvar) {
489 CHECK_EQ(this, var->solver());
490 CHECK_EQ(this, boolvar->solver());
491 if (value == var->Min()) {
492 if (CapSub(var->Max(), var->Min()) == 1) {
493 return MakeEquality(MakeDifference(value + 1, var), boolvar);
494 }
495 return MakeIsLessOrEqualCstCt(var, value, boolvar);
496 }
497 if (value == var->Max()) {
498 if (CapSub(var->Max(), var->Min()) == 1) {
499 return MakeEquality(MakeSum(var, -value + 1), boolvar);
500 }
501 return MakeIsGreaterOrEqualCstCt(var, value, boolvar);
502 }
503 if (boolvar->Bound()) {
504 if (boolvar->Min() == 0) {
505 return MakeNonEquality(var, value);
506 } else {
507 return MakeEquality(var, value);
508 }
509 }
510 // TODO(user) : what happens if the constraint is not posted?
511 // The cache becomes tainted.
512 model_cache_->InsertExprConstantExpression(
514 IntExpr* left = nullptr;
515 IntExpr* right = nullptr;
516 if (IsADifference(var, &left, &right)) {
517 return MakeIsEqualCt(left, MakeSum(right, value), boolvar);
518 } else {
519 return RevAlloc(new IsEqualCstCt(this, var->Var(), value, boolvar));
520 }
521}
522
523// ----- is_diff_cst Constraint -----
524
525namespace {
526class IsDiffCstCt : public CastConstraint {
527 public:
528 IsDiffCstCt(Solver* const s, IntVar* const v, int64_t c, IntVar* const b)
529 : CastConstraint(s, b), var_(v), cst_(c), demon_(nullptr) {}
530
531 void Post() override {
532 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
533 var_->WhenDomain(demon_);
534 target_var_->WhenBound(demon_);
535 }
536
537 void InitialPropagate() override {
538 bool inhibit = var_->Bound();
539 int64_t l = 1 - var_->Contains(cst_);
540 int64_t u = inhibit ? l : 1;
541 target_var_->SetRange(l, u);
542 if (target_var_->Bound()) {
543 if (target_var_->Min() == 1) {
544 if (var_->Size() <= 0xFFFFFF) {
545 var_->RemoveValue(cst_);
546 inhibit = true;
547 }
548 } else {
549 var_->SetValue(cst_);
550 inhibit = true;
551 }
552 }
553 if (inhibit) {
554 demon_->inhibit(solver());
555 }
556 }
557
558 std::string DebugString() const override {
559 return absl::StrFormat("IsDiffCstCt(%s, %d, %s)", var_->DebugString(), cst_,
560 target_var_->DebugString());
561 }
562
563 void Accept(ModelVisitor* const visitor) const override {
564 visitor->BeginVisitConstraint(ModelVisitor::kIsDifferent, this);
565 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
566 var_);
567 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
568 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
570 visitor->EndVisitConstraint(ModelVisitor::kIsDifferent, this);
571 }
572
573 private:
574 IntVar* const var_;
575 int64_t cst_;
576 Demon* demon_;
577};
578} // namespace
579
581 IntExpr* left = nullptr;
582 IntExpr* right = nullptr;
583 if (IsADifference(var, &left, &right)) {
584 return MakeIsDifferentVar(left, MakeSum(right, value));
585 }
586 return var->Var()->IsDifferent(value);
587}
588
590 IntVar* const boolvar) {
591 CHECK_EQ(this, var->solver());
592 CHECK_EQ(this, boolvar->solver());
593 if (value == var->Min()) {
594 return MakeIsGreaterOrEqualCstCt(var, value + 1, boolvar);
595 }
596 if (value == var->Max()) {
597 return MakeIsLessOrEqualCstCt(var, value - 1, boolvar);
598 }
599 if (var->IsVar() && !var->Var()->Contains(value)) {
600 return MakeEquality(boolvar, int64_t{1});
601 }
602 if (var->Bound() && var->Min() == value) {
603 return MakeEquality(boolvar, Zero());
604 }
605 if (boolvar->Bound()) {
606 if (boolvar->Min() == 0) {
607 return MakeEquality(var, value);
608 } else {
609 return MakeNonEquality(var, value);
610 }
611 }
612 model_cache_->InsertExprConstantExpression(
614 IntExpr* left = nullptr;
615 IntExpr* right = nullptr;
616 if (IsADifference(var, &left, &right)) {
617 return MakeIsDifferentCt(left, MakeSum(right, value), boolvar);
618 } else {
619 return RevAlloc(new IsDiffCstCt(this, var->Var(), value, boolvar));
620 }
621}
622
623// ----- is_greater_equal_cst Constraint -----
624
625namespace {
626class IsGreaterEqualCstCt : public CastConstraint {
627 public:
628 IsGreaterEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
629 IntVar* const b)
630 : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
631 void Post() override {
632 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
633 expr_->WhenRange(demon_);
634 target_var_->WhenBound(demon_);
635 }
636 void InitialPropagate() override {
637 bool inhibit = false;
638 int64_t u = expr_->Max() >= cst_;
639 int64_t l = expr_->Min() >= cst_;
640 target_var_->SetRange(l, u);
641 if (target_var_->Bound()) {
642 inhibit = true;
643 if (target_var_->Min() == 0) {
644 expr_->SetMax(cst_ - 1);
645 } else {
646 expr_->SetMin(cst_);
647 }
648 }
649 if (inhibit && ((target_var_->Max() == 0 && expr_->Max() < cst_) ||
650 (target_var_->Min() == 1 && expr_->Min() >= cst_))) {
651 // Can we safely inhibit? Sometimes an expression is not
652 // persistent, just monotonic.
653 demon_->inhibit(solver());
654 }
655 }
656 std::string DebugString() const override {
657 return absl::StrFormat("IsGreaterEqualCstCt(%s, %d, %s)",
658 expr_->DebugString(), cst_,
659 target_var_->DebugString());
660 }
661
662 void Accept(ModelVisitor* const visitor) const override {
663 visitor->BeginVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
664 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
665 expr_);
666 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
667 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
669 visitor->EndVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
670 }
671
672 private:
673 IntExpr* const expr_;
674 int64_t cst_;
675 Demon* demon_;
676};
677} // namespace
678
680 if (var->Min() >= value) {
681 return MakeIntConst(int64_t{1});
682 }
683 if (var->Max() < value) {
684 return MakeIntConst(int64_t{0});
685 }
686 if (var->IsVar()) {
687 return var->Var()->IsGreaterOrEqual(value);
688 } else {
689 IntVar* const boolvar =
690 MakeBoolVar(absl::StrFormat("Is(%s >= %d)", var->DebugString(), value));
692 return boolvar;
693 }
694}
695
698}
699
701 IntVar* const boolvar) {
702 if (boolvar->Bound()) {
703 if (boolvar->Min() == 0) {
704 return MakeLess(var, value);
705 } else {
707 }
708 }
709 CHECK_EQ(this, var->solver());
710 CHECK_EQ(this, boolvar->solver());
711 model_cache_->InsertExprConstantExpression(
713 return RevAlloc(new IsGreaterEqualCstCt(this, var, value, boolvar));
714}
715
717 IntVar* const b) {
718 return MakeIsGreaterOrEqualCstCt(v, c + 1, b);
719}
720
721// ----- is_lesser_equal_cst Constraint -----
722
723namespace {
724class IsLessEqualCstCt : public CastConstraint {
725 public:
726 IsLessEqualCstCt(Solver* const s, IntExpr* const v, int64_t c,
727 IntVar* const b)
728 : CastConstraint(s, b), expr_(v), cst_(c), demon_(nullptr) {}
729
730 void Post() override {
731 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
732 expr_->WhenRange(demon_);
733 target_var_->WhenBound(demon_);
734 }
735
736 void InitialPropagate() override {
737 bool inhibit = false;
738 int64_t u = expr_->Min() <= cst_;
739 int64_t l = expr_->Max() <= cst_;
740 target_var_->SetRange(l, u);
741 if (target_var_->Bound()) {
742 inhibit = true;
743 if (target_var_->Min() == 0) {
744 expr_->SetMin(cst_ + 1);
745 } else {
746 expr_->SetMax(cst_);
747 }
748 }
749 if (inhibit && ((target_var_->Max() == 0 && expr_->Min() > cst_) ||
750 (target_var_->Min() == 1 && expr_->Max() <= cst_))) {
751 // Can we safely inhibit? Sometimes an expression is not
752 // persistent, just monotonic.
753 demon_->inhibit(solver());
754 }
755 }
756
757 std::string DebugString() const override {
758 return absl::StrFormat("IsLessEqualCstCt(%s, %d, %s)", expr_->DebugString(),
759 cst_, target_var_->DebugString());
760 }
761
762 void Accept(ModelVisitor* const visitor) const override {
763 visitor->BeginVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
764 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
765 expr_);
766 visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
767 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
769 visitor->EndVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
770 }
771
772 private:
773 IntExpr* const expr_;
774 int64_t cst_;
775 Demon* demon_;
776};
777} // namespace
778
780 if (var->Max() <= value) {
781 return MakeIntConst(int64_t{1});
782 }
783 if (var->Min() > value) {
784 return MakeIntConst(int64_t{0});
785 }
786 if (var->IsVar()) {
787 return var->Var()->IsLessOrEqual(value);
788 } else {
789 IntVar* const boolvar =
790 MakeBoolVar(absl::StrFormat("Is(%s <= %d)", var->DebugString(), value));
792 return boolvar;
793 }
794}
795
797 return MakeIsLessOrEqualCstVar(var, value - 1);
798}
799
801 IntVar* const boolvar) {
802 if (boolvar->Bound()) {
803 if (boolvar->Min() == 0) {
804 return MakeGreater(var, value);
805 } else {
806 return MakeLessOrEqual(var, value);
807 }
808 }
809 CHECK_EQ(this, var->solver());
810 CHECK_EQ(this, boolvar->solver());
811 model_cache_->InsertExprConstantExpression(
813 return RevAlloc(new IsLessEqualCstCt(this, var, value, boolvar));
814}
815
817 IntVar* const b) {
818 return MakeIsLessOrEqualCstCt(v, c - 1, b);
819}
820
821// ----- BetweenCt -----
822
823namespace {
824class BetweenCt : public Constraint {
825 public:
826 BetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
827 : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
828
829 void Post() override {
830 if (!expr_->IsVar()) {
831 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
832 expr_->WhenRange(demon_);
833 }
834 }
835
836 void InitialPropagate() override {
837 expr_->SetRange(min_, max_);
838 int64_t emin = 0;
839 int64_t emax = 0;
840 expr_->Range(&emin, &emax);
841 if (demon_ != nullptr && emin >= min_ && emax <= max_) {
842 demon_->inhibit(solver());
843 }
844 }
845
846 std::string DebugString() const override {
847 return absl::StrFormat("BetweenCt(%s, %d, %d)", expr_->DebugString(), min_,
848 max_);
849 }
850
851 void Accept(ModelVisitor* const visitor) const override {
852 visitor->BeginVisitConstraint(ModelVisitor::kBetween, this);
853 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
854 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
855 expr_);
856 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
857 visitor->EndVisitConstraint(ModelVisitor::kBetween, this);
858 }
859
860 private:
861 IntExpr* const expr_;
862 int64_t min_;
863 int64_t max_;
864 Demon* demon_;
865};
866
867// ----- NonMember constraint -----
868
869class NotBetweenCt : public Constraint {
870 public:
871 NotBetweenCt(Solver* const s, IntExpr* const v, int64_t l, int64_t u)
872 : Constraint(s), expr_(v), min_(l), max_(u), demon_(nullptr) {}
873
874 void Post() override {
875 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
876 expr_->WhenRange(demon_);
877 }
878
879 void InitialPropagate() override {
880 int64_t emin = 0;
881 int64_t emax = 0;
882 expr_->Range(&emin, &emax);
883 if (emin >= min_) {
884 expr_->SetMin(max_ + 1);
885 } else if (emax <= max_) {
886 expr_->SetMax(min_ - 1);
887 }
888
889 if (!expr_->IsVar() && (emax < min_ || emin > max_)) {
890 demon_->inhibit(solver());
891 }
892 }
893
894 std::string DebugString() const override {
895 return absl::StrFormat("NotBetweenCt(%s, %d, %d)", expr_->DebugString(),
896 min_, max_);
897 }
898
899 void Accept(ModelVisitor* const visitor) const override {
900 visitor->BeginVisitConstraint(ModelVisitor::kNotBetween, this);
901 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
902 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
903 expr_);
904 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
905 visitor->EndVisitConstraint(ModelVisitor::kBetween, this);
906 }
907
908 private:
909 IntExpr* const expr_;
910 int64_t min_;
911 int64_t max_;
912 Demon* demon_;
913};
914
915int64_t ExtractExprProductCoeff(IntExpr** expr) {
916 int64_t prod = 1;
917 int64_t coeff = 1;
918 while ((*expr)->solver()->IsProduct(*expr, expr, &coeff)) prod *= coeff;
919 return prod;
920}
921} // namespace
922
923Constraint* Solver::MakeBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
924 DCHECK_EQ(this, expr->solver());
925 // Catch empty and singleton intervals.
926 if (l >= u) {
927 if (l > u) return MakeFalseConstraint();
928 return MakeEquality(expr, l);
929 }
930 int64_t emin = 0;
931 int64_t emax = 0;
932 expr->Range(&emin, &emax);
933 // Catch the trivial cases first.
934 if (emax < l || emin > u) return MakeFalseConstraint();
935 if (emin >= l && emax <= u) return MakeTrueConstraint();
936 // Catch one-sided constraints.
937 if (emax <= u) return MakeGreaterOrEqual(expr, l);
938 if (emin >= l) return MakeLessOrEqual(expr, u);
939 // Simplify the common factor, if any.
940 int64_t coeff = ExtractExprProductCoeff(&expr);
941 if (coeff != 1) {
942 CHECK_NE(coeff, 0); // Would have been caught by the trivial cases already.
943 if (coeff < 0) {
944 std::swap(u, l);
945 u = -u;
946 l = -l;
947 coeff = -coeff;
948 }
949 return MakeBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff));
950 } else {
951 // No further reduction is possible.
952 return RevAlloc(new BetweenCt(this, expr, l, u));
953 }
954}
955
956Constraint* Solver::MakeNotBetweenCt(IntExpr* expr, int64_t l, int64_t u) {
957 DCHECK_EQ(this, expr->solver());
958 // Catch empty interval.
959 if (l > u) {
960 return MakeTrueConstraint();
961 }
962
963 int64_t emin = 0;
964 int64_t emax = 0;
965 expr->Range(&emin, &emax);
966 // Catch the trivial cases first.
967 if (emax < l || emin > u) return MakeTrueConstraint();
968 if (emin >= l && emax <= u) return MakeFalseConstraint();
969 // Catch one-sided constraints.
970 if (emin >= l) return MakeGreater(expr, u);
971 if (emax <= u) return MakeLess(expr, l);
972 // TODO(user): Add back simplification code if expr is constant *
973 // other_expr.
974 return RevAlloc(new NotBetweenCt(this, expr, l, u));
975}
976
977// ----- is_between_cst Constraint -----
978
979namespace {
980class IsBetweenCt : public Constraint {
981 public:
982 IsBetweenCt(Solver* const s, IntExpr* const e, int64_t l, int64_t u,
983 IntVar* const b)
984 : Constraint(s),
985 expr_(e),
986 min_(l),
987 max_(u),
988 boolvar_(b),
989 demon_(nullptr) {}
990
991 void Post() override {
992 demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
993 expr_->WhenRange(demon_);
994 boolvar_->WhenBound(demon_);
995 }
996
997 void InitialPropagate() override {
998 bool inhibit = false;
999 int64_t emin = 0;
1000 int64_t emax = 0;
1001 expr_->Range(&emin, &emax);
1002 int64_t u = 1 - (emin > max_ || emax < min_);
1003 int64_t l = emax <= max_ && emin >= min_;
1004 boolvar_->SetRange(l, u);
1005 if (boolvar_->Bound()) {
1006 inhibit = true;
1007 if (boolvar_->Min() == 0) {
1008 if (expr_->IsVar()) {
1009 expr_->Var()->RemoveInterval(min_, max_);
1010 inhibit = true;
1011 } else if (emin > min_) {
1012 expr_->SetMin(max_ + 1);
1013 } else if (emax < max_) {
1014 expr_->SetMax(min_ - 1);
1015 }
1016 } else {
1017 expr_->SetRange(min_, max_);
1018 inhibit = true;
1019 }
1020 if (inhibit && expr_->IsVar()) {
1021 demon_->inhibit(solver());
1022 }
1023 }
1024 }
1025
1026 std::string DebugString() const override {
1027 return absl::StrFormat("IsBetweenCt(%s, %d, %d, %s)", expr_->DebugString(),
1028 min_, max_, boolvar_->DebugString());
1029 }
1030
1031 void Accept(ModelVisitor* const visitor) const override {
1032 visitor->BeginVisitConstraint(ModelVisitor::kIsBetween, this);
1033 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, min_);
1034 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1035 expr_);
1036 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, max_);
1037 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1038 boolvar_);
1039 visitor->EndVisitConstraint(ModelVisitor::kIsBetween, this);
1040 }
1041
1042 private:
1043 IntExpr* const expr_;
1044 int64_t min_;
1045 int64_t max_;
1046 IntVar* const boolvar_;
1047 Demon* demon_;
1048};
1049} // namespace
1050
1051Constraint* Solver::MakeIsBetweenCt(IntExpr* expr, int64_t l, int64_t u,
1052 IntVar* const b) {
1053 CHECK_EQ(this, expr->solver());
1054 CHECK_EQ(this, b->solver());
1055 // Catch empty and singleton intervals.
1056 if (l >= u) {
1057 if (l > u) return MakeEquality(b, Zero());
1058 return MakeIsEqualCstCt(expr, l, b);
1059 }
1060 int64_t emin = 0;
1061 int64_t emax = 0;
1062 expr->Range(&emin, &emax);
1063 // Catch the trivial cases first.
1064 if (emax < l || emin > u) return MakeEquality(b, Zero());
1065 if (emin >= l && emax <= u) return MakeEquality(b, 1);
1066 // Catch one-sided constraints.
1067 if (emax <= u) return MakeIsGreaterOrEqualCstCt(expr, l, b);
1068 if (emin >= l) return MakeIsLessOrEqualCstCt(expr, u, b);
1069 // Simplify the common factor, if any.
1070 int64_t coeff = ExtractExprProductCoeff(&expr);
1071 if (coeff != 1) {
1072 CHECK_NE(coeff, 0); // Would have been caught by the trivial cases already.
1073 if (coeff < 0) {
1074 std::swap(u, l);
1075 u = -u;
1076 l = -l;
1077 coeff = -coeff;
1078 }
1079 return MakeIsBetweenCt(expr, PosIntDivUp(l, coeff), PosIntDivDown(u, coeff),
1080 b);
1081 } else {
1082 // No further reduction is possible.
1083 return RevAlloc(new IsBetweenCt(this, expr, l, u, b));
1084 }
1085}
1086
1087IntVar* Solver::MakeIsBetweenVar(IntExpr* const v, int64_t l, int64_t u) {
1088 CHECK_EQ(this, v->solver());
1089 IntVar* const b = MakeBoolVar();
1090 AddConstraint(MakeIsBetweenCt(v, l, u, b));
1091 return b;
1092}
1093
1094// ---------- Member ----------
1095
1096// ----- Member(IntVar, IntSet) -----
1097
1098namespace {
1099// TODO(user): Do not create holes on expressions.
1100class MemberCt : public Constraint {
1101 public:
1102 MemberCt(Solver* const s, IntVar* const v,
1103 const std::vector<int64_t>& sorted_values)
1104 : Constraint(s), var_(v), values_(sorted_values) {
1105 DCHECK(v != nullptr);
1106 DCHECK(s != nullptr);
1107 }
1108
1109 void Post() override {}
1110
1111 void InitialPropagate() override { var_->SetValues(values_); }
1112
1113 std::string DebugString() const override {
1114 return absl::StrFormat("Member(%s, %s)", var_->DebugString(),
1115 absl::StrJoin(values_, ", "));
1116 }
1117
1118 void Accept(ModelVisitor* const visitor) const override {
1119 visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1120 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1121 var_);
1122 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1123 visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1124 }
1125
1126 private:
1127 IntVar* const var_;
1128 const std::vector<int64_t> values_;
1129};
1130
1131class NotMemberCt : public Constraint {
1132 public:
1133 NotMemberCt(Solver* const s, IntVar* const v,
1134 const std::vector<int64_t>& sorted_values)
1135 : Constraint(s), var_(v), values_(sorted_values) {
1136 DCHECK(v != nullptr);
1137 DCHECK(s != nullptr);
1138 }
1139
1140 void Post() override {}
1141
1142 void InitialPropagate() override { var_->RemoveValues(values_); }
1143
1144 std::string DebugString() const override {
1145 return absl::StrFormat("NotMember(%s, %s)", var_->DebugString(),
1146 absl::StrJoin(values_, ", "));
1147 }
1148
1149 void Accept(ModelVisitor* const visitor) const override {
1150 visitor->BeginVisitConstraint(ModelVisitor::kMember, this);
1151 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1152 var_);
1153 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1154 visitor->EndVisitConstraint(ModelVisitor::kMember, this);
1155 }
1156
1157 private:
1158 IntVar* const var_;
1159 const std::vector<int64_t> values_;
1160};
1161} // namespace
1162
1164 const std::vector<int64_t>& values) {
1165 const int64_t coeff = ExtractExprProductCoeff(&expr);
1166 if (coeff == 0) {
1167 return std::find(values.begin(), values.end(), 0) == values.end()
1170 }
1171 std::vector<int64_t> copied_values = values;
1172 // If the expression is a non-trivial product, we filter out the values that
1173 // aren't multiples of "coeff", and divide them.
1174 if (coeff != 1) {
1175 int num_kept = 0;
1176 for (const int64_t v : copied_values) {
1177 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1178 }
1179 copied_values.resize(num_kept);
1180 }
1181 // Filter out the values that are outside the [Min, Max] interval.
1182 int num_kept = 0;
1183 int64_t emin;
1184 int64_t emax;
1185 expr->Range(&emin, &emax);
1186 for (const int64_t v : copied_values) {
1187 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1188 }
1189 copied_values.resize(num_kept);
1190 // Catch empty set.
1191 if (copied_values.empty()) return MakeFalseConstraint();
1192 // Sort and remove duplicates.
1193 gtl::STLSortAndRemoveDuplicates(&copied_values);
1194 // Special case for singleton.
1195 if (copied_values.size() == 1) return MakeEquality(expr, copied_values[0]);
1196 // Catch contiguous intervals.
1197 if (copied_values.size() ==
1198 copied_values.back() - copied_values.front() + 1) {
1199 // Note: MakeBetweenCt() has a fast-track for trivially true constraints.
1200 return MakeBetweenCt(expr, copied_values.front(), copied_values.back());
1201 }
1202 // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1203 // "values" is smaller than "values", then it's more efficient to use
1204 // NotMemberCt. Catch that case here.
1205 if (emax - emin < 2 * copied_values.size()) {
1206 // Convert "copied_values" to list the values *not* allowed.
1207 std::vector<bool> is_among_input_values(emax - emin + 1, false);
1208 for (const int64_t v : copied_values)
1209 is_among_input_values[v - emin] = true;
1210 // We use the zero valued indices of is_among_input_values to build the
1211 // complement of copied_values.
1212 copied_values.clear();
1213 for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1214 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1215 }
1216 // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1217 // "values" input) was caught earlier, by the "contiguous interval" case.
1218 DCHECK_GE(copied_values.size(), 1);
1219 if (copied_values.size() == 1) {
1220 return MakeNonEquality(expr, copied_values[0]);
1221 }
1222 return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1223 }
1224 // Otherwise, just use MemberCt. No further reduction is possible.
1225 return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1226}
1227
1229 const std::vector<int>& values) {
1230 return MakeMemberCt(expr, ToInt64Vector(values));
1231}
1232
1234 const std::vector<int64_t>& values) {
1235 const int64_t coeff = ExtractExprProductCoeff(&expr);
1236 if (coeff == 0) {
1237 return std::find(values.begin(), values.end(), 0) == values.end()
1240 }
1241 std::vector<int64_t> copied_values = values;
1242 // If the expression is a non-trivial product, we filter out the values that
1243 // aren't multiples of "coeff", and divide them.
1244 if (coeff != 1) {
1245 int num_kept = 0;
1246 for (const int64_t v : copied_values) {
1247 if (v % coeff == 0) copied_values[num_kept++] = v / coeff;
1248 }
1249 copied_values.resize(num_kept);
1250 }
1251 // Filter out the values that are outside the [Min, Max] interval.
1252 int num_kept = 0;
1253 int64_t emin;
1254 int64_t emax;
1255 expr->Range(&emin, &emax);
1256 for (const int64_t v : copied_values) {
1257 if (v >= emin && v <= emax) copied_values[num_kept++] = v;
1258 }
1259 copied_values.resize(num_kept);
1260 // Catch empty set.
1261 if (copied_values.empty()) return MakeTrueConstraint();
1262 // Sort and remove duplicates.
1263 gtl::STLSortAndRemoveDuplicates(&copied_values);
1264 // Special case for singleton.
1265 if (copied_values.size() == 1) return MakeNonEquality(expr, copied_values[0]);
1266 // Catch contiguous intervals.
1267 if (copied_values.size() ==
1268 copied_values.back() - copied_values.front() + 1) {
1269 return MakeNotBetweenCt(expr, copied_values.front(), copied_values.back());
1270 }
1271 // If the set of values in [expr.Min(), expr.Max()] that are *not* in
1272 // "values" is smaller than "values", then it's more efficient to use
1273 // MemberCt. Catch that case here.
1274 if (emax - emin < 2 * copied_values.size()) {
1275 // Convert "copied_values" to a dense boolean vector.
1276 std::vector<bool> is_among_input_values(emax - emin + 1, false);
1277 for (const int64_t v : copied_values)
1278 is_among_input_values[v - emin] = true;
1279 // Use zero valued indices for is_among_input_values to build the
1280 // complement of copied_values.
1281 copied_values.clear();
1282 for (int64_t v_off = 0; v_off < is_among_input_values.size(); ++v_off) {
1283 if (!is_among_input_values[v_off]) copied_values.push_back(v_off + emin);
1284 }
1285 // The empty' case (all values in range [expr.Min(), expr.Max()] are in the
1286 // "values" input) was caught earlier, by the "contiguous interval" case.
1287 DCHECK_GE(copied_values.size(), 1);
1288 if (copied_values.size() == 1) {
1289 return MakeEquality(expr, copied_values[0]);
1290 }
1291 return RevAlloc(new MemberCt(this, expr->Var(), copied_values));
1292 }
1293 // Otherwise, just use NotMemberCt. No further reduction is possible.
1294 return RevAlloc(new NotMemberCt(this, expr->Var(), copied_values));
1295}
1296
1298 const std::vector<int>& values) {
1299 return MakeNotMemberCt(expr, ToInt64Vector(values));
1300}
1301
1302// ----- IsMemberCt -----
1303
1304namespace {
1305class IsMemberCt : public Constraint {
1306 public:
1307 IsMemberCt(Solver* const s, IntVar* const v,
1308 const std::vector<int64_t>& sorted_values, IntVar* const b)
1309 : Constraint(s),
1310 var_(v),
1311 values_as_set_(sorted_values.begin(), sorted_values.end()),
1312 values_(sorted_values),
1313 boolvar_(b),
1314 support_(0),
1315 demon_(nullptr),
1316 domain_(var_->MakeDomainIterator(true)),
1317 neg_support_(std::numeric_limits<int64_t>::min()) {
1318 DCHECK(v != nullptr);
1319 DCHECK(s != nullptr);
1320 DCHECK(b != nullptr);
1321 while (values_as_set_.contains(neg_support_)) {
1322 neg_support_++;
1323 }
1324 }
1325
1326 void Post() override {
1327 demon_ = MakeConstraintDemon0(solver(), this, &IsMemberCt::VarDomain,
1328 "VarDomain");
1329 if (!var_->Bound()) {
1330 var_->WhenDomain(demon_);
1331 }
1332 if (!boolvar_->Bound()) {
1333 Demon* const bdemon = MakeConstraintDemon0(
1334 solver(), this, &IsMemberCt::TargetBound, "TargetBound");
1335 boolvar_->WhenBound(bdemon);
1336 }
1337 }
1338
1339 void InitialPropagate() override {
1340 boolvar_->SetRange(0, 1);
1341 if (boolvar_->Bound()) {
1342 TargetBound();
1343 } else {
1344 VarDomain();
1345 }
1346 }
1347
1348 std::string DebugString() const override {
1349 return absl::StrFormat("IsMemberCt(%s, %s, %s)", var_->DebugString(),
1350 absl::StrJoin(values_, ", "),
1351 boolvar_->DebugString());
1352 }
1353
1354 void Accept(ModelVisitor* const visitor) const override {
1355 visitor->BeginVisitConstraint(ModelVisitor::kIsMember, this);
1356 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1357 var_);
1358 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);
1359 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1360 boolvar_);
1361 visitor->EndVisitConstraint(ModelVisitor::kIsMember, this);
1362 }
1363
1364 private:
1365 void VarDomain() {
1366 if (boolvar_->Bound()) {
1367 TargetBound();
1368 } else {
1369 for (int offset = 0; offset < values_.size(); ++offset) {
1370 const int candidate = (support_ + offset) % values_.size();
1371 if (var_->Contains(values_[candidate])) {
1372 support_ = candidate;
1373 if (var_->Bound()) {
1374 demon_->inhibit(solver());
1375 boolvar_->SetValue(1);
1376 return;
1377 }
1378 // We have found a positive support. Let's check the
1379 // negative support.
1380 if (var_->Contains(neg_support_)) {
1381 return;
1382 } else {
1383 // Look for a new negative support.
1384 for (const int64_t value : InitAndGetValues(domain_)) {
1385 if (!values_as_set_.contains(value)) {
1386 neg_support_ = value;
1387 return;
1388 }
1389 }
1390 }
1391 // No negative support, setting boolvar to true.
1392 demon_->inhibit(solver());
1393 boolvar_->SetValue(1);
1394 return;
1395 }
1396 }
1397 // No positive support, setting boolvar to false.
1398 demon_->inhibit(solver());
1399 boolvar_->SetValue(0);
1400 }
1401 }
1402
1403 void TargetBound() {
1404 DCHECK(boolvar_->Bound());
1405 if (boolvar_->Min() == 1LL) {
1406 demon_->inhibit(solver());
1407 var_->SetValues(values_);
1408 } else {
1409 demon_->inhibit(solver());
1410 var_->RemoveValues(values_);
1411 }
1412 }
1413
1414 IntVar* const var_;
1415 absl::flat_hash_set<int64_t> values_as_set_;
1416 std::vector<int64_t> values_;
1417 IntVar* const boolvar_;
1418 int support_;
1419 Demon* demon_;
1420 IntVarIterator* const domain_;
1421 int64_t neg_support_;
1422};
1423
1424template <class T>
1425Constraint* BuildIsMemberCt(Solver* const solver, IntExpr* const expr,
1426 const std::vector<T>& values,
1427 IntVar* const boolvar) {
1428 // TODO(user): optimize this by copying the code from MakeMemberCt.
1429 // Simplify and filter if expr is a product.
1430 IntExpr* sub = nullptr;
1431 int64_t coef = 1;
1432 if (solver->IsProduct(expr, &sub, &coef) && coef != 0 && coef != 1) {
1433 std::vector<int64_t> new_values;
1434 new_values.reserve(values.size());
1435 for (const int64_t value : values) {
1436 if (value % coef == 0) {
1437 new_values.push_back(value / coef);
1438 }
1439 }
1440 return BuildIsMemberCt(solver, sub, new_values, boolvar);
1441 }
1442
1443 std::set<T> set_of_values(values.begin(), values.end());
1444 std::vector<int64_t> filtered_values;
1445 bool all_values = false;
1446 if (expr->IsVar()) {
1447 IntVar* const var = expr->Var();
1448 for (const T value : set_of_values) {
1449 if (var->Contains(value)) {
1450 filtered_values.push_back(value);
1451 }
1452 }
1453 all_values = (filtered_values.size() == var->Size());
1454 } else {
1455 int64_t emin = 0;
1456 int64_t emax = 0;
1457 expr->Range(&emin, &emax);
1458 for (const T value : set_of_values) {
1459 if (value >= emin && value <= emax) {
1460 filtered_values.push_back(value);
1461 }
1462 }
1463 all_values = (filtered_values.size() == emax - emin + 1);
1464 }
1465 if (filtered_values.empty()) {
1466 return solver->MakeEquality(boolvar, Zero());
1467 } else if (all_values) {
1468 return solver->MakeEquality(boolvar, 1);
1469 } else if (filtered_values.size() == 1) {
1470 return solver->MakeIsEqualCstCt(expr, filtered_values.back(), boolvar);
1471 } else if (filtered_values.back() ==
1472 filtered_values.front() + filtered_values.size() - 1) {
1473 // Contiguous
1474 return solver->MakeIsBetweenCt(expr, filtered_values.front(),
1475 filtered_values.back(), boolvar);
1476 } else {
1477 return solver->RevAlloc(
1478 new IsMemberCt(solver, expr->Var(), filtered_values, boolvar));
1479 }
1480}
1481} // namespace
1482
1484 const std::vector<int64_t>& values,
1485 IntVar* const boolvar) {
1486 return BuildIsMemberCt(this, expr, values, boolvar);
1487}
1488
1490 const std::vector<int>& values,
1491 IntVar* const boolvar) {
1492 return BuildIsMemberCt(this, expr, values, boolvar);
1493}
1494
1496 const std::vector<int64_t>& values) {
1497 IntVar* const b = MakeBoolVar();
1498 AddConstraint(MakeIsMemberCt(expr, values, b));
1499 return b;
1500}
1501
1503 const std::vector<int>& values) {
1504 IntVar* const b = MakeBoolVar();
1505 AddConstraint(MakeIsMemberCt(expr, values, b));
1506 return b;
1507}
1508
1509namespace {
1510class SortedDisjointForbiddenIntervalsConstraint : public Constraint {
1511 public:
1512 SortedDisjointForbiddenIntervalsConstraint(
1513 Solver* const solver, IntVar* const var,
1515 : Constraint(solver), var_(var), intervals_(std::move(intervals)) {}
1516
1517 ~SortedDisjointForbiddenIntervalsConstraint() override {}
1518
1519 void Post() override {
1520 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);
1521 var_->WhenRange(demon);
1522 }
1523
1524 void InitialPropagate() override {
1525 const int64_t vmin = var_->Min();
1526 const int64_t vmax = var_->Max();
1527 const auto first_interval_it = intervals_.FirstIntervalGreaterOrEqual(vmin);
1528 if (first_interval_it == intervals_.end()) {
1529 // No interval intersects the variable's range. Nothing to do.
1530 return;
1531 }
1532 const auto last_interval_it = intervals_.LastIntervalLessOrEqual(vmax);
1533 if (last_interval_it == intervals_.end()) {
1534 // No interval intersects the variable's range. Nothing to do.
1535 return;
1536 }
1537 // TODO(user): Quick fail if first_interval_it == last_interval_it, which
1538 // would imply that the interval contains the entire range of the variable?
1539 if (vmin >= first_interval_it->start) {
1540 // The variable's minimum is inside a forbidden interval. Move it to the
1541 // interval's end.
1542 var_->SetMin(CapAdd(first_interval_it->end, 1));
1543 }
1544 if (vmax <= last_interval_it->end) {
1545 // Ditto, on the other side.
1546 var_->SetMax(CapSub(last_interval_it->start, 1));
1547 }
1548 }
1549
1550 std::string DebugString() const override {
1551 return absl::StrFormat("ForbiddenIntervalCt(%s, %s)", var_->DebugString(),
1552 intervals_.DebugString());
1553 }
1554
1555 void Accept(ModelVisitor* const visitor) const override {
1556 visitor->BeginVisitConstraint(ModelVisitor::kNotMember, this);
1557 visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
1558 var_);
1559 std::vector<int64_t> starts;
1560 std::vector<int64_t> ends;
1561 for (auto& interval : intervals_) {
1562 starts.push_back(interval.start);
1563 ends.push_back(interval.end);
1564 }
1565 visitor->VisitIntegerArrayArgument(ModelVisitor::kStartsArgument, starts);
1566 visitor->VisitIntegerArrayArgument(ModelVisitor::kEndsArgument, ends);
1567 visitor->EndVisitConstraint(ModelVisitor::kNotMember, this);
1568 }
1569
1570 private:
1571 IntVar* const var_;
1572 const SortedDisjointIntervalList intervals_;
1573};
1574} // namespace
1575
1577 std::vector<int64_t> starts,
1578 std::vector<int64_t> ends) {
1579 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1580 this, expr->Var(), {starts, ends}));
1581}
1582
1584 std::vector<int> starts,
1585 std::vector<int> ends) {
1586 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1587 this, expr->Var(), {starts, ends}));
1588}
1589
1591 SortedDisjointIntervalList intervals) {
1592 return RevAlloc(new SortedDisjointForbiddenIntervalsConstraint(
1593 this, expr->Var(), std::move(intervals)));
1594}
1595} // namespace operations_research
int64_t min
Definition: alldiff_cst.cc:139
#define CHECK_EQ(val1, val2)
Definition: base/logging.h:703
#define DCHECK_GE(val1, val2)
Definition: base/logging.h:895
#define CHECK_NE(val1, val2)
Definition: base/logging.h:704
#define DCHECK(condition)
Definition: base/logging.h:890
#define DCHECK_EQ(val1, val2)
Definition: base/logging.h:891
Cast constraints are special channeling constraints designed to keep a variable in sync with an expre...
A constraint is the main modeling object.
virtual void InitialPropagate()=0
This method performs the initial propagation of the constraint.
virtual void Accept(ModelVisitor *const visitor) const
Accepts the given visitor.
virtual IntVar * Var()
Creates a Boolean variable representing the status of the constraint (false = constraint is violated,...
std::string DebugString() const override
virtual void Post()=0
This method is called when the constraint is processed by the solver.
void inhibit(Solver *const s)
This method inhibits the demon in the search tree below the current position.
The class IntExpr is the base of all integer expressions in constraint programming.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual int64_t Min() const =0
virtual void SetMax(int64_t m)=0
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual void SetMin(int64_t m)=0
virtual int64_t Max() const =0
virtual void Range(int64_t *l, int64_t *u)
By default calls Min() and Max(), but can be redefined when Min and Max code can be factorized.
virtual void WhenRange(Demon *d)=0
Attach a demon that will watch the min or the max of the expression.
The class IntVar is a subset of IntExpr.
virtual bool Contains(int64_t v) const =0
This method returns whether the value 'v' is in the domain of the variable.
virtual void RemoveValue(int64_t v)=0
This method removes the value 'v' from the domain of the variable.
IntVar * Var() override
Creates a variable from the expression.
bool IsVar() const override
Returns true if the expression is indeed a variable.
virtual uint64_t Size() const =0
This method returns the number of values in the domain of the variable.
Constraint * MakeBetweenCt(IntExpr *const expr, int64_t l, int64_t u)
(l <= expr <= u)
Definition: expr_cst.cc:923
Constraint * MakeIsLessCstCt(IntExpr *const v, int64_t c, IntVar *const b)
b == (v < c)
Definition: expr_cst.cc:816
IntVar * MakeIsGreaterCstVar(IntExpr *const var, int64_t value)
status var of (var > value)
Definition: expr_cst.cc:696
Constraint * MakeLess(IntExpr *const left, IntExpr *const right)
left < right
Definition: range_cst.cc:546
Constraint * MakeFalseConstraint()
This constraint always fails.
Definition: constraints.cc:523
Constraint * MakeEquality(IntExpr *const left, IntExpr *const right)
left == right
Definition: range_cst.cc:512
Constraint * MakeIsDifferentCt(IntExpr *const v1, IntExpr *const v2, IntVar *const b)
b == (v1 != v2)
Definition: range_cst.cc:686
Constraint * MakeLessOrEqual(IntExpr *const left, IntExpr *const right)
left <= right
Definition: range_cst.cc:526
IntVar * MakeIsGreaterOrEqualCstVar(IntExpr *const var, int64_t value)
status var of (var >= value)
Definition: expr_cst.cc:679
Constraint * MakeIsLessOrEqualCstCt(IntExpr *const var, int64_t value, IntVar *const boolvar)
boolvar == (var <= value)
Definition: expr_cst.cc:800
Constraint * MakeNotMemberCt(IntExpr *const expr, const std::vector< int64_t > &values)
expr not in set.
Definition: expr_cst.cc:1233
IntVar * MakeIsDifferentVar(IntExpr *const v1, IntExpr *const v2)
status var of (v1 != v2)
Definition: range_cst.cc:641
IntVar * MakeIsEqualVar(IntExpr *const v1, IntExpr *v2)
status var of (v1 == v2)
Definition: range_cst.cc:577
Constraint * MakeGreater(IntExpr *const left, IntExpr *const right)
left > right
Definition: range_cst.cc:560
IntVar * MakeIsLessCstVar(IntExpr *const var, int64_t value)
status var of (var < value)
Definition: expr_cst.cc:796
Constraint * MakeMemberCt(IntExpr *const expr, const std::vector< int64_t > &values)
expr in set.
Definition: expr_cst.cc:1163
Constraint * MakeNotBetweenCt(IntExpr *const expr, int64_t l, int64_t u)
(expr < l || expr > u) This constraint is lazy as it will not make holes in the domain of variables.
Definition: expr_cst.cc:956
void AddConstraint(Constraint *const c)
Adds the constraint 'c' to the model.
Constraint * MakeIsEqualCstCt(IntExpr *const var, int64_t value, IntVar *const boolvar)
boolvar == (var == value)
Definition: expr_cst.cc:487
Constraint * MakeIsEqualCt(IntExpr *const v1, IntExpr *v2, IntVar *const b)
b == (v1 == v2)
Definition: range_cst.cc:622
Constraint * MakeTrueConstraint()
This constraint always succeeds.
Definition: constraints.cc:518
IntVar * MakeIsBetweenVar(IntExpr *const v, int64_t l, int64_t u)
Definition: expr_cst.cc:1087
IntVar * MakeIsMemberVar(IntExpr *const expr, const std::vector< int64_t > &values)
Definition: expr_cst.cc:1495
IntVar * MakeIsLessOrEqualCstVar(IntExpr *const var, int64_t value)
status var of (var <= value)
Definition: expr_cst.cc:779
IntExpr * MakeDifference(IntExpr *const left, IntExpr *const right)
left - right
Constraint * MakeIsDifferentCstCt(IntExpr *const var, int64_t value, IntVar *const boolvar)
boolvar == (var != value)
Definition: expr_cst.cc:589
IntVar * MakeBoolVar()
MakeBoolVar will create a variable with a {0, 1} domain.
IntVar * MakeIsDifferentCstVar(IntExpr *const var, int64_t value)
status var of (var != value)
Definition: expr_cst.cc:580
Constraint * MakeNonEquality(IntExpr *const left, IntExpr *const right)
left != right
Definition: range_cst.cc:564
Constraint * MakeIsGreaterOrEqualCstCt(IntExpr *const var, int64_t value, IntVar *const boolvar)
boolvar == (var >= value)
Definition: expr_cst.cc:700
Constraint * MakeIsBetweenCt(IntExpr *const expr, int64_t l, int64_t u, IntVar *const b)
b == (l <= expr <= u)
Definition: expr_cst.cc:1051
IntExpr * MakeSum(IntExpr *const left, IntExpr *const right)
left + right.
Constraint * MakeIsGreaterCstCt(IntExpr *const v, int64_t c, IntVar *const b)
b == (v > c)
Definition: expr_cst.cc:716
IntVar * MakeIntConst(int64_t val, const std::string &name)
IntConst will create a constant expression.
Constraint * MakeGreaterOrEqual(IntExpr *const left, IntExpr *const right)
left >= right
Definition: range_cst.cc:542
IntVar * MakeIsEqualCstVar(IntExpr *const var, int64_t value)
status var of (var == value)
Definition: expr_cst.cc:462
Constraint * MakeIsMemberCt(IntExpr *const expr, const std::vector< int64_t > &values, IntVar *const boolvar)
boolvar == (expr in set)
Definition: expr_cst.cc:1483
T * RevAlloc(T *object)
Registers the given object as being reversible.
This class represents a sorted list of disjoint, closed intervals.
int64_t b
int64_t value
IntVar *const expr_
Definition: element.cc:87
IntVar * var
Definition: expr_array.cc:1874
int64_t coef
Definition: expr_array.cc:1875
ABSL_FLAG(int, cache_initial_size, 1024, "Initial size of the array of the hash " "table of caches for objects of type Var(x == 3)")
const int64_t cst_
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition: stl_util.h:58
void swap(IdMap< K, V > &a, IdMap< K, V > &b)
Definition: id_map.h:262
Collection of objects used to extend the Constraint Solver library.
int64_t CapAdd(int64_t x, int64_t y)
int64_t CapSub(int64_t x, int64_t y)
int64_t Zero()
NOLINT.
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)
Definition: utilities.cc:828
int64_t PosIntDivDown(int64_t e, int64_t v)
int64_t PosIntDivUp(int64_t e, int64_t v)
STL namespace.
IntervalVar * interval
Definition: resource.cc:100
IntervalVar *const target_var_
std::optional< int64_t > end
const double coeff