diff --git a/src/constraint_solver/expr_cst.cc b/src/constraint_solver/expr_cst.cc index b0eb7dc34f..0e94a3c745 100644 --- a/src/constraint_solver/expr_cst.cc +++ b/src/constraint_solver/expr_cst.cc @@ -977,6 +977,36 @@ class MemberCt : public Constraint { IntVar* const var_; const std::vector values_; }; + +class NotMemberCt : public Constraint { + public: + NotMemberCt(Solver* const s, IntVar* const v, const std::vector& sorted_values) + : Constraint(s), var_(v), values_(sorted_values) { + DCHECK(v != nullptr); + DCHECK(s != nullptr); + } + + virtual void Post() {} + + virtual void InitialPropagate() { var_->RemoveValues(values_); } + + virtual std::string DebugString() const { + return StringPrintf("NotMember(%s, %s)", var_->DebugString().c_str(), + strings::Join(values_, ", ").c_str()); + } + + virtual void Accept(ModelVisitor* const visitor) const { + visitor->BeginVisitConstraint(ModelVisitor::kMember, this); + visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument, + var_); + visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_); + visitor->EndVisitConstraint(ModelVisitor::kMember, this); + } + + private: + IntVar* const var_; + const std::vector values_; +}; } // namespace Constraint* Solver::MakeMemberCt(IntExpr* const var, @@ -985,6 +1015,23 @@ Constraint* Solver::MakeMemberCt(IntExpr* const var, if (IsIncreasingContiguous(sorted)) { return MakeBetweenCt(var, sorted.front(), sorted.back()); } else { + // Let's build the reverse vector. + if (var->Max() - var->Min() < 2 * values.size()) { + hash_set value_set(values.begin(), values.end()); + std::vector remaining; + for (int64 value = var->Min(); value <= var->Max(); ++value) { + if (!ContainsKey(value_set, value)) { + remaining.push_back(value); + } + } + if (remaining.empty()) { + return MakeTrueConstraint(); + } else if (remaining.size() == 1) { + return MakeNonEquality(var, remaining.back()); + } else if (remaining.size() < values.size()) { + return RevAlloc(new NotMemberCt(this, var->Var(), remaining)); + } + } return RevAlloc(new MemberCt(this, var->Var(), sorted)); } }