From a0e25debc5b8acc799e2616470147eda2a515359 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Sat, 26 Jul 2025 21:29:40 +0200 Subject: [PATCH] [CP-SAT] extend support for enforcement literals in constraints --- ortools/sat/cp_constraints.cc | 77 +-- ortools/sat/cp_constraints.h | 36 +- ortools/sat/cp_constraints_test.cc | 3 +- ortools/sat/cp_model.proto | 4 +- ortools/sat/cp_model_checker.cc | 2 + ortools/sat/cp_model_copy.cc | 2 + ortools/sat/cp_model_loader.cc | 14 +- ortools/sat/cp_model_presolve.cc | 4 + ortools/sat/cp_model_solver_test.cc | 46 +- ortools/sat/integer_expr.cc | 843 +++++++++++++++++----------- ortools/sat/integer_expr.h | 148 +++-- ortools/sat/integer_expr_test.cc | 173 +++++- 12 files changed, 879 insertions(+), 473 deletions(-) diff --git a/ortools/sat/cp_constraints.cc b/ortools/sat/cp_constraints.cc index f7a7903f49..3dfa1dc8d8 100644 --- a/ortools/sat/cp_constraints.cc +++ b/ortools/sat/cp_constraints.cc @@ -126,7 +126,6 @@ EnforcementId EnforcementPropagator::Register( std::function callback) { int num_true = 0; int num_false = 0; - bool is_always_false = false; temp_literals_.clear(); const int level = trail_.CurrentDecisionLevel(); for (const Literal l : enforcement) { @@ -139,22 +138,13 @@ EnforcementId EnforcementPropagator::Register( if (level == 0 || trail_.Info(l.Variable()).level == 0) continue; ++num_true; } else if (assignment_.LiteralIsFalse(l)) { - if (level == 0 || trail_.Info(l.Variable()).level == 0) { - is_always_false = true; - break; - } ++num_false; } temp_literals_.push_back(l); } gtl::STLSortAndRemoveDuplicates(&temp_literals_); - // Return special indices if never/always enforced. - if (is_always_false) { - if (callback != nullptr) - callback(EnforcementId(-1), EnforcementStatus::IS_FALSE); - return EnforcementId(-1); - } + // Return special index if always enforced. if (temp_literals_.empty()) { if (callback != nullptr) callback(EnforcementId(-1), EnforcementStatus::IS_ENFORCED); @@ -228,6 +218,18 @@ EnforcementId EnforcementPropagator::Register( return id; } +EnforcementId EnforcementPropagator::Register( + absl::Span enforcement_literals, + GenericLiteralWatcher* watcher, int literal_watcher_id) { + return Register(enforcement_literals, + [=](EnforcementId, EnforcementStatus status) { + if (status == EnforcementStatus::CAN_PROPAGATE || + status == EnforcementStatus::IS_ENFORCED) { + watcher->CallOnNextPropagate(literal_watcher_id); + } + }); +} + // Add the enforcement reason to the given vector. void EnforcementPropagator::AddEnforcementReason( EnforcementId id, std::vector* reason) const { @@ -266,6 +268,21 @@ bool EnforcementPropagator::PropagateWhenFalse( return true; } +bool EnforcementPropagator::SafeEnqueue( + EnforcementId id, IntegerLiteral i_lit, + absl::Span integer_reason) { + temp_reason_.clear(); + AddEnforcementReason(id, &temp_reason_); + return integer_trail_->SafeEnqueue(i_lit, temp_reason_, integer_reason); +} + +bool EnforcementPropagator::ReportConflict( + EnforcementId id, absl::Span integer_reason) { + temp_reason_.clear(); + AddEnforcementReason(id, &temp_reason_); + return integer_trail_->ReportConflict(temp_reason_, integer_reason); +} + absl::Span EnforcementPropagator::GetSpan(EnforcementId id) { if (id < 0) return {}; DCHECK_LE(id + 1, starts_.size()); @@ -358,31 +375,20 @@ EnforcementStatus EnforcementPropagator::DebugStatus(EnforcementId id) { BooleanXorPropagator::BooleanXorPropagator( const std::vector& enforcement_literals, - const std::vector& literals, bool value, Trail* trail, - IntegerTrail* integer_trail, EnforcementPropagator* enforcement_propagator) + const std::vector& literals, bool value, Model* model) : literals_(literals), value_(value), - trail_(trail), - integer_trail_(integer_trail), - enforcement_propagator_(enforcement_propagator) { - enforcement_id_ = enforcement_propagator->Register( - enforcement_literals, [this](EnforcementId id, EnforcementStatus status) { - // We cannot call Propagate() because enforcement_id_ is not - // set yet, and because Register() can call this callback - // before returning. - Propagate(id, status); - }); + trail_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + enforcement_propagator_(model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_->Register( + enforcement_literals, watcher, RegisterWith(watcher)); } bool BooleanXorPropagator::Propagate() { const EnforcementStatus status = - enforcement_id_ < 0 ? EnforcementStatus::IS_ENFORCED - : enforcement_propagator_->Status(enforcement_id_); - return Propagate(enforcement_id_, status); -} - -bool BooleanXorPropagator::Propagate(EnforcementId id, - EnforcementStatus status) { + enforcement_propagator_->Status(enforcement_id_); if (status == EnforcementStatus::IS_FALSE || status == EnforcementStatus::CANNOT_PROPAGATE) { return true; @@ -406,7 +412,8 @@ bool BooleanXorPropagator::Propagate(EnforcementId id, // Propagates? if (status == EnforcementStatus::IS_ENFORCED && unassigned_index != -1) { literal_reason_.clear(); - enforcement_propagator_->AddEnforcementReason(id, &literal_reason_); + enforcement_propagator_->AddEnforcementReason(enforcement_id_, + &literal_reason_); for (int i = 0; i < literals_.size(); ++i) { if (i == unassigned_index) continue; const Literal l = literals_[i]; @@ -420,7 +427,8 @@ bool BooleanXorPropagator::Propagate(EnforcementId id, } if (status == EnforcementStatus::CAN_PROPAGATE && unassigned_index == -1 && sum != value_) { - return enforcement_propagator_->PropagateWhenFalse(id, literals_, + return enforcement_propagator_->PropagateWhenFalse(enforcement_id_, + literals_, /*integer_reason=*/{}); } if (status != EnforcementStatus::IS_ENFORCED || unassigned_index != -1) { @@ -433,7 +441,7 @@ bool BooleanXorPropagator::Propagate(EnforcementId id, // Conflict. std::vector* conflict = trail_->MutableConflict(); conflict->clear(); - enforcement_propagator_->AddEnforcementReason(id, conflict); + enforcement_propagator_->AddEnforcementReason(enforcement_id_, conflict); for (const Literal& l : literals_) { conflict->push_back(trail_->Assignment().LiteralIsFalse(l) ? l : l.Negated()); @@ -441,12 +449,13 @@ bool BooleanXorPropagator::Propagate(EnforcementId id, return false; } -void BooleanXorPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int BooleanXorPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); for (const Literal& l : literals_) { watcher->WatchLiteral(l, id); watcher->WatchLiteral(l.Negated(), id); } + return id; } GreaterThanAtLeastOneOfPropagator::GreaterThanAtLeastOneOfPropagator( diff --git a/ortools/sat/cp_constraints.h b/ortools/sat/cp_constraints.h index 2f2f188b39..49d6394dfa 100644 --- a/ortools/sat/cp_constraints.h +++ b/ortools/sat/cp_constraints.h @@ -72,6 +72,13 @@ class EnforcementPropagator : public SatPropagator { absl::Span enforcement, std::function callback = nullptr); + // Calls `Register` with a callback calling + // `watcher->CallOnNextPropagate(literal_watcher_id)` if a propagation might + // be possible. + EnforcementId Register(absl::Span enforcement_literals, + GenericLiteralWatcher* watcher, + int literal_watcher_id); + // Add the enforcement reason to the given vector. void AddEnforcementReason(EnforcementId id, std::vector* reason) const; @@ -82,7 +89,17 @@ class EnforcementPropagator : public SatPropagator { EnforcementId id, absl::Span literal_reason, absl::Span integer_reason); - EnforcementStatus Status(EnforcementId id) const { return statuses_[id]; } + ABSL_MUST_USE_RESULT bool SafeEnqueue( + EnforcementId id, IntegerLiteral i_lit, + absl::Span integer_reason); + + bool ReportConflict(EnforcementId id, + absl::Span integer_reason); + + EnforcementStatus Status(EnforcementId id) const { + if (id < 0) return EnforcementStatus::IS_ENFORCED; + return statuses_[id]; + } // Recompute the status from the current assignment. // This should only used in DCHECK(). @@ -144,18 +161,16 @@ class BooleanXorPropagator : public PropagatorInterface { public: BooleanXorPropagator(const std::vector& enforcement_literals, const std::vector& literals, bool value, - Trail* trail, IntegerTrail* integer_trail, - EnforcementPropagator* enforcement_propagator); + Model* model); // This type is neither copyable nor movable. BooleanXorPropagator(const BooleanXorPropagator&) = delete; BooleanXorPropagator& operator=(const BooleanXorPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: - bool Propagate(EnforcementId id, EnforcementStatus status); + int RegisterWith(GenericLiteralWatcher* watcher); const std::vector literals_; const bool value_; @@ -230,15 +245,8 @@ inline std::function LiteralXorIs( const std::vector& enforcement_literals, const std::vector& literals, bool value) { return [=](Model* model) { - Trail* trail = model->GetOrCreate(); - IntegerTrail* integer_trail = model->GetOrCreate(); - EnforcementPropagator* enforcement_propagator = - model->GetOrCreate(); - BooleanXorPropagator* constraint = - new BooleanXorPropagator(enforcement_literals, literals, value, trail, - integer_trail, enforcement_propagator); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); + model->TakeOwnership( + new BooleanXorPropagator(enforcement_literals, literals, value, model)); }; } diff --git a/ortools/sat/cp_constraints_test.cc b/ortools/sat/cp_constraints_test.cc index e633f6e20c..88efa3d167 100644 --- a/ortools/sat/cp_constraints_test.cc +++ b/ortools/sat/cp_constraints_test.cc @@ -163,8 +163,7 @@ TEST(LiteralXorIsTest, OneEnforcedVariable) { model.Add(LiteralXorIs({Literal(e, true)}, {}, true)); model.Add(LiteralXorIs({Literal(f, false)}, {}, true)); SatSolver* solver = model.GetOrCreate(); - // Propagation happens when the constraints are registered with the - // EnforcementPropagator. + EXPECT_TRUE(solver->Propagate()); EXPECT_TRUE(solver->Assignment().LiteralIsFalse(Literal(e, true))); EXPECT_TRUE(solver->Assignment().LiteralIsFalse(Literal(f, false))); } diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index c68b012eaa..92e0d2fa63 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -331,8 +331,8 @@ message ConstraintProto { // l). // // Important: as of July 2025, only a few constraint support enforcement: - // - bool_or, bool_and, at_most_one, exactly_one, bool_xor, int_prod, linear, - // table: fully supported. + // - bool_or, bool_and, at_most_one, exactly_one, bool_xor, int_div, int_mod, + // int_prod, linear, table: fully supported. // - interval: only support a single enforcement literal. // - other: no support (but can be added on a per-demand basis). repeated int32 enforcement_literal = 2; diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index ffdf285bc8..88ccd6485f 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -1171,9 +1171,11 @@ std::string ValidateCpModel(const CpModelProto& model, bool after_presolve) { RETURN_IF_NOT_EMPTY(ValidateIntProdConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kIntDiv: + support_enforcement = true; RETURN_IF_NOT_EMPTY(ValidateIntDivConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kIntMod: + support_enforcement = true; RETURN_IF_NOT_EMPTY(ValidateIntModConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kInverse: diff --git a/ortools/sat/cp_model_copy.cc b/ortools/sat/cp_model_copy.cc index cc6b31b704..590410ffcd 100644 --- a/ortools/sat/cp_model_copy.cc +++ b/ortools/sat/cp_model_copy.cc @@ -789,6 +789,7 @@ bool ModelCopy::CopyIntDiv(const ConstraintProto& ct, bool ignore_names) { if (!ignore_names) { new_ct->set_name(ct.name()); } + FinishEnforcementCopy(new_ct); for (const LinearExpressionProto& expr : ct.int_div().exprs()) { CopyLinearExpression(expr, new_ct->mutable_int_div()->add_exprs()); } @@ -802,6 +803,7 @@ bool ModelCopy::CopyIntMod(const ConstraintProto& ct, bool ignore_names) { if (!ignore_names) { new_ct->set_name(ct.name()); } + FinishEnforcementCopy(new_ct); for (const LinearExpressionProto& expr : ct.int_mod().exprs()) { CopyLinearExpression(expr, new_ct->mutable_int_mod()->add_exprs()); } diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index 3310aaa9fa..fcecb82670 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -1531,7 +1531,7 @@ void LoadAlwaysFalseConstraint(const ConstraintProto& ct, Model* m) { void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); - std::vector enforcement_literals = + const std::vector enforcement_literals = mapping->Literals(ct.enforcement_literal()); const AffineExpression prod = mapping->Affine(ct.int_prod().target()); std::vector terms; @@ -1574,11 +1574,14 @@ void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { void LoadIntDivConstraint(const ConstraintProto& ct, Model* m) { auto* integer_trail = m->GetOrCreate(); auto* mapping = m->GetOrCreate(); + const std::vector enforcement_literals = + mapping->Literals(ct.enforcement_literal()); const AffineExpression div = mapping->Affine(ct.int_div().target()); const AffineExpression num = mapping->Affine(ct.int_div().exprs(0)); const AffineExpression denom = mapping->Affine(ct.int_div().exprs(1)); if (integer_trail->IsFixed(denom)) { - m->Add(FixedDivisionConstraint(num, integer_trail->FixedValue(denom), div)); + m->Add(FixedDivisionConstraint(enforcement_literals, num, + integer_trail->FixedValue(denom), div)); } else { if (VLOG_IS_ON(1)) { LinearConstraintBuilder builder(m); @@ -1587,7 +1590,7 @@ void LoadIntDivConstraint(const ConstraintProto& ct, Model* m) { VLOG(1) << "Division " << ct << " can be linearized"; } } - m->Add(DivisionConstraint(num, denom, div)); + m->Add(DivisionConstraint(enforcement_literals, num, denom, div)); } } @@ -1595,12 +1598,15 @@ void LoadIntModConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); auto* integer_trail = m->GetOrCreate(); + const std::vector enforcement_literals = + mapping->Literals(ct.enforcement_literal()); const AffineExpression target = mapping->Affine(ct.int_mod().target()); const AffineExpression expr = mapping->Affine(ct.int_mod().exprs(0)); const AffineExpression mod = mapping->Affine(ct.int_mod().exprs(1)); CHECK(integer_trail->IsFixed(mod)); const IntegerValue fixed_modulo = integer_trail->FixedValue(mod); - m->Add(FixedModuloConstraint(expr, fixed_modulo, target)); + m->Add( + FixedModuloConstraint(enforcement_literals, expr, fixed_modulo, target)); } void LoadLinMaxConstraint(const ConstraintProto& ct, Model* m) { diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 44f45f0eb1..8b3d417908 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -1980,6 +1980,8 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; + // TODO(user): add support for this case. + if (HasEnforcementLiteral(*ct)) return false; const LinearExpressionProto target = ct->int_div().target(); const LinearExpressionProto expr = ct->int_div().exprs(0); @@ -2123,6 +2125,8 @@ bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { bool CpModelPresolver::PresolveIntMod(int c, ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; + // TODO(user): add support for this case. + if (HasEnforcementLiteral(*ct)) return false; // TODO(user): Presolve f(X) = g(X) % fixed_mod. const LinearExpressionProto target = ct->int_mod().target(); diff --git a/ortools/sat/cp_model_solver_test.cc b/ortools/sat/cp_model_solver_test.cc index 9b095d0e33..a1ba69d2f1 100644 --- a/ortools/sat/cp_model_solver_test.cc +++ b/ortools/sat/cp_model_solver_test.cc @@ -804,8 +804,50 @@ TEST(SolveCpModelTest, BoolXorWithEnforcementLiteralPresolved) { EXPECT_THAT(response.solution(), ::testing::ElementsAre(1, 0)); } +TEST(SolveCpModelTest, IntDivWithEnforcementLiteral) { + // not(b) => 7x / 3y = 17, x in [0, 10], y in [1, 2] + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 1, 2 ] } + constraints { + enforcement_literal: -1 + int_prod { + target { offset: 17 } + exprs { vars: 1 coeffs: 7 } + exprs { vars: 2 coeffs: 3 } + } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:false")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(response.solution(0), 1); +} + +TEST(SolveCpModelTest, IntModWithEnforcementLiteral) { + // not(b) => x % 10 = y, x in [8, 11], y in [2, 7] + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 2, 7 ] } + variables { domain: [ 8, 11 ] } + constraints { + enforcement_literal: -1 + int_prod { + target { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + exprs { offset: 10 } + } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:false")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(response.solution(0), 1); +} + TEST(SolveCpModelTest, IntProdWithEnforcementLiteral) { - // b => x.y.z = 17 + // not(b) => x.y.z = 17 CpModelProto model_proto = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } variables { domain: [ 2, 20 ] } @@ -828,7 +870,7 @@ TEST(SolveCpModelTest, IntProdWithEnforcementLiteral) { } TEST(SolveCpModelTest, SquareIntProdWithEnforcementLiteral) { - // b => x.y.z = 17 + // not(b) => x.y.z = 17 CpModelProto model_proto = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } variables { domain: [ 2, 20 ] } diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 2393325d84..e7341a53e8 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -851,49 +851,45 @@ void LinMinPropagator::RegisterWith(GenericLiteralWatcher* watcher) { ProductPropagator::ProductPropagator( absl::Span enforcement_literals, AffineExpression a, - AffineExpression b, AffineExpression p, IntegerTrail* integer_trail, - EnforcementPropagator* enforcement_propagator) + AffineExpression b, AffineExpression p, Model* model) : a_(a), b_(b), p_(p), - integer_trail_(integer_trail), - enforcement_propagator_(enforcement_propagator) { - enforcement_id_ = enforcement_propagator->Register( - enforcement_literals, [this](EnforcementId id, EnforcementStatus) { - // Register() can call this callback before returning, and Propagate() - // needs the enforcement id to be set. - enforcement_id_ = id; - Propagate(); - }); + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); } // We want all affine expression to be either non-negative or across zero. bool ProductPropagator::CanonicalizeCases() { - if (integer_trail_->UpperBound(a_) <= 0) { + if (integer_trail_.UpperBound(a_) <= 0) { a_ = a_.Negated(); p_ = p_.Negated(); } - if (integer_trail_->UpperBound(b_) <= 0) { + if (integer_trail_.UpperBound(b_) <= 0) { b_ = b_.Negated(); p_ = p_.Negated(); } // If both a and b positive, p must be too. - if (integer_trail_->LowerBound(a_) >= 0 && - integer_trail_->LowerBound(b_) >= 0) { - return SafeEnqueue(p_.GreaterOrEqual(0), - {a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)}); + if (integer_trail_.LowerBound(a_) >= 0 && + integer_trail_.LowerBound(b_) >= 0) { + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.GreaterOrEqual(0), + {a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)}); } // Otherwise, make sure p is non-negative or across zero. - if (integer_trail_->UpperBound(p_) <= 0) { - if (integer_trail_->LowerBound(a_) < 0) { - DCHECK_GT(integer_trail_->UpperBound(a_), 0); + if (integer_trail_.UpperBound(p_) <= 0) { + if (integer_trail_.LowerBound(a_) < 0) { + DCHECK_GT(integer_trail_.UpperBound(a_), 0); a_ = a_.Negated(); p_ = p_.Negated(); } else { - DCHECK_LT(integer_trail_->LowerBound(b_), 0); - DCHECK_GT(integer_trail_->UpperBound(b_), 0); + DCHECK_LT(integer_trail_.LowerBound(b_), 0); + DCHECK_GT(integer_trail_.UpperBound(b_), 0); b_ = b_.Negated(); p_ = p_.Negated(); } @@ -910,36 +906,38 @@ bool ProductPropagator::CanonicalizeCases() { // smallest domain size between a or b). bool ProductPropagator::PropagateWhenAllNonNegative() { { - const IntegerValue max_a = integer_trail_->UpperBound(a_); - const IntegerValue max_b = integer_trail_->UpperBound(b_); + const IntegerValue max_a = integer_trail_.UpperBound(a_); + const IntegerValue max_b = integer_trail_.UpperBound(b_); const IntegerValue new_max = CapProdI(max_a, max_b); - if (new_max < integer_trail_->UpperBound(p_)) { - if (!SafeEnqueue(p_.LowerOrEqual(new_max), - {integer_trail_->UpperBoundAsLiteral(a_), - integer_trail_->UpperBoundAsLiteral(b_), - a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)})) { + if (new_max < integer_trail_.UpperBound(p_)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.LowerOrEqual(new_max), + {integer_trail_.UpperBoundAsLiteral(a_), + integer_trail_.UpperBoundAsLiteral(b_), a_.GreaterOrEqual(0), + b_.GreaterOrEqual(0)})) { return false; } } } { - const IntegerValue min_a = integer_trail_->LowerBound(a_); - const IntegerValue min_b = integer_trail_->LowerBound(b_); + const IntegerValue min_a = integer_trail_.LowerBound(a_); + const IntegerValue min_b = integer_trail_.LowerBound(b_); const IntegerValue new_min = CapProdI(min_a, min_b); // The conflict test is needed because when new_min is large, we could // have an overflow in p_.GreaterOrEqual(new_min); - if (new_min > integer_trail_->UpperBound(p_)) { - return integer_trail_->ReportConflict( - {integer_trail_->UpperBoundAsLiteral(p_), - integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_)}); + if (new_min > integer_trail_.UpperBound(p_)) { + return enforcement_propagator_.ReportConflict( + enforcement_id_, {integer_trail_.UpperBoundAsLiteral(p_), + integer_trail_.LowerBoundAsLiteral(a_), + integer_trail_.LowerBoundAsLiteral(b_)}); } - if (new_min > integer_trail_->LowerBound(p_)) { - if (!SafeEnqueue(p_.GreaterOrEqual(new_min), - {integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_)})) { + if (new_min > integer_trail_.LowerBound(p_)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.GreaterOrEqual(new_min), + {integer_trail_.LowerBoundAsLiteral(a_), + integer_trail_.LowerBoundAsLiteral(b_)})) { return false; } } @@ -948,23 +946,23 @@ bool ProductPropagator::PropagateWhenAllNonNegative() { for (int i = 0; i < 2; ++i) { const AffineExpression a = i == 0 ? a_ : b_; const AffineExpression b = i == 0 ? b_ : a_; - const IntegerValue max_a = integer_trail_->UpperBound(a); - const IntegerValue min_b = integer_trail_->LowerBound(b); - const IntegerValue min_p = integer_trail_->LowerBound(p_); - const IntegerValue max_p = integer_trail_->UpperBound(p_); + const IntegerValue max_a = integer_trail_.UpperBound(a); + const IntegerValue min_b = integer_trail_.LowerBound(b); + const IntegerValue min_p = integer_trail_.LowerBound(p_); + const IntegerValue max_p = integer_trail_.UpperBound(p_); const IntegerValue prod = CapProdI(max_a, min_b); if (prod > max_p) { - if (!SafeEnqueue(a.LowerOrEqual(FloorRatio(max_p, min_b)), - {integer_trail_->LowerBoundAsLiteral(b), - integer_trail_->UpperBoundAsLiteral(p_), - p_.GreaterOrEqual(0)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(FloorRatio(max_p, min_b)), + {integer_trail_.LowerBoundAsLiteral(b), + integer_trail_.UpperBoundAsLiteral(p_), p_.GreaterOrEqual(0)})) { return false; } } else if (prod < min_p && max_a != 0) { - if (!SafeEnqueue( - b.GreaterOrEqual(CeilRatio(min_p, max_a)), - {integer_trail_->UpperBoundAsLiteral(a), - integer_trail_->LowerBoundAsLiteral(p_), a.GreaterOrEqual(0)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, b.GreaterOrEqual(CeilRatio(min_p, max_a)), + {integer_trail_.UpperBoundAsLiteral(a), + integer_trail_.LowerBoundAsLiteral(p_), a.GreaterOrEqual(0)})) { return false; } } @@ -981,14 +979,15 @@ bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a, AffineExpression b, IntegerValue min_p, IntegerValue max_p) { - const IntegerValue max_a = integer_trail_->UpperBound(a); + const IntegerValue max_a = integer_trail_.UpperBound(a); if (max_a <= 0) return true; DCHECK_GT(min_p, 0); if (max_a >= min_p) { if (max_p < max_a) { - if (!SafeEnqueue(a.LowerOrEqual(max_p), - {p_.LowerOrEqual(max_p), p_.GreaterOrEqual(1)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(max_p), + {p_.LowerOrEqual(max_p), p_.GreaterOrEqual(1)})) { return false; } } @@ -996,22 +995,24 @@ bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a, } const IntegerValue min_pos_b = CeilRatio(min_p, max_a); - if (min_pos_b > integer_trail_->UpperBound(b)) { - if (!SafeEnqueue(b.LowerOrEqual(0), - {integer_trail_->LowerBoundAsLiteral(p_), - integer_trail_->UpperBoundAsLiteral(a), - integer_trail_->UpperBoundAsLiteral(b)})) { + if (min_pos_b > integer_trail_.UpperBound(b)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, b.LowerOrEqual(0), + {integer_trail_.LowerBoundAsLiteral(p_), + integer_trail_.UpperBoundAsLiteral(a), + integer_trail_.UpperBoundAsLiteral(b)})) { return false; } return true; } const IntegerValue new_max_a = FloorRatio(max_p, min_pos_b); - if (new_max_a < integer_trail_->UpperBound(a)) { - if (!SafeEnqueue(a.LowerOrEqual(new_max_a), - {integer_trail_->LowerBoundAsLiteral(p_), - integer_trail_->UpperBoundAsLiteral(a), - integer_trail_->UpperBoundAsLiteral(p_)})) { + if (new_max_a < integer_trail_.UpperBound(a)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(new_max_a), + {integer_trail_.LowerBoundAsLiteral(p_), + integer_trail_.UpperBoundAsLiteral(a), + integer_trail_.UpperBoundAsLiteral(p_)})) { return false; } } @@ -1020,19 +1021,14 @@ bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a, bool ProductPropagator::Propagate() { const EnforcementStatus status = - enforcement_id_ < 0 ? EnforcementStatus::IS_ENFORCED - : enforcement_propagator_->Status(enforcement_id_); - if (status == EnforcementStatus::IS_FALSE || - status == EnforcementStatus::CANNOT_PROPAGATE) { - return true; - } + enforcement_propagator_.Status(enforcement_id_); if (status == EnforcementStatus::CAN_PROPAGATE) { - const int64_t min_a = integer_trail_->LowerBound(a_).value(); - const int64_t max_a = integer_trail_->UpperBound(a_).value(); - const int64_t min_b = integer_trail_->LowerBound(b_).value(); - const int64_t max_b = integer_trail_->UpperBound(b_).value(); - const int64_t min_p = integer_trail_->LowerBound(p_).value(); - const int64_t max_p = integer_trail_->UpperBound(p_).value(); + const int64_t min_a = integer_trail_.LowerBound(a_).value(); + const int64_t max_a = integer_trail_.UpperBound(a_).value(); + const int64_t min_b = integer_trail_.LowerBound(b_).value(); + const int64_t max_b = integer_trail_.UpperBound(b_).value(); + const int64_t min_p = integer_trail_.LowerBound(p_).value(); + const int64_t max_p = integer_trail_.UpperBound(p_).value(); const int64_t p1 = CapProdI(max_a, max_b).value(); const int64_t p2 = CapProdI(max_a, min_b).value(); const int64_t p3 = CapProdI(min_a, max_b).value(); @@ -1042,15 +1038,17 @@ bool ProductPropagator::Propagate() { // If the bounds of a * b and p are disjoint, the enforcement must be false. // TODO(user): relax the reason in a better way. if (min_ab > max_p) { - return enforcement_propagator_->PropagateWhenFalse( - enforcement_id_, /*literal_reason=*/{}, + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, {a_.GreaterOrEqual(min_a), a_.LowerOrEqual(max_a), b_.GreaterOrEqual(min_b), b_.LowerOrEqual(max_b), p_.LowerOrEqual(max_p)}); } if (min_p > max_ab) { - return enforcement_propagator_->PropagateWhenFalse( - enforcement_id_, /*literal_reason=*/{}, + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, {a_.GreaterOrEqual(min_a), a_.LowerOrEqual(max_a), b_.GreaterOrEqual(min_b), b_.LowerOrEqual(max_b), p_.GreaterOrEqual(min_p)}); @@ -1058,16 +1056,17 @@ bool ProductPropagator::Propagate() { // Otherwise we cannot propagate anything since the enforcement is unknown. return true; } - DCHECK_EQ(status, EnforcementStatus::IS_ENFORCED); + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (!CanonicalizeCases()) return false; // In the most common case, we use better reasons even though the code // below would propagate the same. - const int64_t min_a = integer_trail_->LowerBound(a_).value(); - const int64_t min_b = integer_trail_->LowerBound(b_).value(); + const int64_t min_a = integer_trail_.LowerBound(a_).value(); + const int64_t min_b = integer_trail_.LowerBound(b_).value(); if (min_a >= 0 && min_b >= 0) { // This was done by CanonicalizeCases(). - DCHECK_GE(integer_trail_->LowerBound(p_), 0); + DCHECK_GE(integer_trail_.LowerBound(p_), 0); return PropagateWhenAllNonNegative(); } @@ -1077,61 +1076,67 @@ bool ProductPropagator::Propagate() { // // TODO(user): In the reasons, including all 4 bounds is always correct, but // we might be able to relax some of them. - const IntegerValue max_a = integer_trail_->UpperBound(a_); - const IntegerValue max_b = integer_trail_->UpperBound(b_); + const IntegerValue max_a = integer_trail_.UpperBound(a_); + const IntegerValue max_b = integer_trail_.UpperBound(b_); const IntegerValue p1 = CapProdI(max_a, max_b); const IntegerValue p2 = CapProdI(max_a, min_b); const IntegerValue p3 = CapProdI(min_a, max_b); const IntegerValue p4 = CapProdI(min_a, min_b); const IntegerValue new_max_p = std::max({p1, p2, p3, p4}); - if (new_max_p < integer_trail_->UpperBound(p_)) { - if (!SafeEnqueue(p_.LowerOrEqual(new_max_p), - {integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_), - integer_trail_->UpperBoundAsLiteral(a_), - integer_trail_->UpperBoundAsLiteral(b_)})) { + if (new_max_p < integer_trail_.UpperBound(p_)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.LowerOrEqual(new_max_p), + {integer_trail_.LowerBoundAsLiteral(a_), + integer_trail_.LowerBoundAsLiteral(b_), + integer_trail_.UpperBoundAsLiteral(a_), + integer_trail_.UpperBoundAsLiteral(b_)})) { return false; } } const IntegerValue new_min_p = std::min({p1, p2, p3, p4}); - if (new_min_p > integer_trail_->LowerBound(p_)) { - if (!SafeEnqueue(p_.GreaterOrEqual(new_min_p), - {integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_), - integer_trail_->UpperBoundAsLiteral(a_), - integer_trail_->UpperBoundAsLiteral(b_)})) { + if (new_min_p > integer_trail_.LowerBound(p_)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.GreaterOrEqual(new_min_p), + {integer_trail_.LowerBoundAsLiteral(a_), + integer_trail_.LowerBoundAsLiteral(b_), + integer_trail_.UpperBoundAsLiteral(a_), + integer_trail_.UpperBoundAsLiteral(b_)})) { return false; } } // Lets propagate on a and b. - const IntegerValue min_p = integer_trail_->LowerBound(p_); - const IntegerValue max_p = integer_trail_->UpperBound(p_); + const IntegerValue min_p = integer_trail_.LowerBound(p_); + const IntegerValue max_p = integer_trail_.UpperBound(p_); // We need a bit more propagation to avoid bad cases below. const bool zero_is_possible = min_p <= 0; if (!zero_is_possible) { - if (integer_trail_->LowerBound(a_) == 0) { - if (!SafeEnqueue(a_.GreaterOrEqual(1), - {p_.GreaterOrEqual(1), a_.GreaterOrEqual(0)})) { + if (integer_trail_.LowerBound(a_) == 0) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a_.GreaterOrEqual(1), + {p_.GreaterOrEqual(1), a_.GreaterOrEqual(0)})) { return false; } } - if (integer_trail_->LowerBound(b_) == 0) { - if (!SafeEnqueue(b_.GreaterOrEqual(1), - {p_.GreaterOrEqual(1), b_.GreaterOrEqual(0)})) { + if (integer_trail_.LowerBound(b_) == 0) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, b_.GreaterOrEqual(1), + {p_.GreaterOrEqual(1), b_.GreaterOrEqual(0)})) { return false; } } - if (integer_trail_->LowerBound(a_) >= 0 && - integer_trail_->LowerBound(b_) <= 0) { - return SafeEnqueue(b_.GreaterOrEqual(1), - {a_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); + if (integer_trail_.LowerBound(a_) >= 0 && + integer_trail_.LowerBound(b_) <= 0) { + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, b_.GreaterOrEqual(1), + {a_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); } - if (integer_trail_->LowerBound(b_) >= 0 && - integer_trail_->LowerBound(a_) <= 0) { - return SafeEnqueue(a_.GreaterOrEqual(1), - {b_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); + if (integer_trail_.LowerBound(b_) >= 0 && + integer_trail_.LowerBound(a_) <= 0) { + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, a_.GreaterOrEqual(1), + {b_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); } } @@ -1139,8 +1144,8 @@ bool ProductPropagator::Propagate() { // p = a * b, what is the min/max of a? const AffineExpression a = i == 0 ? a_ : b_; const AffineExpression b = i == 0 ? b_ : a_; - const IntegerValue max_b = integer_trail_->UpperBound(b); - const IntegerValue min_b = integer_trail_->LowerBound(b); + const IntegerValue max_b = integer_trail_.UpperBound(b); + const IntegerValue min_b = integer_trail_.LowerBound(b); // If the domain of b contain zero, we can't propagate anything on a. // Because of CanonicalizeCases(), we just deal with min_b > 0 here. @@ -1166,28 +1171,32 @@ bool ProductPropagator::Propagate() { // If it does, we should reach the fixed point on the next iteration. if (min_b <= 0) continue; if (min_p >= 0) { - return SafeEnqueue(a.GreaterOrEqual(0), - {p_.GreaterOrEqual(0), b.GreaterOrEqual(1)}); + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.GreaterOrEqual(0), + {p_.GreaterOrEqual(0), b.GreaterOrEqual(1)}); } if (max_p <= 0) { - return SafeEnqueue(a.LowerOrEqual(0), - {p_.LowerOrEqual(0), b.GreaterOrEqual(1)}); + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(0), + {p_.LowerOrEqual(0), b.GreaterOrEqual(1)}); } // So min_b > 0 and p is across zero: min_p < 0 and max_p > 0. const IntegerValue new_max_a = FloorRatio(max_p, min_b); - if (new_max_a < integer_trail_->UpperBound(a)) { - if (!SafeEnqueue(a.LowerOrEqual(new_max_a), - {integer_trail_->UpperBoundAsLiteral(p_), - integer_trail_->LowerBoundAsLiteral(b)})) { + if (new_max_a < integer_trail_.UpperBound(a)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(new_max_a), + {integer_trail_.UpperBoundAsLiteral(p_), + integer_trail_.LowerBoundAsLiteral(b)})) { return false; } } const IntegerValue new_min_a = CeilRatio(min_p, min_b); - if (new_min_a > integer_trail_->LowerBound(a)) { - if (!SafeEnqueue(a.GreaterOrEqual(new_min_a), - {integer_trail_->LowerBoundAsLiteral(p_), - integer_trail_->LowerBoundAsLiteral(b)})) { + if (new_min_a > integer_trail_.LowerBound(a)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.GreaterOrEqual(new_min_a), + {integer_trail_.LowerBoundAsLiteral(p_), + integer_trail_.LowerBoundAsLiteral(b)})) { return false; } } @@ -1196,108 +1205,84 @@ bool ProductPropagator::Propagate() { return true; } -void ProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int ProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(a_, id); watcher->WatchAffineExpression(b_, id); watcher->WatchAffineExpression(p_, id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); -} - -bool ProductPropagator::SafeEnqueue( - IntegerLiteral i_lit, absl::Span integer_reason) { - tmp_literal_reason_.clear(); - enforcement_propagator_->AddEnforcementReason(enforcement_id_, - &tmp_literal_reason_); - return integer_trail_->SafeEnqueue(i_lit, tmp_literal_reason_, - integer_reason); + return id; } SquarePropagator::SquarePropagator( absl::Span enforcement_literals, AffineExpression x, - AffineExpression s, IntegerTrail* integer_trail, - EnforcementPropagator* enforcement_propagator) + AffineExpression s, Model* model) : x_(x), s_(s), - integer_trail_(integer_trail), - enforcement_propagator_(enforcement_propagator) { - CHECK_GE(integer_trail->LevelZeroLowerBound(x), 0); - enforcement_id_ = enforcement_propagator->Register( - enforcement_literals, [this](EnforcementId id, EnforcementStatus status) { - // We cannot call Propagate() because enforcement_id_ is not - // set yet, and because Register() can call this callback - // before returning. - Propagate(id, status); - }); -} - -bool SquarePropagator::Propagate() { - const EnforcementStatus status = - enforcement_id_ < 0 ? EnforcementStatus::IS_ENFORCED - : enforcement_propagator_->Status(enforcement_id_); - return Propagate(enforcement_id_, status); + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); + CHECK_GE(integer_trail_.LevelZeroLowerBound(x), 0); } // Propagation from x to s: s in [min_x * min_x, max_x * max_x]. // Propagation from s to x: x in [ceil(sqrt(min_s)), floor(sqrt(max_s))]. -bool SquarePropagator::Propagate(EnforcementId id, EnforcementStatus status) { - if (status == EnforcementStatus::IS_FALSE || - status == EnforcementStatus::CANNOT_PROPAGATE) { - return true; - } - const IntegerValue min_x = integer_trail_->LowerBound(x_); - const IntegerValue min_s = integer_trail_->LowerBound(s_); +bool SquarePropagator::Propagate() { + const IntegerValue min_x = integer_trail_.LowerBound(x_); + const IntegerValue min_s = integer_trail_.LowerBound(s_); const IntegerValue min_x_square = CapProdI(min_x, min_x); - const IntegerValue max_x = integer_trail_->UpperBound(x_); - const IntegerValue max_s = integer_trail_->UpperBound(s_); + const IntegerValue max_x = integer_trail_.UpperBound(x_); + const IntegerValue max_s = integer_trail_.UpperBound(s_); const IntegerValue max_x_square = CapProdI(max_x, max_x); + + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); if (status == EnforcementStatus::CAN_PROPAGATE) { // If the bounds of x * x and s are disjoint, the enforcement must be false. // TODO(user): relax the reason in a better way. if (min_x_square > max_s) { - return enforcement_propagator_->PropagateWhenFalse( - id, /*literal_reason=*/{}, + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, {x_.GreaterOrEqual(min_x), s_.LowerOrEqual(min_x - 1)}); } if (min_s > max_x_square) { - return enforcement_propagator_->PropagateWhenFalse( - id, /*literal_reason=*/{}, + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, {s_.GreaterOrEqual(min_s), x_.LowerOrEqual(min_s - 1)}); } // Otherwise we cannot propagate anything since the enforcement is unknown. return true; } - auto safe_enqueue = [this, id]( - IntegerLiteral i_lit, - absl::Span integer_reason) { - tmp_literal_reason_.clear(); - enforcement_propagator_->AddEnforcementReason(id, &tmp_literal_reason_); - return integer_trail_->SafeEnqueue(i_lit, tmp_literal_reason_, - integer_reason); - }; - DCHECK_EQ(status, EnforcementStatus::IS_ENFORCED); + if (status != EnforcementStatus::IS_ENFORCED) return true; if (min_x_square > min_s) { - if (!safe_enqueue(s_.GreaterOrEqual(min_x_square), - {x_.GreaterOrEqual(min_x)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + s_.GreaterOrEqual(min_x_square), + {x_.GreaterOrEqual(min_x)})) { return false; } } else if (min_x_square < min_s) { const IntegerValue new_min(CeilSquareRoot(min_s.value())); - if (!safe_enqueue(x_.GreaterOrEqual(new_min), - {s_.GreaterOrEqual((new_min - 1) * (new_min - 1) + 1)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, x_.GreaterOrEqual(new_min), + {s_.GreaterOrEqual((new_min - 1) * (new_min - 1) + 1)})) { return false; } } if (max_x_square < max_s) { - if (!safe_enqueue(s_.LowerOrEqual(max_x_square), - {x_.LowerOrEqual(max_x)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + s_.LowerOrEqual(max_x_square), + {x_.LowerOrEqual(max_x)})) { return false; } } else if (max_x_square > max_s) { const IntegerValue new_max(FloorSquareRoot(max_s.value())); - if (!safe_enqueue( - x_.LowerOrEqual(new_max), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, x_.LowerOrEqual(new_max), {s_.LowerOrEqual(CapProdI(new_max + 1, new_max + 1) - 1)})) { return false; } @@ -1306,32 +1291,37 @@ bool SquarePropagator::Propagate(EnforcementId id, EnforcementStatus status) { return true; } -void SquarePropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int SquarePropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(x_, id); watcher->WatchAffineExpression(s_, id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); + return id; } -DivisionPropagator::DivisionPropagator(AffineExpression num, - AffineExpression denom, - AffineExpression div, - IntegerTrail* integer_trail) +DivisionPropagator::DivisionPropagator( + absl::Span enforcement_literals, AffineExpression num, + AffineExpression denom, AffineExpression div, Model* model) : num_(num), denom_(denom), div_(div), negated_denom_(denom.Negated()), negated_num_(num.Negated()), negated_div_(div.Negated()), - integer_trail_(integer_trail) {} + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); +} // TODO(user): We can propagate more, especially in the case where denom // spans across 0. // TODO(user): We can propagate a bit more if min_div = 0: // (min_num > -min_denom). bool DivisionPropagator::Propagate() { - if (integer_trail_->LowerBound(denom_) < 0 && - integer_trail_->UpperBound(denom_) > 0) { + if (integer_trail_.LowerBound(denom_) < 0 && + integer_trail_.UpperBound(denom_) > 0) { return true; } @@ -1340,32 +1330,62 @@ bool DivisionPropagator::Propagate() { AffineExpression denom = denom_; AffineExpression negated_denom = negated_denom_; - if (integer_trail_->UpperBound(denom) < 0) { + if (integer_trail_.UpperBound(denom) < 0) { std::swap(num, negated_num); std::swap(denom, negated_denom); } + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); + if (status == EnforcementStatus::CAN_PROPAGATE) { + const IntegerValue min_num = integer_trail_.LowerBound(num); + const IntegerValue max_num = integer_trail_.UpperBound(num); + const IntegerValue min_denom = integer_trail_.LowerBound(denom); + const IntegerValue max_denom = integer_trail_.UpperBound(denom); + const IntegerValue min_div = integer_trail_.LowerBound(div_); + const IntegerValue max_div = integer_trail_.UpperBound(div_); + // If the bounds of num / denom and div are disjoint, the enforcement must + // be false. TODO(user): relax the reason in a better way. + if (min_num / max_denom > max_div) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {num_.GreaterOrEqual(min_num), denom_.LowerOrEqual(max_denom), + div_.LowerOrEqual(max_div)}); + } + if (max_num / min_denom < min_div) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {num_.LowerOrEqual(max_num), denom_.GreaterOrEqual(min_denom), + div_.GreaterOrEqual(min_div)}); + } + // Otherwise we cannot propagate anything since the enforcement is unknown. + return true; + } + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (!PropagateSigns(num, denom, div_)) return false; - if (integer_trail_->UpperBound(num) >= 0 && - integer_trail_->UpperBound(div_) >= 0 && + if (integer_trail_.UpperBound(num) >= 0 && + integer_trail_.UpperBound(div_) >= 0 && !PropagateUpperBounds(num, denom, div_)) { return false; } - if (integer_trail_->UpperBound(negated_num) >= 0 && - integer_trail_->UpperBound(negated_div_) >= 0 && + if (integer_trail_.UpperBound(negated_num) >= 0 && + integer_trail_.UpperBound(negated_div_) >= 0 && !PropagateUpperBounds(negated_num, denom, negated_div_)) { return false; } - if (integer_trail_->LowerBound(num) >= 0 && - integer_trail_->LowerBound(div_) >= 0) { + if (integer_trail_.LowerBound(num) >= 0 && + integer_trail_.LowerBound(div_) >= 0) { return PropagatePositiveDomains(num, denom, div_); } - if (integer_trail_->LowerBound(negated_num) >= 0 && - integer_trail_->LowerBound(negated_div_) >= 0) { + if (integer_trail_.LowerBound(negated_num) >= 0 && + integer_trail_.LowerBound(negated_div_) >= 0) { return PropagatePositiveDomains(negated_num, denom, negated_div_); } @@ -1375,15 +1395,15 @@ bool DivisionPropagator::Propagate() { bool DivisionPropagator::PropagateSigns(AffineExpression num, AffineExpression denom, AffineExpression div) { - const IntegerValue min_num = integer_trail_->LowerBound(num); - const IntegerValue max_num = integer_trail_->UpperBound(num); - const IntegerValue min_div = integer_trail_->LowerBound(div); - const IntegerValue max_div = integer_trail_->UpperBound(div); + const IntegerValue min_num = integer_trail_.LowerBound(num); + const IntegerValue max_num = integer_trail_.UpperBound(num); + const IntegerValue min_div = integer_trail_.LowerBound(div); + const IntegerValue max_div = integer_trail_.UpperBound(div); // If num >= 0, as denom > 0, then div must be >= 0. if (min_num >= 0 && min_div < 0) { - if (!integer_trail_->SafeEnqueue( - div.GreaterOrEqual(0), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, div.GreaterOrEqual(0), {num.GreaterOrEqual(0), denom.GreaterOrEqual(1)})) { return false; } @@ -1391,8 +1411,8 @@ bool DivisionPropagator::PropagateSigns(AffineExpression num, // If div > 0, as denom > 0, then num must be > 0. if (min_num <= 0 && min_div > 0) { - if (!integer_trail_->SafeEnqueue( - num.GreaterOrEqual(1), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, num.GreaterOrEqual(1), {div.GreaterOrEqual(1), denom.GreaterOrEqual(1)})) { return false; } @@ -1400,8 +1420,8 @@ bool DivisionPropagator::PropagateSigns(AffineExpression num, // If num <= 0, as denom > 0, then div must be <= 0. if (max_num <= 0 && max_div > 0) { - if (!integer_trail_->SafeEnqueue( - div.LowerOrEqual(0), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, div.LowerOrEqual(0), {num.LowerOrEqual(0), denom.GreaterOrEqual(1)})) { return false; } @@ -1409,8 +1429,8 @@ bool DivisionPropagator::PropagateSigns(AffineExpression num, // If div < 0, as denom > 0, then num must be < 0. if (max_num >= 0 && max_div < 0) { - if (!integer_trail_->SafeEnqueue( - num.LowerOrEqual(-1), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, num.LowerOrEqual(-1), {div.LowerOrEqual(-1), denom.GreaterOrEqual(1)})) { return false; } @@ -1422,15 +1442,15 @@ bool DivisionPropagator::PropagateSigns(AffineExpression num, bool DivisionPropagator::PropagateUpperBounds(AffineExpression num, AffineExpression denom, AffineExpression div) { - const IntegerValue max_num = integer_trail_->UpperBound(num); - const IntegerValue min_denom = integer_trail_->LowerBound(denom); - const IntegerValue max_denom = integer_trail_->UpperBound(denom); - const IntegerValue max_div = integer_trail_->UpperBound(div); + const IntegerValue max_num = integer_trail_.UpperBound(num); + const IntegerValue min_denom = integer_trail_.LowerBound(denom); + const IntegerValue max_denom = integer_trail_.UpperBound(denom); + const IntegerValue max_div = integer_trail_.UpperBound(div); const IntegerValue new_max_div = max_num / min_denom; if (max_div > new_max_div) { - if (!integer_trail_->SafeEnqueue( - div.LowerOrEqual(new_max_div), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, div.LowerOrEqual(new_max_div), {num.LowerOrEqual(max_num), denom.GreaterOrEqual(min_denom)})) { return false; } @@ -1442,8 +1462,8 @@ bool DivisionPropagator::PropagateUpperBounds(AffineExpression num, const IntegerValue new_max_num = CapAddI(CapProdI(max_div + 1, max_denom), -1); if (max_num > new_max_num) { - if (!integer_trail_->SafeEnqueue( - num.LowerOrEqual(new_max_num), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, num.LowerOrEqual(new_max_num), {denom.LowerOrEqual(max_denom), denom.GreaterOrEqual(1), div.LowerOrEqual(max_div)})) { return false; @@ -1456,17 +1476,17 @@ bool DivisionPropagator::PropagateUpperBounds(AffineExpression num, bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, AffineExpression denom, AffineExpression div) { - const IntegerValue min_num = integer_trail_->LowerBound(num); - const IntegerValue max_num = integer_trail_->UpperBound(num); - const IntegerValue min_denom = integer_trail_->LowerBound(denom); - const IntegerValue max_denom = integer_trail_->UpperBound(denom); - const IntegerValue min_div = integer_trail_->LowerBound(div); - const IntegerValue max_div = integer_trail_->UpperBound(div); + const IntegerValue min_num = integer_trail_.LowerBound(num); + const IntegerValue max_num = integer_trail_.UpperBound(num); + const IntegerValue min_denom = integer_trail_.LowerBound(denom); + const IntegerValue max_denom = integer_trail_.UpperBound(denom); + const IntegerValue min_div = integer_trail_.LowerBound(div); + const IntegerValue max_div = integer_trail_.UpperBound(div); const IntegerValue new_min_div = min_num / max_denom; if (min_div < new_min_div) { - if (!integer_trail_->SafeEnqueue( - div.GreaterOrEqual(new_min_div), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, div.GreaterOrEqual(new_min_div), {num.GreaterOrEqual(min_num), denom.LowerOrEqual(max_denom), denom.GreaterOrEqual(1)})) { return false; @@ -1478,8 +1498,8 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, // num >= min_div * min_denom. const IntegerValue new_min_num = CapProdI(min_denom, min_div); if (min_num < new_min_num) { - if (!integer_trail_->SafeEnqueue( - num.GreaterOrEqual(new_min_num), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, num.GreaterOrEqual(new_min_num), {denom.GreaterOrEqual(min_denom), div.GreaterOrEqual(min_div)})) { return false; } @@ -1492,8 +1512,8 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, if (min_div > 0) { const IntegerValue new_max_denom = max_num / min_div; if (max_denom > new_max_denom) { - if (!integer_trail_->SafeEnqueue( - denom.LowerOrEqual(new_max_denom), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, denom.LowerOrEqual(new_max_denom), {num.LowerOrEqual(max_num), num.GreaterOrEqual(0), div.GreaterOrEqual(min_div), denom.GreaterOrEqual(1)})) { return false; @@ -1505,8 +1525,8 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, // >= CeilRatio(min_num + 1, max_div + 1). const IntegerValue new_min_denom = CeilRatio(min_num + 1, max_div + 1); if (min_denom < new_min_denom) { - if (!integer_trail_->SafeEnqueue( - denom.GreaterOrEqual(new_min_denom), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, denom.GreaterOrEqual(new_min_denom), {num.GreaterOrEqual(min_num), div.LowerOrEqual(max_div), div.GreaterOrEqual(0), denom.GreaterOrEqual(1)})) { return false; @@ -1516,60 +1536,89 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, return true; } -void DivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int DivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(num_, id); watcher->WatchAffineExpression(denom_, id); watcher->WatchAffineExpression(div_, id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); + return id; } -FixedDivisionPropagator::FixedDivisionPropagator(AffineExpression a, - IntegerValue b, - AffineExpression c, - IntegerTrail* integer_trail) - : a_(a), b_(b), c_(c), integer_trail_(integer_trail) { +FixedDivisionPropagator::FixedDivisionPropagator( + absl::Span enforcement_literals, AffineExpression a, + IntegerValue b, AffineExpression c, Model* model) + : a_(a), + b_(b), + c_(c), + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); CHECK_GT(b_, 0); } bool FixedDivisionPropagator::Propagate() { - const IntegerValue min_a = integer_trail_->LowerBound(a_); - const IntegerValue max_a = integer_trail_->UpperBound(a_); - IntegerValue min_c = integer_trail_->LowerBound(c_); - IntegerValue max_c = integer_trail_->UpperBound(c_); + const IntegerValue min_a = integer_trail_.LowerBound(a_); + const IntegerValue max_a = integer_trail_.UpperBound(a_); + IntegerValue min_c = integer_trail_.LowerBound(c_); + IntegerValue max_c = integer_trail_.UpperBound(c_); + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); + if (status == EnforcementStatus::CAN_PROPAGATE) { + // If the bounds of a / b and c are disjoint, the enforcement must be false. + // TODO(user): relax the reason in a better way. + if (min_a / b_ > max_c) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {a_.GreaterOrEqual(max_c * b_ + 1), c_.LowerOrEqual(max_c)}); + } + if (max_a / b_ < min_c) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {a_.LowerOrEqual(min_c * b_ - 1), c_.GreaterOrEqual(min_c)}); + } + // Otherwise we cannot propagate anything since the enforcement is unknown. + return true; + } + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (max_a / b_ < max_c) { max_c = max_a / b_; - if (!integer_trail_->SafeEnqueue( - c_.LowerOrEqual(max_c), - {integer_trail_->UpperBoundAsLiteral(a_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, c_.LowerOrEqual(max_c), + {integer_trail_.UpperBoundAsLiteral(a_)})) { return false; } } else if (max_a / b_ > max_c) { const IntegerValue new_max_a = max_c >= 0 ? max_c * b_ + b_ - 1 : CapProdI(max_c, b_); CHECK_LT(new_max_a, max_a); - if (!integer_trail_->SafeEnqueue( - a_.LowerOrEqual(new_max_a), - {integer_trail_->UpperBoundAsLiteral(c_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a_.LowerOrEqual(new_max_a), + {integer_trail_.UpperBoundAsLiteral(c_)})) { return false; } } if (min_a / b_ > min_c) { min_c = min_a / b_; - if (!integer_trail_->SafeEnqueue( - c_.GreaterOrEqual(min_c), - {integer_trail_->LowerBoundAsLiteral(a_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, c_.GreaterOrEqual(min_c), + {integer_trail_.LowerBoundAsLiteral(a_)})) { return false; } } else if (min_a / b_ < min_c) { const IntegerValue new_min_a = min_c > 0 ? CapProdI(min_c, b_) : min_c * b_ - b_ + 1; CHECK_GT(new_min_a, min_a); - if (!integer_trail_->SafeEnqueue( - a_.GreaterOrEqual(new_min_a), - {integer_trail_->LowerBoundAsLiteral(c_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a_.GreaterOrEqual(new_min_a), + {integer_trail_.LowerBoundAsLiteral(c_)})) { return false; } } @@ -1577,29 +1626,66 @@ bool FixedDivisionPropagator::Propagate() { return true; } -void FixedDivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int FixedDivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(a_, id); watcher->WatchAffineExpression(c_, id); + return id; } -FixedModuloPropagator::FixedModuloPropagator(AffineExpression expr, - IntegerValue mod, - AffineExpression target, - IntegerTrail* integer_trail) - : expr_(expr), mod_(mod), target_(target), integer_trail_(integer_trail) { +FixedModuloPropagator::FixedModuloPropagator( + absl::Span enforcement_literals, AffineExpression expr, + IntegerValue mod, AffineExpression target, Model* model) + : expr_(expr), + mod_(mod), + target_(target), + negated_expr_(expr.Negated()), + negated_target_(target.Negated()), + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { CHECK_GT(mod_, 0); + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); } bool FixedModuloPropagator::Propagate() { + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); + if (status == EnforcementStatus::CAN_PROPAGATE) { + const IntegerValue min_target = integer_trail_.LowerBound(target_); + const IntegerValue max_target = integer_trail_.UpperBound(target_); + if (min_target >= mod_) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {target_.GreaterOrEqual(mod_)}); + } else if (max_target <= -mod_) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {target_.LowerOrEqual(-mod_)}); + } + if (min_target > 0) { + if (!PropagateWhenFalseAndTargetIsPositive(expr_, target_)) return false; + } else if (max_target < 0) { + if (!PropagateWhenFalseAndTargetIsPositive(negated_expr_, + negated_target_)) { + return false; + } + } else if (!PropagateWhenFalseAndTargetDomainContainsZero()) { + return false; + } + // Otherwise we cannot propagate anything since the enforcement is unknown. + return true; + } + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (!PropagateSignsAndTargetRange()) return false; if (!PropagateOuterBounds()) return false; - if (integer_trail_->LowerBound(expr_) >= 0) { - if (!PropagateBoundsWhenExprIsPositive(expr_, target_)) return false; - } else if (integer_trail_->UpperBound(expr_) <= 0) { - if (!PropagateBoundsWhenExprIsPositive(expr_.Negated(), - target_.Negated())) { + if (integer_trail_.LowerBound(expr_) >= 0) { + if (!PropagateBoundsWhenExprIsNonNegative(expr_, target_)) return false; + } else if (integer_trail_.UpperBound(expr_) <= 0) { + if (!PropagateBoundsWhenExprIsNonNegative(negated_expr_, negated_target_)) { return false; } } @@ -1607,53 +1693,135 @@ bool FixedModuloPropagator::Propagate() { return true; } +bool FixedModuloPropagator::PropagateWhenFalseAndTargetIsPositive( + AffineExpression expr, AffineExpression target) { + const IntegerValue min_expr = integer_trail_.LowerBound(expr); + const IntegerValue max_expr = integer_trail_.UpperBound(expr); + // expr % mod_ must be in the target domain intersected with [0, mod_ - 1], + // noted [min_expr_mod, max_expr_mod]. This interval is non-empty. + const IntegerValue min_expr_mod = + std::max(IntegerValue(0), integer_trail_.LowerBound(target)); + const IntegerValue max_expr_mod = + std::min(mod_ - 1, integer_trail_.UpperBound(target)); + // expr must be in [min_expr_mod + k * mod_, max_expr_mod + k * mod_], for + // some k >= 0. If the expr domain is in one of the following intervals, the + // constraint is always false: + // - ]-infinity, min_expr_mod[ + // - ]max_expr_mod + k * mod_ , min_expr_mod + (k + 1) * mod_[ + if (max_expr < min_expr_mod) { + // TODO(user): relax the reason in a better way. + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {expr.LowerOrEqual(min_expr_mod - 1), + target.GreaterOrEqual(min_expr_mod)}); + } + // Compute the smallest k such that max_expr < min_expr_mod + (k + 1) * mod_. + const IntegerValue k = MathUtil::FloorOfRatio(max_expr - min_expr_mod, mod_); + if (min_expr > max_expr_mod + k * mod_) { + // TODO(user): relax the reason in a better way. + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {expr.GreaterOrEqual(max_expr_mod + k * mod_ + 1), + expr.LowerOrEqual(min_expr_mod + (k + 1) * mod_ - 1), + target.GreaterOrEqual(min_expr_mod), + target.LowerOrEqual(max_expr_mod)}); + } + return true; +} + +bool FixedModuloPropagator::PropagateWhenFalseAndTargetDomainContainsZero() { + const IntegerValue neg_max_expr_mod = + std::max(-mod_ + 1, integer_trail_.LowerBound(target_)); + const IntegerValue pos_max_expr_mod = + std::min(mod_ - 1, integer_trail_.UpperBound(target_)); + // expr must be in [k * mod_, pos_max_expr_mod + k * mod_] or in + // [neg_max_expr_mod - k * mod_, -k * mod_] for some k >= 0. If the expr + // domain is in one of the following intervals, the constraint is always + // false: + // - ]-(k + 1) * mod_, neg_max_expr_mod - k * mod_[ + // - ]pos_max_expr_mod + k * mod_ , (k + 1) * mod_[ + const IntegerValue min_expr = integer_trail_.LowerBound(expr_); + const IntegerValue max_expr = integer_trail_.UpperBound(expr_); + // Compute the smallest k such that max_expr < (k + 1) * mod_. + IntegerValue k = MathUtil::FloorOfRatio(max_expr, mod_); + if (k >= 0 && min_expr > pos_max_expr_mod + k * mod_) { + const IntegerValue min_target = integer_trail_.LowerBound(target_); + // TODO(user): relax the reason in a better way. + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {expr_.GreaterOrEqual(pos_max_expr_mod + k * mod_ + 1), + expr_.LowerOrEqual((k + 1) * mod_ - 1), + target_.GreaterOrEqual(min_target), + target_.LowerOrEqual(pos_max_expr_mod)}); + } + // Compute the smallest k such that min_expr > -(k + 1) * mod_. + k = MathUtil::FloorOfRatio(-min_expr, mod_); + if (k >= 0 && max_expr < neg_max_expr_mod - k * mod_) { + const IntegerValue max_target = integer_trail_.UpperBound(target_); + // TODO(user): relax the reason in a better way. + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {expr_.GreaterOrEqual(-(k + 1) * mod_ + 1), + expr_.LowerOrEqual(neg_max_expr_mod - k * mod_ - 1), + target_.GreaterOrEqual(neg_max_expr_mod), + target_.LowerOrEqual(max_target)}); + } + return true; +} + bool FixedModuloPropagator::PropagateSignsAndTargetRange() { // Initial domain reduction on the target. - if (integer_trail_->UpperBound(target_) >= mod_) { - if (!integer_trail_->SafeEnqueue(target_.LowerOrEqual(mod_ - 1), {})) { + if (integer_trail_.UpperBound(target_) >= mod_) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.LowerOrEqual(mod_ - 1), {})) { return false; } } - if (integer_trail_->LowerBound(target_) <= -mod_) { - if (!integer_trail_->SafeEnqueue(target_.GreaterOrEqual(1 - mod_), {})) { + if (integer_trail_.LowerBound(target_) <= -mod_) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.GreaterOrEqual(1 - mod_), {})) { return false; } } // The sign of target_ is fixed by the sign of expr_. - if (integer_trail_->LowerBound(expr_) >= 0 && - integer_trail_->LowerBound(target_) < 0) { + if (integer_trail_.LowerBound(expr_) >= 0 && + integer_trail_.LowerBound(target_) < 0) { // expr >= 0 => target >= 0. - if (!integer_trail_->SafeEnqueue(target_.GreaterOrEqual(0), - {expr_.GreaterOrEqual(0)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + target_.GreaterOrEqual(0), + {expr_.GreaterOrEqual(0)})) { return false; } } - if (integer_trail_->UpperBound(expr_) <= 0 && - integer_trail_->UpperBound(target_) > 0) { + if (integer_trail_.UpperBound(expr_) <= 0 && + integer_trail_.UpperBound(target_) > 0) { // expr <= 0 => target <= 0. - if (!integer_trail_->SafeEnqueue(target_.LowerOrEqual(0), - {expr_.LowerOrEqual(0)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + target_.LowerOrEqual(0), + {expr_.LowerOrEqual(0)})) { return false; } } - if (integer_trail_->LowerBound(target_) > 0 && - integer_trail_->LowerBound(expr_) <= 0) { + if (integer_trail_.LowerBound(target_) > 0 && + integer_trail_.LowerBound(expr_) <= 0) { // target > 0 => expr > 0. - if (!integer_trail_->SafeEnqueue(expr_.GreaterOrEqual(1), - {target_.GreaterOrEqual(1)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + expr_.GreaterOrEqual(1), + {target_.GreaterOrEqual(1)})) { return false; } } - if (integer_trail_->UpperBound(target_) < 0 && - integer_trail_->UpperBound(expr_) >= 0) { + if (integer_trail_.UpperBound(target_) < 0 && + integer_trail_.UpperBound(expr_) >= 0) { // target < 0 => expr < 0. - if (!integer_trail_->SafeEnqueue(expr_.LowerOrEqual(-1), - {target_.LowerOrEqual(-1)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + expr_.LowerOrEqual(-1), + {target_.LowerOrEqual(-1)})) { return false; } } @@ -1662,68 +1830,72 @@ bool FixedModuloPropagator::PropagateSignsAndTargetRange() { } bool FixedModuloPropagator::PropagateOuterBounds() { - const IntegerValue min_expr = integer_trail_->LowerBound(expr_); - const IntegerValue max_expr = integer_trail_->UpperBound(expr_); - const IntegerValue min_target = integer_trail_->LowerBound(target_); - const IntegerValue max_target = integer_trail_->UpperBound(target_); + const IntegerValue min_expr = integer_trail_.LowerBound(expr_); + const IntegerValue max_expr = integer_trail_.UpperBound(expr_); + const IntegerValue min_target = integer_trail_.LowerBound(target_); + const IntegerValue max_target = integer_trail_.UpperBound(target_); if (max_expr % mod_ > max_target) { - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, expr_.LowerOrEqual((max_expr / mod_) * mod_ + max_target), - {integer_trail_->UpperBoundAsLiteral(target_), - integer_trail_->UpperBoundAsLiteral(expr_)})) { + {integer_trail_.UpperBoundAsLiteral(target_), + integer_trail_.UpperBoundAsLiteral(expr_)})) { return false; } } if (min_expr % mod_ < min_target) { - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, expr_.GreaterOrEqual((min_expr / mod_) * mod_ + min_target), - {integer_trail_->LowerBoundAsLiteral(expr_), - integer_trail_->LowerBoundAsLiteral(target_)})) { + {integer_trail_.LowerBoundAsLiteral(expr_), + integer_trail_.LowerBoundAsLiteral(target_)})) { return false; } } if (min_expr / mod_ == max_expr / mod_) { if (min_target < min_expr % mod_) { - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.GreaterOrEqual(min_expr - (min_expr / mod_) * mod_), - {integer_trail_->LowerBoundAsLiteral(target_), - integer_trail_->UpperBoundAsLiteral(target_), - integer_trail_->LowerBoundAsLiteral(expr_), - integer_trail_->UpperBoundAsLiteral(expr_)})) { + {integer_trail_.LowerBoundAsLiteral(target_), + integer_trail_.UpperBoundAsLiteral(target_), + integer_trail_.LowerBoundAsLiteral(expr_), + integer_trail_.UpperBoundAsLiteral(expr_)})) { return false; } } if (max_target > max_expr % mod_) { - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.LowerOrEqual(max_expr - (max_expr / mod_) * mod_), - {integer_trail_->LowerBoundAsLiteral(target_), - integer_trail_->UpperBoundAsLiteral(target_), - integer_trail_->LowerBoundAsLiteral(expr_), - integer_trail_->UpperBoundAsLiteral(expr_)})) { + {integer_trail_.LowerBoundAsLiteral(target_), + integer_trail_.UpperBoundAsLiteral(target_), + integer_trail_.LowerBoundAsLiteral(expr_), + integer_trail_.UpperBoundAsLiteral(expr_)})) { return false; } } } else if (min_expr / mod_ == 0 && min_target < 0) { // expr == target when expr <= 0. if (min_target < min_expr) { - if (!integer_trail_->SafeEnqueue( - target_.GreaterOrEqual(min_expr), - {integer_trail_->LowerBoundAsLiteral(target_), - integer_trail_->LowerBoundAsLiteral(expr_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.GreaterOrEqual(min_expr), + {integer_trail_.LowerBoundAsLiteral(target_), + integer_trail_.LowerBoundAsLiteral(expr_)})) { return false; } } } else if (max_expr / mod_ == 0 && max_target > 0) { // expr == target when expr >= 0. if (max_target > max_expr) { - if (!integer_trail_->SafeEnqueue( - target_.LowerOrEqual(max_expr), - {integer_trail_->UpperBoundAsLiteral(target_), - integer_trail_->UpperBoundAsLiteral(expr_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.LowerOrEqual(max_expr), + {integer_trail_.UpperBoundAsLiteral(target_), + integer_trail_.UpperBoundAsLiteral(expr_)})) { return false; } } @@ -1732,37 +1904,39 @@ bool FixedModuloPropagator::PropagateOuterBounds() { return true; } -bool FixedModuloPropagator::PropagateBoundsWhenExprIsPositive( +bool FixedModuloPropagator::PropagateBoundsWhenExprIsNonNegative( AffineExpression expr, AffineExpression target) { - const IntegerValue min_target = integer_trail_->LowerBound(target); + const IntegerValue min_target = integer_trail_.LowerBound(target); DCHECK_GE(min_target, 0); - const IntegerValue max_target = integer_trail_->UpperBound(target); + const IntegerValue max_target = integer_trail_.UpperBound(target); // The propagation rules below will not be triggered if the domain of target // covers [0..mod_ - 1]. if (min_target == 0 && max_target == mod_ - 1) return true; - const IntegerValue min_expr = integer_trail_->LowerBound(expr); - const IntegerValue max_expr = integer_trail_->UpperBound(expr); + const IntegerValue min_expr = integer_trail_.LowerBound(expr); + const IntegerValue max_expr = integer_trail_.UpperBound(expr); if (max_expr % mod_ < min_target) { DCHECK_GE(max_expr, 0); - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, expr.LowerOrEqual((max_expr / mod_ - 1) * mod_ + max_target), - {integer_trail_->UpperBoundAsLiteral(expr), - integer_trail_->LowerBoundAsLiteral(target), - integer_trail_->UpperBoundAsLiteral(target)})) { + {integer_trail_.UpperBoundAsLiteral(expr), + integer_trail_.LowerBoundAsLiteral(target), + integer_trail_.UpperBoundAsLiteral(target)})) { return false; } } if (min_expr % mod_ > max_target) { DCHECK_GE(min_expr, 0); - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, expr.GreaterOrEqual((min_expr / mod_ + 1) * mod_ + min_target), - {integer_trail_->LowerBoundAsLiteral(target), - integer_trail_->UpperBoundAsLiteral(target), - integer_trail_->LowerBoundAsLiteral(expr)})) { + {integer_trail_.LowerBoundAsLiteral(target), + integer_trail_.UpperBoundAsLiteral(target), + integer_trail_.LowerBoundAsLiteral(expr)})) { return false; } } @@ -1770,11 +1944,12 @@ bool FixedModuloPropagator::PropagateBoundsWhenExprIsPositive( return true; } -void FixedModuloPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int FixedModuloPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(expr_, id); watcher->WatchAffineExpression(target_, id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); + return id; } } // namespace sat diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index 9bfc1b6e1b..094921b1f6 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -287,17 +287,17 @@ class ProductPropagator : public PropagatorInterface { public: ProductPropagator(absl::Span enforcement_literals, AffineExpression a, AffineExpression b, AffineExpression p, - IntegerTrail* integer_trail, - EnforcementPropagator* enforcement_propagator); + Model* model); // This type is neither copyable nor movable. ProductPropagator(const ProductPropagator&) = delete; ProductPropagator& operator=(const ProductPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + // Maybe replace a_, b_ or c_ by their negation to simplify the cases. bool CanonicalizeCases(); @@ -309,19 +309,14 @@ class ProductPropagator : public PropagatorInterface { bool PropagateMaxOnPositiveProduct(AffineExpression a, AffineExpression b, IntegerValue min_p, IntegerValue max_p); - ABSL_MUST_USE_RESULT bool SafeEnqueue( - IntegerLiteral i_lit, absl::Span integer_reason); - // Note that we might negate any two terms in CanonicalizeCases() during // each propagation. This is fine. AffineExpression a_; AffineExpression b_; AffineExpression p_; - - IntegerTrail* integer_trail_; - EnforcementPropagator* enforcement_propagator_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; EnforcementId enforcement_id_; - std::vector tmp_literal_reason_; }; // Propagates num / denom = div. Basic version, we don't extract any special @@ -330,17 +325,19 @@ class ProductPropagator : public PropagatorInterface { // TODO(user): Deal with overflow. class DivisionPropagator : public PropagatorInterface { public: - DivisionPropagator(AffineExpression num, AffineExpression denom, - AffineExpression div, IntegerTrail* integer_trail); + DivisionPropagator(absl::Span enforcement_literals, + AffineExpression num, AffineExpression denom, + AffineExpression div, Model* model); // This type is neither copyable nor movable. DivisionPropagator(const DivisionPropagator&) = delete; DivisionPropagator& operator=(const DivisionPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + // Propagates the fact that the signs of each domain, if fixed, are // compatible. bool PropagateSigns(AffineExpression num, AffineExpression denom, @@ -363,49 +360,59 @@ class DivisionPropagator : public PropagatorInterface { const AffineExpression negated_denom_; const AffineExpression negated_num_; const AffineExpression negated_div_; - IntegerTrail* integer_trail_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; + EnforcementId enforcement_id_; }; // Propagates var_a / cst_b = var_c. Basic version, we don't extract any special // cases, and we only propagates the bounds. cst_b must be > 0. class FixedDivisionPropagator : public PropagatorInterface { public: - FixedDivisionPropagator(AffineExpression a, IntegerValue b, - AffineExpression c, IntegerTrail* integer_trail); + FixedDivisionPropagator(absl::Span enforcement_literals, + AffineExpression a, IntegerValue b, + AffineExpression c, Model* model); // This type is neither copyable nor movable. FixedDivisionPropagator(const FixedDivisionPropagator&) = delete; FixedDivisionPropagator& operator=(const FixedDivisionPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + const AffineExpression a_; const IntegerValue b_; const AffineExpression c_; - - IntegerTrail* integer_trail_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; + EnforcementId enforcement_id_; }; // Propagates target == expr % mod. Basic version, we don't extract any special // cases, and we only propagates the bounds. mod must be > 0. class FixedModuloPropagator : public PropagatorInterface { public: - FixedModuloPropagator(AffineExpression expr, IntegerValue mod, - AffineExpression target, IntegerTrail* integer_trail); + FixedModuloPropagator(absl::Span enforcement_literals, + AffineExpression expr, IntegerValue mod, + AffineExpression target, Model* model); // This type is neither copyable nor movable. FixedModuloPropagator(const FixedModuloPropagator&) = delete; FixedModuloPropagator& operator=(const FixedModuloPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + + bool PropagateWhenFalseAndTargetIsPositive(AffineExpression expr, + AffineExpression target); + bool PropagateWhenFalseAndTargetDomainContainsZero(); bool PropagateSignsAndTargetRange(); - bool PropagateBoundsWhenExprIsPositive(AffineExpression expr, - AffineExpression target); + bool PropagateBoundsWhenExprIsNonNegative(AffineExpression expr, + AffineExpression target); bool PropagateOuterBounds(); const AffineExpression expr_; @@ -413,7 +420,9 @@ class FixedModuloPropagator : public PropagatorInterface { const AffineExpression target_; const AffineExpression negated_expr_; const AffineExpression negated_target_; - IntegerTrail* integer_trail_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; + EnforcementId enforcement_id_; }; // Propagates x * x = s. @@ -421,26 +430,22 @@ class FixedModuloPropagator : public PropagatorInterface { class SquarePropagator : public PropagatorInterface { public: SquarePropagator(absl::Span enforcement_literals, - AffineExpression x, AffineExpression s, - IntegerTrail* integer_trail, - EnforcementPropagator* enforcement_propagator); + AffineExpression x, AffineExpression s, Model* model); // This type is neither copyable nor movable. SquarePropagator(const SquarePropagator&) = delete; SquarePropagator& operator=(const SquarePropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: - bool Propagate(EnforcementId enforcement_id, EnforcementStatus status); + int RegisterWith(GenericLiteralWatcher* watcher); const AffineExpression x_; const AffineExpression s_; - IntegerTrail* integer_trail_; - EnforcementPropagator* enforcement_propagator_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; EnforcementId enforcement_id_; - std::vector tmp_literal_reason_; }; // ============================================================================= @@ -774,82 +779,71 @@ inline std::function IsEqualToMinOf( return [&](Model* model) { AddIsEqualToMinOf(min_expr, exprs, model); }; } -template -void RegisterAndTransferOwnership(Model* model, T* ct) { - ct->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(ct); -} // Adds the constraint: a * b = p. inline std::function ProductConstraint( absl::Span enforcement_literals, AffineExpression a, AffineExpression b, AffineExpression p) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); - EnforcementPropagator* enforcement_propagator = - model->GetOrCreate(); + const IntegerTrail& integer_trail = *model->GetOrCreate(); + // TODO(user): return early if constraint is never enforced. if (a == b) { - if (integer_trail->LowerBound(a) >= 0) { - RegisterAndTransferOwnership( - model, new SquarePropagator(enforcement_literals, a, p, - integer_trail, enforcement_propagator)); + if (integer_trail.LowerBound(a) >= 0) { + model->TakeOwnership( + new SquarePropagator(enforcement_literals, a, p, model)); return; } - if (integer_trail->UpperBound(a) <= 0) { - RegisterAndTransferOwnership( - model, new SquarePropagator(enforcement_literals, a.Negated(), p, - integer_trail, enforcement_propagator)); + if (integer_trail.UpperBound(a) <= 0) { + model->TakeOwnership( + new SquarePropagator(enforcement_literals, a.Negated(), p, model)); return; } } - RegisterAndTransferOwnership( - model, new ProductPropagator(enforcement_literals, a, b, p, - integer_trail, enforcement_propagator)); + model->TakeOwnership( + new ProductPropagator(enforcement_literals, a, b, p, model)); }; } // Adds the constraint: num / denom = div. (denom > 0). -inline std::function DivisionConstraint(AffineExpression num, - AffineExpression denom, - AffineExpression div) { +inline std::function DivisionConstraint( + absl::Span enforcement_literals, AffineExpression num, + AffineExpression denom, AffineExpression div) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); + const IntegerTrail& integer_trail = *model->GetOrCreate(); + // TODO(user): return early if constraint is never enforced. DivisionPropagator* constraint; - if (integer_trail->UpperBound(denom) < 0) { - constraint = new DivisionPropagator(num.Negated(), denom.Negated(), div, - integer_trail); - + if (integer_trail.UpperBound(denom) < 0) { + constraint = new DivisionPropagator(enforcement_literals, num.Negated(), + denom.Negated(), div, model); } else { - constraint = new DivisionPropagator(num, denom, div, integer_trail); + constraint = + new DivisionPropagator(enforcement_literals, num, denom, div, model); } - constraint->RegisterWith(model->GetOrCreate()); model->TakeOwnership(constraint); }; } // Adds the constraint: a / b = c where b is a constant. -inline std::function FixedDivisionConstraint(AffineExpression a, - IntegerValue b, - AffineExpression c) { +inline std::function FixedDivisionConstraint( + absl::Span enforcement_literals, AffineExpression a, + IntegerValue b, AffineExpression c) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); + // TODO(user): return early if constraint is never enforced. FixedDivisionPropagator* constraint = - b > 0 ? new FixedDivisionPropagator(a, b, c, integer_trail) - : new FixedDivisionPropagator(a.Negated(), -b, c, integer_trail); - constraint->RegisterWith(model->GetOrCreate()); + b > 0 + ? new FixedDivisionPropagator(enforcement_literals, a, b, c, model) + : new FixedDivisionPropagator(enforcement_literals, a.Negated(), -b, + c, model); model->TakeOwnership(constraint); }; } // Adds the constraint: a % b = c where b is a constant. -inline std::function FixedModuloConstraint(AffineExpression a, - IntegerValue b, - AffineExpression c) { +inline std::function FixedModuloConstraint( + absl::Span enforcement_literals, AffineExpression a, + IntegerValue b, AffineExpression c) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); - FixedModuloPropagator* constraint = - new FixedModuloPropagator(a, b, c, integer_trail); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); + model->TakeOwnership( + new FixedModuloPropagator(enforcement_literals, a, b, c, model)); }; } diff --git a/ortools/sat/integer_expr_test.cc b/ortools/sat/integer_expr_test.cc index fa1d8105c6..2f2a04cfff 100644 --- a/ortools/sat/integer_expr_test.cc +++ b/ortools/sat/integer_expr_test.cc @@ -1297,7 +1297,7 @@ TEST(DivisionConstraintTest, CheckAllPropagationsRandomProblem) { const IntegerVariable var_x = model.Add(NewIntegerVariable(x_min, x_max)); const IntegerVariable var_y = model.Add(NewIntegerVariable(y_min, y_max)); const IntegerVariable var_z = model.Add(NewIntegerVariable(z_min, z_max)); - model.Add(DivisionConstraint(var_x, var_y, var_z)); + model.Add(DivisionConstraint({}, var_x, var_y, var_z)); const bool result = model.GetOrCreate()->Propagate(); if (result) { EXPECT_BOUNDS_EQ(var_x, expected_x_min, expected_x_max); @@ -1309,6 +1309,54 @@ TEST(DivisionConstraintTest, CheckAllPropagationsRandomProblem) { } } +TEST(DivisionConstraintTest, AlwaysFalseWithUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable denom = model.Add(NewIntegerVariable(2, 3)); + const IntegerVariable div = model.Add(NewIntegerVariable(3, 5)); + // Always false if enforced (num / denom always less than div). + model.Add(DivisionConstraint({b}, num, denom, div)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(denom, 2, 3); + EXPECT_BOUNDS_EQ(div, 3, 5); +} + +TEST(DivisionConstraintTest, AlwaysFalseWithUnassignedEnforcementLiteral2) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable denom = model.Add(NewIntegerVariable(2, 3)); + const IntegerVariable div = model.Add(NewIntegerVariable(-5, -3)); + // Always false if enforced (num / denom always greater than div). + model.Add(DivisionConstraint({b}, num, denom, div)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(denom, 2, 3); + EXPECT_BOUNDS_EQ(div, -5, -3); +} + +TEST(DivisionConstraintTest, NotAlwaysFalseWithUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable denom = model.Add(NewIntegerVariable(2, 3)); + const IntegerVariable div = model.Add(NewIntegerVariable(1, 5)); + model.Add(DivisionConstraint({b}, num, denom, div)); + // Nothing should be propagated. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(denom, 2, 3); + EXPECT_BOUNDS_EQ(div, 1, 5); +} + TEST(DivisionConstraintTest, CheckAllSolutionsOnExprs) { absl::BitGen random; const int kMaxValue = 30; @@ -1428,7 +1476,7 @@ void TestAllDivisionValues(int64_t min_a, int64_t max_a, int64_t b, min_c == max_c ? AffineExpression(IntegerValue(min_c)) : AffineExpression(model.Add(NewIntegerVariable(min_c, max_c))); - model.Add(FixedDivisionConstraint(var_a, IntegerValue(b), var_c)); + model.Add(FixedDivisionConstraint({}, var_a, IntegerValue(b), var_c)); const bool result = model.GetOrCreate()->Propagate(); IntegerTrail* integer_trail = model.GetOrCreate(); if (result) { @@ -1462,7 +1510,7 @@ bool PropagateFixedDivision(int64_t a, int64_t max_a, int64_t b, int64_t c, Model model; const IntegerVariable var_a = model.Add(NewIntegerVariable(a, max_a)); const IntegerVariable var_c = model.Add(NewIntegerVariable(c, max_c)); - model.Add(FixedDivisionConstraint(var_a, IntegerValue(b), var_c)); + model.Add(FixedDivisionConstraint({}, var_a, IntegerValue(b), var_c)); const bool result = model.GetOrCreate()->Propagate(); if (result) { EXPECT_BOUNDS_EQ(var_a, new_a, new_max_a); @@ -1505,6 +1553,50 @@ TEST(FixedDivisionConstraintTest, ExpectedPropagation) { /*new_c=*/3, std::numeric_limits::max() / 10)); } +TEST(FixedDivisionConstraintTest, AlwaysFalseWithUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable div = model.Add(NewIntegerVariable(3, 5)); + // Always false if enforced (num / denom always less than div). + model.Add(FixedDivisionConstraint({b}, num, 2, div)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(div, 3, 5); +} + +TEST(FixedDivisionConstraintTest, + AlwaysFalseWithUnassignedEnforcementLiteral2) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable div = model.Add(NewIntegerVariable(-5, -3)); + // Always false if enforced (num / denom always greater than div). + model.Add(FixedDivisionConstraint({b}, num, 2, div)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(div, -5, -3); +} + +TEST(FixedDivisionConstraintTest, + NotAlwaysFalseWithUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable div = model.Add(NewIntegerVariable(1, 5)); + model.Add(FixedDivisionConstraint({b}, num, 2, div)); + // Nothing should be propagated. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(div, 1, 5); +} + TEST(ModuloConstraintTest, CheckAllSolutions) { absl::BitGen random; const int kMaxValue = 50; @@ -1599,7 +1691,7 @@ TEST(ModuloConstraintTest, CheckAllPropagationsRandomProblem) { const IntegerVariable var = model.Add(NewIntegerVariable(var_min, var_max)); const IntegerVariable target = model.Add(NewIntegerVariable(target_min, target_max)); - model.Add(FixedModuloConstraint(var, IntegerValue(mod), target)); + model.Add(FixedModuloConstraint({}, var, IntegerValue(mod), target)); const bool result = model.GetOrCreate()->Propagate(); if (result) { EXPECT_BOUNDS_EQ(var, expected_var_min, expected_var_max); @@ -1616,6 +1708,79 @@ TEST(ModuloConstraintTest, CheckAllPropagationsRandomProblem) { } } +bool TestModuloPropagationWhenFalse(int min_var, int max_var, int mod, + int min_target, int max_target) { + bool is_always_false = true; + for (int var = min_var; var <= max_var; ++var) { + for (int target = min_target; target <= max_target; ++target) { + if (var % mod == target) { + is_always_false = false; + break; + } + } + } + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var = model.Add(NewIntegerVariable(min_var, max_var)); + const IntegerVariable target = + model.Add(NewIntegerVariable(min_target, max_target)); + model.Add(FixedModuloConstraint({b}, var, IntegerValue(mod), target)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_EQ(model.GetOrCreate()->Assignment().LiteralIsFalse(b), + is_always_false) + << "min_var = " << min_var << " max_var = " << max_var << " mod = " << mod + << " min_target = " << min_target << " max_target = " << max_target; + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsTrue(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(var, min_var, max_var); + EXPECT_BOUNDS_EQ(target, min_target, max_target); + return is_always_false; +} + +TEST(ModuloConstraintTest, CheckPropagationWhenFalse) { + bool propagated_when_false = false; + for (int min_var = -15; min_var <= 15; ++min_var) { + for (int max_var = min_var; max_var <= min_var + 5; ++max_var) { + for (int min_target = -4; min_target <= 4; ++min_target) { + for (int max_target = min_target; max_target <= 4; ++max_target) { + propagated_when_false |= TestModuloPropagationWhenFalse( + min_var, max_var, 3, min_target, max_target); + } + } + } + } + EXPECT_TRUE(propagated_when_false); +} + +TEST(ModuloConstraintTest, + CheckEnumerateAllSolutionsWithoutEnforcementLiteral) { + CpModelProto initial_model = ParseTestProto(R"pb( + variables { name: 'b' domain: 0 domain: 1 } + variables { name: 'x' domain: -10 domain: 10 } + variables { name: 'y' domain: -3 domain: 3 } + constraints { + enforcement_literal: 0 + int_mod { + target { vars: 2 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { offset: 10 } + } + } + )pb"); + absl::btree_set> solutions; + const CpSolverResponse response = + SolveAndCheck(initial_model, "", &solutions); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + + CpModelProto reference_model = initial_model; + reference_model.mutable_constraints(0)->clear_enforcement_literal(); + absl::btree_set> reference_solutions; + const CpSolverResponse reference_response = + SolveAndCheck(initial_model, "", &reference_solutions); + EXPECT_EQ(reference_response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(solutions, reference_solutions); +} + bool TestSquarePropagation(std::pair initial_domain_x, std::pair initial_domain_s, std::pair expected_domain_x,