[CP-SAT] implement proper modulo propagator (with fixed mod)

This commit is contained in:
Laurent Perron
2021-10-27 20:56:28 +02:00
parent 936abd2477
commit bab11deddf
5 changed files with 554 additions and 295 deletions

View File

@@ -610,27 +610,62 @@ void LinMinPropagator::RegisterWith(GenericLiteralWatcher* watcher) {
watcher->RegisterReversibleInt(id, &rev_unique_candidate_);
}
PositiveProductPropagator::PositiveProductPropagator(
AffineExpression a, AffineExpression b, AffineExpression p,
IntegerTrail* integer_trail)
: a_(a), b_(b), p_(p), integer_trail_(integer_trail) {
// Note that we assume this is true at level zero, and so we never include
// that fact in the reasons we compute.
CHECK_GE(integer_trail_->LevelZeroLowerBound(a_), 0);
CHECK_GE(integer_trail_->LevelZeroLowerBound(b_), 0);
ProductPropagator::ProductPropagator(AffineExpression a, AffineExpression b,
AffineExpression p,
IntegerTrail* integer_trail)
: a_(a), b_(b), p_(p), integer_trail_(integer_trail) {}
// We want all affine expression to be either non-negative or across zero.
bool ProductPropagator::CanonicalizeCases() {
if (integer_trail_->UpperBound(a_) <= 0) {
a_ = a_.Negated();
p_ = p_.Negated();
}
if (integer_trail_->UpperBound(b_) <= 0) {
b_ = b_.Negated();
p_ = p_.Negated();
}
// If both a and b positive, p must be too.
if (integer_trail_->LowerBound(a_) >= 0 &&
integer_trail_->LowerBound(b_) >= 0) {
return integer_trail_->UnsafeEnqueue(
p_.GreaterOrEqual(0), {}, {a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)});
}
// Otherwise, make sure p is non-negative or accros zero.
if (integer_trail_->UpperBound(p_) <= 0) {
if (integer_trail_->LowerBound(a_) < 0) {
DCHECK_GT(integer_trail_->UpperBound(a_), 0);
a_ = a_.Negated();
p_ = p_.Negated();
} else {
DCHECK_LT(integer_trail_->LowerBound(b_), 0);
DCHECK_GT(integer_trail_->UpperBound(b_), 0);
b_ = b_.Negated();
p_ = p_.Negated();
}
}
return true;
}
// TODO(user): We can tighten the bounds on p by removing extreme value that
// Note that this propagation is exact, except on the domain of p as this
// involves more complex arithmetic.
//
// TODO(user): We could tighten the bounds on p by removing extreme value that
// do not contains divisor in the domains of a or b. There is an algo in O(
// smallest domain size between a or b).
bool PositiveProductPropagator::Propagate() {
bool ProductPropagator::PropagateWhenAllNonNegative() {
const IntegerValue max_a = integer_trail_->UpperBound(a_);
const IntegerValue max_b = integer_trail_->UpperBound(b_);
const IntegerValue new_max(CapProd(max_a.value(), max_b.value()));
if (new_max < integer_trail_->UpperBound(p_)) {
if (!integer_trail_->Enqueue(p_.LowerOrEqual(new_max), {},
{integer_trail_->UpperBoundAsLiteral(a_),
integer_trail_->UpperBoundAsLiteral(b_)})) {
if (!integer_trail_->Enqueue(
p_.LowerOrEqual(new_max), {},
{integer_trail_->UpperBoundAsLiteral(a_),
integer_trail_->UpperBoundAsLiteral(b_), a_.GreaterOrEqual(0),
b_.GreaterOrEqual(0)})) {
return false;
}
}
@@ -639,9 +674,10 @@ bool PositiveProductPropagator::Propagate() {
const IntegerValue min_b = integer_trail_->LowerBound(b_);
const IntegerValue new_min(CapProd(min_a.value(), min_b.value()));
if (new_min > integer_trail_->LowerBound(p_)) {
if (!integer_trail_->Enqueue(p_.GreaterOrEqual(new_min), {},
{integer_trail_->LowerBoundAsLiteral(a_),
integer_trail_->LowerBoundAsLiteral(b_)})) {
if (!integer_trail_->UnsafeEnqueue(
p_.GreaterOrEqual(new_min), {},
{integer_trail_->LowerBoundAsLiteral(a_),
integer_trail_->LowerBoundAsLiteral(b_)})) {
return false;
}
}
@@ -655,16 +691,18 @@ bool PositiveProductPropagator::Propagate() {
const IntegerValue max_p = integer_trail_->UpperBound(p_);
const IntegerValue prod(CapProd(max_a.value(), min_b.value()));
if (prod > max_p) {
if (!integer_trail_->Enqueue(a.LowerOrEqual(FloorRatio(max_p, min_b)), {},
{integer_trail_->LowerBoundAsLiteral(b),
integer_trail_->UpperBoundAsLiteral(p_)})) {
if (!integer_trail_->UnsafeEnqueue(
a.LowerOrEqual(FloorRatio(max_p, min_b)), {},
{integer_trail_->LowerBoundAsLiteral(b),
integer_trail_->UpperBoundAsLiteral(p_),
p_.GreaterOrEqual(0)})) {
return false;
}
} else if (prod < min_p) {
if (!integer_trail_->Enqueue(b.GreaterOrEqual(CeilRatio(min_p, max_a)),
{},
{integer_trail_->UpperBoundAsLiteral(a),
integer_trail_->LowerBoundAsLiteral(p_)})) {
if (!integer_trail_->UnsafeEnqueue(
b.GreaterOrEqual(CeilRatio(min_p, max_a)), {},
{integer_trail_->UpperBoundAsLiteral(a),
integer_trail_->LowerBoundAsLiteral(p_), a.GreaterOrEqual(0)})) {
return false;
}
}
@@ -673,7 +711,199 @@ bool PositiveProductPropagator::Propagate() {
return true;
}
void PositiveProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) {
// This assumes p > 0, p = a * X, and X can take any value.
// We can propagate max of a by computing a bound on the min b when positive.
// The expression b is just used to detect when there is no solution given the
// upper bound of b.
bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a,
AffineExpression b,
IntegerValue min_p,
IntegerValue max_p) {
const IntegerValue max_a = integer_trail_->UpperBound(a);
DCHECK_GT(max_a, 0);
DCHECK_GT(min_p, 0);
if (max_a >= min_p) {
if (max_p < max_a) {
if (!integer_trail_->UnsafeEnqueue(
a.LowerOrEqual(max_p), {},
{p_.LowerOrEqual(max_p), p_.GreaterOrEqual(1)})) {
return false;
}
}
return true;
}
const IntegerValue min_pos_b = CeilRatio(min_p, max_a);
if (min_pos_b > integer_trail_->UpperBound(b)) {
if (!integer_trail_->UnsafeEnqueue(
b.LowerOrEqual(0), {},
{integer_trail_->LowerBoundAsLiteral(p_),
integer_trail_->UpperBoundAsLiteral(a),
integer_trail_->UpperBoundAsLiteral(b)})) {
return false;
}
return true;
}
const IntegerValue new_max_a = FloorRatio(max_p, min_pos_b);
if (new_max_a < integer_trail_->UpperBound(a)) {
if (!integer_trail_->UnsafeEnqueue(
a.LowerOrEqual(new_max_a), {},
{integer_trail_->LowerBoundAsLiteral(p_),
integer_trail_->UpperBoundAsLiteral(a),
integer_trail_->UpperBoundAsLiteral(p_)})) {
return false;
}
}
return true;
}
bool ProductPropagator::Propagate() {
if (!CanonicalizeCases()) return false;
// In the most common case, we use better reasons even though the code
// below would propagate the same.
const int64_t min_a = integer_trail_->LowerBound(a_).value();
const int64_t min_b = integer_trail_->LowerBound(b_).value();
if (min_a >= 0 && min_b >= 0) {
// This was done by CanonicalizeCases().
DCHECK_GE(integer_trail_->LowerBound(p_), 0);
return PropagateWhenAllNonNegative();
}
// Lets propagate on p_ first, the max/min is given by one of: max_a * max_b,
// max_a * min_b, min_a * max_b, min_a * min_b. This is true, because any
// product x * y, depending on the sign, is dominated by one of these.
//
// TODO(user): In the reasons, including all 4 bounds is always correct, but
// we might be able to relax some of them.
const int64_t max_a = integer_trail_->UpperBound(a_).value();
const int64_t max_b = integer_trail_->UpperBound(b_).value();
const IntegerValue p1(CapProd(max_a, max_b));
const IntegerValue p2(CapProd(max_a, min_b));
const IntegerValue p3(CapProd(min_a, max_b));
const IntegerValue p4(CapProd(min_a, min_b));
const IntegerValue new_max_p = std::max({p1, p2, p3, p4});
if (new_max_p < integer_trail_->UpperBound(p_)) {
if (!integer_trail_->UnsafeEnqueue(
p_.LowerOrEqual(new_max_p), {},
{integer_trail_->LowerBoundAsLiteral(a_),
integer_trail_->LowerBoundAsLiteral(b_),
integer_trail_->UpperBoundAsLiteral(a_),
integer_trail_->UpperBoundAsLiteral(b_)})) {
return false;
}
}
const IntegerValue new_min_p = std::min({p1, p2, p3, p4});
if (new_min_p > integer_trail_->LowerBound(p_)) {
if (!integer_trail_->UnsafeEnqueue(
p_.GreaterOrEqual(new_min_p), {},
{integer_trail_->LowerBoundAsLiteral(a_),
integer_trail_->LowerBoundAsLiteral(b_),
integer_trail_->UpperBoundAsLiteral(a_),
integer_trail_->UpperBoundAsLiteral(b_)})) {
return false;
}
}
// Lets propagate on a and b.
const IntegerValue min_p = integer_trail_->LowerBound(p_);
const IntegerValue max_p = integer_trail_->UpperBound(p_);
// We need a bit more propagation to avoid bad cases below.
const bool zero_is_possible = min_p <= 0;
if (!zero_is_possible) {
if (integer_trail_->LowerBound(a_) == 0) {
if (!integer_trail_->UnsafeEnqueue(
a_.GreaterOrEqual(1), {},
{p_.GreaterOrEqual(1), a_.GreaterOrEqual(0)})) {
return false;
}
}
if (integer_trail_->LowerBound(b_) == 0) {
if (!integer_trail_->UnsafeEnqueue(
b_.GreaterOrEqual(1), {},
{p_.GreaterOrEqual(1), b_.GreaterOrEqual(0)})) {
return false;
}
}
if (integer_trail_->LowerBound(a_) >= 0 &&
integer_trail_->LowerBound(b_) <= 0) {
return integer_trail_->UnsafeEnqueue(
b_.GreaterOrEqual(1), {},
{a_.GreaterOrEqual(0), p_.GreaterOrEqual(1)});
}
if (integer_trail_->LowerBound(b_) >= 0 &&
integer_trail_->LowerBound(a_) <= 0) {
return integer_trail_->UnsafeEnqueue(
a_.GreaterOrEqual(1), {},
{b_.GreaterOrEqual(0), p_.GreaterOrEqual(1)});
}
}
for (int i = 0; i < 2; ++i) {
// p = a * b, what is the min/max of a?
const AffineExpression a = i == 0 ? a_ : b_;
const AffineExpression b = i == 0 ? b_ : a_;
const IntegerValue max_b = integer_trail_->UpperBound(b);
const IntegerValue min_b = integer_trail_->LowerBound(b);
const IntegerValue max_a = integer_trail_->UpperBound(a);
const IntegerValue min_a = integer_trail_->LowerBound(a);
// If the domain of b contain zero, we can't propagate anything on a.
// Because of CanonicalizeCases(), we just deal with min_b > 0 here.
if (zero_is_possible && min_b <= 0) continue;
// Here both a and b are across zero, but zero is not possible.
if (min_b < 0 && max_b > 0) {
CHECK_GT(min_p, 0); // Because zero is not possible.
// This should be done on the next Propagate() call.
if (min_a >= 0 || max_a <= 0) continue;
PropagateMaxOnPositiveProduct(a, b, min_p, max_p);
PropagateMaxOnPositiveProduct(a.Negated(), b.Negated(), min_p, max_p);
continue;
}
// This shouldn't happen here.
// If it does, we should reach the fixed point on the next iteration.
if (min_b <= 0) continue;
if (min_p >= 0) {
return integer_trail_->UnsafeEnqueue(
a.GreaterOrEqual(0), {}, {p_.GreaterOrEqual(0), b.GreaterOrEqual(1)});
}
if (max_p <= 0) {
return integer_trail_->UnsafeEnqueue(
a.LowerOrEqual(0), {}, {p_.LowerOrEqual(0), b.GreaterOrEqual(1)});
}
// So min_b > 0 and p is across zero: min_p < 0 and max_p > 0.
const IntegerValue new_max_a = FloorRatio(max_p, min_b);
if (new_max_a < integer_trail_->UpperBound(a)) {
if (!integer_trail_->UnsafeEnqueue(
a.LowerOrEqual(new_max_a), {},
{integer_trail_->UpperBoundAsLiteral(p_),
integer_trail_->LowerBoundAsLiteral(b)})) {
return false;
}
}
const IntegerValue new_min_a = CeilRatio(min_p, min_b);
if (new_min_a > integer_trail_->LowerBound(a)) {
if (!integer_trail_->UnsafeEnqueue(
a.GreaterOrEqual(new_min_a), {},
{integer_trail_->LowerBoundAsLiteral(p_),
integer_trail_->LowerBoundAsLiteral(b)})) {
return false;
}
}
}
return true;
}
void ProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) {
const int id = watcher->Register(this);
watcher->WatchAffineExpression(a_, id);
watcher->WatchAffineExpression(b_, id);
@@ -956,7 +1186,9 @@ FixedDivisionPropagator::FixedDivisionPropagator(AffineExpression a,
IntegerValue b,
AffineExpression c,
IntegerTrail* integer_trail)
: a_(a), b_(b), c_(c), integer_trail_(integer_trail) {}
: a_(a), b_(b), c_(c), integer_trail_(integer_trail) {
CHECK_GT(b_, 0);
}
bool FixedDivisionPropagator::Propagate() {
const IntegerValue min_a = integer_trail_->LowerBound(a_);
@@ -964,8 +1196,6 @@ bool FixedDivisionPropagator::Propagate() {
IntegerValue min_c = integer_trail_->LowerBound(c_);
IntegerValue max_c = integer_trail_->UpperBound(c_);
CHECK_GT(b_, 0);
if (max_a / b_ < max_c) {
max_c = max_a / b_;
if (!integer_trail_->UnsafeEnqueue(
@@ -1013,6 +1243,187 @@ void FixedDivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) {
watcher->WatchAffineExpression(c_, id);
}
FixedModuloPropagator::FixedModuloPropagator(AffineExpression expr,
IntegerValue mod,
AffineExpression target,
IntegerTrail* integer_trail)
: expr_(expr),
mod_(mod),
target_(target),
negated_expr_(expr.Negated()),
negated_target_(target.Negated()),
integer_trail_(integer_trail) {
CHECK_GT(mod_, 0);
}
bool FixedModuloPropagator::Propagate() {
if (!PropagateSignsAndTargetRange()) return false;
if (!PropagateOuterBounds()) return false;
if (integer_trail_->LowerBound(expr_) >= 0) {
if (!PropagatePositiveDomains(expr_, target_)) return false;
} else if (integer_trail_->UpperBound(expr_) <= 0) {
if (!PropagatePositiveDomains(negated_expr_, negated_target_)) return false;
}
return true;
}
bool FixedModuloPropagator::PropagateSignsAndTargetRange() {
// Initial domain reduction on the target.
if (integer_trail_->UpperBound(target_) >= mod_) {
if (!integer_trail_->UnsafeEnqueue(target_.LowerOrEqual(mod_ - 1), {},
{})) {
return false;
}
}
if (integer_trail_->LowerBound(target_) <= -mod_) {
if (!integer_trail_->UnsafeEnqueue(target_.GreaterOrEqual(1 - mod_), {},
{})) {
return false;
}
}
// The sign of target_ is fixed by the sign of expr_.
if (integer_trail_->LowerBound(expr_) >= 0 &&
integer_trail_->LowerBound(target_) < 0) {
if (!integer_trail_->UnsafeEnqueue(target_.GreaterOrEqual(0), {},
{expr_.GreaterOrEqual(0)})) {
return false;
}
}
if (integer_trail_->UpperBound(expr_) <= 0 &&
integer_trail_->UpperBound(target_) > 0) {
if (!integer_trail_->UnsafeEnqueue(target_.LowerOrEqual(0), {},
{expr_.LowerOrEqual(0)})) {
return false;
}
}
return true;
}
bool FixedModuloPropagator::PropagateOuterBounds() {
const IntegerValue min_expr = integer_trail_->LowerBound(expr_);
const IntegerValue max_expr = integer_trail_->UpperBound(expr_);
const IntegerValue min_target = integer_trail_->LowerBound(target_);
const IntegerValue max_target = integer_trail_->UpperBound(target_);
const IntegerValue min_expr_round = (min_expr / mod_) * mod_;
const IntegerValue max_expr_round = (max_expr / mod_) * mod_;
if (max_expr > max_expr_round + max_target) {
if (!integer_trail_->UnsafeEnqueue(
expr_.LowerOrEqual(max_expr_round + max_target), {},
{integer_trail_->UpperBoundAsLiteral(target_),
integer_trail_->UpperBoundAsLiteral(expr_)})) {
return false;
}
}
if (min_expr < min_expr_round + min_target) {
if (!integer_trail_->UnsafeEnqueue(
expr_.GreaterOrEqual(min_expr_round + min_target), {},
{integer_trail_->LowerBoundAsLiteral(expr_),
integer_trail_->LowerBoundAsLiteral(target_)})) {
return false;
}
}
if (min_expr_round == max_expr_round) {
if (min_expr_round + min_target < min_expr) {
if (!integer_trail_->UnsafeEnqueue(
target_.GreaterOrEqual(min_expr - min_expr_round), {},
{integer_trail_->LowerBoundAsLiteral(target_),
integer_trail_->UpperBoundAsLiteral(target_),
integer_trail_->LowerBoundAsLiteral(expr_),
integer_trail_->UpperBoundAsLiteral(expr_)})) {
return false;
}
}
if (max_expr_round + max_target > max_expr) {
if (!integer_trail_->UnsafeEnqueue(
target_.LowerOrEqual(max_expr - max_expr_round), {},
{integer_trail_->LowerBoundAsLiteral(target_),
integer_trail_->UpperBoundAsLiteral(target_),
integer_trail_->LowerBoundAsLiteral(expr_),
integer_trail_->UpperBoundAsLiteral(expr_)})) {
return false;
}
}
} else if (min_expr_round == 0 && min_target < 0) {
// expr == target when expr <= 0.
if (min_target < min_expr) {
if (!integer_trail_->UnsafeEnqueue(
target_.GreaterOrEqual(min_expr), {},
{integer_trail_->LowerBoundAsLiteral(target_),
integer_trail_->LowerBoundAsLiteral(expr_)})) {
return false;
}
}
} else if (max_expr_round == 0 && max_target > 0) {
// expr == target when expr >= 0.
if (max_target > max_expr) {
if (!integer_trail_->UnsafeEnqueue(
target_.LowerOrEqual(max_expr), {},
{integer_trail_->UpperBoundAsLiteral(target_),
integer_trail_->UpperBoundAsLiteral(expr_)})) {
return false;
}
}
}
return true;
}
bool FixedModuloPropagator::PropagatePositiveDomains(AffineExpression expr,
AffineExpression target) {
const IntegerValue min_target = integer_trail_->LowerBound(target);
DCHECK_GE(min_target, 0);
const IntegerValue max_target = integer_trail_->UpperBound(target);
// The propagation rules below will not be triggered if the domain of target
// covers [0..mod_ - 1].
if (min_target == 0 && max_target == mod_ - 1) return true;
const IntegerValue min_expr = integer_trail_->LowerBound(expr);
DCHECK_GE(min_expr, 0);
const IntegerValue max_expr = integer_trail_->UpperBound(expr);
const IntegerValue min_expr_round = (min_expr / mod_) * mod_;
const IntegerValue max_expr_round = (max_expr / mod_) * mod_;
if (max_expr < max_expr_round + min_target) {
if (!integer_trail_->UnsafeEnqueue(
expr.LowerOrEqual(max_expr_round - mod_ + max_target), {},
{expr.GreaterOrEqual(0), integer_trail_->UpperBoundAsLiteral(expr),
integer_trail_->LowerBoundAsLiteral(target),
integer_trail_->UpperBoundAsLiteral(target)})) {
return false;
}
}
if (min_expr > min_expr_round + max_target) {
if (!integer_trail_->UnsafeEnqueue(
expr.GreaterOrEqual(min_expr_round + mod_ + min_target), {},
{integer_trail_->LowerBoundAsLiteral(target),
integer_trail_->UpperBoundAsLiteral(target),
integer_trail_->LowerBoundAsLiteral(expr)})) {
return false;
}
}
return true;
}
void FixedModuloPropagator::RegisterWith(GenericLiteralWatcher* watcher) {
const int id = watcher->Register(this);
watcher->WatchAffineExpression(expr_, id);
watcher->WatchAffineExpression(target_, id);
}
std::function<void(Model*)> IsOneOf(IntegerVariable var,
const std::vector<Literal>& selectors,
const std::vector<IntegerValue>& values) {