From bab11deddf3156b2d14a5524b5dea521945409cb Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Wed, 27 Oct 2021 20:56:28 +0200 Subject: [PATCH] [CP-SAT] implement proper modulo propagator (with fixed mod) --- ortools/sat/cp_model_expand.cc | 259 ++---------------- ortools/sat/cp_model_loader.cc | 15 ++ ortools/sat/integer_expr.cc | 465 +++++++++++++++++++++++++++++++-- ortools/sat/integer_expr.h | 109 +++++--- ortools/sat/lb_tree_search.cc | 1 - 5 files changed, 554 insertions(+), 295 deletions(-) diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index 2263eb68e8..feb98fe4eb 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -149,8 +149,10 @@ void ExpandReservoir(ConstraintProto* ct, PresolveContext* context) { void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { const IntegerArgumentProto& int_mod = ct->int_mod(); - const int var = int_mod.vars(0); const int mod_var = int_mod.vars(1); + if (context->IsFixed(mod_var)) return; + + const int var = int_mod.vars(0); const int target_var = int_mod.target(); // We reduce the domain of target_var to avoid later overflow. @@ -177,49 +179,36 @@ void ExpandIntMod(ConstraintProto* ct, PresolveContext* context) { div_proto->add_vars(var); div_proto->add_vars(mod_var); - if (context->IsFixed(mod_var)) { - // var - div_var * mod = target. - LinearConstraintProto* const lin = - new_enforced_constraint()->mutable_linear(); - lin->add_vars(int_mod.vars(0)); - lin->add_coeffs(1); - lin->add_vars(div_var); - lin->add_coeffs(-context->MinOf(mod_var)); - lin->add_vars(target_var); - lin->add_coeffs(-1); - lin->add_domain(0); - lin->add_domain(0); - } else { - // Create prod_var = div_var * mod. - const Domain prod_domain = - context->DomainOf(div_var) - .ContinuousMultiplicationBy(context->DomainOf(mod_var)) - .IntersectionWith(context->DomainOf(var).AdditionWith( - context->DomainOf(target_var).Negation())); - const int prod_var = context->NewIntVar(prod_domain); - IntegerArgumentProto* const int_prod = - new_enforced_constraint()->mutable_int_prod(); - int_prod->set_target(prod_var); - int_prod->add_vars(div_var); - int_prod->add_vars(mod_var); + // Create prod_var = div_var * mod. + const Domain prod_domain = + context->DomainOf(div_var) + .ContinuousMultiplicationBy(context->DomainOf(mod_var)) + .IntersectionWith(context->DomainOf(var).AdditionWith( + context->DomainOf(target_var).Negation())); + const int prod_var = context->NewIntVar(prod_domain); + IntegerArgumentProto* const int_prod = + new_enforced_constraint()->mutable_int_prod(); + int_prod->set_target(prod_var); + int_prod->add_vars(div_var); + int_prod->add_vars(mod_var); - // var - prod_var = target. - LinearConstraintProto* const lin = - new_enforced_constraint()->mutable_linear(); - lin->add_vars(var); - lin->add_coeffs(1); - lin->add_vars(prod_var); - lin->add_coeffs(-1); - lin->add_vars(target_var); - lin->add_coeffs(-1); - lin->add_domain(0); - lin->add_domain(0); - } + // var - prod_var = target. + LinearConstraintProto* const lin = + new_enforced_constraint()->mutable_linear(); + lin->add_vars(var); + lin->add_coeffs(1); + lin->add_vars(prod_var); + lin->add_coeffs(-1); + lin->add_vars(target_var); + lin->add_coeffs(-1); + lin->add_domain(0); + lin->add_domain(0); ct->Clear(); context->UpdateRuleStats("int_mod: expanded"); } +// TODO(user): Move this into the presolve instead? void ExpandIntProdWithBoolean(int bool_ref, int int_ref, int product_ref, PresolveContext* context) { ConstraintProto* const one = context->working_model->add_constraints(); @@ -239,174 +228,6 @@ void ExpandIntProdWithBoolean(int bool_ref, int int_ref, int product_ref, zero->mutable_linear()->add_domain(0); } -// a_ref spans across 0, b_ref does not. -void ExpandIntProdWithOneAcrossZero(int a_ref, int b_ref, int product_ref, - PresolveContext* context) { - DCHECK_LT(context->MinOf(a_ref), 0); - DCHECK_GT(context->MaxOf(a_ref), 0); - DCHECK(context->MinOf(b_ref) >= 0 || context->MaxOf(b_ref) <= 0); - - // Split the domain of a in two, controlled by a new literal. - const int a_is_positive = context->NewBoolVar(); - context->AddImplyInDomain(a_is_positive, a_ref, - {0, std::numeric_limits::max()}); - context->AddImplyInDomain(NegatedRef(a_is_positive), a_ref, - {std::numeric_limits::min(), -1}); - const int pos_a_ref = context->NewIntVar({0, context->MaxOf(a_ref)}); - AddXEqualYOrXEqualZero(a_is_positive, pos_a_ref, a_ref, context); - - const int neg_a_ref = context->NewIntVar({context->MinOf(a_ref), 0}); - AddXEqualYOrXEqualZero(NegatedRef(a_is_positive), neg_a_ref, a_ref, context); - - // Create product with the positive part ofa_ref. - const bool b_is_positive = context->MinOf(b_ref) >= 0; - const Domain pos_a_product_domain = - b_is_positive ? Domain({0, context->MaxOf(product_ref)}) - : Domain({context->MinOf(product_ref), 0}); - const int pos_a_product = context->NewIntVar(pos_a_product_domain); - IntegerArgumentProto* pos_product = - context->working_model->add_constraints()->mutable_int_prod(); - pos_product->set_target(pos_a_product); - pos_product->add_vars(pos_a_ref); - pos_product->add_vars(b_ref); - - // Create product with the negative part of a_ref. - const Domain neg_a_product_domain = - b_is_positive ? Domain({context->MinOf(product_ref), 0}) - : Domain({0, context->MaxOf(product_ref)}); - const int neg_a_product = context->NewIntVar(neg_a_product_domain); - IntegerArgumentProto* neg_product = - context->working_model->add_constraints()->mutable_int_prod(); - neg_product->set_target(neg_a_product); - neg_product->add_vars(neg_a_ref); - neg_product->add_vars(b_ref); - - // Link back to the original product. - LinearConstraintProto* lin = - context->working_model->add_constraints()->mutable_linear(); - lin->add_vars(product_ref); - lin->add_coeffs(-1); - lin->add_vars(pos_a_product); - lin->add_coeffs(1); - lin->add_vars(neg_a_product); - lin->add_coeffs(1); - lin->add_domain(0); - lin->add_domain(0); -} - -void ExpandPositiveIntProdWithTwoAcrossZero(int a_ref, int b_ref, - int product_ref, - PresolveContext* context) { - const int terms_are_positive = context->NewBoolVar(); - - const Domain product_domain = context->DomainOf(product_ref); - - const int64_t max_of_a = context->MaxOf(a_ref); - const int64_t max_of_b = context->MaxOf(b_ref); - const Domain positive_vars_domain = - Domain(0, CapProd(max_of_a, max_of_b)).IntersectionWith(product_domain); - int both_positive_product_ref = std::numeric_limits::min(); - if (!positive_vars_domain.IsEmpty()) { - const int pos_a_ref = context->NewIntVar({0, max_of_a}); - AddXEqualYOrXEqualZero(terms_are_positive, pos_a_ref, a_ref, context); - const int pos_b_ref = context->NewIntVar({0, max_of_b}); - AddXEqualYOrXEqualZero(terms_are_positive, pos_b_ref, b_ref, context); - - // We add 0 to the domain in case this product is not selected. - both_positive_product_ref = - context->NewIntVar(positive_vars_domain.UnionWith(Domain(0))); - IntegerArgumentProto* pos_product = - context->working_model->add_constraints()->mutable_int_prod(); - pos_product->set_target(both_positive_product_ref); - pos_product->add_vars(pos_a_ref); - pos_product->add_vars(pos_b_ref); - } - - const int64_t min_of_a = context->MinOf(a_ref); - const int64_t min_of_b = context->MinOf(b_ref); - const Domain negative_vars_domain = - Domain(0, CapProd(min_of_a, min_of_b)).IntersectionWith(product_domain); - int both_negative_product_ref = std::numeric_limits::min(); - if (!negative_vars_domain.IsEmpty()) { - const int neg_a_ref = context->NewIntVar({min_of_a, 0}); - AddXEqualYOrXEqualZero(NegatedRef(terms_are_positive), neg_a_ref, a_ref, - context); - const int neg_b_ref = context->NewIntVar({min_of_b, 0}); - AddXEqualYOrXEqualZero(NegatedRef(terms_are_positive), neg_b_ref, b_ref, - context); - // We add 0 to the domain in case this product is not selected. - both_negative_product_ref = - context->NewIntVar(negative_vars_domain.UnionWith(Domain(0))); - IntegerArgumentProto* neg_product = - context->working_model->add_constraints()->mutable_int_prod(); - neg_product->set_target(both_negative_product_ref); - neg_product->add_vars(neg_a_ref); - neg_product->add_vars(neg_b_ref); - } - - // Link back to the original product. - LinearConstraintProto* lin = - context->working_model->add_constraints()->mutable_linear(); - lin->add_vars(product_ref); - lin->add_coeffs(-1); - if (both_positive_product_ref != std::numeric_limits::min()) { - lin->add_vars(both_positive_product_ref); - lin->add_coeffs(1); - } - if (both_negative_product_ref != std::numeric_limits::min()) { - lin->add_vars(both_negative_product_ref); - lin->add_coeffs(1); - } - lin->add_domain(0); - lin->add_domain(0); -} - -void ExpandIntProdWithTwoAcrossZero(int a_ref, int b_ref, int product_ref, - PresolveContext* context) { - if (context->MinOf(product_ref) >= 0) { - ExpandPositiveIntProdWithTwoAcrossZero(a_ref, b_ref, product_ref, context); - return; - } else if (context->MaxOf(product_ref) <= 0) { - ExpandPositiveIntProdWithTwoAcrossZero(a_ref, NegatedRef(b_ref), - NegatedRef(product_ref), context); - return; - } - // Split a_ref domain in two, controlled by a new literal. - const int a_is_positive = context->NewBoolVar(); - context->AddImplyInDomain(a_is_positive, a_ref, - {0, std::numeric_limits::max()}); - context->AddImplyInDomain(NegatedRef(a_is_positive), a_ref, - {std::numeric_limits::min(), -1}); - const int64_t min_of_a = context->MinOf(a_ref); - const int64_t max_of_a = context->MaxOf(a_ref); - - const int pos_a_ref = context->NewIntVar({0, max_of_a}); - AddXEqualYOrXEqualZero(a_is_positive, pos_a_ref, a_ref, context); - - const int neg_a_ref = context->NewIntVar({min_of_a, 0}); - AddXEqualYOrXEqualZero(NegatedRef(a_is_positive), neg_a_ref, a_ref, context); - - // Create product with two sub parts of a_ref. - const int pos_product_ref = - context->NewIntVar(context->DomainOf(product_ref)); - ExpandIntProdWithOneAcrossZero(b_ref, pos_a_ref, pos_product_ref, context); - const int neg_product_ref = - context->NewIntVar(context->DomainOf(product_ref)); - ExpandIntProdWithOneAcrossZero(b_ref, neg_a_ref, neg_product_ref, context); - - // Link back to the original product. - LinearConstraintProto* lin = - context->working_model->add_constraints()->mutable_linear(); - lin->add_vars(product_ref); - lin->add_coeffs(-1); - lin->add_vars(pos_product_ref); - lin->add_coeffs(1); - lin->add_vars(neg_product_ref); - lin->add_coeffs(1); - lin->add_domain(0); - lin->add_domain(0); -} - void ExpandIntProd(ConstraintProto* ct, PresolveContext* context) { const IntegerArgumentProto& int_prod = ct->int_prod(); if (int_prod.vars_size() != 2) return; @@ -432,32 +253,6 @@ void ExpandIntProd(ConstraintProto* ct, PresolveContext* context) { context->UpdateRuleStats("int_prod: expanded product with Boolean var"); return; } - - const bool a_span_across_zero = - context->MinOf(a) < 0 && context->MaxOf(a) > 0; - const bool b_span_across_zero = - context->MinOf(b) < 0 && context->MaxOf(b) > 0; - if (a_span_across_zero && !b_span_across_zero) { - ExpandIntProdWithOneAcrossZero(a, b, p, context); - ct->Clear(); - context->UpdateRuleStats( - "int_prod: expanded product with general integer variables"); - return; - } - if (!a_span_across_zero && b_span_across_zero) { - ExpandIntProdWithOneAcrossZero(b, a, p, context); - ct->Clear(); - context->UpdateRuleStats( - "int_prod: expanded product with general integer variables"); - return; - } - if (a_span_across_zero && b_span_across_zero) { - ExpandIntProdWithTwoAcrossZero(a, b, p, context); - ct->Clear(); - context->UpdateRuleStats( - "int_prod: expanded product with general integer variables"); - return; - } } void ExpandInverse(ConstraintProto* ct, PresolveContext* context) { diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index ef1a580fc6..24b6bfc805 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -1156,6 +1156,18 @@ void LoadIntDivConstraint(const ConstraintProto& ct, Model* m) { } } +void LoadIntModConstraint(const ConstraintProto& ct, Model* m) { + auto* mapping = m->GetOrCreate(); + auto* integer_trail = m->GetOrCreate(); + + const IntegerVariable target = mapping->Integer(ct.int_mod().target()); + const std::vector vars = + mapping->Integers(ct.int_mod().vars()); + CHECK(integer_trail->IsFixed(vars[1])); + const IntegerValue fixed_modulo = integer_trail->FixedValue(vars[1]); + m->Add(FixedModuloConstraint(vars[0], fixed_modulo, target)); +} + void LoadLinMaxConstraint(const ConstraintProto& ct, Model* m) { if (ct.lin_max().exprs().empty()) { m->GetOrCreate()->NotifyThatModelIsUnsat(); @@ -1259,6 +1271,9 @@ bool LoadConstraint(const ConstraintProto& ct, Model* m) { case ConstraintProto::ConstraintProto::kIntDiv: LoadIntDivConstraint(ct, m); return true; + case ConstraintProto::ConstraintProto::kIntMod: + LoadIntModConstraint(ct, m); + return true; case ConstraintProto::ConstraintProto::kLinMax: LoadLinMaxConstraint(ct, m); return true; diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 6e2a4324c2..7c84b75982 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -610,27 +610,62 @@ void LinMinPropagator::RegisterWith(GenericLiteralWatcher* watcher) { watcher->RegisterReversibleInt(id, &rev_unique_candidate_); } -PositiveProductPropagator::PositiveProductPropagator( - AffineExpression a, AffineExpression b, AffineExpression p, - IntegerTrail* integer_trail) - : a_(a), b_(b), p_(p), integer_trail_(integer_trail) { - // Note that we assume this is true at level zero, and so we never include - // that fact in the reasons we compute. - CHECK_GE(integer_trail_->LevelZeroLowerBound(a_), 0); - CHECK_GE(integer_trail_->LevelZeroLowerBound(b_), 0); +ProductPropagator::ProductPropagator(AffineExpression a, AffineExpression b, + AffineExpression p, + IntegerTrail* integer_trail) + : a_(a), b_(b), p_(p), integer_trail_(integer_trail) {} + +// We want all affine expression to be either non-negative or across zero. +bool ProductPropagator::CanonicalizeCases() { + if (integer_trail_->UpperBound(a_) <= 0) { + a_ = a_.Negated(); + p_ = p_.Negated(); + } + if (integer_trail_->UpperBound(b_) <= 0) { + b_ = b_.Negated(); + p_ = p_.Negated(); + } + + // If both a and b positive, p must be too. + if (integer_trail_->LowerBound(a_) >= 0 && + integer_trail_->LowerBound(b_) >= 0) { + return integer_trail_->UnsafeEnqueue( + p_.GreaterOrEqual(0), {}, {a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)}); + } + + // Otherwise, make sure p is non-negative or accros zero. + if (integer_trail_->UpperBound(p_) <= 0) { + if (integer_trail_->LowerBound(a_) < 0) { + DCHECK_GT(integer_trail_->UpperBound(a_), 0); + a_ = a_.Negated(); + p_ = p_.Negated(); + } else { + DCHECK_LT(integer_trail_->LowerBound(b_), 0); + DCHECK_GT(integer_trail_->UpperBound(b_), 0); + b_ = b_.Negated(); + p_ = p_.Negated(); + } + } + + return true; } -// TODO(user): We can tighten the bounds on p by removing extreme value that +// Note that this propagation is exact, except on the domain of p as this +// involves more complex arithmetic. +// +// TODO(user): We could tighten the bounds on p by removing extreme value that // do not contains divisor in the domains of a or b. There is an algo in O( // smallest domain size between a or b). -bool PositiveProductPropagator::Propagate() { +bool ProductPropagator::PropagateWhenAllNonNegative() { const IntegerValue max_a = integer_trail_->UpperBound(a_); const IntegerValue max_b = integer_trail_->UpperBound(b_); const IntegerValue new_max(CapProd(max_a.value(), max_b.value())); if (new_max < integer_trail_->UpperBound(p_)) { - if (!integer_trail_->Enqueue(p_.LowerOrEqual(new_max), {}, - {integer_trail_->UpperBoundAsLiteral(a_), - integer_trail_->UpperBoundAsLiteral(b_)})) { + if (!integer_trail_->Enqueue( + p_.LowerOrEqual(new_max), {}, + {integer_trail_->UpperBoundAsLiteral(a_), + integer_trail_->UpperBoundAsLiteral(b_), a_.GreaterOrEqual(0), + b_.GreaterOrEqual(0)})) { return false; } } @@ -639,9 +674,10 @@ bool PositiveProductPropagator::Propagate() { const IntegerValue min_b = integer_trail_->LowerBound(b_); const IntegerValue new_min(CapProd(min_a.value(), min_b.value())); if (new_min > integer_trail_->LowerBound(p_)) { - if (!integer_trail_->Enqueue(p_.GreaterOrEqual(new_min), {}, - {integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_)})) { + if (!integer_trail_->UnsafeEnqueue( + p_.GreaterOrEqual(new_min), {}, + {integer_trail_->LowerBoundAsLiteral(a_), + integer_trail_->LowerBoundAsLiteral(b_)})) { return false; } } @@ -655,16 +691,18 @@ bool PositiveProductPropagator::Propagate() { const IntegerValue max_p = integer_trail_->UpperBound(p_); const IntegerValue prod(CapProd(max_a.value(), min_b.value())); if (prod > max_p) { - if (!integer_trail_->Enqueue(a.LowerOrEqual(FloorRatio(max_p, min_b)), {}, - {integer_trail_->LowerBoundAsLiteral(b), - integer_trail_->UpperBoundAsLiteral(p_)})) { + if (!integer_trail_->UnsafeEnqueue( + a.LowerOrEqual(FloorRatio(max_p, min_b)), {}, + {integer_trail_->LowerBoundAsLiteral(b), + integer_trail_->UpperBoundAsLiteral(p_), + p_.GreaterOrEqual(0)})) { return false; } } else if (prod < min_p) { - if (!integer_trail_->Enqueue(b.GreaterOrEqual(CeilRatio(min_p, max_a)), - {}, - {integer_trail_->UpperBoundAsLiteral(a), - integer_trail_->LowerBoundAsLiteral(p_)})) { + if (!integer_trail_->UnsafeEnqueue( + b.GreaterOrEqual(CeilRatio(min_p, max_a)), {}, + {integer_trail_->UpperBoundAsLiteral(a), + integer_trail_->LowerBoundAsLiteral(p_), a.GreaterOrEqual(0)})) { return false; } } @@ -673,7 +711,199 @@ bool PositiveProductPropagator::Propagate() { return true; } -void PositiveProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +// This assumes p > 0, p = a * X, and X can take any value. +// We can propagate max of a by computing a bound on the min b when positive. +// The expression b is just used to detect when there is no solution given the +// upper bound of b. +bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a, + AffineExpression b, + IntegerValue min_p, + IntegerValue max_p) { + const IntegerValue max_a = integer_trail_->UpperBound(a); + DCHECK_GT(max_a, 0); + DCHECK_GT(min_p, 0); + + if (max_a >= min_p) { + if (max_p < max_a) { + if (!integer_trail_->UnsafeEnqueue( + a.LowerOrEqual(max_p), {}, + {p_.LowerOrEqual(max_p), p_.GreaterOrEqual(1)})) { + return false; + } + } + return true; + } + + const IntegerValue min_pos_b = CeilRatio(min_p, max_a); + if (min_pos_b > integer_trail_->UpperBound(b)) { + if (!integer_trail_->UnsafeEnqueue( + b.LowerOrEqual(0), {}, + {integer_trail_->LowerBoundAsLiteral(p_), + integer_trail_->UpperBoundAsLiteral(a), + integer_trail_->UpperBoundAsLiteral(b)})) { + return false; + } + return true; + } + + const IntegerValue new_max_a = FloorRatio(max_p, min_pos_b); + if (new_max_a < integer_trail_->UpperBound(a)) { + if (!integer_trail_->UnsafeEnqueue( + a.LowerOrEqual(new_max_a), {}, + {integer_trail_->LowerBoundAsLiteral(p_), + integer_trail_->UpperBoundAsLiteral(a), + integer_trail_->UpperBoundAsLiteral(p_)})) { + return false; + } + } + return true; +} + +bool ProductPropagator::Propagate() { + if (!CanonicalizeCases()) return false; + + // In the most common case, we use better reasons even though the code + // below would propagate the same. + const int64_t min_a = integer_trail_->LowerBound(a_).value(); + const int64_t min_b = integer_trail_->LowerBound(b_).value(); + if (min_a >= 0 && min_b >= 0) { + // This was done by CanonicalizeCases(). + DCHECK_GE(integer_trail_->LowerBound(p_), 0); + return PropagateWhenAllNonNegative(); + } + + // Lets propagate on p_ first, the max/min is given by one of: max_a * max_b, + // max_a * min_b, min_a * max_b, min_a * min_b. This is true, because any + // product x * y, depending on the sign, is dominated by one of these. + // + // TODO(user): In the reasons, including all 4 bounds is always correct, but + // we might be able to relax some of them. + const int64_t max_a = integer_trail_->UpperBound(a_).value(); + const int64_t max_b = integer_trail_->UpperBound(b_).value(); + const IntegerValue p1(CapProd(max_a, max_b)); + const IntegerValue p2(CapProd(max_a, min_b)); + const IntegerValue p3(CapProd(min_a, max_b)); + const IntegerValue p4(CapProd(min_a, min_b)); + const IntegerValue new_max_p = std::max({p1, p2, p3, p4}); + if (new_max_p < integer_trail_->UpperBound(p_)) { + if (!integer_trail_->UnsafeEnqueue( + p_.LowerOrEqual(new_max_p), {}, + {integer_trail_->LowerBoundAsLiteral(a_), + integer_trail_->LowerBoundAsLiteral(b_), + integer_trail_->UpperBoundAsLiteral(a_), + integer_trail_->UpperBoundAsLiteral(b_)})) { + return false; + } + } + const IntegerValue new_min_p = std::min({p1, p2, p3, p4}); + if (new_min_p > integer_trail_->LowerBound(p_)) { + if (!integer_trail_->UnsafeEnqueue( + p_.GreaterOrEqual(new_min_p), {}, + {integer_trail_->LowerBoundAsLiteral(a_), + integer_trail_->LowerBoundAsLiteral(b_), + integer_trail_->UpperBoundAsLiteral(a_), + integer_trail_->UpperBoundAsLiteral(b_)})) { + return false; + } + } + + // Lets propagate on a and b. + const IntegerValue min_p = integer_trail_->LowerBound(p_); + const IntegerValue max_p = integer_trail_->UpperBound(p_); + + // We need a bit more propagation to avoid bad cases below. + const bool zero_is_possible = min_p <= 0; + if (!zero_is_possible) { + if (integer_trail_->LowerBound(a_) == 0) { + if (!integer_trail_->UnsafeEnqueue( + a_.GreaterOrEqual(1), {}, + {p_.GreaterOrEqual(1), a_.GreaterOrEqual(0)})) { + return false; + } + } + if (integer_trail_->LowerBound(b_) == 0) { + if (!integer_trail_->UnsafeEnqueue( + b_.GreaterOrEqual(1), {}, + {p_.GreaterOrEqual(1), b_.GreaterOrEqual(0)})) { + return false; + } + } + if (integer_trail_->LowerBound(a_) >= 0 && + integer_trail_->LowerBound(b_) <= 0) { + return integer_trail_->UnsafeEnqueue( + b_.GreaterOrEqual(1), {}, + {a_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); + } + if (integer_trail_->LowerBound(b_) >= 0 && + integer_trail_->LowerBound(a_) <= 0) { + return integer_trail_->UnsafeEnqueue( + a_.GreaterOrEqual(1), {}, + {b_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); + } + } + + for (int i = 0; i < 2; ++i) { + // p = a * b, what is the min/max of a? + const AffineExpression a = i == 0 ? a_ : b_; + const AffineExpression b = i == 0 ? b_ : a_; + const IntegerValue max_b = integer_trail_->UpperBound(b); + const IntegerValue min_b = integer_trail_->LowerBound(b); + const IntegerValue max_a = integer_trail_->UpperBound(a); + const IntegerValue min_a = integer_trail_->LowerBound(a); + + // If the domain of b contain zero, we can't propagate anything on a. + // Because of CanonicalizeCases(), we just deal with min_b > 0 here. + if (zero_is_possible && min_b <= 0) continue; + + // Here both a and b are across zero, but zero is not possible. + if (min_b < 0 && max_b > 0) { + CHECK_GT(min_p, 0); // Because zero is not possible. + + // This should be done on the next Propagate() call. + if (min_a >= 0 || max_a <= 0) continue; + + PropagateMaxOnPositiveProduct(a, b, min_p, max_p); + PropagateMaxOnPositiveProduct(a.Negated(), b.Negated(), min_p, max_p); + continue; + } + + // This shouldn't happen here. + // If it does, we should reach the fixed point on the next iteration. + if (min_b <= 0) continue; + if (min_p >= 0) { + return integer_trail_->UnsafeEnqueue( + a.GreaterOrEqual(0), {}, {p_.GreaterOrEqual(0), b.GreaterOrEqual(1)}); + } + if (max_p <= 0) { + return integer_trail_->UnsafeEnqueue( + a.LowerOrEqual(0), {}, {p_.LowerOrEqual(0), b.GreaterOrEqual(1)}); + } + + // So min_b > 0 and p is across zero: min_p < 0 and max_p > 0. + const IntegerValue new_max_a = FloorRatio(max_p, min_b); + if (new_max_a < integer_trail_->UpperBound(a)) { + if (!integer_trail_->UnsafeEnqueue( + a.LowerOrEqual(new_max_a), {}, + {integer_trail_->UpperBoundAsLiteral(p_), + integer_trail_->LowerBoundAsLiteral(b)})) { + return false; + } + } + const IntegerValue new_min_a = CeilRatio(min_p, min_b); + if (new_min_a > integer_trail_->LowerBound(a)) { + if (!integer_trail_->UnsafeEnqueue( + a.GreaterOrEqual(new_min_a), {}, + {integer_trail_->LowerBoundAsLiteral(p_), + integer_trail_->LowerBoundAsLiteral(b)})) { + return false; + } + } + } + + return true; +} + +void ProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(a_, id); watcher->WatchAffineExpression(b_, id); @@ -956,7 +1186,9 @@ FixedDivisionPropagator::FixedDivisionPropagator(AffineExpression a, IntegerValue b, AffineExpression c, IntegerTrail* integer_trail) - : a_(a), b_(b), c_(c), integer_trail_(integer_trail) {} + : a_(a), b_(b), c_(c), integer_trail_(integer_trail) { + CHECK_GT(b_, 0); +} bool FixedDivisionPropagator::Propagate() { const IntegerValue min_a = integer_trail_->LowerBound(a_); @@ -964,8 +1196,6 @@ bool FixedDivisionPropagator::Propagate() { IntegerValue min_c = integer_trail_->LowerBound(c_); IntegerValue max_c = integer_trail_->UpperBound(c_); - CHECK_GT(b_, 0); - if (max_a / b_ < max_c) { max_c = max_a / b_; if (!integer_trail_->UnsafeEnqueue( @@ -1013,6 +1243,187 @@ void FixedDivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { watcher->WatchAffineExpression(c_, id); } +FixedModuloPropagator::FixedModuloPropagator(AffineExpression expr, + IntegerValue mod, + AffineExpression target, + IntegerTrail* integer_trail) + : expr_(expr), + mod_(mod), + target_(target), + negated_expr_(expr.Negated()), + negated_target_(target.Negated()), + integer_trail_(integer_trail) { + CHECK_GT(mod_, 0); +} + +bool FixedModuloPropagator::Propagate() { + if (!PropagateSignsAndTargetRange()) return false; + if (!PropagateOuterBounds()) return false; + + if (integer_trail_->LowerBound(expr_) >= 0) { + if (!PropagatePositiveDomains(expr_, target_)) return false; + } else if (integer_trail_->UpperBound(expr_) <= 0) { + if (!PropagatePositiveDomains(negated_expr_, negated_target_)) return false; + } + + return true; +} + +bool FixedModuloPropagator::PropagateSignsAndTargetRange() { + // Initial domain reduction on the target. + if (integer_trail_->UpperBound(target_) >= mod_) { + if (!integer_trail_->UnsafeEnqueue(target_.LowerOrEqual(mod_ - 1), {}, + {})) { + return false; + } + } + + if (integer_trail_->LowerBound(target_) <= -mod_) { + if (!integer_trail_->UnsafeEnqueue(target_.GreaterOrEqual(1 - mod_), {}, + {})) { + return false; + } + } + + // The sign of target_ is fixed by the sign of expr_. + if (integer_trail_->LowerBound(expr_) >= 0 && + integer_trail_->LowerBound(target_) < 0) { + if (!integer_trail_->UnsafeEnqueue(target_.GreaterOrEqual(0), {}, + {expr_.GreaterOrEqual(0)})) { + return false; + } + } + + if (integer_trail_->UpperBound(expr_) <= 0 && + integer_trail_->UpperBound(target_) > 0) { + if (!integer_trail_->UnsafeEnqueue(target_.LowerOrEqual(0), {}, + {expr_.LowerOrEqual(0)})) { + return false; + } + } + + return true; +} + +bool FixedModuloPropagator::PropagateOuterBounds() { + const IntegerValue min_expr = integer_trail_->LowerBound(expr_); + const IntegerValue max_expr = integer_trail_->UpperBound(expr_); + const IntegerValue min_target = integer_trail_->LowerBound(target_); + const IntegerValue max_target = integer_trail_->UpperBound(target_); + const IntegerValue min_expr_round = (min_expr / mod_) * mod_; + const IntegerValue max_expr_round = (max_expr / mod_) * mod_; + + if (max_expr > max_expr_round + max_target) { + if (!integer_trail_->UnsafeEnqueue( + expr_.LowerOrEqual(max_expr_round + max_target), {}, + {integer_trail_->UpperBoundAsLiteral(target_), + integer_trail_->UpperBoundAsLiteral(expr_)})) { + return false; + } + } + + if (min_expr < min_expr_round + min_target) { + if (!integer_trail_->UnsafeEnqueue( + expr_.GreaterOrEqual(min_expr_round + min_target), {}, + {integer_trail_->LowerBoundAsLiteral(expr_), + integer_trail_->LowerBoundAsLiteral(target_)})) { + return false; + } + } + + if (min_expr_round == max_expr_round) { + if (min_expr_round + min_target < min_expr) { + if (!integer_trail_->UnsafeEnqueue( + target_.GreaterOrEqual(min_expr - min_expr_round), {}, + {integer_trail_->LowerBoundAsLiteral(target_), + integer_trail_->UpperBoundAsLiteral(target_), + integer_trail_->LowerBoundAsLiteral(expr_), + integer_trail_->UpperBoundAsLiteral(expr_)})) { + return false; + } + } + + if (max_expr_round + max_target > max_expr) { + if (!integer_trail_->UnsafeEnqueue( + target_.LowerOrEqual(max_expr - max_expr_round), {}, + {integer_trail_->LowerBoundAsLiteral(target_), + integer_trail_->UpperBoundAsLiteral(target_), + integer_trail_->LowerBoundAsLiteral(expr_), + integer_trail_->UpperBoundAsLiteral(expr_)})) { + return false; + } + } + } else if (min_expr_round == 0 && min_target < 0) { + // expr == target when expr <= 0. + if (min_target < min_expr) { + if (!integer_trail_->UnsafeEnqueue( + target_.GreaterOrEqual(min_expr), {}, + {integer_trail_->LowerBoundAsLiteral(target_), + integer_trail_->LowerBoundAsLiteral(expr_)})) { + return false; + } + } + } else if (max_expr_round == 0 && max_target > 0) { + // expr == target when expr >= 0. + if (max_target > max_expr) { + if (!integer_trail_->UnsafeEnqueue( + target_.LowerOrEqual(max_expr), {}, + {integer_trail_->UpperBoundAsLiteral(target_), + integer_trail_->UpperBoundAsLiteral(expr_)})) { + return false; + } + } + } + + return true; +} + +bool FixedModuloPropagator::PropagatePositiveDomains(AffineExpression expr, + AffineExpression target) { + const IntegerValue min_target = integer_trail_->LowerBound(target); + DCHECK_GE(min_target, 0); + const IntegerValue max_target = integer_trail_->UpperBound(target); + + // The propagation rules below will not be triggered if the domain of target + // covers [0..mod_ - 1]. + if (min_target == 0 && max_target == mod_ - 1) return true; + + const IntegerValue min_expr = integer_trail_->LowerBound(expr); + DCHECK_GE(min_expr, 0); + const IntegerValue max_expr = integer_trail_->UpperBound(expr); + + const IntegerValue min_expr_round = (min_expr / mod_) * mod_; + const IntegerValue max_expr_round = (max_expr / mod_) * mod_; + + if (max_expr < max_expr_round + min_target) { + if (!integer_trail_->UnsafeEnqueue( + expr.LowerOrEqual(max_expr_round - mod_ + max_target), {}, + {expr.GreaterOrEqual(0), integer_trail_->UpperBoundAsLiteral(expr), + integer_trail_->LowerBoundAsLiteral(target), + integer_trail_->UpperBoundAsLiteral(target)})) { + return false; + } + } + + if (min_expr > min_expr_round + max_target) { + if (!integer_trail_->UnsafeEnqueue( + expr.GreaterOrEqual(min_expr_round + mod_ + min_target), {}, + {integer_trail_->LowerBoundAsLiteral(target), + integer_trail_->UpperBoundAsLiteral(target), + integer_trail_->LowerBoundAsLiteral(expr)})) { + return false; + } + } + + return true; +} + +void FixedModuloPropagator::RegisterWith(GenericLiteralWatcher* watcher) { + const int id = watcher->Register(this); + watcher->WatchAffineExpression(expr_, id); + watcher->WatchAffineExpression(target_, id); +} + std::function IsOneOf(IntegerVariable var, const std::vector& selectors, const std::vector& values) { diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index cbd924095b..d0758a5918 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -208,26 +208,39 @@ class LinMinPropagator : public PropagatorInterface { int rev_unique_candidate_ = 0; }; -// Propagates a * b = c. Basic version, we don't extract any special cases, and -// we only propagates the bounds. +// Propagates a * b = p. // -// TODO(user): For now this only works on variables that are non-negative. -// TODO(user): Deal with overflow. -class PositiveProductPropagator : public PropagatorInterface { +// The bounds [min, max] of a and b will be propagated perfectly, but not +// the bounds on p as this require more complex arithmetics. +class ProductPropagator : public PropagatorInterface { public: - PositiveProductPropagator(AffineExpression a, AffineExpression b, - AffineExpression p, IntegerTrail* integer_trail); + ProductPropagator(AffineExpression a, AffineExpression b, AffineExpression p, + IntegerTrail* integer_trail); bool Propagate() final; void RegisterWith(GenericLiteralWatcher* watcher); private: - const AffineExpression a_; - const AffineExpression b_; - const AffineExpression p_; + // Maybe replace a_, b_ or c_ by their negation to simplify the cases. + bool CanonicalizeCases(); + + // Special case when all are >= 0. + // We use faster code and better reasons than the generic code. + bool PropagateWhenAllNonNegative(); + + // Internal helper, see code for more details. + bool PropagateMaxOnPositiveProduct(AffineExpression a, AffineExpression b, + IntegerValue min_p, IntegerValue max_p); + + // Note that we might negate any two terms in CanonicalizeCases() during + // each propagation. This is fine. + AffineExpression a_; + AffineExpression b_; + AffineExpression p_; + IntegerTrail* integer_trail_; - DISALLOW_COPY_AND_ASSIGN(PositiveProductPropagator); + DISALLOW_COPY_AND_ASSIGN(ProductPropagator); }; // Propagates num / denom = div. Basic version, we don't extract any special @@ -282,11 +295,42 @@ class FixedDivisionPropagator : public PropagatorInterface { const AffineExpression a_; const IntegerValue b_; const AffineExpression c_; + IntegerTrail* integer_trail_; DISALLOW_COPY_AND_ASSIGN(FixedDivisionPropagator); }; +// Propagates var_a % cst_b = var_c. Basic version, we don't extract any special +// cases, and we only propagates the bounds. cst_b must be > 0. +class FixedModuloPropagator : public PropagatorInterface { + public: + FixedModuloPropagator(AffineExpression expr, IntegerValue mod, + AffineExpression target, IntegerTrail* integer_trail); + + bool Propagate() final; + void RegisterWith(GenericLiteralWatcher* watcher); + + private: + // Propagates sign and basic bounds. + bool PropagateSignsAndTargetRange(); + + // Propagates on the positive domains. + bool PropagatePositiveDomains(AffineExpression expr, AffineExpression target); + + // Propagates outer bounds. + bool PropagateOuterBounds(); + + const AffineExpression expr_; + const IntegerValue mod_; + const AffineExpression target_; + const AffineExpression negated_expr_; + const AffineExpression negated_target_; + IntegerTrail* integer_trail_; + + DISALLOW_COPY_AND_ASSIGN(FixedModuloPropagator); +}; + // Propagates x * x = s. // TODO(user): Only works for x nonnegative. class SquarePropagator : public PropagatorInterface { @@ -793,34 +837,16 @@ inline std::function ProductConstraint(AffineExpression a, if (integer_trail->LowerBound(a) >= 0) { RegisterAndTransferOwnership(model, new SquarePropagator(a, p, integer_trail)); - } else if (integer_trail->UpperBound(a) <= 0) { + return; + } + if (integer_trail->UpperBound(a) <= 0) { RegisterAndTransferOwnership( model, new SquarePropagator(a.Negated(), p, integer_trail)); - } else { - LOG(FATAL) << "Not supported"; + return; } - } else if (integer_trail->LowerBound(a) >= 0 && - integer_trail->LowerBound(b) >= 0) { - RegisterAndTransferOwnership( - model, new PositiveProductPropagator(a, b, p, integer_trail)); - } else if (integer_trail->LowerBound(a) >= 0 && - integer_trail->UpperBound(b) <= 0) { - RegisterAndTransferOwnership( - model, new PositiveProductPropagator(a, b.Negated(), p.Negated(), - integer_trail)); - } else if (integer_trail->UpperBound(a) <= 0 && - integer_trail->LowerBound(b) >= 0) { - RegisterAndTransferOwnership( - model, new PositiveProductPropagator(a.Negated(), b, p.Negated(), - integer_trail)); - } else if (integer_trail->UpperBound(a) <= 0 && - integer_trail->UpperBound(b) <= 0) { - RegisterAndTransferOwnership( - model, new PositiveProductPropagator(a.Negated(), b.Negated(), p, - integer_trail)); - } else { - LOG(FATAL) << "Not supported"; } + RegisterAndTransferOwnership(model, + new ProductPropagator(a, b, p, integer_trail)); }; } @@ -857,6 +883,19 @@ inline std::function FixedDivisionConstraint(AffineExpression a, }; } +// Adds the constraint: a % b = c where b is a constant. +inline std::function FixedModuloConstraint(AffineExpression a, + IntegerValue b, + AffineExpression c) { + return [=](Model* model) { + IntegerTrail* integer_trail = model->GetOrCreate(); + FixedModuloPropagator* constraint = + new FixedModuloPropagator(a, b, c, integer_trail); + constraint->RegisterWith(model->GetOrCreate()); + model->TakeOwnership(constraint); + }; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/lb_tree_search.cc b/ortools/sat/lb_tree_search.cc index 5c3068ac6c..36b8c9b235 100644 --- a/ortools/sat/lb_tree_search.cc +++ b/ortools/sat/lb_tree_search.cc @@ -231,7 +231,6 @@ SatSolver::Status LbTreeSearch::Search( if (node.objective_lb > current_objective_lb_) { break; } - CHECK_EQ(node.objective_lb, current_objective_lb_); // This will be set to the next node index. NodeIndex n;