diff --git a/src/constraint_solver/expr_cst.cc b/src/constraint_solver/expr_cst.cc index 8044106878..468d23638e 100644 --- a/src/constraint_solver/expr_cst.cc +++ b/src/constraint_solver/expr_cst.cc @@ -1054,17 +1054,29 @@ class NotMemberCt : public Constraint { }; } // namespace -Constraint* Solver::MakeMemberCt(IntExpr* const var, +Constraint* Solver::MakeMemberCt(IntExpr* const expr, const std::vector& values) { + IntExpr* sub = nullptr; + int64 coef = 1; + if (IsProduct(expr, &sub, &coef) && coef != 0 && coef != 1) { + std::vector new_values; + new_values.reserve(values.size()); + for (const int64 value : values) { + if (value % coef == 0) { + new_values.push_back(value / coef); + } + } + return MakeMemberCt(sub, new_values); + } std::vector sorted = SortedNoDuplicates(values); if (IsIncreasingContiguous(sorted)) { - return MakeBetweenCt(var, sorted.front(), sorted.back()); + return MakeBetweenCt(expr, sorted.front(), sorted.back()); } else { // Let's build the reverse vector. - if (var->Max() - var->Min() < 2 * values.size()) { + if (expr->Max() - expr->Min() < 2 * values.size()) { hash_set value_set(values.begin(), values.end()); std::vector remaining; - for (int64 value = var->Min(); value <= var->Max(); ++value) { + for (int64 value = expr->Min(); value <= expr->Max(); ++value) { if (!ContainsKey(value_set, value)) { remaining.push_back(value); } @@ -1072,23 +1084,18 @@ Constraint* Solver::MakeMemberCt(IntExpr* const var, if (remaining.empty()) { return MakeTrueConstraint(); } else if (remaining.size() == 1) { - return MakeNonEquality(var, remaining.back()); + return MakeNonEquality(expr, remaining.back()); } else if (remaining.size() < values.size()) { - return RevAlloc(new NotMemberCt(this, var->Var(), remaining)); + return RevAlloc(new NotMemberCt(this, expr->Var(), remaining)); } } - return RevAlloc(new MemberCt(this, var->Var(), sorted)); + return RevAlloc(new MemberCt(this, expr->Var(), sorted)); } } -Constraint* Solver::MakeMemberCt(IntExpr* const var, +Constraint* Solver::MakeMemberCt(IntExpr* const expr, const std::vector& values) { - std::vector sorted = SortedNoDuplicates(ToInt64Vector(values)); - if (IsIncreasingContiguous(sorted)) { - return MakeBetweenCt(var, sorted.front(), sorted.back()); - } else { - return RevAlloc(new MemberCt(this, var->Var(), sorted)); - } + return MakeMemberCt(expr, ToInt64Vector(values)); } // ----- IsMemberCt -----