diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 94478c4b9c..d5278112e3 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -224,6 +224,32 @@ cc_test( ], ) +cc_library( + name = "old_precedences_propagator", + srcs = ["old_precedences_propagator.cc"], + hdrs = ["old_precedences_propagator.h"], + deps = [ + ":integer", + ":integer_base", + ":model", + ":precedences", + ":sat_base", + ":sat_solver", + ":synchronization", + "//ortools/base", + "//ortools/base:stl_util", + "//ortools/base:strong_vector", + "//ortools/util:bitset", + "//ortools/util:strong_integers", + "@abseil-cpp//absl/cleanup", + "@abseil-cpp//absl/container:inlined_vector", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/log:vlog_is_on", + "@abseil-cpp//absl/types:span", + ], +) + cc_proto_library( name = "cp_model_cc_proto", deps = [":cp_model_proto"], @@ -1776,6 +1802,7 @@ cc_test( deps = [ ":integer_base", "//ortools/base:gmock_main", + "@abseil-cpp//absl/log:check", ], ) @@ -2073,16 +2100,12 @@ cc_library( ":util", "//ortools/base", "//ortools/base:mathutil", - "//ortools/base:stl_util", "//ortools/base:strong_vector", - "//ortools/graph", "//ortools/graph:topologicalsorter", - "//ortools/util:bitset", "//ortools/util:logging", "//ortools/util:rev", "//ortools/util:strong_integers", "//ortools/util:time_limit", - "@abseil-cpp//absl/cleanup", "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", @@ -2144,12 +2167,13 @@ cc_library( srcs = ["integer_expr.cc"], hdrs = ["integer_expr.h"], deps = [ + ":cp_constraints", ":integer", ":integer_base", ":linear_constraint", ":linear_propagation", ":model", - ":precedences", + ":old_precedences_propagator", ":sat_base", ":sat_parameters_cc_proto", ":sat_solver", @@ -2158,6 +2182,7 @@ cc_library( "//ortools/base:mathutil", "//ortools/util:strong_integers", "//ortools/util:time_limit", + "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/numeric:int128", @@ -2203,15 +2228,14 @@ cc_library( srcs = ["linear_propagation.cc"], hdrs = ["linear_propagation.h"], deps = [ + ":cp_constraints", ":integer", ":integer_base", ":model", ":precedences", ":sat_base", - ":sat_solver", ":synchronization", ":util", - "//ortools/base:stl_util", "//ortools/base:strong_vector", "//ortools/util:bitset", "//ortools/util:rev", @@ -2220,7 +2244,6 @@ cc_library( "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/base:log_severity", "@abseil-cpp//absl/cleanup", - "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/container:inlined_vector", "@abseil-cpp//absl/log", @@ -2255,11 +2278,13 @@ cc_library( srcs = ["all_different.cc"], hdrs = ["all_different.h"], deps = [ + ":cp_constraints", ":integer", ":integer_base", ":model", ":sat_base", ":sat_solver", + ":util", "//ortools/base", "//ortools/graph:strongly_connected_components", "//ortools/util:bitset", @@ -2311,6 +2336,7 @@ cc_library( "//ortools/util:sort", "//ortools/util:strong_integers", "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/cleanup", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:vlog_is_on", @@ -3291,8 +3317,13 @@ cc_library( ":integer_base", ":model", ":sat_base", + ":sat_solver", "//ortools/base", + "//ortools/base:stl_util", + "//ortools/base:strong_vector", "//ortools/util:strong_integers", + "@abseil-cpp//absl/base:core_headers", + "@abseil-cpp//absl/container:inlined_vector", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/types:span", ], @@ -3585,6 +3616,7 @@ cc_library( ":synchronization", ":timetable", ":util", + "//ortools/base:stl_util", "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", "//ortools/util:strong_integers", @@ -4286,7 +4318,6 @@ cc_library( "//ortools/util:running_stat", "//ortools/util:strong_integers", "//ortools/util:time_limit", - "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", diff --git a/ortools/sat/constraint_violation.cc b/ortools/sat/constraint_violation.cc index f59bff555f..ec971d59fd 100644 --- a/ortools/sat/constraint_violation.cc +++ b/ortools/sat/constraint_violation.cc @@ -942,6 +942,42 @@ CompiledConstraintWithProto::CompiledConstraintWithProto( const ConstraintProto& ct_proto) : ct_proto_(ct_proto) {} +int64_t CompiledConstraintWithProto::ComputeViolation( + absl::Span solution) { + for (const int lit : ct_proto_.enforcement_literal()) { + if (!LiteralValue(lit, solution)) return 0; + } + return ComputeViolationWhenEnforced(solution); +} + +int64_t CompiledConstraintWithProto::ViolationDelta( + int var, int64_t old_value, + absl::Span solution_with_new_value) { + bool becomes_enforced = false; + bool becomes_unenforced = false; + for (const int lit : ct_proto().enforcement_literal()) { + if (var == PositiveRef(lit)) { + if (LiteralValue(lit, solution_with_new_value) == 1) { + becomes_enforced = true; + } else { + becomes_unenforced = true; + } + } else if (!LiteralValue(lit, solution_with_new_value)) { + // If an enforcement literal stays false, the violation stays 0. + return 0; + } + } + if (becomes_enforced) { + // New violation (ComputeViolationWhenEnforced()) minus old violation (0). + return ComputeViolationWhenEnforced(solution_with_new_value); + } + if (becomes_unenforced) { + // New violation (0) minus old violation (violation()). + return -violation(); + } + return ViolationDeltaWhenEnforced(var, old_value, solution_with_new_value); +} + std::vector CompiledConstraintWithProto::UsedVariables( const CpModelProto& model_proto) const { std::vector result = sat::UsedVariables(ct_proto_); @@ -956,13 +992,19 @@ std::vector CompiledConstraintWithProto::UsedVariables( return result; } +int64_t CompiledConstraintWithProto::ViolationDeltaWhenEnforced( + int /*var*/, int64_t /*old_value*/, + absl::Span solution_with_new_value) { + return ComputeViolationWhenEnforced(solution_with_new_value) - violation(); +} + // ----- CompiledBoolXorConstraint ----- CompiledBoolXorConstraint::CompiledBoolXorConstraint( const ConstraintProto& ct_proto) : CompiledConstraintWithProto(ct_proto) {} -int64_t CompiledBoolXorConstraint::ComputeViolation( +int64_t CompiledBoolXorConstraint::ComputeViolationWhenEnforced( absl::Span solution) { int64_t sum_of_literals = 0; for (const int lit : ct_proto().bool_xor().literals()) { @@ -971,7 +1013,7 @@ int64_t CompiledBoolXorConstraint::ComputeViolation( return 1 - (sum_of_literals % 2); } -int64_t CompiledBoolXorConstraint::ViolationDelta( +int64_t CompiledBoolXorConstraint::ViolationDeltaWhenEnforced( int /*var*/, int64_t /*old_value*/, absl::Span /*solution_with_new_value*/) { return violation() == 0 ? 1 : -1; @@ -983,7 +1025,7 @@ CompiledLinMaxConstraint::CompiledLinMaxConstraint( const ConstraintProto& ct_proto) : CompiledConstraintWithProto(ct_proto) {} -int64_t CompiledLinMaxConstraint::ComputeViolation( +int64_t CompiledLinMaxConstraint::ComputeViolationWhenEnforced( absl::Span solution) { const int64_t target_value = ExprValue(ct_proto().lin_max().target(), solution); @@ -1001,7 +1043,7 @@ CompiledIntProdConstraint::CompiledIntProdConstraint( const ConstraintProto& ct_proto) : CompiledConstraintWithProto(ct_proto) {} -int64_t CompiledIntProdConstraint::ComputeViolation( +int64_t CompiledIntProdConstraint::ComputeViolationWhenEnforced( absl::Span solution) { const int64_t target_value = ExprValue(ct_proto().int_prod().target(), solution); @@ -1018,7 +1060,7 @@ CompiledIntDivConstraint::CompiledIntDivConstraint( const ConstraintProto& ct_proto) : CompiledConstraintWithProto(ct_proto) {} -int64_t CompiledIntDivConstraint::ComputeViolation( +int64_t CompiledIntDivConstraint::ComputeViolationWhenEnforced( absl::Span solution) { const int64_t target_value = ExprValue(ct_proto().int_div().target(), solution); @@ -1034,7 +1076,7 @@ CompiledIntModConstraint::CompiledIntModConstraint( const ConstraintProto& ct_proto) : CompiledConstraintWithProto(ct_proto) {} -int64_t CompiledIntModConstraint::ComputeViolation( +int64_t CompiledIntModConstraint::ComputeViolationWhenEnforced( absl::Span solution) { const int64_t target_value = ExprValue(ct_proto().int_mod().target(), solution); @@ -1063,7 +1105,7 @@ CompiledAllDiffConstraint::CompiledAllDiffConstraint( const ConstraintProto& ct_proto) : CompiledConstraintWithProto(ct_proto) {} -int64_t CompiledAllDiffConstraint::ComputeViolation( +int64_t CompiledAllDiffConstraint::ComputeViolationWhenEnforced( absl::Span solution) { values_.clear(); for (const LinearExpressionProto& expr : ct_proto().all_diff().exprs()) { @@ -1175,7 +1217,7 @@ CompiledNoOverlap2dConstraint::CompiledNoOverlap2dConstraint( const ConstraintProto& ct_proto, const CpModelProto& cp_model) : CompiledConstraintWithProto(ct_proto), cp_model_(cp_model) {} -int64_t CompiledNoOverlap2dConstraint::ComputeViolation( +int64_t CompiledNoOverlap2dConstraint::ComputeViolationWhenEnforced( absl::Span solution) { DCHECK_GE(ct_proto().no_overlap_2d().x_intervals_size(), 2); const int size = ct_proto().no_overlap_2d().x_intervals_size(); @@ -1277,10 +1319,11 @@ class CompiledCircuitConstraint : public CompiledConstraintWithProto { explicit CompiledCircuitConstraint(const ConstraintProto& ct_proto); ~CompiledCircuitConstraint() override = default; - int64_t ComputeViolation(absl::Span solution) override; + int64_t ComputeViolationWhenEnforced( + absl::Span solution) override; void PerformMove(int var, int64_t old_value, absl::Span new_solution) override; - int64_t ViolationDelta( + int64_t ViolationDeltaWhenEnforced( int var, int64_t old_value, absl::Span solution_with_new_value) override; @@ -1385,7 +1428,7 @@ void CompiledCircuitConstraint::PerformMove( std::swap(committed_sccs_, sccs_); } -int64_t CompiledCircuitConstraint::ComputeViolation( +int64_t CompiledCircuitConstraint::ComputeViolationWhenEnforced( absl::Span solution) { InitGraph(solution); int64_t result = ViolationForCurrentGraph(); @@ -1393,7 +1436,7 @@ int64_t CompiledCircuitConstraint::ComputeViolation( return result; } -int64_t CompiledCircuitConstraint::ViolationDelta( +int64_t CompiledCircuitConstraint::ViolationDeltaWhenEnforced( int var, int64_t old_value, absl::Span solution_with_new_value) { int64_t result = 0; diff --git a/ortools/sat/constraint_violation.h b/ortools/sat/constraint_violation.h index cc09718d24..6284cf924a 100644 --- a/ortools/sat/constraint_violation.h +++ b/ortools/sat/constraint_violation.h @@ -291,9 +291,27 @@ class CompiledConstraintWithProto : public CompiledConstraint { const ConstraintProto& ct_proto() const { return ct_proto_; } + int64_t ComputeViolation(absl::Span solution) final; + + // Returns the delta if var changes from old_value to solution[var]. + int64_t ViolationDelta( + int var, int64_t old_value, + absl::Span solution_with_new_value) final; + // This just returns the variables used by the stored ct_proto_. std::vector UsedVariables(const CpModelProto& model_proto) const final; + protected: + // Computes the violation of a constraint when it is enforced. + virtual int64_t ComputeViolationWhenEnforced( + absl::Span solution) = 0; + + // Returns the delta if var changes from old_value to solution[var], assuming + // that the constraint was and stays enforced after the change. + virtual int64_t ViolationDeltaWhenEnforced( + int var, int64_t old_value, + absl::Span solution_with_new_value); + private: const ConstraintProto& ct_proto_; }; @@ -470,9 +488,10 @@ class CompiledBoolXorConstraint : public CompiledConstraintWithProto { explicit CompiledBoolXorConstraint(const ConstraintProto& ct_proto); ~CompiledBoolXorConstraint() override = default; - int64_t ComputeViolation(absl::Span solution) override; - int64_t ViolationDelta( - int /*var*/, int64_t /*old_value*/, + int64_t ComputeViolationWhenEnforced( + absl::Span solution) override; + int64_t ViolationDeltaWhenEnforced( + int var, int64_t old_value, absl::Span solution_with_new_value) override; }; @@ -485,7 +504,8 @@ class CompiledLinMaxConstraint : public CompiledConstraintWithProto { explicit CompiledLinMaxConstraint(const ConstraintProto& ct_proto); ~CompiledLinMaxConstraint() override = default; - int64_t ComputeViolation(absl::Span solution) override; + int64_t ComputeViolationWhenEnforced( + absl::Span solution) override; }; // The violation of an int_prod constraint is @@ -495,7 +515,8 @@ class CompiledIntProdConstraint : public CompiledConstraintWithProto { explicit CompiledIntProdConstraint(const ConstraintProto& ct_proto); ~CompiledIntProdConstraint() override = default; - int64_t ComputeViolation(absl::Span solution) override; + int64_t ComputeViolationWhenEnforced( + absl::Span solution) override; }; // The violation of an int_div constraint is @@ -505,7 +526,8 @@ class CompiledIntDivConstraint : public CompiledConstraintWithProto { explicit CompiledIntDivConstraint(const ConstraintProto& ct_proto); ~CompiledIntDivConstraint() override = default; - int64_t ComputeViolation(absl::Span solution) override; + int64_t ComputeViolationWhenEnforced( + absl::Span solution) override; }; // The violation of an int_mod constraint is defined as follow: @@ -525,7 +547,8 @@ class CompiledIntModConstraint : public CompiledConstraintWithProto { explicit CompiledIntModConstraint(const ConstraintProto& ct_proto); ~CompiledIntModConstraint() override = default; - int64_t ComputeViolation(absl::Span solution) override; + int64_t ComputeViolationWhenEnforced( + absl::Span solution) override; }; // The violation of a all_diff is the number of unordered pairs of expressions @@ -535,7 +558,8 @@ class CompiledAllDiffConstraint : public CompiledConstraintWithProto { explicit CompiledAllDiffConstraint(const ConstraintProto& ct_proto); ~CompiledAllDiffConstraint() override = default; - int64_t ComputeViolation(absl::Span solution) override; + int64_t ComputeViolationWhenEnforced( + absl::Span solution) override; private: std::vector values_; @@ -618,7 +642,8 @@ class CompiledNoOverlap2dConstraint : public CompiledConstraintWithProto { const CpModelProto& cp_model); ~CompiledNoOverlap2dConstraint() override = default; - int64_t ComputeViolation(absl::Span solution) override; + int64_t ComputeViolationWhenEnforced( + absl::Span solution) override; private: const CpModelProto& cp_model_; diff --git a/ortools/sat/constraint_violation_test.cc b/ortools/sat/constraint_violation_test.cc index eb1199c043..35209a5bce 100644 --- a/ortools/sat/constraint_violation_test.cc +++ b/ortools/sat/constraint_violation_test.cc @@ -231,6 +231,51 @@ TEST(ConstraintViolationTest, BasicBoolXorExample) { EXPECT_EQ(1, ct.ComputeViolation({1, 0, 0})); } +TEST(ConstraintViolationTest, ComputeViolationWithEnforcementLiteral) { + const ConstraintProto ct_proto = + ParseTestProto(R"pb(enforcement_literal: 0 + bool_xor { literals: [ 1, 2 ] })pb"); + CompiledBoolXorConstraint ct(ct_proto); + EXPECT_EQ(0, ct.ComputeViolation({0, 1, 1})); // Not enforced. + EXPECT_EQ(1, ct.ComputeViolation({1, 1, 1})); // Enforced. +} + +TEST(ConstraintViolationTest, ViolationDeltaWithEnforcementLiteral) { + const ConstraintProto ct_proto = + ParseTestProto(R"pb(enforcement_literal: 0 + enforcement_literal: 1 + bool_xor { literals: [ 2, 3 ] })pb"); + CompiledBoolXorConstraint ct(ct_proto); + ct.InitializeViolation({0, 0, 0, 0}); + EXPECT_EQ(0, ct.violation()); + + // Was not enforced and stays unenforced: no change. + EXPECT_EQ(0, ct.ViolationDelta(0, 0, {1, 0, 1, 1})); + ct.PerformMove(0, 0, {1, 0, 1, 1}); + EXPECT_EQ(0, ct.violation()); + + // Was not enforced and becomes enforced and violated. + EXPECT_EQ(1, ct.ViolationDelta(1, 0, {1, 1, 1, 1})); + ct.PerformMove(1, 0, {1, 1, 1, 1}); + EXPECT_EQ(1, ct.violation()); + + // Was enforced and violated, becomes unenforced. + EXPECT_EQ(-1, ct.ViolationDelta(0, 1, {0, 1, 1, 1})); + ct.PerformMove(0, 1, {0, 1, 1, 1}); + EXPECT_EQ(0, ct.violation()); +} + +TEST(ConstraintViolationTest, ViolationDeltaWhenEnforced) { + const ConstraintProto ct_proto = + ParseTestProto(R"pb(enforcement_literal: 0 + enforcement_literal: 1 + bool_xor { literals: [ 2, 3 ] })pb"); + CompiledBoolXorConstraint ct(ct_proto); + ct.InitializeViolation({1, 1, 0, 1}); + EXPECT_EQ(0, ct.violation()); + EXPECT_EQ(1, ct.ViolationDelta(2, 0, {1, 1, 1, 1})); +} + TEST(ConstraintViolationTest, BasicLinMaxExampleNoViolation) { const CpModelProto model = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } diff --git a/ortools/sat/cp_constraints.cc b/ortools/sat/cp_constraints.cc index 80c29cf3f3..341ca496bd 100644 --- a/ortools/sat/cp_constraints.cc +++ b/ortools/sat/cp_constraints.cc @@ -14,19 +14,386 @@ #include "ortools/sat/cp_constraints.h" #include +#include +#include +#include #include +#include "absl/log/check.h" #include "absl/types/span.h" +#include "ortools/base/stl_util.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" #include "ortools/util/strong_integers.h" namespace operations_research { namespace sat { +std::ostream& operator<<(std::ostream& os, const EnforcementStatus& e) { + switch (e) { + case EnforcementStatus::IS_FALSE: + os << "IS_FALSE"; + break; + case EnforcementStatus::CANNOT_PROPAGATE: + os << "CANNOT_PROPAGATE"; + break; + case EnforcementStatus::CAN_PROPAGATE: + os << "CAN_PROPAGATE"; + break; + case EnforcementStatus::IS_ENFORCED: + os << "IS_ENFORCED"; + break; + } + return os; +} + +EnforcementPropagator::EnforcementPropagator(Model* model) + : SatPropagator("EnforcementPropagator"), + trail_(*model->GetOrCreate()), + assignment_(trail_.Assignment()), + integer_trail_(model->GetOrCreate()), + rev_int_repository_(model->GetOrCreate()) { + // Note that this will be after the integer trail since rev_int_repository_ + // depends on IntegerTrail. + model->GetOrCreate()->AddPropagator(this); + + // Sentinel - also start of next Register(). + starts_.push_back(0); +} + +bool EnforcementPropagator::Propagate(Trail* /*trail*/) { + rev_int_repository_->SaveStateWithStamp(&rev_stack_size_, &rev_stamp_); + while (propagation_trail_index_ < trail_.Index()) { + const Literal literal = trail_[propagation_trail_index_++]; + if (literal.Index() >= static_cast(watcher_.size())) continue; + + int new_size = 0; + auto& watch_list = watcher_[literal.Index()]; + for (const EnforcementId id : watch_list) { + const LiteralIndex index = ProcessIdOnTrue(literal, id); + if (index == kNoLiteralIndex) { + // We keep the same watcher. + watch_list[new_size++] = id; + } else { + // Change the watcher. + CHECK_NE(index, literal.Index()); + watcher_[index].push_back(id); + } + } + watch_list.resize(new_size); + + // We also mark some constraint false. + for (const EnforcementId id : watcher_[literal.NegatedIndex()]) { + ChangeStatus(id, EnforcementStatus::IS_FALSE); + } + } + rev_stack_size_ = static_cast(untrail_stack_.size()); + + // Compute the enforcement status of any constraint added at a positive level. + // This is only needed until we are back to level zero. + for (const EnforcementId id : ids_to_fix_until_next_root_level_) { + ChangeStatus(id, DebugStatus(id)); + } + if (trail_.CurrentDecisionLevel() == 0) { + ids_to_fix_until_next_root_level_.clear(); + } + + return true; +} + +void EnforcementPropagator::Untrail(const Trail& /*trail*/, int trail_index) { + // Simply revert the status change. + const int size = static_cast(untrail_stack_.size()); + for (int i = size - 1; i >= rev_stack_size_; --i) { + const auto [id, status] = untrail_stack_[i]; + statuses_[id] = status; + if (callbacks_[id] != nullptr) callbacks_[id](id, status); + } + untrail_stack_.resize(rev_stack_size_); + propagation_trail_index_ = trail_index; +} + +// Adds a new constraint to the class and returns the constraint id. +// +// Note that we accept empty enforcement list so that client code can be used +// regardless of the presence of enforcement or not. A negative id means the +// constraint is never enforced, and should be ignored. +EnforcementId EnforcementPropagator::Register( + absl::Span enforcement, + std::function callback) { + int num_true = 0; + int num_false = 0; + temp_literals_.clear(); + const int level = trail_.CurrentDecisionLevel(); + for (const Literal l : enforcement) { + // Make sure we always have enough room for the literal and its negation. + const int size = std::max(l.Index().value(), l.NegatedIndex().value()) + 1; + if (size > static_cast(watcher_.size())) { + watcher_.resize(size); + } + if (assignment_.LiteralIsTrue(l)) { + if (level == 0 || trail_.Info(l.Variable()).level == 0) continue; + ++num_true; + } else if (assignment_.LiteralIsFalse(l)) { + ++num_false; + } + temp_literals_.push_back(l); + } + gtl::STLSortAndRemoveDuplicates(&temp_literals_); + + // Return special index if always enforced. + if (temp_literals_.empty()) { + if (callback != nullptr) + callback(EnforcementId(-1), EnforcementStatus::IS_ENFORCED); + return EnforcementId(-1); + } + + const EnforcementId id(static_cast(callbacks_.size())); + callbacks_.push_back(std::move(callback)); + + CHECK(!temp_literals_.empty()); + buffer_.insert(buffer_.end(), temp_literals_.begin(), temp_literals_.end()); + starts_.push_back(buffer_.size()); // Sentinel/next-start. + + // The default status at level zero. + statuses_.push_back(temp_literals_.size() == 1 + ? EnforcementStatus::CAN_PROPAGATE + : EnforcementStatus::CANNOT_PROPAGATE); + + if (temp_literals_.size() == 1) { + watcher_[temp_literals_[0].Index()].push_back(id); + } else { + // Make sure we watch correct literals. + const auto span = GetSpan(id); + int num_not_true = 0; + for (int i = 0; i < span.size(); ++i) { + if (assignment_.LiteralIsTrue(span[i])) continue; + std::swap(span[num_not_true], span[i]); + ++num_not_true; + if (num_not_true == 2) break; + } + + // We need to watch one of the literals at highest level. + if (num_not_true == 1) { + int max_level = trail_.Info(span[1].Variable()).level; + for (int i = 2; i < span.size(); ++i) { + const int level = trail_.Info(span[i].Variable()).level; + if (level > max_level) { + max_level = level; + std::swap(span[1], span[i]); + } + } + } + + watcher_[span[0].Index()].push_back(id); + watcher_[span[1].Index()].push_back(id); + } + + // Change status, call callback and set up untrail if the status is different + // from EnforcementStatus::CANNOT_PROPAGATE. + if (num_false > 0) { + ChangeStatus(id, EnforcementStatus::IS_FALSE); + } else if (num_true == temp_literals_.size()) { + ChangeStatus(id, EnforcementStatus::IS_ENFORCED); + } else if (num_true + 1 == temp_literals_.size()) { + ChangeStatus(id, EnforcementStatus::CAN_PROPAGATE); + // Because this is the default status, we still need to call the callback. + if (temp_literals_.size() == 1) { + if (callbacks_[id] != nullptr) { + callbacks_[id](id, EnforcementStatus::CAN_PROPAGATE); + } + } + } + + // Tricky: if we added something at a positive level, and its status is + // not CANNOT_PROPAGATE, then we might need to fix it on backtrack. + if (trail_.CurrentDecisionLevel() > 0 && + statuses_[id] != EnforcementStatus::CANNOT_PROPAGATE) { + ids_to_fix_until_next_root_level_.push_back(id); + } + + return id; +} + +EnforcementId EnforcementPropagator::Register( + absl::Span enforcement_literals, + GenericLiteralWatcher* watcher, int literal_watcher_id) { + return Register(enforcement_literals, + [=](EnforcementId, EnforcementStatus status) { + if (status == EnforcementStatus::CAN_PROPAGATE || + status == EnforcementStatus::IS_ENFORCED) { + watcher->CallOnNextPropagate(literal_watcher_id); + } + }); +} + +// Add the enforcement reason to the given vector. +void EnforcementPropagator::AddEnforcementReason( + EnforcementId id, std::vector* reason) const { + for (const Literal l : GetSpan(id)) { + reason->push_back(l.Negated()); + } +} + +// Try to propagate when the enforced constraint is not satisfiable. +// This is currently in O(enforcement_size); +bool EnforcementPropagator::PropagateWhenFalse( + EnforcementId id, absl::Span literal_reason, + absl::Span integer_reason) { + temp_reason_.clear(); + LiteralIndex unique_unassigned = kNoLiteralIndex; + for (const Literal l : GetSpan(id)) { + if (assignment_.LiteralIsFalse(l)) return true; + if (assignment_.LiteralIsTrue(l)) { + temp_reason_.push_back(l.Negated()); + continue; + } + if (unique_unassigned != kNoLiteralIndex) return true; + unique_unassigned = l.Index(); + } + + temp_reason_.insert(temp_reason_.end(), literal_reason.begin(), + literal_reason.end()); + if (unique_unassigned == kNoLiteralIndex) { + return integer_trail_->ReportConflict(temp_reason_, integer_reason); + } + + // We also change the status right away. + ChangeStatus(id, EnforcementStatus::IS_FALSE); + integer_trail_->EnqueueLiteral(Literal(unique_unassigned).Negated(), + temp_reason_, integer_reason); + return true; +} + +bool EnforcementPropagator::SafeEnqueue( + EnforcementId id, IntegerLiteral i_lit, + absl::Span integer_reason) { + temp_reason_.clear(); + AddEnforcementReason(id, &temp_reason_); + return integer_trail_->SafeEnqueue(i_lit, temp_reason_, integer_reason); +} + +bool EnforcementPropagator::ReportConflict( + EnforcementId id, absl::Span integer_reason) { + temp_reason_.clear(); + AddEnforcementReason(id, &temp_reason_); + return integer_trail_->ReportConflict(temp_reason_, integer_reason); +} + +absl::Span EnforcementPropagator::GetSpan(EnforcementId id) { + if (id < 0) return {}; + DCHECK_LE(id + 1, starts_.size()); + const int size = starts_[id + 1] - starts_[id]; + DCHECK_NE(size, 0); + return absl::MakeSpan(&buffer_[starts_[id]], size); +} + +absl::Span EnforcementPropagator::GetSpan( + EnforcementId id) const { + if (id < 0) return {}; + DCHECK_LE(id + 1, starts_.size()); + const int size = starts_[id + 1] - starts_[id]; + DCHECK_NE(size, 0); + return absl::MakeSpan(&buffer_[starts_[id]], size); +} + +LiteralIndex EnforcementPropagator::ProcessIdOnTrue(Literal watched, + EnforcementId id) { + const EnforcementStatus status = statuses_[id]; + if (status == EnforcementStatus::IS_FALSE) return kNoLiteralIndex; + + const auto span = GetSpan(id); + if (span.size() == 1) { + CHECK_EQ(status, EnforcementStatus::CAN_PROPAGATE); + ChangeStatus(id, EnforcementStatus::IS_ENFORCED); + return kNoLiteralIndex; + } + + const int watched_pos = (span[0] == watched) ? 0 : 1; + CHECK_EQ(span[watched_pos], watched); + if (assignment_.LiteralIsFalse(span[watched_pos ^ 1])) { + ChangeStatus(id, EnforcementStatus::IS_FALSE); + return kNoLiteralIndex; + } + + for (int i = 2; i < span.size(); ++i) { + const Literal l = span[i]; + if (assignment_.LiteralIsFalse(l)) { + ChangeStatus(id, EnforcementStatus::IS_FALSE); + return kNoLiteralIndex; + } + if (!assignment_.LiteralIsAssigned(l)) { + // Replace the watched literal. Note that if the other watched literal is + // true, it should be processed afterwards. We do not change the status + std::swap(span[watched_pos], span[i]); + return span[watched_pos].Index(); + } + } + + // All literal with index > 1 are true. Two case. + if (assignment_.LiteralIsTrue(span[watched_pos ^ 1])) { + // All literals are true. + ChangeStatus(id, EnforcementStatus::IS_ENFORCED); + return kNoLiteralIndex; + } else { + // The other watched literal is the last unassigned + CHECK_EQ(status, EnforcementStatus::CANNOT_PROPAGATE); + ChangeStatus(id, EnforcementStatus::CAN_PROPAGATE); + return kNoLiteralIndex; + } +} + +void EnforcementPropagator::ChangeStatus(EnforcementId id, + EnforcementStatus new_status) { + const EnforcementStatus old_status = statuses_[id]; + if (old_status == new_status) return; + if (trail_.CurrentDecisionLevel() != 0) { + untrail_stack_.push_back({id, old_status}); + } + statuses_[id] = new_status; + if (callbacks_[id] != nullptr) callbacks_[id](id, new_status); +} + +EnforcementStatus EnforcementPropagator::DebugStatus(EnforcementId id) { + if (id < 0) return EnforcementStatus::IS_ENFORCED; + + int num_true = 0; + for (const Literal l : GetSpan(id)) { + if (assignment_.LiteralIsFalse(l)) { + return EnforcementStatus::IS_FALSE; + } + if (assignment_.LiteralIsTrue(l)) ++num_true; + } + const int size = GetSpan(id).size(); + if (num_true == size) return EnforcementStatus::IS_ENFORCED; + if (num_true + 1 == size) return EnforcementStatus::CAN_PROPAGATE; + return EnforcementStatus::CANNOT_PROPAGATE; +} + +BooleanXorPropagator::BooleanXorPropagator( + absl::Span enforcement_literals, + const std::vector& literals, bool value, Model* model) + : literals_(literals), + value_(value), + trail_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + enforcement_propagator_(model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_->Register( + enforcement_literals, watcher, RegisterWith(watcher)); +} + bool BooleanXorPropagator::Propagate() { + const EnforcementStatus status = + enforcement_propagator_->Status(enforcement_id_); + if (status == EnforcementStatus::IS_FALSE || + status == EnforcementStatus::CANNOT_PROPAGATE) { + return true; + } + bool sum = false; int unassigned_index = -1; for (int i = 0; i < literals_.size(); ++i) { @@ -43,8 +410,10 @@ bool BooleanXorPropagator::Propagate() { } // Propagates? - if (unassigned_index != -1) { + if (status == EnforcementStatus::IS_ENFORCED && unassigned_index != -1) { literal_reason_.clear(); + enforcement_propagator_->AddEnforcementReason(enforcement_id_, + &literal_reason_); for (int i = 0; i < literals_.size(); ++i) { if (i == unassigned_index) continue; const Literal l = literals_[i]; @@ -56,6 +425,15 @@ bool BooleanXorPropagator::Propagate() { literal_reason_, {}); return true; } + if (status == EnforcementStatus::CAN_PROPAGATE && unassigned_index == -1 && + sum != value_) { + return enforcement_propagator_->PropagateWhenFalse(enforcement_id_, + literals_, + /*integer_reason=*/{}); + } + if (status != EnforcementStatus::IS_ENFORCED || unassigned_index != -1) { + return true; + } // Ok. if (sum == value_) return true; @@ -63,20 +441,21 @@ bool BooleanXorPropagator::Propagate() { // Conflict. std::vector* conflict = trail_->MutableConflict(); conflict->clear(); - for (int i = 0; i < literals_.size(); ++i) { - const Literal l = literals_[i]; + enforcement_propagator_->AddEnforcementReason(enforcement_id_, conflict); + for (const Literal& l : literals_) { conflict->push_back(trail_->Assignment().LiteralIsFalse(l) ? l : l.Negated()); } return false; } -void BooleanXorPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int BooleanXorPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); for (const Literal& l : literals_) { watcher->WatchLiteral(l, id); watcher->WatchLiteral(l.Negated(), id); } + return id; } GreaterThanAtLeastOneOfPropagator::GreaterThanAtLeastOneOfPropagator( diff --git a/ortools/sat/cp_constraints.h b/ortools/sat/cp_constraints.h index a88b635d8c..4e8a296366 100644 --- a/ortools/sat/cp_constraints.h +++ b/ortools/sat/cp_constraints.h @@ -16,19 +16,142 @@ #include #include +#include +#include #include +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/logging.h" +#include "ortools/base/strong_vector.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" +#include "ortools/util/strong_integers.h" namespace operations_research { namespace sat { +DEFINE_STRONG_INDEX_TYPE(EnforcementId); + +// An enforced constraint can be in one of these 4 states. +// Note that we rely on the integer encoding to take 2 bits for optimization. +enum class EnforcementStatus { + // One enforcement literal is false. + IS_FALSE = 0, + // More than two literals are unassigned. + CANNOT_PROPAGATE = 1, + // All enforcement literals are true but one. + CAN_PROPAGATE = 2, + // All enforcement literals are true. + IS_ENFORCED = 3, +}; + +std::ostream& operator<<(std::ostream& os, const EnforcementStatus& e); + +// This is meant as an helper to deal with enforcement for any constraint. +class EnforcementPropagator : public SatPropagator { + public: + explicit EnforcementPropagator(Model* model); + + // SatPropagator interface. + bool Propagate(Trail* trail) final; + void Untrail(const Trail& trail, int trail_index) final; + + // Adds a new constraint to the class and register a callback that will + // be called on status change. Note that we also call the callback with the + // initial status if different from CANNOT_PROPAGATE when added. + // + // It is better to not call this for empty enforcement list, but you can. A + // negative id means the level zero status will never change, and only the + // first call to callback() should be necessary, we don't save it. + EnforcementId Register( + absl::Span enforcement, + std::function callback = nullptr); + + // Calls `Register` with a callback calling + // `watcher->CallOnNextPropagate(literal_watcher_id)` if a propagation might + // be possible. + EnforcementId Register(absl::Span enforcement_literals, + GenericLiteralWatcher* watcher, + int literal_watcher_id); + + // Add the enforcement reason to the given vector. + void AddEnforcementReason(EnforcementId id, + std::vector* reason) const; + + // Try to propagate when the enforced constraint is not satisfiable. + // This is currently in O(enforcement_size). + ABSL_MUST_USE_RESULT bool PropagateWhenFalse( + EnforcementId id, absl::Span literal_reason, + absl::Span integer_reason); + + ABSL_MUST_USE_RESULT bool SafeEnqueue( + EnforcementId id, IntegerLiteral i_lit, + absl::Span integer_reason); + + bool ReportConflict(EnforcementId id, + absl::Span integer_reason); + + EnforcementStatus Status(EnforcementId id) const { + if (id < 0) return EnforcementStatus::IS_ENFORCED; + return statuses_[id]; + } + + // Recompute the status from the current assignment. + // This should only used in DCHECK(). + EnforcementStatus DebugStatus(EnforcementId id); + + // Returns the enforcement literals of the given id. + absl::Span GetEnforcementLiterals(EnforcementId id) const { + if (id < 0) return {}; + return GetSpan(id); + } + + private: + absl::Span GetSpan(EnforcementId id); + absl::Span GetSpan(EnforcementId id) const; + void ChangeStatus(EnforcementId id, EnforcementStatus new_status); + + // Returns kNoLiteralIndex if nothing need to change or a new literal to + // watch. This also calls the registered callback. + LiteralIndex ProcessIdOnTrue(Literal watched, EnforcementId id); + + // External classes. + const Trail& trail_; + const VariablesAssignment& assignment_; + IntegerTrail* integer_trail_; + RevIntRepository* rev_int_repository_; + + // All enforcement will be copied there, and we will create Span out of this. + // Note that we don't store the span so that we are not invalidated on buffer_ + // resizing. + util_intops::StrongVector starts_; + std::vector buffer_; + + util_intops::StrongVector statuses_; + util_intops::StrongVector< + EnforcementId, std::function> + callbacks_; + + // Used to restore status and call callback on untrail. + std::vector> untrail_stack_; + int rev_stack_size_ = 0; + int64_t rev_stamp_ = 0; + + // We use a two watcher scheme. + util_intops::StrongVector> + watcher_; + + std::vector temp_literals_; + std::vector temp_reason_; + + std::vector ids_to_fix_until_next_root_level_; +}; + // Propagate the fact that a XOR of literals is equal to the given value. // The complexity is in O(n). // @@ -36,26 +159,26 @@ namespace sat { // faster. class BooleanXorPropagator : public PropagatorInterface { public: - BooleanXorPropagator(const std::vector& literals, bool value, - Trail* trail, IntegerTrail* integer_trail) - : literals_(literals), - value_(value), - trail_(trail), - integer_trail_(integer_trail) {} + BooleanXorPropagator(absl::Span enforcement_literals, + const std::vector& literals, bool value, + Model* model); // This type is neither copyable nor movable. BooleanXorPropagator(const BooleanXorPropagator&) = delete; BooleanXorPropagator& operator=(const BooleanXorPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + const std::vector literals_; const bool value_; std::vector literal_reason_; Trail* trail_; IntegerTrail* integer_trail_; + EnforcementPropagator* enforcement_propagator_; + EnforcementId enforcement_id_; }; // If we have: @@ -119,14 +242,11 @@ inline std::vector ToIntegerValueVector( // Enforces the XOR of a set of literals to be equal to the given value. inline std::function LiteralXorIs( + const std::vector& enforcement_literals, const std::vector& literals, bool value) { return [=](Model* model) { - Trail* trail = model->GetOrCreate(); - IntegerTrail* integer_trail = model->GetOrCreate(); - BooleanXorPropagator* constraint = - new BooleanXorPropagator(literals, value, trail, integer_trail); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); + model->TakeOwnership( + new BooleanXorPropagator(enforcement_literals, literals, value, model)); }; } diff --git a/ortools/sat/cp_constraints_test.cc b/ortools/sat/cp_constraints_test.cc index c71cabd4b6..88efa3d167 100644 --- a/ortools/sat/cp_constraints_test.cc +++ b/ortools/sat/cp_constraints_test.cc @@ -32,18 +32,142 @@ namespace operations_research { namespace sat { namespace { +TEST(EnforcementPropagatorTest, BasicTest) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* propag = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + + const EnforcementId id1 = propag->Register(Literals({+1})); + const EnforcementId id2 = propag->Register(Literals({+1, +2})); + const EnforcementId id3 = propag->Register(Literals({-2})); + + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CANNOT_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+2)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::IS_FALSE); + + CHECK(sat_solver->ResetToLevelZero()); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CANNOT_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); +} + +TEST(EnforcementPropagatorTest, UntrailWork) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* propag = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + + const EnforcementId id1 = propag->Register(Literals({+1})); + const EnforcementId id2 = propag->Register(Literals({+2})); + const EnforcementId id3 = propag->Register(Literals({+3})); + + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+2)); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); + const int level = sat_solver->CurrentDecisionLevel(); + + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+3)); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::IS_ENFORCED); + + sat_solver->Backtrack(level); + EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); + EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); +} + +TEST(EnforcementPropagatorTest, AddingAtPositiveLevelTrue) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* propag = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + + EXPECT_TRUE(propag->Propagate(trail)); + sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); + EXPECT_TRUE(propag->Propagate(trail)); + + const EnforcementId id = propag->Register(std::vector{+1}); + EXPECT_EQ(propag->Status(id), EnforcementStatus::IS_ENFORCED); + + sat_solver->Backtrack(0); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id), EnforcementStatus::CAN_PROPAGATE); +} + +TEST(EnforcementPropagatorTest, AddingAtPositiveLevelFalse) { + Model model; + auto* sat_solver = model.GetOrCreate(); + auto* trail = model.GetOrCreate(); + auto* propag = model.GetOrCreate(); + sat_solver->SetNumVariables(10); + + EXPECT_TRUE(propag->Propagate(trail)); + sat_solver->EnqueueDecisionIfNotConflicting(Literal(-1)); + EXPECT_TRUE(propag->Propagate(trail)); + + const EnforcementId id = propag->Register(std::vector{+1}); + EXPECT_EQ(propag->Status(id), EnforcementStatus::IS_FALSE); + + sat_solver->Backtrack(0); + EXPECT_TRUE(propag->Propagate(trail)); + EXPECT_EQ(propag->Status(id), EnforcementStatus::CAN_PROPAGATE); +} + TEST(LiteralXorIsTest, OneVariable) { Model model; const BooleanVariable a = model.Add(NewBooleanVariable()); const BooleanVariable b = model.Add(NewBooleanVariable()); - model.Add(LiteralXorIs({Literal(a, true)}, true)); - model.Add(LiteralXorIs({Literal(b, true)}, false)); + model.Add(LiteralXorIs({}, {Literal(a, true)}, true)); + model.Add(LiteralXorIs({}, {Literal(b, true)}, false)); SatSolver* solver = model.GetOrCreate(); EXPECT_TRUE(solver->Propagate()); EXPECT_TRUE(solver->Assignment().LiteralIsTrue(Literal(a, true))); EXPECT_TRUE(solver->Assignment().LiteralIsFalse(Literal(b, true))); } +TEST(LiteralXorIsTest, OneEnforcedVariable) { + Model model; + const BooleanVariable e = model.Add(NewBooleanVariable()); + const BooleanVariable f = model.Add(NewBooleanVariable()); + model.Add(LiteralXorIs({Literal(e, true)}, {}, true)); + model.Add(LiteralXorIs({Literal(f, false)}, {}, true)); + SatSolver* solver = model.GetOrCreate(); + EXPECT_TRUE(solver->Propagate()); + EXPECT_TRUE(solver->Assignment().LiteralIsFalse(Literal(e, true))); + EXPECT_TRUE(solver->Assignment().LiteralIsFalse(Literal(f, false))); +} + // A simple macro to make the code more readable. #define EXPECT_BOUNDS_EQ(var, lb, ub) \ EXPECT_EQ(model.Get(LowerBound(var)), lb); \ diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index d52fd8bfe3..92e0d2fa63 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -330,8 +330,9 @@ message ConstraintProto { // (controlled by a literal l) and its negation (controlled by the negation of // l). // - // Important: as of September 2018, only a few constraint support enforcement: - // - bool_or, bool_and, linear: fully supported. + // Important: as of July 2025, only a few constraint support enforcement: + // - bool_or, bool_and, at_most_one, exactly_one, bool_xor, int_div, int_mod, + // int_prod, linear, table: fully supported. // - interval: only support a single enforcement literal. // - other: no support (but can be added on a per-demand basis). repeated int32 enforcement_literal = 2; @@ -355,10 +356,6 @@ message ConstraintProto { // bool_and constraint with n-1 term on the right hand side. So in a sense, // this constraint contribute directly to the "implication-graph" or the // 2-SAT part of the model. - // - // This constraint does not support enforcement_literal. Just use a linear - // constraint if you need to enforce it. You also do not need to use it - // directly, we will extract it from the model in most situations. BoolArgumentProto at_most_one = 26; // The exactly_one constraint force exactly one literal to true and no more. @@ -369,10 +366,6 @@ message ConstraintProto { // So in this sense, this constraint is not really needed. it is just here // for a better description of the problem structure and to facilitate some // algorithm. - // - // This constraint does not support enforcement_literal. Just use a linear - // constraint if you need to enforce it. You also do not need to use it - // directly, we will extract it from the model in most situations. BoolArgumentProto exactly_one = 29; // The bool_xor constraint forces an odd number of the literals to be true. diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index c35afde2bc..88ccd6485f 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -1145,6 +1145,15 @@ std::string ValidateCpModel(const CpModelProto& model, bool after_presolve) { case ConstraintProto::ConstraintCase::kBoolAnd: support_enforcement = true; break; + case ConstraintProto::ConstraintCase::kAtMostOne: + support_enforcement = true; + break; + case ConstraintProto::ConstraintCase::kExactlyOne: + support_enforcement = true; + break; + case ConstraintProto::ConstraintCase::kBoolXor: + support_enforcement = true; + break; case ConstraintProto::ConstraintCase::kLinear: support_enforcement = true; RETURN_IF_NOT_EMPTY(ValidateLinearConstraint(model, ct)); @@ -1158,12 +1167,15 @@ std::string ValidateCpModel(const CpModelProto& model, bool after_presolve) { break; } case ConstraintProto::ConstraintCase::kIntProd: + support_enforcement = true; RETURN_IF_NOT_EMPTY(ValidateIntProdConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kIntDiv: + support_enforcement = true; RETURN_IF_NOT_EMPTY(ValidateIntDivConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kIntMod: + support_enforcement = true; RETURN_IF_NOT_EMPTY(ValidateIntModConstraint(model, ct)); break; case ConstraintProto::ConstraintCase::kInverse: diff --git a/ortools/sat/cp_model_copy.cc b/ortools/sat/cp_model_copy.cc index 2ca0c8ed9c..590410ffcd 100644 --- a/ortools/sat/cp_model_copy.cc +++ b/ortools/sat/cp_model_copy.cc @@ -676,7 +676,33 @@ bool ModelCopy::CopyLinMax(const ConstraintProto& ct) { return true; } +namespace { +void LiteralsToLinear(absl::Span literals, int64_t lb, int64_t ub, + LinearConstraintProto* linear) { + for (const int lit : literals) { + if (RefIsPositive(lit)) { + linear->add_vars(lit); + linear->add_coeffs(1); + } else { + linear->add_vars(NegatedRef(lit)); + linear->add_coeffs(-1); + lb -= 1; + ub -= 1; + } + } + linear->add_domain(lb); + linear->add_domain(ub); +} +} // namespace + bool ModelCopy::CopyAtMostOne(const ConstraintProto& ct) { + if (!ct.enforcement_literal().empty()) { + ConstraintProto new_ct; + FinishEnforcementCopy(&new_ct); + LiteralsToLinear(ct.at_most_one().literals(), /*lb=*/0, /*ub=*/1, + new_ct.mutable_linear()); + return CopyLinear(new_ct, true); + } int num_true = 0; temp_literals_.clear(); for (const int lit : ct.at_most_one().literals()) { @@ -690,13 +716,19 @@ bool ModelCopy::CopyAtMostOne(const ConstraintProto& ct) { // TODO(user): presolve if num_true == 1. ConstraintProto* new_ct = context_->working_model->add_constraints(); - FinishEnforcementCopy(new_ct); new_ct->mutable_at_most_one()->mutable_literals()->Add(temp_literals_.begin(), temp_literals_.end()); return true; } bool ModelCopy::CopyExactlyOne(const ConstraintProto& ct) { + if (!ct.enforcement_literal().empty()) { + ConstraintProto new_ct; + FinishEnforcementCopy(&new_ct); + LiteralsToLinear(ct.exactly_one().literals(), /*lb=*/1, /*ub=*/1, + new_ct.mutable_linear()); + return CopyLinear(new_ct, true); + } int num_true = 0; temp_literals_.clear(); for (const int lit : ct.exactly_one().literals()) { @@ -710,7 +742,6 @@ bool ModelCopy::CopyExactlyOne(const ConstraintProto& ct) { // TODO(user): presolve if num_true == 1 and not everything is false. ConstraintProto* new_ct = context_->working_model->add_constraints(); - FinishEnforcementCopy(new_ct); new_ct->mutable_exactly_one()->mutable_literals()->Add(temp_literals_.begin(), temp_literals_.end()); return true; @@ -744,6 +775,7 @@ bool ModelCopy::CopyIntProd(const ConstraintProto& ct, bool ignore_names) { if (!ignore_names) { new_ct->set_name(ct.name()); } + FinishEnforcementCopy(new_ct); for (const LinearExpressionProto& expr : ct.int_prod().exprs()) { CopyLinearExpression(expr, new_ct->mutable_int_prod()->add_exprs()); } @@ -757,6 +789,7 @@ bool ModelCopy::CopyIntDiv(const ConstraintProto& ct, bool ignore_names) { if (!ignore_names) { new_ct->set_name(ct.name()); } + FinishEnforcementCopy(new_ct); for (const LinearExpressionProto& expr : ct.int_div().exprs()) { CopyLinearExpression(expr, new_ct->mutable_int_div()->add_exprs()); } @@ -770,6 +803,7 @@ bool ModelCopy::CopyIntMod(const ConstraintProto& ct, bool ignore_names) { if (!ignore_names) { new_ct->set_name(ct.name()); } + FinishEnforcementCopy(new_ct); for (const LinearExpressionProto& expr : ct.int_mod().exprs()) { CopyLinearExpression(expr, new_ct->mutable_int_mod()->add_exprs()); } diff --git a/ortools/sat/cp_model_copy_test.cc b/ortools/sat/cp_model_copy_test.cc index daad1d55d3..4675f66f41 100644 --- a/ortools/sat/cp_model_copy_test.cc +++ b/ortools/sat/cp_model_copy_test.cc @@ -209,6 +209,52 @@ TEST(ModelCopyTest, RemoveDuplicateFromEnforcementLiterals) { EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); } +TEST(ModelCopyTest, ChangeEnforcedAtMostOrExactlyOneToLinear) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: [ 0, 1 ] + at_most_one { literals: [ 2, -4 ] } + } + constraints { + enforcement_literal: [ 0, 1 ] + exactly_one { literals: [ 2, 3 ] } + } + )pb"); + const CpModelProto expected_moded = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: [ 0, 1 ] + linear { + vars: [ 2, 3 ] + coeffs: [ 1, -1 ] + domain: [ -1, 0 ] + } + } + constraints { + enforcement_literal: [ 0, 1 ] + linear { + vars: [ 2, 3 ] + coeffs: [ 1, 1 ] + domain: [ 1, 1 ] + } + } + )pb"); + CpModelProto new_cp_model; + Model model; + model.GetOrCreate() + ->set_keep_all_feasible_solutions_in_presolve(true); + PresolveContext context(&model, &new_cp_model, nullptr); + ImportModelWithBasicPresolveIntoContext(initial_model, &context); + EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); +} + } // namespace } // namespace sat } // namespace operations_research diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index b6d4c2404f..618a51b5bf 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -507,8 +507,12 @@ void ExpandIntProd(ConstraintProto* ct, PresolveContext* context) { context->DomainSuperSetOf(right)); const int new_var = context->NewIntVar(new_domain); new_vars.push_back(new_var); - LinearArgumentProto* const int_prod = - context->working_model->add_constraints()->mutable_int_prod(); + ConstraintProto* new_ct = context->working_model->add_constraints(); + // TODO(user): since we copy the enforcement literals in the final int + // prod constraint below, this is not strictly necessary. Is it better with + // or without? + *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + LinearArgumentProto* const int_prod = new_ct->mutable_int_prod(); *int_prod->add_exprs() = left; *int_prod->add_exprs() = right; int_prod->mutable_target()->add_vars(new_var); @@ -517,8 +521,9 @@ void ExpandIntProd(ConstraintProto* ct, PresolveContext* context) { terms.front() = int_prod->target(); } - LinearArgumentProto* const final_int_prod = - context->working_model->add_constraints()->mutable_int_prod(); + ConstraintProto* new_ct = context->working_model->add_constraints(); + *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + LinearArgumentProto* const final_int_prod = new_ct->mutable_int_prod(); *final_int_prod->add_exprs() = terms[0]; *final_int_prod->add_exprs() = terms[1]; *final_int_prod->mutable_target() = ct->int_prod().target(); diff --git a/ortools/sat/cp_model_expand_test.cc b/ortools/sat/cp_model_expand_test.cc index 49d637b94e..8114f6892f 100644 --- a/ortools/sat/cp_model_expand_test.cc +++ b/ortools/sat/cp_model_expand_test.cc @@ -547,6 +547,52 @@ TEST(IntModExpansionTest, ExpandIntModPreservesSolutionHint) { EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); } +TEST(IntProdExpandTest, EnforcementLiterals) { + CpModelProto initial_model = ParseTestProto(R"pb( + variables { name: 'b' domain: 0 domain: 1 } + variables { name: 'x' domain: -100 domain: 100 } + variables { name: 'y' domain: -100 domain: 100 } + variables { name: 'z' domain: -100 domain: 100 } + constraints { + enforcement_literal: 0 + int_prod { + target { offset: 27 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + exprs { vars: 3 coeffs: 1 } + } + } + )pb"); + Model model; + PresolveContext context(&model, &initial_model, nullptr); + ExpandCpModel(&context); + + const CpModelProto expected_model = ParseTestProto(R"pb( + variables { name: "b" domain: 0 domain: 1 } + variables { name: "x" domain: -100 domain: 100 } + variables { name: "y" domain: -100 domain: 100 } + variables { name: "z" domain: -100 domain: 100 } + variables { domain: -10000 domain: 10000 } + constraints {} + constraints { + enforcement_literal: 0 + int_prod { + target { vars: 4 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + } + } + constraints { + enforcement_literal: 0 + int_prod { + target { offset: 27 } + exprs { vars: 4 coeffs: 1 } + exprs { vars: 3 coeffs: 1 } + } + })pb"); + EXPECT_THAT(initial_model, testing::EqualsProto(expected_model)); +} + TEST(IntProdExpandTest, LeftCase) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { name: 'x' domain: -50 domain: -40 domain: 10 domain: 20 } diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index 3a691a1ddf..fcecb82670 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -60,6 +60,7 @@ #include "ortools/sat/sat_solver.h" #include "ortools/sat/symmetry.h" #include "ortools/sat/timetable.h" +#include "ortools/sat/util.h" #include "ortools/util/logging.h" #include "ortools/util/sorted_interval_list.h" #include "ortools/util/strong_integers.h" @@ -1038,8 +1039,8 @@ void LoadExactlyOneConstraint(const ConstraintProto& ct, Model* m) { void LoadBoolXorConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); - CHECK(!HasEnforcementLiteral(ct)) << "Not supported."; - m->Add(LiteralXorIs(mapping->Literals(ct.bool_xor().literals()), true)); + m->Add(LiteralXorIs(mapping->Literals(ct.enforcement_literal()), + mapping->Literals(ct.bool_xor().literals()), true)); } namespace { @@ -1130,8 +1131,6 @@ bool IsPartOfProductEncoding(const ConstraintProto& ct) { } // namespace -// TODO(user): We could use a smarter way to determine buckets, like putting -// everyone with the same coeff together if possible and the split is ok. void SplitAndLoadIntermediateConstraints(bool lb_required, bool ub_required, std::vector* vars, std::vector* coeffs, @@ -1143,25 +1142,61 @@ void SplitAndLoadIntermediateConstraints(bool lb_required, bool ub_required, ub_required = true; } + // We sort by absolute value of coefficients. The separate - from +, and then + // by variable order, usually variable with the same "meaning" are defined + // together in a model. + const int num_terms = vars->size(); + std::vector> terms; + { + terms.reserve(num_terms); + for (int i = 0; i < num_terms; ++i) { + terms.push_back({(*vars)[i], (*coeffs)[i]}); + } + std::sort(terms.begin(), terms.end(), + [](const std::pair a, + const std::pair b) { + const int64_t abs_coeff_a = std::abs(a.second); + const int64_t abs_coeff_b = std::abs(b.second); + if (abs_coeff_a != abs_coeff_b) { + return abs_coeff_a < abs_coeff_b; + } + if (a.second != b.second) { + return a.second < b.second; + } + return a.first < b.first; + }); + } + std::vector sorted_coeffs; + sorted_coeffs.resize(num_terms); + for (int i = 0; i < num_terms; ++i) { + sorted_coeffs[i] = terms[i].second; + } + const std::vector> buckets = + HeuristicallySplitLongLinear(sorted_coeffs); + std::vector bucket_sum_vars; std::vector bucket_sum_coeffs; std::vector local_vars; std::vector local_coeffs; - int64_t i = 0; - const int64_t num_vars = vars->size(); - const int64_t num_buckets = static_cast(std::round(std::sqrt(num_vars))); auto* integer_trail = m->GetOrCreate(); - for (int64_t b = 0; b < num_buckets; ++b) { + for (const auto [start, size] : buckets) { + // Just keep the same variable if the size of that bucket is one. + if (size == 1) { + const auto [var, coeff] = terms[start]; + bucket_sum_vars.push_back(var); + bucket_sum_coeffs.push_back(coeff); + continue; + } + local_vars.clear(); local_coeffs.clear(); int64_t bucket_lb = 0; int64_t bucket_ub = 0; int64_t gcd = 0; - const int64_t limit = num_vars * (b + 1); - for (; i * num_buckets < limit; ++i) { - const IntegerVariable var = (*vars)[i]; - const int64_t coeff = (*coeffs)[i]; + + for (int i = 0; i < size; ++i) { + const auto [var, coeff] = terms[start + i]; gcd = std::gcd(gcd, std::abs(coeff)); local_vars.push_back(var); local_coeffs.push_back(coeff); @@ -1170,6 +1205,7 @@ void SplitAndLoadIntermediateConstraints(bool lb_required, bool ub_required, bucket_lb += std::min(term1, term2); bucket_ub += std::max(term1, term2); } + if (gcd == 0) continue; if (gcd > 1) { // Everything should be exactly divisible! @@ -1194,6 +1230,8 @@ void SplitAndLoadIntermediateConstraints(bool lb_required, bool ub_required, m->Add(WeightedSumLowerOrEqual(local_vars, local_coeffs, 0)); } } + + // Rewrite the constraint. *vars = bucket_sum_vars; *coeffs = bucket_sum_coeffs; } @@ -1479,8 +1517,22 @@ void LoadAllDiffConstraint(const ConstraintProto& ct, Model* m) { m->Add(AllDifferentOnBounds(expressions)); } +void LoadAlwaysFalseConstraint(const ConstraintProto& ct, Model* m) { + if (ct.enforcement_literal().empty()) { + m->GetOrCreate()->NotifyThatModelIsUnsat(); + } + ConstraintProto new_ct = ct; + BoolArgumentProto& bool_or = *new_ct.mutable_bool_or(); + for (const int literal : ct.enforcement_literal()) { + bool_or.add_literals(NegatedRef(literal)); + } + LoadBoolOrConstraint(new_ct, m); +} + void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); + const std::vector enforcement_literals = + mapping->Literals(ct.enforcement_literal()); const AffineExpression prod = mapping->Affine(ct.int_prod().target()); std::vector terms; for (const LinearExpressionProto& expr : ct.int_prod().exprs()) { @@ -1489,15 +1541,14 @@ void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { switch (terms.size()) { case 0: { auto* integer_trail = m->GetOrCreate(); - auto* sat_solver = m->GetOrCreate(); if (prod.IsConstant()) { if (prod.constant.value() != 1) { - sat_solver->NotifyThatModelIsUnsat(); + LoadAlwaysFalseConstraint(ct, m); } } else { if (!integer_trail->Enqueue(prod.LowerOrEqual(1)) || !integer_trail->Enqueue(prod.GreaterOrEqual(1))) { - sat_solver->NotifyThatModelIsUnsat(); + LoadAlwaysFalseConstraint(ct, m); } } break; @@ -1506,11 +1557,11 @@ void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { LinearConstraintBuilder builder(m, /*lb=*/0, /*ub=*/0); builder.AddTerm(prod, 1); builder.AddTerm(terms[0], -1); - LoadLinearConstraint(builder.Build(), m); + LoadConditionalLinearConstraint(enforcement_literals, builder.Build(), m); break; } case 2: { - m->Add(ProductConstraint(terms[0], terms[1], prod)); + m->Add(ProductConstraint(enforcement_literals, terms[0], terms[1], prod)); break; } default: { @@ -1523,11 +1574,14 @@ void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { void LoadIntDivConstraint(const ConstraintProto& ct, Model* m) { auto* integer_trail = m->GetOrCreate(); auto* mapping = m->GetOrCreate(); + const std::vector enforcement_literals = + mapping->Literals(ct.enforcement_literal()); const AffineExpression div = mapping->Affine(ct.int_div().target()); const AffineExpression num = mapping->Affine(ct.int_div().exprs(0)); const AffineExpression denom = mapping->Affine(ct.int_div().exprs(1)); if (integer_trail->IsFixed(denom)) { - m->Add(FixedDivisionConstraint(num, integer_trail->FixedValue(denom), div)); + m->Add(FixedDivisionConstraint(enforcement_literals, num, + integer_trail->FixedValue(denom), div)); } else { if (VLOG_IS_ON(1)) { LinearConstraintBuilder builder(m); @@ -1536,7 +1590,7 @@ void LoadIntDivConstraint(const ConstraintProto& ct, Model* m) { VLOG(1) << "Division " << ct << " can be linearized"; } } - m->Add(DivisionConstraint(num, denom, div)); + m->Add(DivisionConstraint(enforcement_literals, num, denom, div)); } } @@ -1544,12 +1598,15 @@ void LoadIntModConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); auto* integer_trail = m->GetOrCreate(); + const std::vector enforcement_literals = + mapping->Literals(ct.enforcement_literal()); const AffineExpression target = mapping->Affine(ct.int_mod().target()); const AffineExpression expr = mapping->Affine(ct.int_mod().exprs(0)); const AffineExpression mod = mapping->Affine(ct.int_mod().exprs(1)); CHECK(integer_trail->IsFixed(mod)); const IntegerValue fixed_modulo = integer_trail->FixedValue(mod); - m->Add(FixedModuloConstraint(expr, fixed_modulo, target)); + m->Add( + FixedModuloConstraint(enforcement_literals, expr, fixed_modulo, target)); } void LoadLinMaxConstraint(const ConstraintProto& ct, Model* m) { diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index c17f6da241..b5bd658cac 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -238,7 +238,6 @@ bool CpModelPresolver::PresolveEnforcementLiteral(ConstraintProto* ct) { bool CpModelPresolver::PresolveBoolXor(ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; - if (HasEnforcementLiteral(*ct)) return false; int new_size = 0; bool changed = false; @@ -270,12 +269,13 @@ bool CpModelPresolver::PresolveBoolXor(ConstraintProto* ct) { if (new_size == 0) { if (num_true_literals % 2 == 0) { - return context_->NotifyThatModelIsUnsat("bool_xor: always false"); + return MarkConstraintAsFalse(ct, "bool_xor: always false"); } else { context_->UpdateRuleStats("bool_xor: always true"); return RemoveConstraint(ct); } - } else if (new_size == 1) { // We can fix the only active literal. + } else if (new_size == 1 && !HasEnforcementLiteral(*ct)) { + // We can fix the only active literal. if (num_true_literals % 2 == 0) { if (!context_->SetLiteralToTrue(ct->bool_xor().literals(0))) { return context_->NotifyThatModelIsUnsat( @@ -294,7 +294,7 @@ bool CpModelPresolver::PresolveBoolXor(ConstraintProto* ct) { const int b = ct->bool_xor().literals(1); if (a == b) { if (num_true_literals % 2 == 0) { - return context_->NotifyThatModelIsUnsat("bool_xor: always false"); + return MarkConstraintAsFalse(ct, "bool_xor: always false"); } else { context_->UpdateRuleStats("bool_xor: always true"); return RemoveConstraint(ct); @@ -302,19 +302,21 @@ bool CpModelPresolver::PresolveBoolXor(ConstraintProto* ct) { } if (a == NegatedRef(b)) { if (num_true_literals % 2 == 1) { - return context_->NotifyThatModelIsUnsat("bool_xor: always false"); + return MarkConstraintAsFalse(ct, "bool_xor: always false"); } else { context_->UpdateRuleStats("bool_xor: always true"); return RemoveConstraint(ct); } } - if (num_true_literals % 2 == 0) { // a == not(b). - if (!context_->StoreBooleanEqualityRelation(a, NegatedRef(b))) { - return false; - } - } else { // a == b. - if (!context_->StoreBooleanEqualityRelation(a, b)) { - return false; + if (!HasEnforcementLiteral(*ct)) { + if (num_true_literals % 2 == 0) { // a == not(b). + if (!context_->StoreBooleanEqualityRelation(a, NegatedRef(b))) { + return false; + } + } else { // a == b. + if (!context_->StoreBooleanEqualityRelation(a, b)) { + return false; + } } } context_->UpdateNewConstraintsVariableUsage(); @@ -417,7 +419,7 @@ bool CpModelPresolver::PresolveBoolOr(ConstraintProto* ct) { // Note this function does not update the constraint graph. It assumes this is // done elsewhere. ABSL_MUST_USE_RESULT bool CpModelPresolver::MarkConstraintAsFalse( - ConstraintProto* ct) { + ConstraintProto* ct, const std::string& reason) { if (HasEnforcementLiteral(*ct)) { // Change the constraint to a bool_or. ct->mutable_bool_or()->clear_literals(); @@ -426,9 +428,10 @@ ABSL_MUST_USE_RESULT bool CpModelPresolver::MarkConstraintAsFalse( } ct->clear_enforcement_literal(); PresolveBoolOr(ct); + context_->UpdateRuleStats(reason); return true; } else { - return context_->NotifyThatModelIsUnsat(); + return context_->NotifyThatModelIsUnsat(reason); } } @@ -1578,13 +1581,15 @@ Domain EvaluateImpliedIntProdDomain(const LinearArgumentProto& expr, bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; - if (HasEnforcementLiteral(*ct)) return false; // Start by restricting the domain of target. We will be more precise later. bool domain_modified = false; Domain implied_domain = EvaluateImpliedIntProdDomain(ct->int_prod(), *context_); - if (!context_->IntersectDomainWith(ct->int_prod().target(), implied_domain, + // TODO(user): if implied_domain and target domain are disjoint, mark the + // constraint as false. + if (!HasEnforcementLiteral(*ct) && + !context_->IntersectDomainWith(ct->int_prod().target(), implied_domain, &domain_modified)) { return false; } @@ -1594,6 +1599,8 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { // - The target is an affine linear with coefficient -1 or 1. // - The target does not appear in the rhs (no x = (a*x + b) * ...). // - The target domain covers all the possible range of the rhs. + // This can be done whether or not there are enforcement literals, even if + // they are used in the target or the rhs. if (ExpressionContainsSingleRef(ct->int_prod().target()) && context_->VariableIsUniqueAndRemovable(ct->int_prod().target().vars(0)) && std::abs(ct->int_prod().target().coeffs(0)) == 1) { @@ -1639,12 +1646,20 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { proto->mutable_exprs()->end()); if (ct->int_prod().exprs().empty() || constant_factor == 0) { - if (!context_->IntersectDomainWith(ct->int_prod().target(), - Domain(constant_factor))) { - return false; + if (!context_->DomainContains(ct->int_prod().target(), constant_factor)) { + return MarkConstraintAsFalse(ct, "int_prod: always false"); + } + if (!HasEnforcementLiteral(*ct)) { + if (!context_->IntersectDomainWith(ct->int_prod().target(), + Domain(constant_factor))) { + return false; + } + context_->UpdateRuleStats("int_prod: constant product"); + return RemoveConstraint(ct); + } else { + context_->UpdateRuleStats("TODO enforced int_prod: constant product"); + // Replace ct with an enforced linear "target == constant_factor". } - context_->UpdateRuleStats("int_prod: constant product"); - return RemoveConstraint(ct); } // If target is fixed to zero, we can forget the constant factor. @@ -1655,9 +1670,9 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { constant_factor = 1; } - // In this case, the only possible value that fit in the domains is zero. + // In this case, the only possible value that fits in the domains is zero. // We will check for UNSAT if zero is not achievable by the rhs below. - if (AtMinOrMaxInt64(constant_factor)) { + if (!HasEnforcementLiteral(*ct) && AtMinOrMaxInt64(constant_factor)) { context_->UpdateRuleStats("int_prod: overflow if non zero"); if (!context_->IntersectDomainWith(ct->int_prod().target(), Domain(0))) { return false; @@ -1665,18 +1680,19 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { constant_factor = 1; } - // Replace by linear if it cannot overflow. + // Replace with linear if it cannot overflow. if (ct->int_prod().exprs().size() == 1) { LinearExpressionProto* const target = ct->mutable_int_prod()->mutable_target(); - LinearConstraintProto* const lin = - context_->working_model->add_constraints()->mutable_linear(); + ConstraintProto* const new_ct = context_->working_model->add_constraints(); + *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + LinearConstraintProto* const lin = new_ct->mutable_linear(); if (context_->IsFixed(*target)) { int64_t target_value = context_->FixedValue(*target); if (target_value % constant_factor != 0) { - return context_->NotifyThatModelIsUnsat( - "int_prod: product incompatible with fixed target"); + return MarkConstraintAsFalse( + ct, "int_prod: product incompatible with fixed target"); } // expression == target_value / constant_factor. lin->add_domain(target_value / constant_factor); @@ -1755,8 +1771,8 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { if (context_->IsFixed(old_target)) { const int64_t target_value = context_->FixedValue(old_target); if (target_value % constant_factor != 0) { - return context_->NotifyThatModelIsUnsat( - "int_prod: constant factor does not divide constant target"); + return MarkConstraintAsFalse( + ct, "int_prod: constant factor does not divide constant target"); } changed = true; proto->clear_target(); @@ -1788,8 +1804,8 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { new_coeff < absl::int128(std::numeric_limits::min()) || new_offset > absl::int128(std::numeric_limits::max()) || new_offset < absl::int128(std::numeric_limits::min())) { - return context_->NotifyThatModelIsUnsat( - "int_prod: overflow during simplification."); + return MarkConstraintAsFalse( + ct, "int_prod: overflow during simplification."); } // Rewrite the target. @@ -1806,7 +1822,8 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { const bool is_square = ct->int_prod().exprs_size() == 2 && LinearExpressionProtosAreEqual( ct->int_prod().exprs(0), ct->int_prod().exprs(1)); - if (!context_->IntersectDomainWith(ct->int_prod().target(), implied_domain, + if (!HasEnforcementLiteral(*ct) && + !context_->IntersectDomainWith(ct->int_prod().target(), implied_domain, &domain_modified)) { return false; } @@ -1821,7 +1838,8 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { DCHECK_GE(target_max, 0); const int64_t sqrt_max = FloorSquareRoot(target_max); bool expr_reduced = false; - if (!context_->IntersectDomainWith(ct->int_prod().exprs(0), + if (!HasEnforcementLiteral(*ct) && + !context_->IntersectDomainWith(ct->int_prod().exprs(0), {-sqrt_max, sqrt_max}, &expr_reduced)) { return false; } @@ -1837,11 +1855,17 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { if (LinearExpressionProtosAreEqual(a, b) && LinearExpressionProtosAreEqual( a, product)) { // x = x * x, only true for {0, 1}. - if (!context_->IntersectDomainWith(product, Domain(0, 1))) { - return false; + if (!HasEnforcementLiteral(*ct)) { + if (!context_->IntersectDomainWith(product, Domain(0, 1))) { + return false; + } + context_->UpdateRuleStats("int_square: fix variable to zero or one."); + return RemoveConstraint(ct); + } else { + context_->UpdateRuleStats( + "TODO enforced int_square: fix variable to zero or one."); + // Replace ct with an enforced linear "product in [0, 1]". } - context_->UpdateRuleStats("int_square: fix variable to zero or one."); - return RemoveConstraint(ct); } } @@ -1871,6 +1895,10 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { context_->working_model->add_constraints(); ConstraintProto* constraint_for_true = context_->working_model->add_constraints(); + *constraint_for_true->mutable_enforcement_literal() = + ct->enforcement_literal(); + *constraint_for_false->mutable_enforcement_literal() = + ct->enforcement_literal(); constraint_for_true->add_enforcement_literal(boolean_linear->vars(0)); constraint_for_false->add_enforcement_literal( NegatedRef(boolean_linear->vars(0))); @@ -1930,6 +1958,7 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { context_->UpdateRuleStats("int_prod: all Boolean."); { ConstraintProto* new_ct = context_->working_model->add_constraints(); + *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); new_ct->add_enforcement_literal(target); auto* arg = new_ct->mutable_bool_and(); for (const int lit : literals) { @@ -1938,6 +1967,7 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { } { ConstraintProto* new_ct = context_->working_model->add_constraints(); + *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); auto* arg = new_ct->mutable_bool_or(); arg->add_literals(target); for (const int lit : literals) { @@ -1950,6 +1980,8 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; + // TODO(user): add support for this case. + if (HasEnforcementLiteral(*ct)) return false; const LinearExpressionProto target = ct->int_div().target(); const LinearExpressionProto expr = ct->int_div().exprs(0); @@ -2093,6 +2125,8 @@ bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { bool CpModelPresolver::PresolveIntMod(int c, ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; + // TODO(user): add support for this case. + if (HasEnforcementLiteral(*ct)) return false; // TODO(user): Presolve f(X) = g(X) % fixed_mod. const LinearExpressionProto target = ct->int_mod().target(); @@ -8862,31 +8896,42 @@ void CpModelPresolver::MergeNoOverlapConstraints() { &cliques, SafeDoubleToInt64(context_->params().merge_no_overlap_work_limit())); - // Replace each no-overlap with an extended version, or remove if empty. + time_limit_->ResetHistory(); + int new_num_no_overlaps = 0; int new_num_intervals = 0; + for (int i = 0; i < cliques.size(); ++i) { + new_num_no_overlaps++; + new_num_intervals += cliques[i].size(); + } + + if (old_num_intervals == new_num_intervals && + old_num_no_overlaps == new_num_no_overlaps) { + return; + } + + // Remove previous no_overlap constraints and add the new recomputed ones. for (int i = 0; i < cliques.size(); ++i) { const int ct_index = disjunctive_index[i]; - ConstraintProto* ct = - context_->working_model->mutable_constraints(ct_index); - ct->Clear(); + if (RemoveConstraint( + context_->working_model->mutable_constraints(ct_index))) { + context_->UpdateConstraintVariableUsage(ct_index); + } + } + for (int i = 0; i < cliques.size(); ++i) { if (cliques[i].empty()) continue; + ConstraintProto* ct = context_->working_model->add_constraints(); for (const Literal l : cliques[i]) { CHECK(l.IsPositive()); ct->mutable_no_overlap()->add_intervals(l.Variable().value()); } - new_num_no_overlaps++; - new_num_intervals += cliques[i].size(); } - if (old_num_intervals != new_num_intervals || - old_num_no_overlaps != new_num_no_overlaps) { - VLOG(1) << absl::StrCat("Merged ", old_num_no_overlaps, " no-overlaps (", - old_num_intervals, " intervals) into ", - new_num_no_overlaps, " no-overlaps (", - new_num_intervals, " intervals)."); - context_->UpdateRuleStats("no_overlap: merged constraints"); - } - time_limit_->ResetHistory(); + VLOG(1) << absl::StrCat("Merged ", old_num_no_overlaps, " no-overlaps (", + old_num_intervals, " intervals) into ", + new_num_no_overlaps, " no-overlaps (", + new_num_intervals, " intervals)."); + context_->UpdateRuleStats("no_overlap: merged constraints"); + context_->UpdateNewConstraintsVariableUsage(); } // TODO(user): Should we take into account the exactly_one constraints? note diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index 9c7a706014..725f5f8a82 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -29,6 +29,7 @@ #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer_base.h" #include "ortools/sat/presolve_context.h" #include "ortools/sat/presolve_util.h" #include "ortools/sat/sat_base.h" @@ -346,7 +347,8 @@ class CpModelPresolver { bool ExploitEquivalenceRelations(int c, ConstraintProto* ct); ABSL_MUST_USE_RESULT bool RemoveConstraint(ConstraintProto* ct); - ABSL_MUST_USE_RESULT bool MarkConstraintAsFalse(ConstraintProto* ct); + ABSL_MUST_USE_RESULT bool MarkConstraintAsFalse( + ConstraintProto* ct, const std::string& reason = ""); std::vector* postsolve_mapping_; PresolveContext* context_; diff --git a/ortools/sat/cp_model_presolve_test.cc b/ortools/sat/cp_model_presolve_test.cc index 0f04decacc..f73272ac0e 100644 --- a/ortools/sat/cp_model_presolve_test.cc +++ b/ortools/sat/cp_model_presolve_test.cc @@ -2554,6 +2554,42 @@ TEST(PresolveCpModelTest, IntProdWithLeftConstant) { EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); } +TEST(PresolveCpModelTest, EnforcedIntProdWithLeftConstant) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 10, 12 ] } + variables { domain: [ 2, 2 ] } + variables { domain: [ 0, 100 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + int_prod { + target { vars: 2 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 0 coeffs: 1 } + } + } + )pb"); + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 10, 12 ] } + variables { domain: [ 2, 2 ] } + variables { domain: [ 0, 100 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + linear { + vars: 2 + vars: 0 + coeffs: 1 + coeffs: -2 + domain: [ 0, 0 ] + } + } + )pb"); + const CpModelProto presolved_model = + PresolveOneConstraint(initial_model, /*constraint_index=*/0); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + TEST(PresolveCpModelTest, IntProdWithRightConstant) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { @@ -2655,6 +2691,42 @@ TEST(PresolveCpModelTest, IntProdWithConstantProduct) { EXPECT_THAT(expected_mapping_model, testing::EqualsProto(mapping_model)); } +TEST(PresolveCpModelTest, AlwaysFalseIntProd) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 20, 30 ] } + variables { domain: [ 2, 2 ] } + variables { domain: [ 5, 5 ] } + constraints { + int_prod { + target { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + } + } + )pb"); + PresolveForTest(initial_model, SatParameters(), CpSolverStatus::INFEASIBLE); +} + +TEST(PresolveCpModelTest, EnforcedAlwaysFalseIntProd) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 20, 30 ] } + variables { domain: [ 2, 2 ] } + variables { domain: [ 5, 5 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + int_prod { + target { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + } + } + )pb"); + const CpModelProto expected_presolved_model; + const CpModelProto presolved_model = PresolveForTest(initial_model); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + TEST(PresolveCpModelTest, IntProdWithOverflow) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ -100000000000, 100000000000 ] } @@ -2898,6 +2970,49 @@ TEST(PresolveCpModelTest, IntProdWithAffineRelation) { EXPECT_THAT(presolved_model, EqualsProto(expected_presolved_model)); } +TEST(PresolveCpModelTest, EnforcedIntProdWithAffineRelation) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ -10, 20 ] } + variables { domain: [ 0, 5 ] } + variables { domain: [ 0, 0, 3, 3, 6, 6, 9, 9 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + int_prod { + target { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + } + } + # Add this just to avoid triggering the rule of unused target variable. + objective { + vars: [ 0, 1, 3 ] + coeffs: [ 1, 1, -1 ] + } + )pb"); + + // The variable 2 is detected to be of the form 3 * new_var1. Subsequently, + // the product target is detected to be a multiple of 3, so its target is + // replaced by new_var2. The domain are computed accordingly. + CpModelProto presolved_model = PresolveForTest(initial_model); + presolved_model.clear_objective(); + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 0, 9 ] } # This is old_var_0 / 3. + variables { domain: [ 0, 5 ] } + variables { domain: [ 0, 3 ] } # This is old_var_2 / 3. + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + int_prod { + target { vars: 0 coeffs: 1 offset: -3 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + } + } + )pb"); + EXPECT_THAT(presolved_model, EqualsProto(expected_presolved_model)); +} + TEST(PresolveCpModelTest, IntProdCoeffDividesTarget) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ 3, 9 ] } @@ -2928,6 +3043,40 @@ TEST(PresolveCpModelTest, IntProdCoeffDividesTarget) { EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); } +TEST(PresolveCpModelTest, EnforcedIntProdCoeffDividesTarget) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 3, 9 ] } + variables { domain: [ 1, 10 ] } + variables { domain: [ 0, 1000 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + int_prod { + target { vars: 2 coeffs: 10 offset: 20 } + exprs { vars: 0 coeffs: 1 offset: 3 } + exprs { vars: 1 coeffs: 5 } + } + } + )pb"); + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 3, 9 ] } + variables { domain: [ 1, 10 ] } + variables { domain: [ 0, 1000 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + int_prod { + target { vars: 2 coeffs: 2 offset: 4 } + exprs { vars: 0 coeffs: 1 offset: 3 } + exprs { vars: 1 coeffs: 1 } + } + } + )pb"); + const CpModelProto presolved_model = + PresolveOneConstraint(initial_model, /*constraint_index=*/0); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + TEST(PresolveCpModelTest, IntProdGlobalGcd) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ 3, 9 ] } @@ -2986,6 +3135,32 @@ TEST(PresolveCpModelTest, NullProduct) { EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); } +TEST(PresolveCpModelTest, EnforcedNullProduct) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ -10, 20 ] } + variables { domain: [ 0, 5 ] } + variables { domain: [ 0, 0 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + int_prod { + target { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + } + } + constraints { dummy_constraint { vars: [ 0, 1, 2 ] } } + )pb"); + + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ -10, 20 ] } + variables { domain: [ 0, 5 ] } # Many possible values here. + variables { domain: [ 0, 0 ] } + )pb"); + const CpModelProto presolved_model = PresolveForTest(initial_model); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + TEST(PresolveCpModelTest, BooleanProduct) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } @@ -2993,7 +3168,9 @@ TEST(PresolveCpModelTest, BooleanProduct) { variables { domain: [ 0, 1 ] } variables { domain: [ 0, 1 ] } variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } constraints { + enforcement_literal: 5 int_prod { target { vars: 0 coeffs: 1 } exprs { vars: 1 coeffs: 1 } @@ -3009,17 +3186,13 @@ TEST(PresolveCpModelTest, BooleanProduct) { variables { domain: [ 0, 1 ] } variables { domain: [ 0, 1 ] } variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } constraints { - bool_or { literals: -4 literals: -2 literals: 0 literals: 2 literals: 4 } - } - constraints { - enforcement_literal: -2 - bool_and { literals: -1 } - } - constraints { + enforcement_literal: 5 enforcement_literal: 0 - bool_and { literals: -3 literals: 3 literals: -5 } + bool_and { literals: [ 1, -3, 3, -5 ] } } + constraints { bool_or { literals: [ -6, -4, -2, 0, 2, 4 ] } } )pb"); SatParameters params; params.set_keep_all_feasible_solutions_in_presolve(true); @@ -3063,6 +3236,43 @@ TEST(PresolveCpModelTest, AffineBooleanProduct) { EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); } +TEST(PresolveCpModelTest, EnforcedAffineBooleanProduct) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 30 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 3 + int_prod { + target { vars: 0 coeffs: 1 } + exprs { vars: 1 coeffs: 2 offset: 3 } + exprs { vars: 2 coeffs: 3 offset: 2 } + } + } + )pb"); + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 0, 30 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: [ 3, -2 ] + linear { vars: 0 vars: 2 coeffs: 1 coeffs: -9 domain: 6 domain: 6 } + } + constraints { + enforcement_literal: [ 3, 1 ] + linear { vars: 0 vars: 2 coeffs: 1 coeffs: -15 domain: 10 domain: 10 } + } + )pb"); + SatParameters params; + params.set_keep_all_feasible_solutions_in_presolve(true); + params.set_permute_variable_randomly(false); + params.set_cp_model_probing_level(0); + const CpModelProto presolved_model = PresolveForTest(initial_model, params); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + TEST(PresolveCpModelTest, IntDivSimplification) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ 3, 20 ] } @@ -5683,6 +5893,82 @@ TEST(PresolveCpModelTest, OneActiveLiteralToFalseBoolXor) { EXPECT_THAT(expected_presolved_model, testing::EqualsProto(presolved_model)); } +TEST(PresolveCpModelTest, BoolXorNotPresolvedIfEnforcementUnknown) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 1, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 2 + bool_xor { literals: [ 0, 1 ] } + } + )pb"); + const CpModelProto presolved_model = PresolveOneConstraint(initial_model, 0); + EXPECT_THAT(presolved_model, testing::EqualsProto(initial_model)); +} + +TEST(PresolveCpModelTest, BoolXorChangedToBoolOrIfAlwaysFalseWhenEnforced) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: [ 0, 1, 2 ] + bool_xor {} + } + )pb"); + const CpModelProto presolved_model = PresolveOneConstraint(initial_model, 0); + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { bool_or { literals: [ -1, -2, -3 ] } } + )pb"); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + +TEST(PresolveCpModelTest, BoolXorChangedToBoolOrIfAlwaysFalseWhenEnforced2) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: [ 0, 1, 2 ] + bool_xor { literals: [ 1, 1 ] } + } + )pb"); + const CpModelProto presolved_model = PresolveOneConstraint(initial_model, 0); + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { bool_or { literals: [ -1, -2, -3 ] } } + )pb"); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + +TEST(PresolveCpModelTest, BoolXorChangedToBoolOrIfAlwaysFalseWhenEnforced3) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 1, 1 ] } + constraints { + enforcement_literal: [ 0, 1, 2 ] + bool_xor { literals: [ 1, -2, 3 ] } + } + )pb"); + const CpModelProto presolved_model = PresolveOneConstraint(initial_model, 0); + const CpModelProto expected_presolved_model = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + variables { domain: [ 1, 1 ] } + constraints { bool_or { literals: [ -1, -2, -3 ] } } + )pb"); + EXPECT_THAT(presolved_model, testing::EqualsProto(expected_presolved_model)); +} + TEST(PresolveCpModelTest, OneActiveLiteralToTrueBoolXor) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ 1, 1 ] } diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index e4c60ab3de..38db8b87a5 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1076,7 +1076,9 @@ class FullProblemSolver : public SubSolver { previous_task_is_completed_ = false; } return [this]() { + auto* time_limit = local_model_.GetOrCreate(); if (solving_first_chunk_) { + const double init_dtime = time_limit->GetElapsedDeterministicTime(); LoadCpModel(shared_->model_proto, &local_model_); // Level zero variable bounds sharing. It is important to register @@ -1127,15 +1129,18 @@ class FullProblemSolver : public SubSolver { // No need for mutex since we only run one task at the time. solving_first_chunk_ = false; + // Make sure we count the loading/hint dtime. + absl::MutexLock mutex_lock(&mutex_); + dtime_since_last_sync_ += + time_limit->GetElapsedDeterministicTime() - init_dtime; + + // Abort first chunk and allow to schedule the next. if (split_in_chunks_) { - // Abort first chunk and allow to schedule the next. - absl::MutexLock mutex_lock(&mutex_); previous_task_is_completed_ = true; return; } } - auto* time_limit = local_model_.GetOrCreate(); if (split_in_chunks_) { // Configure time limit for chunk solving. Note that we do not want // to do that for the hint search for now. diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index e64e37ae16..9d6a72423f 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -1131,6 +1131,7 @@ void FillBinaryRelationRepository(const CpModelProto& model_proto, auto* mapping = model->GetOrCreate(); auto* repository = model->GetOrCreate(); auto* root_level_lin2_bounds = model->GetOrCreate(); + auto* reified_lin2_bounds = model->GetOrCreate(); for (const ConstraintProto& ct : model_proto.constraints()) { // Load conditional precedences and always true binary relations. @@ -1198,6 +1199,8 @@ void FillBinaryRelationRepository(const CpModelProto& model_proto, if (vars.size() == 2) { const LinearExpression2 expr(vars[0], vars[1], coeffs[0], coeffs[1]); root_level_lin2_bounds->Add(expr, rhs_min, rhs_max); + } else if (vars.size() == 3 && rhs_min == rhs_max) { + reified_lin2_bounds->AddLinear3(vars, coeffs, rhs_min); } } else { const Literal lit = mapping->Literal(ct.enforcement_literal(0)); @@ -1866,10 +1869,9 @@ void QuickSolveWithHint(const CpModelProto& model_proto, Model* model) { void MinimizeL1DistanceWithHint(const CpModelProto& model_proto, Model* model) { Model local_model; - // Forward some shared class. - local_model.Register( - model->GetOrCreate()); - local_model.Register(model->GetOrCreate()); + // Pass the time limit and stop boolean to local limit. + model->GetOrCreate()->UpdateLocalLimit( + local_model.GetOrCreate()); if (!model_proto.has_solution_hint()) return; @@ -1967,6 +1969,10 @@ void MinimizeL1DistanceWithHint(const CpModelProto& model_proto, Model* model) { shared_response_manager->NewSolution( solution, absl::StrCat(solution_info, " [repaired]"), &local_model); } + + // Make sure we update the higher model with the timing info. + model->GetOrCreate()->AdvanceDeterministicTime( + local_model.GetOrCreate()->GetElapsedDeterministicTime()); } // TODO(user): If this ever shows up in the profile, we could avoid copying diff --git a/ortools/sat/cp_model_solver_test.cc b/ortools/sat/cp_model_solver_test.cc index e3d719b400..a1ba69d2f1 100644 --- a/ortools/sat/cp_model_solver_test.cc +++ b/ortools/sat/cp_model_solver_test.cc @@ -741,6 +741,154 @@ TEST(SolveCpModelTest, ObjectiveDomainLowerBound) { } } +TEST(SolveCpModelTest, AtMostAndExactlyOneWithEnforcementLiteral) { + // a => at_most_one(a, a) + // not(b) => exactly_one(not(b), not(b)) + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 0 + at_most_one { literals: [ 0, 0 ] } + } + constraints { + enforcement_literal: -2 + exactly_one { literals: [ -2, -2 ] } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:false")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_THAT(response.solution(), ::testing::ElementsAre(0, 1)); +} + +TEST(SolveCpModelTest, BoolXorWithEnforcementLiteral) { + // a => a xor b + // not(a) => a xor a + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 0 + bool_xor { literals: [ 0, 1 ] } + } + constraints { + enforcement_literal: -1 + bool_xor { literals: [ 0, 0 ] } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:false")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_THAT(response.solution(), ::testing::ElementsAre(1, 0)); +} + +TEST(SolveCpModelTest, BoolXorWithEnforcementLiteralPresolved) { + // a => a xor b + // not(a) => a xor a + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 1 ] } + constraints { + enforcement_literal: 0 + bool_xor { literals: [ 0, 1 ] } + } + constraints { + enforcement_literal: -1 + bool_xor { literals: [ 0, 0 ] } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:true")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_THAT(response.solution(), ::testing::ElementsAre(1, 0)); +} + +TEST(SolveCpModelTest, IntDivWithEnforcementLiteral) { + // not(b) => 7x / 3y = 17, x in [0, 10], y in [1, 2] + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 0, 10 ] } + variables { domain: [ 1, 2 ] } + constraints { + enforcement_literal: -1 + int_prod { + target { offset: 17 } + exprs { vars: 1 coeffs: 7 } + exprs { vars: 2 coeffs: 3 } + } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:false")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(response.solution(0), 1); +} + +TEST(SolveCpModelTest, IntModWithEnforcementLiteral) { + // not(b) => x % 10 = y, x in [8, 11], y in [2, 7] + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 2, 7 ] } + variables { domain: [ 8, 11 ] } + constraints { + enforcement_literal: -1 + int_prod { + target { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + exprs { offset: 10 } + } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:false")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(response.solution(0), 1); +} + +TEST(SolveCpModelTest, IntProdWithEnforcementLiteral) { + // not(b) => x.y.z = 17 + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 2, 20 ] } + variables { domain: [ 2, 20 ] } + variables { domain: [ 2, 20 ] } + constraints { + enforcement_literal: -1 + int_prod { + target { offset: 17 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 2 coeffs: 1 } + exprs { vars: 3 coeffs: 1 } + } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:false")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(response.solution(0), 1); +} + +TEST(SolveCpModelTest, SquareIntProdWithEnforcementLiteral) { + // not(b) => x.y.z = 17 + CpModelProto model_proto = ParseTestProto(R"pb( + variables { domain: [ 0, 1 ] } + variables { domain: [ 2, 20 ] } + constraints { + enforcement_literal: -1 + int_prod { + target { offset: 17 } + exprs { vars: 1 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + } + })pb"); + Model model; + model.Add(NewSatParameters("cp_model_presolve:false")); + const CpSolverResponse response = SolveCpModel(model_proto, &model); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(response.solution(0), 1); +} + TEST(SolveCpModelTest, LinMaxObjectiveDomainLowerBoundInfeasible) { const CpModelProto model_proto = ParseTestProto(R"pb( variables { domain: [ 0, 5 ] } diff --git a/ortools/sat/cp_model_symmetries_test.cc b/ortools/sat/cp_model_symmetries_test.cc index 30ce5f293e..4eaed61aeb 100644 --- a/ortools/sat/cp_model_symmetries_test.cc +++ b/ortools/sat/cp_model_symmetries_test.cc @@ -496,7 +496,7 @@ TEST(FindCpModelSymmetries, FindsSymmetryInBoolXorWithEnforcementLiteral) { CpModelProto model = ParseTestProto(absl::StrCat(kBooleanModel, R"pb( constraints { enforcement_literal: 0 - bool_or { literals: [ 1, 2 ] } + bool_xor { literals: [ 1, 2 ] } } )pb")); diff --git a/ortools/sat/csharp/sat.i b/ortools/sat/csharp/sat.i index cfe50f0f8b..0239f350f8 100644 --- a/ortools/sat/csharp/sat.i +++ b/ortools/sat/csharp/sat.i @@ -112,6 +112,7 @@ JAGGED_MATRIX_AS_CSHARP_ARRAY(int64_t, int64_t, long, Int64VectorVector); %feature("director") operations_research::sat::SolutionCallback; %unignore operations_research::sat::SolutionCallback; +%unignore operations_research::sat::SolutionCallback::SolutionCallback; %unignore operations_research::sat::SolutionCallback::~SolutionCallback; %unignore operations_research::sat::SolutionCallback::BestObjectiveBound; %feature("nodirector") operations_research::sat::SolutionCallback::BestObjectiveBound; diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index c475f64cc4..869cf1b07e 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -14,10 +14,12 @@ #include "ortools/sat/disjunctive.h" #include +#include #include #include #include "absl/algorithm/container.h" +#include "absl/cleanup/cleanup.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/types/span.h" @@ -523,7 +525,7 @@ bool DisjunctiveOverloadChecker::Propagate() { // propagation would have been done by the linear propagator, but if we // didn't add such relations yet, it is beneficial to detect that here! // - // TODO(user): Actually, we just infered a "not last" so we could check + // TODO(user): Actually, we just inferred a "not last" so we could check // for relevant_size > 2 potential propagation? // // TODO(user): Can we detect and propagate all such relations easily and @@ -577,6 +579,12 @@ bool DisjunctiveOverloadChecker::Propagate() { return false; } + // Subwindow propagation might have propagated that the + // task_with_max_end_min must be absent. + if (helper_->IsAbsent(task_with_max_end_min.task_index)) { + task_with_max_end_min = {0, kMinIntegerValue}; + } + // Start of the next window. window_size = 0; window[window_size++] = {task, start_min}; @@ -1189,6 +1197,13 @@ bool DisjunctivePrecedences::Propagate() { } bool DisjunctivePrecedences::PropagateSubwindow() { + // This function can be slow, so count it in the dtime. + int64_t num_hash_lookup = 0; + auto cleanup = ::absl::MakeCleanup([&num_hash_lookup, this]() { + time_limit_->AdvanceDeterministicTime(static_cast(num_hash_lookup) * + 5e-8); + }); + // TODO(user): We shouldn't consider ends for fixed intervals here. But // then we should do a better job of computing the min-end of a subset of // intervals from this disjunctive (like using fixed intervals even if there @@ -1272,6 +1287,7 @@ bool DisjunctivePrecedences::PropagateSubwindow() { // TODO(user): The lookup here is a bit slow, so we avoid fetching // the offset as much as possible. + ++num_hash_lookup; const IntegerValue inner_offset = -linear2_bounds_->NonTrivialUpperBound(lin2_index); DCHECK_NE(inner_offset, kMinIntegerValue); diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index e50d022975..442c5ebfb6 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -355,6 +355,7 @@ class DisjunctivePrecedences : public PropagatorInterface { integer_trail_(model->GetOrCreate()), precedence_relations_(model->GetOrCreate()), linear2_bounds_(model->GetOrCreate()), + time_limit_(model->GetOrCreate()), stats_("DisjunctivePrecedences", model) { window_.ClearAndReserve(helper->NumTasks()); index_to_end_vars_.ClearAndReserve(helper->NumTasks()); @@ -372,6 +373,7 @@ class DisjunctivePrecedences : public PropagatorInterface { IntegerTrail* integer_trail_; EnforcedLinear2Bounds* precedence_relations_; Linear2Bounds* linear2_bounds_; + TimeLimit* time_limit_; FixedCapacityVector window_; FixedCapacityVector index_to_end_vars_; diff --git a/ortools/sat/docs/boolean_logic.md b/ortools/sat/docs/boolean_logic.md index c22ae86ee3..2f96caad1f 100644 --- a/ortools/sat/docs/boolean_logic.md +++ b/ortools/sat/docs/boolean_logic.md @@ -20,6 +20,7 @@ negation of `x`. ### Python code ```python +# Snippet from ortools/sat/samples/literal_sample_sat.py #!/usr/bin/env python3 """Code sample to demonstrate Boolean variable and literals.""" @@ -28,25 +29,27 @@ from ortools.sat.python import cp_model def literal_sample_sat(): - model = cp_model.CpModel() - x = model.new_bool_var("x") - not_x = ~x - print(x) - print(not_x) + model = cp_model.CpModel() + x = model.new_bool_var('x') + not_x = ~x + print(x) + print(not_x) literal_sample_sat() + ``` ### C++ code ```cpp +// Snippet from ortools/sat/samples/literal_sample_sat.cc #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" #include "ortools/sat/cp_model.h" namespace operations_research { @@ -69,17 +72,19 @@ int main(int argc, char* argv[]) { operations_research::sat::LiteralSampleSat(); return EXIT_SUCCESS; } + ``` ### Java code ```java +// Snippet from ortools/sat/samples/LiteralSampleSat.java package com.google.ortools.sat.samples; -import com.google.ortools.Loader; import com.google.ortools.sat.BoolVar; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.Literal; +import com.google.ortools.Loader; /** Code sample to demonstrate Boolean variable and literals. */ public class LiteralSampleSat { @@ -91,11 +96,15 @@ public class LiteralSampleSat { System.out.println(notX); } } + ``` ### C\# code -```cs +```csharp +// Snippet from ortools/sat/samples/LiteralSampleSat.cs + + using System; using Google.OrTools.Sat; @@ -108,11 +117,14 @@ public class LiteralSampleSat ILiteral not_x = x.Not(); } } + ``` ### Go code -```cs +```go +// Snippet from ortools/sat/samples/literal_sample_sat.go + // The literal_sample_sat command is a simple example of literals. package main @@ -133,6 +145,7 @@ func literalSampleSat() { func main() { literalSampleSat() } + ``` ## Boolean constraints @@ -150,6 +163,7 @@ constraints. For instance, we can add a constraint Or(x, not(y)). ### Python code ```python +# Snippet from ortools/sat/samples/bool_or_sample_sat.py #!/usr/bin/env python3 """Code sample to demonstrates a simple Boolean constraint.""" @@ -158,29 +172,33 @@ from ortools.sat.python import cp_model def bool_or_sample_sat(): - model = cp_model.CpModel() + model = cp_model.CpModel() - x = model.new_bool_var("x") - y = model.new_bool_var("y") + x = model.new_bool_var('x') + y = model.new_bool_var('y') - model.add_bool_or([x, y.negated()]) - # The [] is not mandatory. - # ~y is equivalent to y.negated() - model.add_bool_or(x, ~y) + model.add_bool_or([x, y.negated()]) + # The [] is not mandatory. + # ~y is equivalent to y.negated() + model.add_bool_or(x, ~y) bool_or_sample_sat() + ``` ### C++ code ```cpp +// Snippet from ortools/sat/samples/bool_or_sample_sat.cc + + #include +#include "ortools/base/init_google.h" #include "absl/base/log_severity.h" #include "absl/log/globals.h" #include "absl/types/span.h" -#include "ortools/base/init_google.h" #include "ortools/sat/cp_model.h" namespace operations_research { @@ -205,11 +223,15 @@ int main(int argc, char* argv[]) { operations_research::sat::BoolOrSampleSat(); return EXIT_SUCCESS; } + ``` ### Java code ```java +// Snippet from ortools/sat/samples/BoolOrSampleSat.java + + package com.google.ortools.sat.samples; import com.google.ortools.Loader; @@ -227,11 +249,15 @@ public class BoolOrSampleSat { model.addBoolOr(new Literal[] {x, y.not()}); } } + ``` ### C\# code -```cs +```csharp +// Snippet from ortools/sat/samples/BoolOrSampleSat.cs + + using System; using Google.OrTools.Sat; @@ -247,11 +273,15 @@ public class BoolOrSampleSat model.AddBoolOr(new ILiteral[] { x, y.Not() }); } } + ``` ### Go code -```cs +```go +// Snippet from ortools/sat/samples/bool_or_sample_sat.go + + // The bool_or_sample_sat command is simple example of the BoolOr constraint. package main @@ -271,6 +301,7 @@ func boolOrSampleSat() { func main() { boolOrSampleSat() } + ``` ## Reified constraints @@ -278,12 +309,14 @@ func main() { The CP-SAT solver supports *half-reified* constraints, also called *implications*, which are of the form: - x implies constraint +``` +x implies constraint +``` where the constraint must hold if `x` is true. -Please note that this is not an equivalence relation. The constraint can still -be true if `x` is false. +Note that this is not an equivalence relation. The constraint can still be true +if `x` is false. So we can write b => And(x, not y). That is, if b is true, then x is true and y is false. Note that in this particular example, there are multiple ways to @@ -293,6 +326,7 @@ then is written as Or(not b, x) and Or(not b, not y). ### Python code ```python +# Snippet from ortools/sat/samples/reified_sample_sat.py #!/usr/bin/env python3 """Simple model with a reified constraint.""" @@ -300,37 +334,42 @@ from ortools.sat.python import cp_model def reified_sample_sat(): - """Showcase creating a reified constraint.""" - model = cp_model.CpModel() + """Showcase creating a reified constraint.""" + model = cp_model.CpModel() - x = model.new_bool_var("x") - y = model.new_bool_var("y") - b = model.new_bool_var("b") + x = model.new_bool_var('x') + y = model.new_bool_var('y') + b = model.new_bool_var('b') - # First version using a half-reified bool and. - model.add_bool_and(x, ~y).only_enforce_if(b) + # First version using a half-reified bool and. + model.add_bool_and(x, ~y).only_enforce_if(b) - # Second version using implications. - model.add_implication(b, x) - model.add_implication(b, ~y) + # Second version using implications. + model.add_implication(b, x) + model.add_implication(b, ~y) - # Third version using bool or. - model.add_bool_or(~b, x) - model.add_bool_or(~b, ~y) + # Third version using bool or. + model.add_bool_or(~b, x) + model.add_bool_or(~b, ~y) reified_sample_sat() + ``` ### C++ code ```cpp +// Snippet from ortools/sat/samples/reified_sample_sat.cc + + + #include +#include "ortools/base/init_google.h" #include "absl/base/log_severity.h" #include "absl/log/globals.h" #include "absl/types/span.h" -#include "ortools/base/init_google.h" #include "ortools/sat/cp_model.h" namespace operations_research { @@ -364,11 +403,15 @@ int main(int argc, char* argv[]) { operations_research::sat::ReifiedSampleSat(); return EXIT_SUCCESS; } + ``` ### Java code ```java +// Snippet from ortools/sat/samples/ReifiedSampleSat.java + + package com.google.ortools.sat.samples; import com.google.ortools.Loader; @@ -407,11 +450,15 @@ public class ReifiedSampleSat { model.addBoolOr(new Literal[] {b.not(), y.not()}); } } + ``` ### C\# code -```cs +```csharp +// Snippet from ortools/sat/samples/ReifiedSampleSat.cs + + using System; using Google.OrTools.Sat; @@ -437,11 +484,15 @@ public class ReifiedSampleSat model.AddBoolOr(new ILiteral[] { b.Not(), y.Not() }); } } + ``` ### Go code -```cs +```go +// Snippet from ortools/sat/samples/reified_sample_sat.go + + // The reified_sample_sat command is a simple example of implication constraints. package main @@ -471,29 +522,37 @@ func reifiedSampleSat() { func main() { reifiedSampleSat() } + ``` ## Product of two Boolean Variables A useful construct is the product `p` of two Boolean variables `x` and `y`. - p == x * y +``` +p == x * y +``` This is equivalent to the logical relation - p <=> x and y +``` +p <=> x and y +``` This is encoded using one bool_or constraint and two implications. The following code samples output this truth table: - x = 0 y = 0 p = 0 - x = 1 y = 0 p = 0 - x = 0 y = 1 p = 0 - x = 1 y = 1 p = 1 +``` +x = 0 y = 0 p = 0 +x = 1 y = 0 p = 0 +x = 0 y = 1 p = 0 +x = 1 y = 1 p = 1 +``` ### Python code ```python +# Snippet from ortools/sat/samples/boolean_product_sample_sat.py #!/usr/bin/env python3 """Code sample that encodes the product of two Boolean variables.""" @@ -502,35 +561,39 @@ from ortools.sat.python import cp_model def boolean_product_sample_sat(): - """Encoding of the product of two Boolean variables. + """Encoding of the product of two Boolean variables. - p == x * y, which is the same as p <=> x and y - """ - model = cp_model.CpModel() - x = model.new_bool_var("x") - y = model.new_bool_var("y") - p = model.new_bool_var("p") + p == x * y, which is the same as p <=> x and y + """ + model = cp_model.CpModel() + x = model.new_bool_var('x') + y = model.new_bool_var('y') + p = model.new_bool_var('p') - # x and y implies p, rewrite as not(x and y) or p. - model.add_bool_or(~x, ~y, p) + # x and y implies p, rewrite as not(x and y) or p. + model.add_bool_or(~x, ~y, p) - # p implies x and y, expanded into two implications. - model.add_implication(p, x) - model.add_implication(p, y) + # p implies x and y, expanded into two implications. + model.add_implication(p, x) + model.add_implication(p, y) - # Create a solver and solve. - solver = cp_model.CpSolver() - solution_printer = cp_model.VarArraySolutionPrinter([x, y, p]) - solver.parameters.enumerate_all_solutions = True - solver.solve(model, solution_printer) + # Create a solver and solve. + solver = cp_model.CpSolver() + solution_printer = cp_model.VarArraySolutionPrinter([x, y, p]) + solver.parameters.enumerate_all_solutions = True + solver.solve(model, solution_printer) boolean_product_sample_sat() + ``` ### Go code -```cs +```go +// Snippet from ortools/sat/samples/boolean_product_sample_sat.go + + // The boolean_product_sample_sat command is a simple example of the product of two literals. package main @@ -538,9 +601,10 @@ import ( "fmt" log "github.com/golang/glog" - "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func booleanProductSample() error { @@ -564,7 +628,7 @@ func booleanProductSample() error { } // Set `fill_additional_solutions_in_response` and `enumerate_all_solutions` to true so // the solver returns all solutions found. - params := &sppb.SatParameters{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(4), diff --git a/ortools/sat/docs/integer_arithmetic.md b/ortools/sat/docs/integer_arithmetic.md index f714931b5f..fac52d3acd 100644 --- a/ortools/sat/docs/integer_arithmetic.md +++ b/ortools/sat/docs/integer_arithmetic.md @@ -120,6 +120,7 @@ rabbits and pheasants are there? ### Python code ```python +# Snippet from ortools/sat/samples/rabbits_and_pheasants_sat.py #!/usr/bin/env python3 """Rabbits and Pheasants quizz.""" @@ -127,26 +128,27 @@ from ortools.sat.python import cp_model def rabbits_and_pheasants_sat(): - """Solves the rabbits + pheasants problem.""" - model = cp_model.CpModel() + """Solves the rabbits + pheasants problem.""" + model = cp_model.CpModel() - r = model.new_int_var(0, 100, "r") - p = model.new_int_var(0, 100, "p") + r = model.new_int_var(0, 100, 'r') + p = model.new_int_var(0, 100, 'p') - # 20 heads. - model.add(r + p == 20) - # 56 legs. - model.add(4 * r + 2 * p == 56) + # 20 heads. + model.add(r + p == 20) + # 56 legs. + model.add(4 * r + 2 * p == 56) - # Solves and prints out the solution. - solver = cp_model.CpSolver() - status = solver.solve(model) + # Solves and prints out the solution. + solver = cp_model.CpSolver() + status = solver.solve(model) - if status == cp_model.OPTIMAL: - print(f"{solver.value(r)} rabbits and {solver.value(p)} pheasants") + if status == cp_model.OPTIMAL: + print(f'{solver.value(r)} rabbits and {solver.value(p)} pheasants') rabbits_and_pheasants_sat() + ``` ### C++ code @@ -154,10 +156,10 @@ rabbits_and_pheasants_sat() ```cpp #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -195,17 +197,19 @@ int main(int argc, char* argv[]) { operations_research::sat::RabbitsAndPheasantsSat(); return EXIT_SUCCESS; } + ``` ### Java code ```java +// Snippet from ortools/sat/samples/RabbitsAndPheasantsSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; -import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.IntVar; import com.google.ortools.sat.LinearExpr; @@ -235,11 +239,12 @@ public class RabbitsAndPheasantsSat { } } } + ``` ### C\# code -```cs +```csharp using System; using Google.OrTools.Sat; @@ -271,7 +276,7 @@ public class RabbitsAndPheasantsSat ### Go code -```cs +```go // The rabbits_and_pheasants_sat command is an example of a simple sat program that // solves the rabbits and pheasants problem. package main @@ -281,6 +286,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) @@ -322,6 +328,7 @@ func main() { log.Exitf("rabbitsAndPheasants returned with error: %v", err) } } + ``` ## Earliness-Tardiness cost function. @@ -336,31 +343,34 @@ the max of them to define the piecewise linear function. The following samples output: - x=0 expr=40 - x=1 expr=32 - x=2 expr=24 - x=3 expr=16 - x=4 expr=8 - x=5 expr=0 - x=6 expr=0 - x=7 expr=0 - x=8 expr=0 - x=9 expr=0 - x=10 expr=0 - x=11 expr=0 - x=12 expr=0 - x=13 expr=0 - x=14 expr=0 - x=15 expr=0 - x=16 expr=12 - x=17 expr=24 - x=18 expr=36 - x=19 expr=48 - x=20 expr=60 +``` +x=0 expr=40 +x=1 expr=32 +x=2 expr=24 +x=3 expr=16 +x=4 expr=8 +x=5 expr=0 +x=6 expr=0 +x=7 expr=0 +x=8 expr=0 +x=9 expr=0 +x=10 expr=0 +x=11 expr=0 +x=12 expr=0 +x=13 expr=0 +x=14 expr=0 +x=15 expr=0 +x=16 expr=12 +x=17 expr=24 +x=18 expr=36 +x=19 expr=48 +x=20 expr=60 +``` ### Python code ```python +# Snippet from ortools/sat/samples/earliness_tardiness_cost_sample_sat.py #!/usr/bin/env python3 """Encodes a convex piecewise linear function.""" @@ -369,86 +379,90 @@ from ortools.sat.python import cp_model class VarArraySolutionPrinter(cp_model.CpSolverSolutionCallback): - """Print intermediate solutions.""" + """Print intermediate solutions.""" - def __init__(self, variables: list[cp_model.IntVar]): - cp_model.CpSolverSolutionCallback.__init__(self) - self.__variables = variables + def __init__(self, variables: list[cp_model.IntVar]): + cp_model.CpSolverSolutionCallback.__init__(self) + self.__variables = variables - def on_solution_callback(self) -> None: - for v in self.__variables: - print(f"{v}={self.value(v)}", end=" ") - print() + def on_solution_callback(self) -> None: + for v in self.__variables: + print(f'{v}={self.value(v)}', end=' ') + print() def earliness_tardiness_cost_sample_sat(): - """Encode the piecewise linear expression.""" + """Encode the piecewise linear expression.""" - earliness_date = 5 # ed. - earliness_cost = 8 - lateness_date = 15 # ld. - lateness_cost = 12 + earliness_date = 5 # ed. + earliness_cost = 8 + lateness_date = 15 # ld. + lateness_cost = 12 - # Model. - model = cp_model.CpModel() + # Model. + model = cp_model.CpModel() - # Declare our primary variable. - x = model.new_int_var(0, 20, "x") + # Declare our primary variable. + x = model.new_int_var(0, 20, 'x') - # Create the expression variable and implement the piecewise linear function. - # - # \ / - # \______/ - # ed ld - # - large_constant = 1000 - expr = model.new_int_var(0, large_constant, "expr") + # Create the expression variable and implement the piecewise linear function. + # + # \ / + # \______/ + # ed ld + # + large_constant = 1000 + expr = model.new_int_var(0, large_constant, 'expr') - # First segment. - s1 = model.new_int_var(-large_constant, large_constant, "s1") - model.add(s1 == earliness_cost * (earliness_date - x)) + # First segment. + s1 = model.new_int_var(-large_constant, large_constant, 's1') + model.add(s1 == earliness_cost * (earliness_date - x)) - # Second segment. - s2 = 0 + # Second segment. + s2 = 0 - # Third segment. - s3 = model.new_int_var(-large_constant, large_constant, "s3") - model.add(s3 == lateness_cost * (x - lateness_date)) + # Third segment. + s3 = model.new_int_var(-large_constant, large_constant, 's3') + model.add(s3 == lateness_cost * (x - lateness_date)) - # Link together expr and x through s1, s2, and s3. - model.add_max_equality(expr, [s1, s2, s3]) + # Link together expr and x through s1, s2, and s3. + model.add_max_equality(expr, [s1, s2, s3]) - # Search for x values in increasing order. - model.add_decision_strategy([x], cp_model.CHOOSE_FIRST, cp_model.SELECT_MIN_VALUE) + # Search for x values in increasing order. + model.add_decision_strategy( + [x], cp_model.CHOOSE_FIRST, cp_model.SELECT_MIN_VALUE + ) - # Create a solver and solve with a fixed search. - solver = cp_model.CpSolver() + # Create a solver and solve with a fixed search. + solver = cp_model.CpSolver() - # Force the solver to follow the decision strategy exactly. - solver.parameters.search_branching = cp_model.FIXED_SEARCH - # Enumerate all solutions. - solver.parameters.enumerate_all_solutions = True + # Force the solver to follow the decision strategy exactly. + solver.parameters.search_branching = cp_model.FIXED_SEARCH + # Enumerate all solutions. + solver.parameters.enumerate_all_solutions = True - # Search and print out all solutions. - solution_printer = VarArraySolutionPrinter([x, expr]) - solver.solve(model, solution_printer) + # Search and print out all solutions. + solution_printer = VarArraySolutionPrinter([x, expr]) + solver.solve(model, solution_printer) earliness_tardiness_cost_sample_sat() + ``` ### C++ code ```cpp +// Snippet from ortools/sat/samples/earliness_tardiness_cost_sample_sat.cc #include #include +#include "ortools/base/init_google.h" +#include "ortools/base/logging.h" #include "absl/base/log_severity.h" #include "absl/log/globals.h" #include "absl/types/span.h" -#include "ortools/base/init_google.h" -#include "ortools/base/logging.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -509,22 +523,24 @@ int main(int argc, char* argv[]) { operations_research::sat::EarlinessTardinessCostSampleSat(); return EXIT_SUCCESS; } + ``` ### Java code ```java +// Snippet from ortools/sat/samples/EarlinessTardinessCostSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; +import com.google.ortools.sat.DecisionStrategyProto; +import com.google.ortools.sat.SatParameters; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; import com.google.ortools.sat.CpSolverSolutionCallback; -import com.google.ortools.sat.CpSolverStatus; -import com.google.ortools.sat.DecisionStrategyProto; import com.google.ortools.sat.IntVar; import com.google.ortools.sat.LinearExpr; -import com.google.ortools.sat.SatParameters; /** Encode the piecewise linear expression. */ public class EarlinessTardinessCostSampleSat { @@ -554,19 +570,20 @@ public class EarlinessTardinessCostSampleSat { // First segment: y == earlinessCost * (earlinessDate - x). // Second segment: y = 0 // Third segment: y == latenessCost * (x - latenessDate). - model.addMaxEquality(expr, - new LinearExpr[] {LinearExpr.newBuilder() - .addTerm(x, -earlinessCost) - .add(earlinessCost * earlinessDate) - .build(), - LinearExpr.constant(0), - LinearExpr.newBuilder() - .addTerm(x, latenessCost) - .add(-latenessCost * latenessDate) - .build()}); + model.addMaxEquality( + expr, + new LinearExpr[] { + LinearExpr.newBuilder() + .addTerm(x, -earlinessCost) + .add(earlinessCost * earlinessDate) + .build(), + LinearExpr.constant(0), + LinearExpr.newBuilder().addTerm(x, latenessCost).add(-latenessCost * latenessDate).build() + }); // Search for x values in increasing order. - model.addDecisionStrategy(new IntVar[] {x}, + model.addDecisionStrategy( + new IntVar[] {x}, DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_FIRST, DecisionStrategyProto.DomainReductionStrategy.SELECT_MIN_VALUE); @@ -579,29 +596,34 @@ public class EarlinessTardinessCostSampleSat { solver.getParameters().setEnumerateAllSolutions(true); // Solve the problem with the printer callback. - CpSolverStatus unusedStatus = solver.solve(model, new CpSolverSolutionCallback() { - public CpSolverSolutionCallback init(IntVar[] variables) { - variableArray = variables; - return this; - } + CpSolverStatus unusedStatus = + solver.solve( + model, + new CpSolverSolutionCallback() { + public CpSolverSolutionCallback init(IntVar[] variables) { + variableArray = variables; + return this; + } - @Override - public void onSolutionCallback() { - for (IntVar v : variableArray) { - System.out.printf("%s=%d ", v.getName(), value(v)); - } - System.out.println(); - } + @Override + public void onSolutionCallback() { + for (IntVar v : variableArray) { + System.out.printf("%s=%d ", v.getName(), value(v)); + } + System.out.println(); + } - private IntVar[] variableArray; - }.init(new IntVar[] {x, expr})); + private IntVar[] variableArray; + }.init(new IntVar[] {x, expr})); } } + ``` ### C\# code -```cs +```csharp +// Snippet from ortools/sat/samples/EarlinessTardinessCostSampleSat.cs using System; using Google.OrTools.Sat; using Google.OrTools.Util; @@ -671,11 +693,13 @@ public class EarlinessTardinessCostSampleSat solver.Solve(model, cb); } } + ``` ### Go code -```cs +```go +// Snippet from ortools/sat/samples/earliness_tardiness_cost_sample_sat.go // The earliness_tardiness_cost_sample_sat command is an example of an implementation of a convex // piecewise linear function. package main @@ -684,10 +708,11 @@ import ( "fmt" log "github.com/golang/glog" + "google.golang.org/protobuf/proto" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) const ( @@ -727,7 +752,7 @@ func earlinessTardinessCostSampleSat() error { if err != nil { return fmt.Errorf("failed to instantiate the CP model: %w", err) } - params := &sppb.SatParameters{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(21), @@ -753,6 +778,7 @@ func main() { log.Exitf("earlinessTardinessCostSampleSat returned with error: %v", err) } } + ``` ## Step function. @@ -762,30 +788,33 @@ and filter the admissible domain of the input variable with this variable. The following samples output: - x=0 expr=2 - x=1 expr=2 - x=3 expr=2 - x=4 expr=2 - x=5 expr=0 - x=6 expr=0 - x=7 expr=3 - x=8 expr=0 - x=9 expr=0 - x=10 expr=0 - x=11 expr=2 - x=12 expr=2 - x=13 expr=2 - x=14 expr=2 - x=15 expr=2 - x=16 expr=2 - x=17 expr=2 - x=18 expr=2 - x=19 expr=2 - x=20 expr=2 +``` +x=0 expr=2 +x=1 expr=2 +x=3 expr=2 +x=4 expr=2 +x=5 expr=0 +x=6 expr=0 +x=7 expr=3 +x=8 expr=0 +x=9 expr=0 +x=10 expr=0 +x=11 expr=2 +x=12 expr=2 +x=13 expr=2 +x=14 expr=2 +x=15 expr=2 +x=16 expr=2 +x=17 expr=2 +x=18 expr=2 +x=19 expr=2 +x=20 expr=2 +``` ### Python code ```python +# Snippet from ortools/sat/samples/step_function_sample_sat.py #!/usr/bin/env python3 """Implements a step function.""" @@ -793,89 +822,93 @@ from ortools.sat.python import cp_model class VarArraySolutionPrinter(cp_model.CpSolverSolutionCallback): - """Print intermediate solutions.""" + """Print intermediate solutions.""" - def __init__(self, variables: list[cp_model.IntVar]): - cp_model.CpSolverSolutionCallback.__init__(self) - self.__variables = variables + def __init__(self, variables: list[cp_model.IntVar]): + cp_model.CpSolverSolutionCallback.__init__(self) + self.__variables = variables - def on_solution_callback(self) -> None: - for v in self.__variables: - print(f"{v}={self.value(v)}", end=" ") - print() + def on_solution_callback(self) -> None: + for v in self.__variables: + print(f'{v}={self.value(v)}', end=' ') + print() def step_function_sample_sat(): - """Encode the step function.""" + """Encode the step function.""" - # Model. - model = cp_model.CpModel() + # Model. + model = cp_model.CpModel() - # Declare our primary variable. - x = model.new_int_var(0, 20, "x") + # Declare our primary variable. + x = model.new_int_var(0, 20, 'x') - # Create the expression variable and implement the step function - # Note it is not defined for x == 2. - # - # - 3 - # -- -- --------- 2 - # 1 - # -- --- 0 - # 0 ================ 20 - # - expr = model.new_int_var(0, 3, "expr") + # Create the expression variable and implement the step function + # Note it is not defined for x == 2. + # + # - 3 + # -- -- --------- 2 + # 1 + # -- --- 0 + # 0 ================ 20 + # + expr = model.new_int_var(0, 3, 'expr') - # expr == 0 on [5, 6] U [8, 10] - b0 = model.new_bool_var("b0") - model.add_linear_expression_in_domain( - x, cp_model.Domain.from_intervals([(5, 6), (8, 10)]) - ).only_enforce_if(b0) - model.add(expr == 0).only_enforce_if(b0) + # expr == 0 on [5, 6] U [8, 10] + b0 = model.new_bool_var('b0') + model.add_linear_expression_in_domain( + x, cp_model.Domain.from_intervals([(5, 6), (8, 10)]) + ).only_enforce_if(b0) + model.add(expr == 0).only_enforce_if(b0) - # expr == 2 on [0, 1] U [3, 4] U [11, 20] - b2 = model.new_bool_var("b2") - model.add_linear_expression_in_domain( - x, cp_model.Domain.from_intervals([(0, 1), (3, 4), (11, 20)]) - ).only_enforce_if(b2) - model.add(expr == 2).only_enforce_if(b2) + # expr == 2 on [0, 1] U [3, 4] U [11, 20] + b2 = model.new_bool_var('b2') + model.add_linear_expression_in_domain( + x, cp_model.Domain.from_intervals([(0, 1), (3, 4), (11, 20)]) + ).only_enforce_if(b2) + model.add(expr == 2).only_enforce_if(b2) - # expr == 3 when x == 7 - b3 = model.new_bool_var("b3") - model.add(x == 7).only_enforce_if(b3) - model.add(expr == 3).only_enforce_if(b3) + # expr == 3 when x == 7 + b3 = model.new_bool_var('b3') + model.add(x == 7).only_enforce_if(b3) + model.add(expr == 3).only_enforce_if(b3) - # At least one bi is true. (we could use an exactly one constraint). - model.add_bool_or(b0, b2, b3) + # At least one bi is true. (we could use an exactly one constraint). + model.add_bool_or(b0, b2, b3) - # Search for x values in increasing order. - model.add_decision_strategy([x], cp_model.CHOOSE_FIRST, cp_model.SELECT_MIN_VALUE) + # Search for x values in increasing order. + model.add_decision_strategy( + [x], cp_model.CHOOSE_FIRST, cp_model.SELECT_MIN_VALUE + ) - # Create a solver and solve with a fixed search. - solver = cp_model.CpSolver() + # Create a solver and solve with a fixed search. + solver = cp_model.CpSolver() - # Force the solver to follow the decision strategy exactly. - solver.parameters.search_branching = cp_model.FIXED_SEARCH - # Enumerate all solutions. - solver.parameters.enumerate_all_solutions = True + # Force the solver to follow the decision strategy exactly. + solver.parameters.search_branching = cp_model.FIXED_SEARCH + # Enumerate all solutions. + solver.parameters.enumerate_all_solutions = True - # Search and print out all solutions. - solution_printer = VarArraySolutionPrinter([x, expr]) - solver.solve(model, solution_printer) + # Search and print out all solutions. + solution_printer = VarArraySolutionPrinter([x, expr]) + solver.solve(model, solution_printer) step_function_sample_sat() + ``` ### C++ code ```cpp +// Snippet from ortools/sat/samples/step_function_sample_sat.cc #include +#include "ortools/base/init_google.h" +#include "ortools/base/logging.h" #include "absl/base/log_severity.h" #include "absl/log/globals.h" #include "absl/types/span.h" -#include "ortools/base/init_google.h" -#include "ortools/base/logging.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -951,22 +984,24 @@ int main(int argc, char* argv[]) { operations_research::sat::StepFunctionSampleSat(); return EXIT_SUCCESS; } + ``` ### Java code ```java +// Snippet from ortools/sat/samples/StepFunctionSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; +import com.google.ortools.sat.DecisionStrategyProto; +import com.google.ortools.sat.SatParameters; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; import com.google.ortools.sat.CpSolverSolutionCallback; -import com.google.ortools.sat.CpSolverStatus; -import com.google.ortools.sat.DecisionStrategyProto; import com.google.ortools.sat.IntVar; import com.google.ortools.sat.Literal; -import com.google.ortools.sat.SatParameters; import com.google.ortools.util.Domain; /** Link integer constraints together. */ @@ -992,7 +1027,8 @@ public class StepFunctionSampleSat { // expr == 0 on [5, 6] U [8, 10] Literal b0 = model.newBoolVar("b0"); - model.addLinearExpressionInDomain(x, Domain.fromValues(new long[] {5, 6, 8, 9, 10})) + model + .addLinearExpressionInDomain(x, Domain.fromValues(new long[] {5, 6, 8, 9, 10})) .onlyEnforceIf(b0); model.addEquality(expr, 0).onlyEnforceIf(b0); @@ -1013,7 +1049,8 @@ public class StepFunctionSampleSat { model.addBoolOr(new Literal[] {b0, b2, b3}); // Search for x values in increasing order. - model.addDecisionStrategy(new IntVar[] {x}, + model.addDecisionStrategy( + new IntVar[] {x}, DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_FIRST, DecisionStrategyProto.DomainReductionStrategy.SELECT_MIN_VALUE); @@ -1026,29 +1063,33 @@ public class StepFunctionSampleSat { solver.getParameters().setEnumerateAllSolutions(true); // Solve the problem with the printer callback. - CpSolverStatus unusedStatus = solver.solve(model, new CpSolverSolutionCallback() { - public CpSolverSolutionCallback init(IntVar[] variables) { - variableArray = variables; - return this; - } + CpSolverStatus unusedStatus = + solver.solve( + model, + new CpSolverSolutionCallback() { + public CpSolverSolutionCallback init(IntVar[] variables) { + variableArray = variables; + return this; + } - @Override - public void onSolutionCallback() { - for (IntVar v : variableArray) { - System.out.printf("%s=%d ", v.getName(), value(v)); - } - System.out.println(); - } + @Override + public void onSolutionCallback() { + for (IntVar v : variableArray) { + System.out.printf("%s=%d ", v.getName(), value(v)); + } + System.out.println(); + } - private IntVar[] variableArray; - }.init(new IntVar[] {x, expr})); + private IntVar[] variableArray; + }.init(new IntVar[] {x, expr})); } } ``` ### C\# code -```cs +```csharp +// Snippet from ortools/sat/samples/StepFunctionSampleSat.cs using System; using Google.OrTools.Sat; using Google.OrTools.Util; @@ -1132,11 +1173,14 @@ public class StepFunctionSampleSat solver.Solve(model, cb); } } + ``` ### Go code -```cs +```go +// Snippet from ortools/sat/samples/step_function_sample_sat.go + // The step_function_sample_sat command is an example of an implementation of a step function. package main @@ -1144,10 +1188,11 @@ import ( "fmt" log "github.com/golang/glog" + "google.golang.org/protobuf/proto" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) func stepFunctionSampleSat() error { @@ -1197,7 +1242,7 @@ func stepFunctionSampleSat() error { if err != nil { return fmt.Errorf("failed to instantiate the CP model: %w", err) } - params := &sppb.SatParameters{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), EnumerateAllSolutions: proto.Bool(true), SolutionPoolSize: proto.Int32(21), @@ -1223,6 +1268,7 @@ func main() { log.Exitf("stepFunctionSampleSat returned with error: %v", err) } } + ``` ## Product of a Boolean variable and an integer variable @@ -1255,6 +1301,7 @@ x=10 b=1 p=10 ### Python code ```python +# Snippet from ortools/sat/samples/bool_and_int_var_product_sample_sat.py #!/usr/bin/env python3 """Code sample that encodes the product of a Boolean and an integer variable.""" @@ -1262,59 +1309,60 @@ from ortools.sat.python import cp_model class VarArraySolutionPrinter(cp_model.CpSolverSolutionCallback): - """Print intermediate solutions.""" + """Print intermediate solutions.""" - def __init__(self, variables: list[cp_model.IntVar]): - cp_model.CpSolverSolutionCallback.__init__(self) - self.__variables = variables + def __init__(self, variables: list[cp_model.IntVar]): + cp_model.CpSolverSolutionCallback.__init__(self) + self.__variables = variables - def on_solution_callback(self) -> None: - for v in self.__variables: - print(f"{v}={self.value(v)}", end=" ") - print() + def on_solution_callback(self) -> None: + for v in self.__variables: + print(f'{v}={self.value(v)}', end=' ') + print() def build_product_var( model: cp_model.CpModel, b: cp_model.IntVar, x: cp_model.IntVar, name: str ) -> cp_model.IntVar: - """Builds the product of a Boolean variable and an integer variable.""" - p = model.new_int_var_from_domain( - cp_model.Domain.from_flat_intervals(x.proto.domain).union_with( - cp_model.Domain(0, 0) - ), - name, - ) - model.add(p == x).only_enforce_if(b) - model.add(p == 0).only_enforce_if(~b) - return p + """Builds the product of a Boolean variable and an integer variable.""" + p = model.new_int_var_from_domain( + cp_model.Domain.from_flat_intervals(x.proto.domain).union_with( + cp_model.Domain(0, 0) + ), + name, + ) + model.add(p == x).only_enforce_if(b) + model.add(p == 0).only_enforce_if(~b) + return p def bool_and_int_var_product_sample_sat(): - """Encoding of the product of two Boolean variables. + """Encoding of the product of two Boolean variables. - p == x * y, which is the same as p <=> x and y - """ - model = cp_model.CpModel() - b = model.new_bool_var("b") - x = model.new_int_var_from_domain( - cp_model.Domain.from_values([1, 2, 3, 5, 6, 7, 9, 10]), "x" - ) - p = build_product_var(model, b, x, "p") + p == x * y, which is the same as p <=> x and y + """ + model = cp_model.CpModel() + b = model.new_bool_var('b') + x = model.new_int_var_from_domain( + cp_model.Domain.from_values([1, 2, 3, 5, 6, 7, 9, 10]), 'x' + ) + p = build_product_var(model, b, x, 'p') - # Search for x and b values in increasing order. - model.add_decision_strategy( - [b, x], cp_model.CHOOSE_FIRST, cp_model.SELECT_MIN_VALUE - ) + # Search for x and b values in increasing order. + model.add_decision_strategy( + [b, x], cp_model.CHOOSE_FIRST, cp_model.SELECT_MIN_VALUE + ) - # Create a solver and solve. - solver = cp_model.CpSolver() - solution_printer = VarArraySolutionPrinter([x, b, p]) - solver.parameters.enumerate_all_solutions = True - solver.parameters.search_branching = cp_model.FIXED_SEARCH - solver.solve(model, solution_printer) + # Create a solver and solve. + solver = cp_model.CpSolver() + solution_printer = VarArraySolutionPrinter([x, b, p]) + solver.parameters.enumerate_all_solutions = True + solver.parameters.search_branching = cp_model.FIXED_SEARCH + solver.solve(model, solution_printer) bool_and_int_var_product_sample_sat() + ``` ## Scanning the domain of variables. @@ -1331,6 +1379,7 @@ reading back the values from the model. ### Python code ```python +# Snippet from ortools/sat/samples/all_different_except_zero_sample_sat.py #!/usr/bin/env python3 """Implements AllDifferentExcept0 using atomic constraints.""" @@ -1340,66 +1389,67 @@ from ortools.sat.python import cp_model def all_different_except_0(): - """Encode the AllDifferentExcept0 constraint.""" + """Encode the AllDifferentExcept0 constraint.""" - # Model. - model = cp_model.CpModel() + # Model. + model = cp_model.CpModel() - # Declare our primary variable. - x = [model.new_int_var(0, 10, f"x{i}") for i in range(5)] + # Declare our primary variable. + x = [model.new_int_var(0, 10, f'x{i}') for i in range(5)] - # Expand the AllDifferentExcept0 constraint. - variables_per_value = collections.defaultdict(list) - all_values = set() + # Expand the AllDifferentExcept0 constraint. + variables_per_value = collections.defaultdict(list) + all_values = set() - for var in x: - all_encoding_literals = [] - # Domains of variables are represented by flat intervals. - for i in range(0, len(var.proto.domain), 2): - start = var.proto.domain[i] - end = var.proto.domain[i + 1] - for value in range(start, end + 1): # Intervals are inclusive. - # Create the literal attached to var == value. - bool_var = model.new_bool_var(f"{var} == {value}") - model.add(var == value).only_enforce_if(bool_var) + for var in x: + all_encoding_literals = [] + # Domains of variables are represented by flat intervals. + for i in range(0, len(var.proto.domain), 2): + start = var.proto.domain[i] + end = var.proto.domain[i + 1] + for value in range(start, end + 1): # Intervals are inclusive. + # Create the literal attached to var == value. + bool_var = model.new_bool_var(f'{var} == {value}') + model.add(var == value).only_enforce_if(bool_var) - # Collect all encoding literals for a given variable. - all_encoding_literals.append(bool_var) + # Collect all encoding literals for a given variable. + all_encoding_literals.append(bool_var) - # Collect all encoding literals for a given value. - variables_per_value[value].append(bool_var) + # Collect all encoding literals for a given value. + variables_per_value[value].append(bool_var) - # Collect all different values. - all_values.add(value) + # Collect all different values. + all_values.add(value) - # One variable must have exactly one value. - model.add_exactly_one(all_encoding_literals) + # One variable must have exactly one value. + model.add_exactly_one(all_encoding_literals) - # Add the all_different constraints. - for value, literals in variables_per_value.items(): - if value == 0: - continue - model.add_at_most_one(literals) + # Add the all_different constraints. + for value, literals in variables_per_value.items(): + if value == 0: + continue + model.add_at_most_one(literals) - model.add(x[0] == 0) - model.add(x[1] == 0) + model.add(x[0] == 0) + model.add(x[1] == 0) - model.maximize(sum(x)) + model.maximize(sum(x)) - # Create a solver and solve. - solver = cp_model.CpSolver() - status = solver.solve(model) + # Create a solver and solve. + solver = cp_model.CpSolver() + status = solver.solve(model) - # Checks and prints the output. - if status == cp_model.OPTIMAL: - print(f"Optimal solution: {solver.objective_value}, expected: 27.0") - elif status == cp_model.FEASIBLE: - print(f"Feasible solution: {solver.objective_value}, optimal 27.0") - elif status == cp_model.INFEASIBLE: - print("The model is infeasible") - else: - print("Something went wrong. Please check the status and the log") + # Checks and prints the output. + if status == cp_model.OPTIMAL: + print(f'Optimal solution: {solver.objective_value}, expected: 27.0') + elif status == cp_model.FEASIBLE: + print(f'Feasible solution: {solver.objective_value}, optimal 27.0') + elif status == cp_model.INFEASIBLE: + print('The model is infeasible') + else: + print('Something went wrong. Please check the status and the log') all_different_except_0() + ``` diff --git a/ortools/sat/docs/model.md b/ortools/sat/docs/model.md index 10b5b6e67b..a6c63d929e 100644 --- a/ortools/sat/docs/model.md +++ b/ortools/sat/docs/model.md @@ -77,39 +77,37 @@ Some remarks: ### Python code ```python -#!/usr/bin/env python3 -"""Code sample that solves a model using solution hinting.""" - +# Snippet from ortools/sat/samples/solution_hinting_sample_sat.py from ortools.sat.python import cp_model def solution_hinting_sample_sat(): - """Showcases solution hinting.""" - # Creates the model. - model = cp_model.CpModel() + """Showcases solution hinting.""" + # Creates the model. + model = cp_model.CpModel() - # Creates the variables. - num_vals = 3 - x = model.new_int_var(0, num_vals - 1, "x") - y = model.new_int_var(0, num_vals - 1, "y") - z = model.new_int_var(0, num_vals - 1, "z") + # Creates the variables. + num_vals = 3 + x = model.new_int_var(0, num_vals - 1, 'x') + y = model.new_int_var(0, num_vals - 1, 'y') + z = model.new_int_var(0, num_vals - 1, 'z') - # Creates the constraints. - model.add(x != y) + # Creates the constraints. + model.add(x != y) - model.maximize(x + 2 * y + 3 * z) + model.maximize(x + 2 * y + 3 * z) - # Solution hinting: x <- 1, y <- 2 - model.add_hint(x, 1) - model.add_hint(y, 2) + # Solution hinting: x <- 1, y <- 2 + model.add_hint(x, 1) + model.add_hint(y, 2) - # Creates a solver and solves. - solver = cp_model.CpSolver() - solution_printer = cp_model.VarArrayAndObjectiveSolutionPrinter([x, y, z]) - status = solver.solve(model, solution_printer) + # Creates a solver and solves. + solver = cp_model.CpSolver() + solution_printer = cp_model.VarArrayAndObjectiveSolutionPrinter([x, y, z]) + status = solver.solve(model, solution_printer) - print(f"Status = {solver.status_name(status)}") - print(f"Number of solutions found: {solution_printer.solution_count}") + print(f'Status = {solver.status_name(status)}') + print(f'Number of solutions found: {solution_printer.solution_count}') solution_hinting_sample_sat() @@ -118,12 +116,13 @@ solution_hinting_sample_sat() ### C++ code ```cpp +// Snippet from ortools/sat/samples/solution_hinting_sample_sat.cc #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -179,13 +178,14 @@ int main(int argc, char* argv[]) { ### Java code ```java +// Snippet from ortools/sat/samples/SolutionHintingSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; import com.google.ortools.sat.CpSolverSolutionCallback; -import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.IntVar; import com.google.ortools.sat.LinearExpr; @@ -242,7 +242,8 @@ public class SolutionHintingSampleSat { ### C\# code -```cs +```csharp +// Snippet from ortools/sat/samples/SolutionHintingSampleSat.cs using System; using Google.OrTools.Sat; @@ -307,7 +308,7 @@ public class SolutionHintingSampleSat ### Go code -```cs +```go // The solution_hinting_sample_sat command is an example of setting solution hints on the model. package main @@ -316,6 +317,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) @@ -360,6 +362,7 @@ func main() { log.Exitf("solutionHintingSampleSat returned with error: %v", err) } } + ``` ## Model copy @@ -375,58 +378,60 @@ The deep copy python mechanism relies on the [`copy` Python Standard Library](https://docs.python.org/3/library/copy.html). ```python -#!/usr/bin/env python3 -"""Showcases deep copying of a model.""" - +# Snippet from ortools/sat/samples/clone_model_sample_sat.py import copy from ortools.sat.python import cp_model def clone_model_sample_sat(): - """Showcases cloning a model.""" - # Creates the model. - model = cp_model.CpModel() + """Showcases cloning a model.""" + # Creates the model. + model = cp_model.CpModel() - # Creates the variables. - num_vals = 3 - x = model.new_int_var(0, num_vals - 1, "x") - y = model.new_int_var(0, num_vals - 1, "y") - z = model.new_int_var(0, num_vals - 1, "z") + # Creates the variables. + num_vals = 3 + x = model.new_int_var(0, num_vals - 1, 'x') + y = model.new_int_var(0, num_vals - 1, 'y') + z = model.new_int_var(0, num_vals - 1, 'z') - # Creates the constraints. - model.add(x != y) + # Creates the constraints. + model.add(x != y) - model.maximize(x + 2 * y + 3 * z) + model.maximize(x + 2 * y + 3 * z) - # Creates a solver and solves. - solver = cp_model.CpSolver() - status = solver.solve(model) + # Creates a solver and solves. + solver = cp_model.CpSolver() + status = solver.solve(model) - if status == cp_model.OPTIMAL: - print("Optimal value of the original model: {}".format(solver.objective_value)) + if status == cp_model.OPTIMAL: + print( + 'Optimal value of the original model: {}'.format(solver.objective_value) + ) - # Creates a dictionary holding the model and the variables you want to use. - to_clone = { - "model": model, - "x": x, - "y": y, - "z": z, - } + # Creates a dictionary holding the model and the variables you want to use. + to_clone = { + 'model': model, + 'x': x, + 'y': y, + 'z': z, + } - # Deep copy the dictionary. - clone = copy.deepcopy(to_clone) + # Deep copy the dictionary. + clone = copy.deepcopy(to_clone) - # Retrieve the cloned model and variables. - cloned_model: cp_model.CpModel = clone["model"] - cloned_x = clone["x"] - cloned_y = clone["y"] - cloned_model.add(cloned_x + cloned_y <= 1) + # Retrieve the cloned model and variables. + cloned_model: cp_model.CpModel = clone['model'] + cloned_x = clone['x'] + cloned_y = clone['y'] + cloned_model.add(cloned_x + cloned_y <= 1) - status = solver.solve(cloned_model) + status = solver.solve(cloned_model) - if status == cp_model.OPTIMAL: - print("Optimal value of the modified model: {}".format(solver.objective_value)) + if status == cp_model.OPTIMAL: + print( + 'Optimal value of the modified model: {}'.format(solver.objective_value) + ) clone_model_sample_sat() @@ -435,12 +440,13 @@ clone_model_sample_sat() ### C++ code ```cpp +// Snippet from ortools/sat/samples/clone_model_sample_sat.cc #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -492,12 +498,13 @@ int main(int argc, char* argv[]) { ### Java code ```java +// Snippet from ortools/sat/samples/CloneModelSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; -import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.IntVar; import com.google.ortools.sat.LinearExpr; diff --git a/ortools/sat/docs/solver.md b/ortools/sat/docs/solver.md index d8264d903d..c83b886167 100644 --- a/ortools/sat/docs/solver.md +++ b/ortools/sat/docs/solver.md @@ -12,36 +12,36 @@ solver. The most useful one is the time limit. ### Specifying the time limit in Python ```python -#!/usr/bin/env python3 +# Snippet from ortools/sat/samples/solve_with_time_limit_sample_sat.py """Solves a problem with a time limit.""" from ortools.sat.python import cp_model def solve_with_time_limit_sample_sat(): - """Minimal CP-SAT example to showcase calling the solver.""" - # Creates the model. - model = cp_model.CpModel() - # Creates the variables. - num_vals = 3 - x = model.new_int_var(0, num_vals - 1, "x") - y = model.new_int_var(0, num_vals - 1, "y") - z = model.new_int_var(0, num_vals - 1, "z") - # Adds an all-different constraint. - model.add(x != y) + """Minimal CP-SAT example to showcase calling the solver.""" + # Creates the model. + model = cp_model.CpModel() + # Creates the variables. + num_vals = 3 + x = model.new_int_var(0, num_vals - 1, 'x') + y = model.new_int_var(0, num_vals - 1, 'y') + z = model.new_int_var(0, num_vals - 1, 'z') + # Adds an all-different constraint. + model.add(x != y) - # Creates a solver and solves the model. - solver = cp_model.CpSolver() + # Creates a solver and solves the model. + solver = cp_model.CpSolver() - # Sets a time limit of 10 seconds. - solver.parameters.max_time_in_seconds = 10.0 + # Sets a time limit of 10 seconds. + solver.parameters.max_time_in_seconds = 10.0 - status = solver.solve(model) + status = solver.solve(model) - if status == cp_model.OPTIMAL: - print(f"x = {solver.value(x)}") - print(f"y = {solver.value(y)}") - print(f"z = {solver.value(z)}") + if status == cp_model.OPTIMAL: + print(f'x = {solver.value(x)}') + print(f'y = {solver.value(y)}') + print(f'z = {solver.value(z)}') solve_with_time_limit_sample_sat() @@ -50,12 +50,13 @@ solve_with_time_limit_sample_sat() ### Specifying the time limit in C++ ```cpp +// Snippet from ortools/sat/samples/solve_with_time_limit_sample_sat.cc #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -109,12 +110,13 @@ int main(int argc, char* argv[]) { ### Specifying the time limit in Java ```java +// Snippet from ortools/sat/samples/SolveWithTimeLimitSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; -import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.IntVar; /** Solves a problem with a time limit. */ @@ -152,7 +154,8 @@ public final class SolveWithTimeLimitSampleSat { Parameters must be passed as string to the solver. -```cs +```csharp +// Snippet from ortools/sat/samples/SolveWithTimeLimitSampleSat.cs using System; using Google.OrTools.Sat; @@ -191,7 +194,7 @@ public class SolveWithTimeLimitSampleSat ### Specifying the time limit in Go -```cs +```go // The solve_with_time_limit_sample_sat command is an example of setting a time limit on the model. package main @@ -199,10 +202,11 @@ import ( "fmt" log "github.com/golang/glog" + "google.golang.org/protobuf/proto" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) func solveWithTimeLimitSampleSat() error { @@ -221,7 +225,7 @@ func solveWithTimeLimitSampleSat() error { } // Sets a time limit of 10 seconds. - params := &sppb.SatParameters{ + params := &sppb.SatParameters{ MaxTimeInSeconds: proto.Float64(10.0), } @@ -247,6 +251,7 @@ func main() { log.Exitf("solveWithTimeLimitSampleSat returned with error: %v", err) } } + ``` ## Printing intermediate solutions @@ -261,57 +266,55 @@ The exact implementation depends on the target language. ### Python code ```python -#!/usr/bin/env python3 -"""Solves an optimization problem and displays all intermediate solutions.""" - +# Snippet from ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.py from ortools.sat.python import cp_model # You need to subclass the cp_model.CpSolverSolutionCallback class. class VarArrayAndObjectiveSolutionPrinter(cp_model.CpSolverSolutionCallback): - """Print intermediate solutions.""" + """Print intermediate solutions.""" - def __init__(self, variables: list[cp_model.IntVar]): - cp_model.CpSolverSolutionCallback.__init__(self) - self.__variables = variables - self.__solution_count = 0 + def __init__(self, variables: list[cp_model.IntVar]): + cp_model.CpSolverSolutionCallback.__init__(self) + self.__variables = variables + self.__solution_count = 0 - def on_solution_callback(self) -> None: - print(f"Solution {self.__solution_count}") - print(f" objective value = {self.objective_value}") - for v in self.__variables: - print(f" {v}={self.value(v)}", end=" ") - print() - self.__solution_count += 1 + def on_solution_callback(self) -> None: + print(f'Solution {self.__solution_count}') + print(f' objective value = {self.objective_value}') + for v in self.__variables: + print(f' {v}={self.value(v)}', end=' ') + print() + self.__solution_count += 1 - @property - def solution_count(self) -> int: - return self.__solution_count + @property + def solution_count(self) -> int: + return self.__solution_count def solve_and_print_intermediate_solutions_sample_sat(): - """Showcases printing intermediate solutions found during search.""" - # Creates the model. - model = cp_model.CpModel() + """Showcases printing intermediate solutions found during search.""" + # Creates the model. + model = cp_model.CpModel() - # Creates the variables. - num_vals = 3 - x = model.new_int_var(0, num_vals - 1, "x") - y = model.new_int_var(0, num_vals - 1, "y") - z = model.new_int_var(0, num_vals - 1, "z") + # Creates the variables. + num_vals = 3 + x = model.new_int_var(0, num_vals - 1, 'x') + y = model.new_int_var(0, num_vals - 1, 'y') + z = model.new_int_var(0, num_vals - 1, 'z') - # Creates the constraints. - model.add(x != y) + # Creates the constraints. + model.add(x != y) - model.maximize(x + 2 * y + 3 * z) + model.maximize(x + 2 * y + 3 * z) - # Creates a solver and solves. - solver = cp_model.CpSolver() - solution_printer = VarArrayAndObjectiveSolutionPrinter([x, y, z]) - status = solver.solve(model, solution_printer) + # Creates a solver and solves. + solver = cp_model.CpSolver() + solution_printer = VarArrayAndObjectiveSolutionPrinter([x, y, z]) + status = solver.solve(model, solution_printer) - print(f"Status = {solver.status_name(status)}") - print(f"Number of solutions found: {solution_printer.solution_count}") + print(f'Status = {solver.status_name(status)}') + print(f'Number of solutions found: {solution_printer.solution_count}') solve_and_print_intermediate_solutions_sample_sat() @@ -320,12 +323,13 @@ solve_and_print_intermediate_solutions_sample_sat() ### C++ code ```cpp +// Snippet from ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.cc #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -377,13 +381,14 @@ int main(int argc, char* argv[]) { ### Java code ```java +// Snippet from ortools/sat/samples/SolveAndPrintIntermediateSolutionsSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; import com.google.ortools.sat.CpSolverSolutionCallback; -import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.IntVar; import com.google.ortools.sat.LinearExpr; import java.util.function.Consumer; @@ -473,7 +478,8 @@ public final class SolveAndPrintIntermediateSolutionsSampleSat { ### C\# code -```cs +```csharp +// Snippet from ortools/sat/samples/SolveAndPrintIntermediateSolutionsSampleSat.cs using System; using Google.OrTools.Sat; @@ -536,7 +542,7 @@ public class SolveAndPrintIntermediateSolutionsSampleSat ### Go code -```cs +```go // The solve_and_print_intermediate_solutions_sample_sat command package main @@ -544,9 +550,10 @@ import ( "fmt" log "github.com/golang/glog" - "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func solveAndPrintIntermediateSolutionsSampleSat() error { @@ -570,7 +577,7 @@ func solveAndPrintIntermediateSolutionsSampleSat() error { // Currently, the CpModelBuilder does not allow for callbacks, so intermediate solutions // cannot be printed while solving. However, the CP-SAT solver does allow for returning // the intermediate solutions found while solving in the response. - params := &sppb.SatParameters{ + params := &sppb.SatParameters{ FillAdditionalSolutionsInResponse: proto.Bool(true), SolutionPoolSize: proto.Int32(10), } @@ -594,6 +601,7 @@ func main() { log.Exitf("solveAndPrintIntermediateSolutionsSampleSat returned with error: %v", err) } } + ``` ## Searching for all solutions in a satisfiability model @@ -603,16 +611,15 @@ languages except Go, you need to register a callback on the solver that will be called at each solution. For Go, callbacks are not implemented, but you can still get the intermediate solutions in the response. -Please note that it does not work in parallel -(i. e. parameter `num_search_workers` > 1). +Note that it does not work in parallel (i. e. parameter `num_workers` > 1). It also does not work if the model contains an objective. The method will return the following: - * *FEASIBLE* if some solutions have been found - * *INFEASIBLE* if the solver has proved there are no solution - * *OPTIMAL* if all solutions have been found +* *FEASIBLE* if some solutions have been found +* *INFEASIBLE* if the solver has proved there are no solution +* *OPTIMAL* if all solutions have been found The exact implementation depends on the target language. @@ -622,55 +629,53 @@ To search for all solutions, use the Solve() method after setting the correct parameter. ```python -#!/usr/bin/env python3 -"""Code sample that solves a model and displays all solutions.""" - +# Snippet from ortools/sat/samples/search_for_all_solutions_sample_sat.py from ortools.sat.python import cp_model class VarArraySolutionPrinter(cp_model.CpSolverSolutionCallback): - """Print intermediate solutions.""" + """Print intermediate solutions.""" - def __init__(self, variables: list[cp_model.IntVar]): - cp_model.CpSolverSolutionCallback.__init__(self) - self.__variables = variables - self.__solution_count = 0 + def __init__(self, variables: list[cp_model.IntVar]): + cp_model.CpSolverSolutionCallback.__init__(self) + self.__variables = variables + self.__solution_count = 0 - def on_solution_callback(self) -> None: - self.__solution_count += 1 - for v in self.__variables: - print(f"{v}={self.value(v)}", end=" ") - print() + def on_solution_callback(self) -> None: + self.__solution_count += 1 + for v in self.__variables: + print(f'{v}={self.value(v)}', end=' ') + print() - @property - def solution_count(self) -> int: - return self.__solution_count + @property + def solution_count(self) -> int: + return self.__solution_count def search_for_all_solutions_sample_sat(): - """Showcases calling the solver to search for all solutions.""" - # Creates the model. - model = cp_model.CpModel() + """Showcases calling the solver to search for all solutions.""" + # Creates the model. + model = cp_model.CpModel() - # Creates the variables. - num_vals = 3 - x = model.new_int_var(0, num_vals - 1, "x") - y = model.new_int_var(0, num_vals - 1, "y") - z = model.new_int_var(0, num_vals - 1, "z") + # Creates the variables. + num_vals = 3 + x = model.new_int_var(0, num_vals - 1, 'x') + y = model.new_int_var(0, num_vals - 1, 'y') + z = model.new_int_var(0, num_vals - 1, 'z') - # Create the constraints. - model.add(x != y) + # Create the constraints. + model.add(x != y) - # Create a solver and solve. - solver = cp_model.CpSolver() - solution_printer = VarArraySolutionPrinter([x, y, z]) - # Enumerate all solutions. - solver.parameters.enumerate_all_solutions = True - # Solve. - status = solver.solve(model, solution_printer) + # Create a solver and solve. + solver = cp_model.CpSolver() + solution_printer = VarArraySolutionPrinter([x, y, z]) + # Enumerate all solutions. + solver.parameters.enumerate_all_solutions = True + # Solve. + status = solver.solve(model, solution_printer) - print(f"Status = {solver.status_name(status)}") - print(f"Number of solutions found: {solution_printer.solution_count}") + print(f'Status = {solver.status_name(status)}') + print(f'Number of solutions found: {solution_printer.solution_count}') search_for_all_solutions_sample_sat() @@ -681,12 +686,13 @@ search_for_all_solutions_sample_sat() To search for all solutions, a parameter of the SAT solver must be changed. ```cpp +// Snippet from ortools/sat/samples/search_for_all_solutions_sample_sat.cc #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -744,13 +750,14 @@ As in Python, CpSolver.solve() must be called after setting the correct parameter. ```java +// Snippet from ortools/sat/samples/SearchForAllSolutionsSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; import com.google.ortools.sat.CpSolverSolutionCallback; -import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.IntVar; /** Code sample that solves a model and displays all solutions. */ @@ -810,7 +817,8 @@ public class SearchForAllSolutionsSampleSat { As in Python, CpSolver.Solve() must be called after setting the correct string parameter. -```cs +```csharp +// Snippet from ortools/sat/samples/SearchForAllSolutionsSampleSat.cs using System; using Google.OrTools.Sat; @@ -876,7 +884,7 @@ public class SearchForAllSolutionsSampleSat To search for all solutions, a parameter of the SAT solver must be changed. -```cs +```go // The search_for_all_solutions_sample_sat command is an example for how to search for // all solutions. package main @@ -885,9 +893,10 @@ import ( "fmt" log "github.com/golang/glog" - "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + "github.com/google/or-tools/ortools/sat/go/cpmodel" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func searchForAllSolutionsSampleSat() error { @@ -907,7 +916,7 @@ func searchForAllSolutionsSampleSat() error { // Currently, the CpModelBuilder does not allow for callbacks, so each feasible solution cannot // be printed while solving. However, the CP Solver can return all of the enumerated solutions // in the response by setting the following parameters. - params := &sppb.SatParameters{ + params := &sppb.SatParameters{ EnumerateAllSolutions: proto.Bool(true), FillAdditionalSolutionsInResponse: proto.Bool(true), SolutionPoolSize: proto.Int32(27), @@ -932,6 +941,7 @@ func main() { log.Exitf("searchForAllSolutionsSampleSat returned with error: %v", err) } } + ``` ## Stopping search early @@ -946,55 +956,55 @@ You can stop the search by calling StopSearch() inside of CpSolverSolutionCallback.OnSolutionCallback(). ```python -#!/usr/bin/env python3 +# Snippet from ortools/sat/samples/stop_after_n_solutions_sample_sat.py """Code sample that solves a model and displays a small number of solutions.""" from ortools.sat.python import cp_model class VarArraySolutionPrinterWithLimit(cp_model.CpSolverSolutionCallback): - """Print intermediate solutions.""" + """Print intermediate solutions.""" - def __init__(self, variables: list[cp_model.IntVar], limit: int): - cp_model.CpSolverSolutionCallback.__init__(self) - self.__variables = variables - self.__solution_count = 0 - self.__solution_limit = limit + def __init__(self, variables: list[cp_model.IntVar], limit: int): + cp_model.CpSolverSolutionCallback.__init__(self) + self.__variables = variables + self.__solution_count = 0 + self.__solution_limit = limit - def on_solution_callback(self) -> None: - self.__solution_count += 1 - for v in self.__variables: - print(f"{v}={self.value(v)}", end=" ") - print() - if self.__solution_count >= self.__solution_limit: - print(f"Stop search after {self.__solution_limit} solutions") - self.stop_search() + def on_solution_callback(self) -> None: + self.__solution_count += 1 + for v in self.__variables: + print(f'{v}={self.value(v)}', end=' ') + print() + if self.__solution_count >= self.__solution_limit: + print(f'Stop search after {self.__solution_limit} solutions') + self.stop_search() - @property - def solution_count(self) -> int: - return self.__solution_count + @property + def solution_count(self) -> int: + return self.__solution_count def stop_after_n_solutions_sample_sat(): - """Showcases calling the solver to search for small number of solutions.""" - # Creates the model. - model = cp_model.CpModel() - # Creates the variables. - num_vals = 3 - x = model.new_int_var(0, num_vals - 1, "x") - y = model.new_int_var(0, num_vals - 1, "y") - z = model.new_int_var(0, num_vals - 1, "z") + """Showcases calling the solver to search for small number of solutions.""" + # Creates the model. + model = cp_model.CpModel() + # Creates the variables. + num_vals = 3 + x = model.new_int_var(0, num_vals - 1, 'x') + y = model.new_int_var(0, num_vals - 1, 'y') + z = model.new_int_var(0, num_vals - 1, 'z') - # Create a solver and solve. - solver = cp_model.CpSolver() - solution_printer = VarArraySolutionPrinterWithLimit([x, y, z], 5) - # Enumerate all solutions. - solver.parameters.enumerate_all_solutions = True - # Solve. - status = solver.solve(model, solution_printer) - print(f"Status = {solver.status_name(status)}") - print(f"Number of solutions found: {solution_printer.solution_count}") - assert solution_printer.solution_count == 5 + # Create a solver and solve. + solver = cp_model.CpSolver() + solution_printer = VarArraySolutionPrinterWithLimit([x, y, z], 5) + # Enumerate all solutions. + solver.parameters.enumerate_all_solutions = True + # Solve. + status = solver.solve(model, solution_printer) + print(f'Status = {solver.status_name(status)}') + print(f'Number of solutions found: {solution_printer.solution_count}') + assert solution_printer.solution_count == 5 stop_after_n_solutions_sample_sat() @@ -1006,14 +1016,15 @@ Stopping search is done by registering an atomic bool on the model-owned time limit, and setting that bool to true. ```cpp +// Snippet from ortools/sat/samples/stop_after_n_solutions_sample_sat.cc #include #include -#include "absl/base/log_severity.h" -#include "absl/log/globals.h" #include "ortools/base/init_google.h" #include "ortools/base/logging.h" +#include "absl/base/log_severity.h" +#include "absl/log/globals.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -1075,13 +1086,14 @@ Stopping search is performed by calling stopSearch() inside of CpSolverSolutionCallback.onSolutionCallback(). ```java +// Snippet from ortools/sat/samples/StopAfterNSolutionsSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; import com.google.ortools.sat.CpSolverSolutionCallback; -import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.IntVar; /** Code sample that solves a model and displays a small number of solutions. */ @@ -1149,7 +1161,8 @@ public final class StopAfterNSolutionsSampleSat { Stopping search is performed by calling StopSearch() inside of CpSolverSolutionCallback.OnSolutionCallback(). -```cs +```csharp +// Snippet from ortools/sat/samples/StopAfterNSolutionsSampleSat.cs using System; using Google.OrTools.Sat; diff --git a/ortools/sat/docs/troubleshooting.md b/ortools/sat/docs/troubleshooting.md index 3c6d8f729d..c5c73edbc3 100644 --- a/ortools/sat/docs/troubleshooting.md +++ b/ortools/sat/docs/troubleshooting.md @@ -96,59 +96,59 @@ parallelism. Therefore, the number of workers must be set to 1. ### Python code sample ```python -#!/usr/bin/env python3 -"""Code sample that solves a model and gets the infeasibility assumptions.""" +# Snippet from ortools/sat/samples/assumptions_sample_sat.py from ortools.sat.python import cp_model def main() -> None: - """Showcases assumptions.""" - # Creates the model. - model = cp_model.CpModel() + """Showcases assumptions.""" + # Creates the model. + model = cp_model.CpModel() - # Creates the variables. - x = model.new_int_var(0, 10, "x") - y = model.new_int_var(0, 10, "y") - z = model.new_int_var(0, 10, "z") - a = model.new_bool_var("a") - b = model.new_bool_var("b") - c = model.new_bool_var("c") + # Creates the variables. + x = model.new_int_var(0, 10, 'x') + y = model.new_int_var(0, 10, 'y') + z = model.new_int_var(0, 10, 'z') + a = model.new_bool_var('a') + b = model.new_bool_var('b') + c = model.new_bool_var('c') - # Creates the constraints. - model.add(x > y).only_enforce_if(a) - model.add(y > z).only_enforce_if(b) - model.add(z > x).only_enforce_if(c) + # Creates the constraints. + model.add(x > y).only_enforce_if(a) + model.add(y > z).only_enforce_if(b) + model.add(z > x).only_enforce_if(c) - # Add assumptions - model.add_assumptions([a, b, c]) + # Add assumptions + model.add_assumptions([a, b, c]) - # Creates a solver and solves. - solver = cp_model.CpSolver() - status = solver.solve(model) + # Creates a solver and solves. + solver = cp_model.CpSolver() + status = solver.solve(model) - # Print solution. - print(f"Status = {solver.status_name(status)}") - if status == cp_model.INFEASIBLE: - print( - "sufficient_assumptions_for_infeasibility = " - f"{solver.sufficient_assumptions_for_infeasibility()}" - ) + # Print solution. + print(f'Status = {solver.status_name(status)}') + if status == cp_model.INFEASIBLE: + print( + 'sufficient_assumptions_for_infeasibility = ' + f'{solver.sufficient_assumptions_for_infeasibility()}' + ) -if __name__ == "__main__": - main() +if __name__ == '__main__': + main() ``` ### C++ code samples ```cpp +// Snippet from ortools/sat/samples/assumptions_sample_sat.cc #include +#include "ortools/base/init_google.h" +#include "ortools/base/logging.h" #include "absl/base/log_severity.h" #include "absl/log/globals.h" #include "absl/types/span.h" -#include "ortools/base/init_google.h" -#include "ortools/base/logging.h" #include "ortools/sat/cp_model.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" @@ -201,11 +201,12 @@ int main(int argc, char* argv[]) { ### Java code samples ```java +// Snippet from ortools/sat/samples/AssumptionsSampleSat.java package com.google.ortools.sat.samples; import com.google.ortools.Loader; +import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.CpModel; import com.google.ortools.sat.CpSolver; -import com.google.ortools.sat.CpSolverStatus; import com.google.ortools.sat.IntVar; import com.google.ortools.sat.Literal; @@ -249,7 +250,8 @@ public class AssumptionsSampleSat { ### C\# code samples -```cs +```csharp +// Snippet from ortools/sat/samples/AssumptionsSampleSat.cs using System; using Google.OrTools.Sat; diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 5ecc0455a8..57f4d9db2f 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -1245,9 +1245,17 @@ bool IntegerTrail::RootLevelEnqueue(IntegerLiteral i_lit) { bool IntegerTrail::SafeEnqueue( IntegerLiteral i_lit, absl::Span integer_reason) { + return SafeEnqueue(i_lit, {}, integer_reason); +} + +bool IntegerTrail::SafeEnqueue( + IntegerLiteral i_lit, absl::Span literal_reason, + absl::Span integer_reason) { // Note that ReportConflict() deal correctly with constant literals. if (i_lit.IsAlwaysTrue()) return true; - if (i_lit.IsAlwaysFalse()) return ReportConflict({}, integer_reason); + if (i_lit.IsAlwaysFalse()) { + return ReportConflict(literal_reason, integer_reason); + } // Most of our propagation code do not use "constant" literal, so to not // have to test for them in Enqueue(), we clear them beforehand. @@ -1257,7 +1265,7 @@ bool IntegerTrail::SafeEnqueue( if (lit.IsAlwaysTrue()) continue; tmp_cleaned_reason_.push_back(lit); } - return Enqueue(i_lit, {}, tmp_cleaned_reason_); + return Enqueue(i_lit, literal_reason, tmp_cleaned_reason_); } bool IntegerTrail::ConditionalEnqueue( diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 1b926092e7..3d667b9a3c 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -629,6 +629,9 @@ class IntegerTrail final : public SatPropagator { // ReportConflict() or Enqueue(). ABSL_MUST_USE_RESULT bool SafeEnqueue( IntegerLiteral i_lit, absl::Span integer_reason); + ABSL_MUST_USE_RESULT bool SafeEnqueue( + IntegerLiteral i_lit, absl::Span literal_reason, + absl::Span integer_reason); // Pushes the given integer literal assuming that the Boolean literal is true. // This can do a few things: diff --git a/ortools/sat/integer_base.cc b/ortools/sat/integer_base.cc index 29a7d8d186..6775f989a7 100644 --- a/ortools/sat/integer_base.cc +++ b/ortools/sat/integer_base.cc @@ -119,7 +119,12 @@ bool LinearExpression2::IsCanonicalized() const { return false; } } - if (vars[0] >= vars[1]) return false; + if (vars[0] >= vars[1]) { + if (vars[0] == kNoIntegerVariable && vars[1] == kNoIntegerVariable) { + return true; + } + return false; + } if (vars[0] == kNoIntegerVariable) return true; diff --git a/ortools/sat/integer_base.h b/ortools/sat/integer_base.h index 14aa492ced..f757b792aa 100644 --- a/ortools/sat/integer_base.h +++ b/ortools/sat/integer_base.h @@ -358,7 +358,9 @@ H AbslHashValue(H h, const AffineExpression& e) { // A linear expression with at most two variables (coeffs can be zero). // And some utility to canonicalize them. struct LinearExpression2 { + // Construct a zero expression. LinearExpression2() = default; + LinearExpression2(IntegerVariable v1, IntegerVariable v2, IntegerValue c1, IntegerValue c2) { vars[0] = v1; @@ -428,7 +430,7 @@ struct LinearExpression2 { } IntegerValue coeffs[2]; - IntegerVariable vars[2]; + IntegerVariable vars[2] = {kNoIntegerVariable, kNoIntegerVariable}; template friend void AbslStringify(Sink& sink, const LinearExpression2& expr) { diff --git a/ortools/sat/integer_base_test.cc b/ortools/sat/integer_base_test.cc index 10774a554a..358ed06aaf 100644 --- a/ortools/sat/integer_base_test.cc +++ b/ortools/sat/integer_base_test.cc @@ -15,6 +15,7 @@ #include +#include "absl/log/check.h" #include "gtest/gtest.h" namespace operations_research::sat { @@ -22,6 +23,7 @@ namespace { TEST(CanonicalizeAffinePrecedenceTest, Basic) { LinearExpression2 expr; + CHECK(expr.IsCanonicalized()) << expr; expr.vars[0] = IntegerVariable(0); expr.vars[1] = IntegerVariable(2); expr.coeffs[0] = IntegerValue(4); @@ -30,6 +32,7 @@ TEST(CanonicalizeAffinePrecedenceTest, Basic) { IntegerValue lb(0); IntegerValue ub(11); expr.CanonicalizeAndUpdateBounds(lb, ub); + CHECK(expr.IsCanonicalized()); EXPECT_EQ(expr.vars[0], IntegerVariable(0)); EXPECT_EQ(expr.vars[1], IntegerVariable(2)); @@ -47,6 +50,7 @@ TEST(CanonicalizeAffinePrecedenceTest, OneSingleVariable) { expr.coeffs[1] = IntegerValue(2); expr.SimpleCanonicalization(); + CHECK(expr.IsCanonicalized()); EXPECT_EQ(expr.vars[0], kNoIntegerVariable); EXPECT_EQ(expr.vars[1], IntegerVariable(0)); diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index c74909f5fd..e7341a53e8 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -26,6 +26,7 @@ #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/base/mathutil.h" +#include "ortools/sat/cp_constraints.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/linear_constraint.h" @@ -848,38 +849,47 @@ void LinMinPropagator::RegisterWith(GenericLiteralWatcher* watcher) { } } -ProductPropagator::ProductPropagator(AffineExpression a, AffineExpression b, - AffineExpression p, - IntegerTrail* integer_trail) - : a_(a), b_(b), p_(p), integer_trail_(integer_trail) {} +ProductPropagator::ProductPropagator( + absl::Span enforcement_literals, AffineExpression a, + AffineExpression b, AffineExpression p, Model* model) + : a_(a), + b_(b), + p_(p), + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); +} // We want all affine expression to be either non-negative or across zero. bool ProductPropagator::CanonicalizeCases() { - if (integer_trail_->UpperBound(a_) <= 0) { + if (integer_trail_.UpperBound(a_) <= 0) { a_ = a_.Negated(); p_ = p_.Negated(); } - if (integer_trail_->UpperBound(b_) <= 0) { + if (integer_trail_.UpperBound(b_) <= 0) { b_ = b_.Negated(); p_ = p_.Negated(); } // If both a and b positive, p must be too. - if (integer_trail_->LowerBound(a_) >= 0 && - integer_trail_->LowerBound(b_) >= 0) { - return integer_trail_->SafeEnqueue( - p_.GreaterOrEqual(0), {a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)}); + if (integer_trail_.LowerBound(a_) >= 0 && + integer_trail_.LowerBound(b_) >= 0) { + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.GreaterOrEqual(0), + {a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)}); } // Otherwise, make sure p is non-negative or across zero. - if (integer_trail_->UpperBound(p_) <= 0) { - if (integer_trail_->LowerBound(a_) < 0) { - DCHECK_GT(integer_trail_->UpperBound(a_), 0); + if (integer_trail_.UpperBound(p_) <= 0) { + if (integer_trail_.LowerBound(a_) < 0) { + DCHECK_GT(integer_trail_.UpperBound(a_), 0); a_ = a_.Negated(); p_ = p_.Negated(); } else { - DCHECK_LT(integer_trail_->LowerBound(b_), 0); - DCHECK_GT(integer_trail_->UpperBound(b_), 0); + DCHECK_LT(integer_trail_.LowerBound(b_), 0); + DCHECK_GT(integer_trail_.UpperBound(b_), 0); b_ = b_.Negated(); p_ = p_.Negated(); } @@ -896,14 +906,14 @@ bool ProductPropagator::CanonicalizeCases() { // smallest domain size between a or b). bool ProductPropagator::PropagateWhenAllNonNegative() { { - const IntegerValue max_a = integer_trail_->UpperBound(a_); - const IntegerValue max_b = integer_trail_->UpperBound(b_); + const IntegerValue max_a = integer_trail_.UpperBound(a_); + const IntegerValue max_b = integer_trail_.UpperBound(b_); const IntegerValue new_max = CapProdI(max_a, max_b); - if (new_max < integer_trail_->UpperBound(p_)) { - if (!integer_trail_->SafeEnqueue( - p_.LowerOrEqual(new_max), - {integer_trail_->UpperBoundAsLiteral(a_), - integer_trail_->UpperBoundAsLiteral(b_), a_.GreaterOrEqual(0), + if (new_max < integer_trail_.UpperBound(p_)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.LowerOrEqual(new_max), + {integer_trail_.UpperBoundAsLiteral(a_), + integer_trail_.UpperBoundAsLiteral(b_), a_.GreaterOrEqual(0), b_.GreaterOrEqual(0)})) { return false; } @@ -911,23 +921,23 @@ bool ProductPropagator::PropagateWhenAllNonNegative() { } { - const IntegerValue min_a = integer_trail_->LowerBound(a_); - const IntegerValue min_b = integer_trail_->LowerBound(b_); + const IntegerValue min_a = integer_trail_.LowerBound(a_); + const IntegerValue min_b = integer_trail_.LowerBound(b_); const IntegerValue new_min = CapProdI(min_a, min_b); // The conflict test is needed because when new_min is large, we could // have an overflow in p_.GreaterOrEqual(new_min); - if (new_min > integer_trail_->UpperBound(p_)) { - return integer_trail_->ReportConflict( - {integer_trail_->UpperBoundAsLiteral(p_), - integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_)}); + if (new_min > integer_trail_.UpperBound(p_)) { + return enforcement_propagator_.ReportConflict( + enforcement_id_, {integer_trail_.UpperBoundAsLiteral(p_), + integer_trail_.LowerBoundAsLiteral(a_), + integer_trail_.LowerBoundAsLiteral(b_)}); } - if (new_min > integer_trail_->LowerBound(p_)) { - if (!integer_trail_->SafeEnqueue( - p_.GreaterOrEqual(new_min), - {integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_)})) { + if (new_min > integer_trail_.LowerBound(p_)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.GreaterOrEqual(new_min), + {integer_trail_.LowerBoundAsLiteral(a_), + integer_trail_.LowerBoundAsLiteral(b_)})) { return false; } } @@ -936,23 +946,23 @@ bool ProductPropagator::PropagateWhenAllNonNegative() { for (int i = 0; i < 2; ++i) { const AffineExpression a = i == 0 ? a_ : b_; const AffineExpression b = i == 0 ? b_ : a_; - const IntegerValue max_a = integer_trail_->UpperBound(a); - const IntegerValue min_b = integer_trail_->LowerBound(b); - const IntegerValue min_p = integer_trail_->LowerBound(p_); - const IntegerValue max_p = integer_trail_->UpperBound(p_); + const IntegerValue max_a = integer_trail_.UpperBound(a); + const IntegerValue min_b = integer_trail_.LowerBound(b); + const IntegerValue min_p = integer_trail_.LowerBound(p_); + const IntegerValue max_p = integer_trail_.UpperBound(p_); const IntegerValue prod = CapProdI(max_a, min_b); if (prod > max_p) { - if (!integer_trail_->SafeEnqueue(a.LowerOrEqual(FloorRatio(max_p, min_b)), - {integer_trail_->LowerBoundAsLiteral(b), - integer_trail_->UpperBoundAsLiteral(p_), - p_.GreaterOrEqual(0)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(FloorRatio(max_p, min_b)), + {integer_trail_.LowerBoundAsLiteral(b), + integer_trail_.UpperBoundAsLiteral(p_), p_.GreaterOrEqual(0)})) { return false; } } else if (prod < min_p && max_a != 0) { - if (!integer_trail_->SafeEnqueue( - b.GreaterOrEqual(CeilRatio(min_p, max_a)), - {integer_trail_->UpperBoundAsLiteral(a), - integer_trail_->LowerBoundAsLiteral(p_), a.GreaterOrEqual(0)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, b.GreaterOrEqual(CeilRatio(min_p, max_a)), + {integer_trail_.UpperBoundAsLiteral(a), + integer_trail_.LowerBoundAsLiteral(p_), a.GreaterOrEqual(0)})) { return false; } } @@ -969,14 +979,14 @@ bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a, AffineExpression b, IntegerValue min_p, IntegerValue max_p) { - const IntegerValue max_a = integer_trail_->UpperBound(a); + const IntegerValue max_a = integer_trail_.UpperBound(a); if (max_a <= 0) return true; DCHECK_GT(min_p, 0); if (max_a >= min_p) { if (max_p < max_a) { - if (!integer_trail_->SafeEnqueue( - a.LowerOrEqual(max_p), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(max_p), {p_.LowerOrEqual(max_p), p_.GreaterOrEqual(1)})) { return false; } @@ -985,23 +995,24 @@ bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a, } const IntegerValue min_pos_b = CeilRatio(min_p, max_a); - if (min_pos_b > integer_trail_->UpperBound(b)) { - if (!integer_trail_->SafeEnqueue( - b.LowerOrEqual(0), {integer_trail_->LowerBoundAsLiteral(p_), - integer_trail_->UpperBoundAsLiteral(a), - integer_trail_->UpperBoundAsLiteral(b)})) { + if (min_pos_b > integer_trail_.UpperBound(b)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, b.LowerOrEqual(0), + {integer_trail_.LowerBoundAsLiteral(p_), + integer_trail_.UpperBoundAsLiteral(a), + integer_trail_.UpperBoundAsLiteral(b)})) { return false; } return true; } const IntegerValue new_max_a = FloorRatio(max_p, min_pos_b); - if (new_max_a < integer_trail_->UpperBound(a)) { - if (!integer_trail_->SafeEnqueue( - a.LowerOrEqual(new_max_a), - {integer_trail_->LowerBoundAsLiteral(p_), - integer_trail_->UpperBoundAsLiteral(a), - integer_trail_->UpperBoundAsLiteral(p_)})) { + if (new_max_a < integer_trail_.UpperBound(a)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(new_max_a), + {integer_trail_.LowerBoundAsLiteral(p_), + integer_trail_.UpperBoundAsLiteral(a), + integer_trail_.UpperBoundAsLiteral(p_)})) { return false; } } @@ -1009,15 +1020,53 @@ bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a, } bool ProductPropagator::Propagate() { + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); + if (status == EnforcementStatus::CAN_PROPAGATE) { + const int64_t min_a = integer_trail_.LowerBound(a_).value(); + const int64_t max_a = integer_trail_.UpperBound(a_).value(); + const int64_t min_b = integer_trail_.LowerBound(b_).value(); + const int64_t max_b = integer_trail_.UpperBound(b_).value(); + const int64_t min_p = integer_trail_.LowerBound(p_).value(); + const int64_t max_p = integer_trail_.UpperBound(p_).value(); + const int64_t p1 = CapProdI(max_a, max_b).value(); + const int64_t p2 = CapProdI(max_a, min_b).value(); + const int64_t p3 = CapProdI(min_a, max_b).value(); + const int64_t p4 = CapProdI(min_a, min_b).value(); + const int64_t min_ab = std::min({p1, p2, p3, p4}); + const int64_t max_ab = std::max({p1, p2, p3, p4}); + // If the bounds of a * b and p are disjoint, the enforcement must be false. + // TODO(user): relax the reason in a better way. + if (min_ab > max_p) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {a_.GreaterOrEqual(min_a), a_.LowerOrEqual(max_a), + b_.GreaterOrEqual(min_b), b_.LowerOrEqual(max_b), + p_.LowerOrEqual(max_p)}); + } + if (min_p > max_ab) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {a_.GreaterOrEqual(min_a), a_.LowerOrEqual(max_a), + b_.GreaterOrEqual(min_b), b_.LowerOrEqual(max_b), + p_.GreaterOrEqual(min_p)}); + } + // Otherwise we cannot propagate anything since the enforcement is unknown. + return true; + } + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (!CanonicalizeCases()) return false; // In the most common case, we use better reasons even though the code // below would propagate the same. - const int64_t min_a = integer_trail_->LowerBound(a_).value(); - const int64_t min_b = integer_trail_->LowerBound(b_).value(); + const int64_t min_a = integer_trail_.LowerBound(a_).value(); + const int64_t min_b = integer_trail_.LowerBound(b_).value(); if (min_a >= 0 && min_b >= 0) { // This was done by CanonicalizeCases(). - DCHECK_GE(integer_trail_->LowerBound(p_), 0); + DCHECK_GE(integer_trail_.LowerBound(p_), 0); return PropagateWhenAllNonNegative(); } @@ -1027,65 +1076,67 @@ bool ProductPropagator::Propagate() { // // TODO(user): In the reasons, including all 4 bounds is always correct, but // we might be able to relax some of them. - const IntegerValue max_a = integer_trail_->UpperBound(a_); - const IntegerValue max_b = integer_trail_->UpperBound(b_); + const IntegerValue max_a = integer_trail_.UpperBound(a_); + const IntegerValue max_b = integer_trail_.UpperBound(b_); const IntegerValue p1 = CapProdI(max_a, max_b); const IntegerValue p2 = CapProdI(max_a, min_b); const IntegerValue p3 = CapProdI(min_a, max_b); const IntegerValue p4 = CapProdI(min_a, min_b); const IntegerValue new_max_p = std::max({p1, p2, p3, p4}); - if (new_max_p < integer_trail_->UpperBound(p_)) { - if (!integer_trail_->SafeEnqueue( - p_.LowerOrEqual(new_max_p), - {integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_), - integer_trail_->UpperBoundAsLiteral(a_), - integer_trail_->UpperBoundAsLiteral(b_)})) { + if (new_max_p < integer_trail_.UpperBound(p_)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.LowerOrEqual(new_max_p), + {integer_trail_.LowerBoundAsLiteral(a_), + integer_trail_.LowerBoundAsLiteral(b_), + integer_trail_.UpperBoundAsLiteral(a_), + integer_trail_.UpperBoundAsLiteral(b_)})) { return false; } } const IntegerValue new_min_p = std::min({p1, p2, p3, p4}); - if (new_min_p > integer_trail_->LowerBound(p_)) { - if (!integer_trail_->SafeEnqueue( - p_.GreaterOrEqual(new_min_p), - {integer_trail_->LowerBoundAsLiteral(a_), - integer_trail_->LowerBoundAsLiteral(b_), - integer_trail_->UpperBoundAsLiteral(a_), - integer_trail_->UpperBoundAsLiteral(b_)})) { + if (new_min_p > integer_trail_.LowerBound(p_)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, p_.GreaterOrEqual(new_min_p), + {integer_trail_.LowerBoundAsLiteral(a_), + integer_trail_.LowerBoundAsLiteral(b_), + integer_trail_.UpperBoundAsLiteral(a_), + integer_trail_.UpperBoundAsLiteral(b_)})) { return false; } } // Lets propagate on a and b. - const IntegerValue min_p = integer_trail_->LowerBound(p_); - const IntegerValue max_p = integer_trail_->UpperBound(p_); + const IntegerValue min_p = integer_trail_.LowerBound(p_); + const IntegerValue max_p = integer_trail_.UpperBound(p_); // We need a bit more propagation to avoid bad cases below. const bool zero_is_possible = min_p <= 0; if (!zero_is_possible) { - if (integer_trail_->LowerBound(a_) == 0) { - if (!integer_trail_->SafeEnqueue( - a_.GreaterOrEqual(1), + if (integer_trail_.LowerBound(a_) == 0) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a_.GreaterOrEqual(1), {p_.GreaterOrEqual(1), a_.GreaterOrEqual(0)})) { return false; } } - if (integer_trail_->LowerBound(b_) == 0) { - if (!integer_trail_->SafeEnqueue( - b_.GreaterOrEqual(1), + if (integer_trail_.LowerBound(b_) == 0) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, b_.GreaterOrEqual(1), {p_.GreaterOrEqual(1), b_.GreaterOrEqual(0)})) { return false; } } - if (integer_trail_->LowerBound(a_) >= 0 && - integer_trail_->LowerBound(b_) <= 0) { - return integer_trail_->SafeEnqueue( - b_.GreaterOrEqual(1), {a_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); + if (integer_trail_.LowerBound(a_) >= 0 && + integer_trail_.LowerBound(b_) <= 0) { + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, b_.GreaterOrEqual(1), + {a_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); } - if (integer_trail_->LowerBound(b_) >= 0 && - integer_trail_->LowerBound(a_) <= 0) { - return integer_trail_->SafeEnqueue( - a_.GreaterOrEqual(1), {b_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); + if (integer_trail_.LowerBound(b_) >= 0 && + integer_trail_.LowerBound(a_) <= 0) { + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, a_.GreaterOrEqual(1), + {b_.GreaterOrEqual(0), p_.GreaterOrEqual(1)}); } } @@ -1093,8 +1144,8 @@ bool ProductPropagator::Propagate() { // p = a * b, what is the min/max of a? const AffineExpression a = i == 0 ? a_ : b_; const AffineExpression b = i == 0 ? b_ : a_; - const IntegerValue max_b = integer_trail_->UpperBound(b); - const IntegerValue min_b = integer_trail_->LowerBound(b); + const IntegerValue max_b = integer_trail_.UpperBound(b); + const IntegerValue min_b = integer_trail_.LowerBound(b); // If the domain of b contain zero, we can't propagate anything on a. // Because of CanonicalizeCases(), we just deal with min_b > 0 here. @@ -1120,30 +1171,32 @@ bool ProductPropagator::Propagate() { // If it does, we should reach the fixed point on the next iteration. if (min_b <= 0) continue; if (min_p >= 0) { - return integer_trail_->SafeEnqueue( - a.GreaterOrEqual(0), {p_.GreaterOrEqual(0), b.GreaterOrEqual(1)}); + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.GreaterOrEqual(0), + {p_.GreaterOrEqual(0), b.GreaterOrEqual(1)}); } if (max_p <= 0) { - return integer_trail_->SafeEnqueue( - a.LowerOrEqual(0), {p_.LowerOrEqual(0), b.GreaterOrEqual(1)}); + return enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(0), + {p_.LowerOrEqual(0), b.GreaterOrEqual(1)}); } // So min_b > 0 and p is across zero: min_p < 0 and max_p > 0. const IntegerValue new_max_a = FloorRatio(max_p, min_b); - if (new_max_a < integer_trail_->UpperBound(a)) { - if (!integer_trail_->SafeEnqueue( - a.LowerOrEqual(new_max_a), - {integer_trail_->UpperBoundAsLiteral(p_), - integer_trail_->LowerBoundAsLiteral(b)})) { + if (new_max_a < integer_trail_.UpperBound(a)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.LowerOrEqual(new_max_a), + {integer_trail_.UpperBoundAsLiteral(p_), + integer_trail_.LowerBoundAsLiteral(b)})) { return false; } } const IntegerValue new_min_a = CeilRatio(min_p, min_b); - if (new_min_a > integer_trail_->LowerBound(a)) { - if (!integer_trail_->SafeEnqueue( - a.GreaterOrEqual(new_min_a), - {integer_trail_->LowerBoundAsLiteral(p_), - integer_trail_->LowerBoundAsLiteral(b)})) { + if (new_min_a > integer_trail_.LowerBound(a)) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a.GreaterOrEqual(new_min_a), + {integer_trail_.LowerBoundAsLiteral(p_), + integer_trail_.LowerBoundAsLiteral(b)})) { return false; } } @@ -1152,52 +1205,84 @@ bool ProductPropagator::Propagate() { return true; } -void ProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int ProductPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(a_, id); watcher->WatchAffineExpression(b_, id); watcher->WatchAffineExpression(p_, id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); + return id; } -SquarePropagator::SquarePropagator(AffineExpression x, AffineExpression s, - IntegerTrail* integer_trail) - : x_(x), s_(s), integer_trail_(integer_trail) { - CHECK_GE(integer_trail->LevelZeroLowerBound(x), 0); +SquarePropagator::SquarePropagator( + absl::Span enforcement_literals, AffineExpression x, + AffineExpression s, Model* model) + : x_(x), + s_(s), + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); + CHECK_GE(integer_trail_.LevelZeroLowerBound(x), 0); } // Propagation from x to s: s in [min_x * min_x, max_x * max_x]. // Propagation from s to x: x in [ceil(sqrt(min_s)), floor(sqrt(max_s))]. bool SquarePropagator::Propagate() { - const IntegerValue min_x = integer_trail_->LowerBound(x_); - const IntegerValue min_s = integer_trail_->LowerBound(s_); + const IntegerValue min_x = integer_trail_.LowerBound(x_); + const IntegerValue min_s = integer_trail_.LowerBound(s_); const IntegerValue min_x_square = CapProdI(min_x, min_x); + const IntegerValue max_x = integer_trail_.UpperBound(x_); + const IntegerValue max_s = integer_trail_.UpperBound(s_); + const IntegerValue max_x_square = CapProdI(max_x, max_x); + + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); + if (status == EnforcementStatus::CAN_PROPAGATE) { + // If the bounds of x * x and s are disjoint, the enforcement must be false. + // TODO(user): relax the reason in a better way. + if (min_x_square > max_s) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {x_.GreaterOrEqual(min_x), s_.LowerOrEqual(min_x - 1)}); + } + if (min_s > max_x_square) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {s_.GreaterOrEqual(min_s), x_.LowerOrEqual(min_s - 1)}); + } + // Otherwise we cannot propagate anything since the enforcement is unknown. + return true; + } + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (min_x_square > min_s) { - if (!integer_trail_->SafeEnqueue(s_.GreaterOrEqual(min_x_square), - {x_.GreaterOrEqual(min_x)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + s_.GreaterOrEqual(min_x_square), + {x_.GreaterOrEqual(min_x)})) { return false; } } else if (min_x_square < min_s) { const IntegerValue new_min(CeilSquareRoot(min_s.value())); - if (!integer_trail_->SafeEnqueue( - x_.GreaterOrEqual(new_min), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, x_.GreaterOrEqual(new_min), {s_.GreaterOrEqual((new_min - 1) * (new_min - 1) + 1)})) { return false; } } - - const IntegerValue max_x = integer_trail_->UpperBound(x_); - const IntegerValue max_s = integer_trail_->UpperBound(s_); - const IntegerValue max_x_square = CapProdI(max_x, max_x); if (max_x_square < max_s) { - if (!integer_trail_->SafeEnqueue(s_.LowerOrEqual(max_x_square), - {x_.LowerOrEqual(max_x)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + s_.LowerOrEqual(max_x_square), + {x_.LowerOrEqual(max_x)})) { return false; } } else if (max_x_square > max_s) { const IntegerValue new_max(FloorSquareRoot(max_s.value())); - if (!integer_trail_->SafeEnqueue( - x_.LowerOrEqual(new_max), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, x_.LowerOrEqual(new_max), {s_.LowerOrEqual(CapProdI(new_max + 1, new_max + 1) - 1)})) { return false; } @@ -1206,32 +1291,37 @@ bool SquarePropagator::Propagate() { return true; } -void SquarePropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int SquarePropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(x_, id); watcher->WatchAffineExpression(s_, id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); + return id; } -DivisionPropagator::DivisionPropagator(AffineExpression num, - AffineExpression denom, - AffineExpression div, - IntegerTrail* integer_trail) +DivisionPropagator::DivisionPropagator( + absl::Span enforcement_literals, AffineExpression num, + AffineExpression denom, AffineExpression div, Model* model) : num_(num), denom_(denom), div_(div), negated_denom_(denom.Negated()), negated_num_(num.Negated()), negated_div_(div.Negated()), - integer_trail_(integer_trail) {} + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); +} // TODO(user): We can propagate more, especially in the case where denom // spans across 0. // TODO(user): We can propagate a bit more if min_div = 0: // (min_num > -min_denom). bool DivisionPropagator::Propagate() { - if (integer_trail_->LowerBound(denom_) < 0 && - integer_trail_->UpperBound(denom_) > 0) { + if (integer_trail_.LowerBound(denom_) < 0 && + integer_trail_.UpperBound(denom_) > 0) { return true; } @@ -1240,32 +1330,62 @@ bool DivisionPropagator::Propagate() { AffineExpression denom = denom_; AffineExpression negated_denom = negated_denom_; - if (integer_trail_->UpperBound(denom) < 0) { + if (integer_trail_.UpperBound(denom) < 0) { std::swap(num, negated_num); std::swap(denom, negated_denom); } + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); + if (status == EnforcementStatus::CAN_PROPAGATE) { + const IntegerValue min_num = integer_trail_.LowerBound(num); + const IntegerValue max_num = integer_trail_.UpperBound(num); + const IntegerValue min_denom = integer_trail_.LowerBound(denom); + const IntegerValue max_denom = integer_trail_.UpperBound(denom); + const IntegerValue min_div = integer_trail_.LowerBound(div_); + const IntegerValue max_div = integer_trail_.UpperBound(div_); + // If the bounds of num / denom and div are disjoint, the enforcement must + // be false. TODO(user): relax the reason in a better way. + if (min_num / max_denom > max_div) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {num_.GreaterOrEqual(min_num), denom_.LowerOrEqual(max_denom), + div_.LowerOrEqual(max_div)}); + } + if (max_num / min_denom < min_div) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {num_.LowerOrEqual(max_num), denom_.GreaterOrEqual(min_denom), + div_.GreaterOrEqual(min_div)}); + } + // Otherwise we cannot propagate anything since the enforcement is unknown. + return true; + } + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (!PropagateSigns(num, denom, div_)) return false; - if (integer_trail_->UpperBound(num) >= 0 && - integer_trail_->UpperBound(div_) >= 0 && + if (integer_trail_.UpperBound(num) >= 0 && + integer_trail_.UpperBound(div_) >= 0 && !PropagateUpperBounds(num, denom, div_)) { return false; } - if (integer_trail_->UpperBound(negated_num) >= 0 && - integer_trail_->UpperBound(negated_div_) >= 0 && + if (integer_trail_.UpperBound(negated_num) >= 0 && + integer_trail_.UpperBound(negated_div_) >= 0 && !PropagateUpperBounds(negated_num, denom, negated_div_)) { return false; } - if (integer_trail_->LowerBound(num) >= 0 && - integer_trail_->LowerBound(div_) >= 0) { + if (integer_trail_.LowerBound(num) >= 0 && + integer_trail_.LowerBound(div_) >= 0) { return PropagatePositiveDomains(num, denom, div_); } - if (integer_trail_->LowerBound(negated_num) >= 0 && - integer_trail_->LowerBound(negated_div_) >= 0) { + if (integer_trail_.LowerBound(negated_num) >= 0 && + integer_trail_.LowerBound(negated_div_) >= 0) { return PropagatePositiveDomains(negated_num, denom, negated_div_); } @@ -1275,15 +1395,15 @@ bool DivisionPropagator::Propagate() { bool DivisionPropagator::PropagateSigns(AffineExpression num, AffineExpression denom, AffineExpression div) { - const IntegerValue min_num = integer_trail_->LowerBound(num); - const IntegerValue max_num = integer_trail_->UpperBound(num); - const IntegerValue min_div = integer_trail_->LowerBound(div); - const IntegerValue max_div = integer_trail_->UpperBound(div); + const IntegerValue min_num = integer_trail_.LowerBound(num); + const IntegerValue max_num = integer_trail_.UpperBound(num); + const IntegerValue min_div = integer_trail_.LowerBound(div); + const IntegerValue max_div = integer_trail_.UpperBound(div); // If num >= 0, as denom > 0, then div must be >= 0. if (min_num >= 0 && min_div < 0) { - if (!integer_trail_->SafeEnqueue( - div.GreaterOrEqual(0), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, div.GreaterOrEqual(0), {num.GreaterOrEqual(0), denom.GreaterOrEqual(1)})) { return false; } @@ -1291,8 +1411,8 @@ bool DivisionPropagator::PropagateSigns(AffineExpression num, // If div > 0, as denom > 0, then num must be > 0. if (min_num <= 0 && min_div > 0) { - if (!integer_trail_->SafeEnqueue( - num.GreaterOrEqual(1), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, num.GreaterOrEqual(1), {div.GreaterOrEqual(1), denom.GreaterOrEqual(1)})) { return false; } @@ -1300,8 +1420,8 @@ bool DivisionPropagator::PropagateSigns(AffineExpression num, // If num <= 0, as denom > 0, then div must be <= 0. if (max_num <= 0 && max_div > 0) { - if (!integer_trail_->SafeEnqueue( - div.LowerOrEqual(0), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, div.LowerOrEqual(0), {num.LowerOrEqual(0), denom.GreaterOrEqual(1)})) { return false; } @@ -1309,8 +1429,8 @@ bool DivisionPropagator::PropagateSigns(AffineExpression num, // If div < 0, as denom > 0, then num must be < 0. if (max_num >= 0 && max_div < 0) { - if (!integer_trail_->SafeEnqueue( - num.LowerOrEqual(-1), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, num.LowerOrEqual(-1), {div.LowerOrEqual(-1), denom.GreaterOrEqual(1)})) { return false; } @@ -1322,15 +1442,15 @@ bool DivisionPropagator::PropagateSigns(AffineExpression num, bool DivisionPropagator::PropagateUpperBounds(AffineExpression num, AffineExpression denom, AffineExpression div) { - const IntegerValue max_num = integer_trail_->UpperBound(num); - const IntegerValue min_denom = integer_trail_->LowerBound(denom); - const IntegerValue max_denom = integer_trail_->UpperBound(denom); - const IntegerValue max_div = integer_trail_->UpperBound(div); + const IntegerValue max_num = integer_trail_.UpperBound(num); + const IntegerValue min_denom = integer_trail_.LowerBound(denom); + const IntegerValue max_denom = integer_trail_.UpperBound(denom); + const IntegerValue max_div = integer_trail_.UpperBound(div); const IntegerValue new_max_div = max_num / min_denom; if (max_div > new_max_div) { - if (!integer_trail_->SafeEnqueue( - div.LowerOrEqual(new_max_div), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, div.LowerOrEqual(new_max_div), {num.LowerOrEqual(max_num), denom.GreaterOrEqual(min_denom)})) { return false; } @@ -1342,8 +1462,8 @@ bool DivisionPropagator::PropagateUpperBounds(AffineExpression num, const IntegerValue new_max_num = CapAddI(CapProdI(max_div + 1, max_denom), -1); if (max_num > new_max_num) { - if (!integer_trail_->SafeEnqueue( - num.LowerOrEqual(new_max_num), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, num.LowerOrEqual(new_max_num), {denom.LowerOrEqual(max_denom), denom.GreaterOrEqual(1), div.LowerOrEqual(max_div)})) { return false; @@ -1356,17 +1476,17 @@ bool DivisionPropagator::PropagateUpperBounds(AffineExpression num, bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, AffineExpression denom, AffineExpression div) { - const IntegerValue min_num = integer_trail_->LowerBound(num); - const IntegerValue max_num = integer_trail_->UpperBound(num); - const IntegerValue min_denom = integer_trail_->LowerBound(denom); - const IntegerValue max_denom = integer_trail_->UpperBound(denom); - const IntegerValue min_div = integer_trail_->LowerBound(div); - const IntegerValue max_div = integer_trail_->UpperBound(div); + const IntegerValue min_num = integer_trail_.LowerBound(num); + const IntegerValue max_num = integer_trail_.UpperBound(num); + const IntegerValue min_denom = integer_trail_.LowerBound(denom); + const IntegerValue max_denom = integer_trail_.UpperBound(denom); + const IntegerValue min_div = integer_trail_.LowerBound(div); + const IntegerValue max_div = integer_trail_.UpperBound(div); const IntegerValue new_min_div = min_num / max_denom; if (min_div < new_min_div) { - if (!integer_trail_->SafeEnqueue( - div.GreaterOrEqual(new_min_div), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, div.GreaterOrEqual(new_min_div), {num.GreaterOrEqual(min_num), denom.LowerOrEqual(max_denom), denom.GreaterOrEqual(1)})) { return false; @@ -1378,8 +1498,8 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, // num >= min_div * min_denom. const IntegerValue new_min_num = CapProdI(min_denom, min_div); if (min_num < new_min_num) { - if (!integer_trail_->SafeEnqueue( - num.GreaterOrEqual(new_min_num), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, num.GreaterOrEqual(new_min_num), {denom.GreaterOrEqual(min_denom), div.GreaterOrEqual(min_div)})) { return false; } @@ -1392,8 +1512,8 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, if (min_div > 0) { const IntegerValue new_max_denom = max_num / min_div; if (max_denom > new_max_denom) { - if (!integer_trail_->SafeEnqueue( - denom.LowerOrEqual(new_max_denom), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, denom.LowerOrEqual(new_max_denom), {num.LowerOrEqual(max_num), num.GreaterOrEqual(0), div.GreaterOrEqual(min_div), denom.GreaterOrEqual(1)})) { return false; @@ -1405,8 +1525,8 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, // >= CeilRatio(min_num + 1, max_div + 1). const IntegerValue new_min_denom = CeilRatio(min_num + 1, max_div + 1); if (min_denom < new_min_denom) { - if (!integer_trail_->SafeEnqueue( - denom.GreaterOrEqual(new_min_denom), + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, denom.GreaterOrEqual(new_min_denom), {num.GreaterOrEqual(min_num), div.LowerOrEqual(max_div), div.GreaterOrEqual(0), denom.GreaterOrEqual(1)})) { return false; @@ -1416,60 +1536,89 @@ bool DivisionPropagator::PropagatePositiveDomains(AffineExpression num, return true; } -void DivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int DivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(num_, id); watcher->WatchAffineExpression(denom_, id); watcher->WatchAffineExpression(div_, id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); + return id; } -FixedDivisionPropagator::FixedDivisionPropagator(AffineExpression a, - IntegerValue b, - AffineExpression c, - IntegerTrail* integer_trail) - : a_(a), b_(b), c_(c), integer_trail_(integer_trail) { +FixedDivisionPropagator::FixedDivisionPropagator( + absl::Span enforcement_literals, AffineExpression a, + IntegerValue b, AffineExpression c, Model* model) + : a_(a), + b_(b), + c_(c), + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); CHECK_GT(b_, 0); } bool FixedDivisionPropagator::Propagate() { - const IntegerValue min_a = integer_trail_->LowerBound(a_); - const IntegerValue max_a = integer_trail_->UpperBound(a_); - IntegerValue min_c = integer_trail_->LowerBound(c_); - IntegerValue max_c = integer_trail_->UpperBound(c_); + const IntegerValue min_a = integer_trail_.LowerBound(a_); + const IntegerValue max_a = integer_trail_.UpperBound(a_); + IntegerValue min_c = integer_trail_.LowerBound(c_); + IntegerValue max_c = integer_trail_.UpperBound(c_); + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); + if (status == EnforcementStatus::CAN_PROPAGATE) { + // If the bounds of a / b and c are disjoint, the enforcement must be false. + // TODO(user): relax the reason in a better way. + if (min_a / b_ > max_c) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {a_.GreaterOrEqual(max_c * b_ + 1), c_.LowerOrEqual(max_c)}); + } + if (max_a / b_ < min_c) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, + /*literal_reason=*/{}, + {a_.LowerOrEqual(min_c * b_ - 1), c_.GreaterOrEqual(min_c)}); + } + // Otherwise we cannot propagate anything since the enforcement is unknown. + return true; + } + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (max_a / b_ < max_c) { max_c = max_a / b_; - if (!integer_trail_->SafeEnqueue( - c_.LowerOrEqual(max_c), - {integer_trail_->UpperBoundAsLiteral(a_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, c_.LowerOrEqual(max_c), + {integer_trail_.UpperBoundAsLiteral(a_)})) { return false; } } else if (max_a / b_ > max_c) { const IntegerValue new_max_a = max_c >= 0 ? max_c * b_ + b_ - 1 : CapProdI(max_c, b_); CHECK_LT(new_max_a, max_a); - if (!integer_trail_->SafeEnqueue( - a_.LowerOrEqual(new_max_a), - {integer_trail_->UpperBoundAsLiteral(c_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a_.LowerOrEqual(new_max_a), + {integer_trail_.UpperBoundAsLiteral(c_)})) { return false; } } if (min_a / b_ > min_c) { min_c = min_a / b_; - if (!integer_trail_->SafeEnqueue( - c_.GreaterOrEqual(min_c), - {integer_trail_->LowerBoundAsLiteral(a_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, c_.GreaterOrEqual(min_c), + {integer_trail_.LowerBoundAsLiteral(a_)})) { return false; } } else if (min_a / b_ < min_c) { const IntegerValue new_min_a = min_c > 0 ? CapProdI(min_c, b_) : min_c * b_ - b_ + 1; CHECK_GT(new_min_a, min_a); - if (!integer_trail_->SafeEnqueue( - a_.GreaterOrEqual(new_min_a), - {integer_trail_->LowerBoundAsLiteral(c_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, a_.GreaterOrEqual(new_min_a), + {integer_trail_.LowerBoundAsLiteral(c_)})) { return false; } } @@ -1477,29 +1626,66 @@ bool FixedDivisionPropagator::Propagate() { return true; } -void FixedDivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int FixedDivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(a_, id); watcher->WatchAffineExpression(c_, id); + return id; } -FixedModuloPropagator::FixedModuloPropagator(AffineExpression expr, - IntegerValue mod, - AffineExpression target, - IntegerTrail* integer_trail) - : expr_(expr), mod_(mod), target_(target), integer_trail_(integer_trail) { +FixedModuloPropagator::FixedModuloPropagator( + absl::Span enforcement_literals, AffineExpression expr, + IntegerValue mod, AffineExpression target, Model* model) + : expr_(expr), + mod_(mod), + target_(target), + negated_expr_(expr.Negated()), + negated_target_(target.Negated()), + integer_trail_(*model->GetOrCreate()), + enforcement_propagator_(*model->GetOrCreate()) { CHECK_GT(mod_, 0); + GenericLiteralWatcher* watcher = model->GetOrCreate(); + enforcement_id_ = enforcement_propagator_.Register( + enforcement_literals, watcher, RegisterWith(watcher)); } bool FixedModuloPropagator::Propagate() { + const EnforcementStatus status = + enforcement_propagator_.Status(enforcement_id_); + if (status == EnforcementStatus::CAN_PROPAGATE) { + const IntegerValue min_target = integer_trail_.LowerBound(target_); + const IntegerValue max_target = integer_trail_.UpperBound(target_); + if (min_target >= mod_) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {target_.GreaterOrEqual(mod_)}); + } else if (max_target <= -mod_) { + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {target_.LowerOrEqual(-mod_)}); + } + if (min_target > 0) { + if (!PropagateWhenFalseAndTargetIsPositive(expr_, target_)) return false; + } else if (max_target < 0) { + if (!PropagateWhenFalseAndTargetIsPositive(negated_expr_, + negated_target_)) { + return false; + } + } else if (!PropagateWhenFalseAndTargetDomainContainsZero()) { + return false; + } + // Otherwise we cannot propagate anything since the enforcement is unknown. + return true; + } + + if (status != EnforcementStatus::IS_ENFORCED) return true; if (!PropagateSignsAndTargetRange()) return false; if (!PropagateOuterBounds()) return false; - if (integer_trail_->LowerBound(expr_) >= 0) { - if (!PropagateBoundsWhenExprIsPositive(expr_, target_)) return false; - } else if (integer_trail_->UpperBound(expr_) <= 0) { - if (!PropagateBoundsWhenExprIsPositive(expr_.Negated(), - target_.Negated())) { + if (integer_trail_.LowerBound(expr_) >= 0) { + if (!PropagateBoundsWhenExprIsNonNegative(expr_, target_)) return false; + } else if (integer_trail_.UpperBound(expr_) <= 0) { + if (!PropagateBoundsWhenExprIsNonNegative(negated_expr_, negated_target_)) { return false; } } @@ -1507,53 +1693,135 @@ bool FixedModuloPropagator::Propagate() { return true; } +bool FixedModuloPropagator::PropagateWhenFalseAndTargetIsPositive( + AffineExpression expr, AffineExpression target) { + const IntegerValue min_expr = integer_trail_.LowerBound(expr); + const IntegerValue max_expr = integer_trail_.UpperBound(expr); + // expr % mod_ must be in the target domain intersected with [0, mod_ - 1], + // noted [min_expr_mod, max_expr_mod]. This interval is non-empty. + const IntegerValue min_expr_mod = + std::max(IntegerValue(0), integer_trail_.LowerBound(target)); + const IntegerValue max_expr_mod = + std::min(mod_ - 1, integer_trail_.UpperBound(target)); + // expr must be in [min_expr_mod + k * mod_, max_expr_mod + k * mod_], for + // some k >= 0. If the expr domain is in one of the following intervals, the + // constraint is always false: + // - ]-infinity, min_expr_mod[ + // - ]max_expr_mod + k * mod_ , min_expr_mod + (k + 1) * mod_[ + if (max_expr < min_expr_mod) { + // TODO(user): relax the reason in a better way. + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {expr.LowerOrEqual(min_expr_mod - 1), + target.GreaterOrEqual(min_expr_mod)}); + } + // Compute the smallest k such that max_expr < min_expr_mod + (k + 1) * mod_. + const IntegerValue k = MathUtil::FloorOfRatio(max_expr - min_expr_mod, mod_); + if (min_expr > max_expr_mod + k * mod_) { + // TODO(user): relax the reason in a better way. + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {expr.GreaterOrEqual(max_expr_mod + k * mod_ + 1), + expr.LowerOrEqual(min_expr_mod + (k + 1) * mod_ - 1), + target.GreaterOrEqual(min_expr_mod), + target.LowerOrEqual(max_expr_mod)}); + } + return true; +} + +bool FixedModuloPropagator::PropagateWhenFalseAndTargetDomainContainsZero() { + const IntegerValue neg_max_expr_mod = + std::max(-mod_ + 1, integer_trail_.LowerBound(target_)); + const IntegerValue pos_max_expr_mod = + std::min(mod_ - 1, integer_trail_.UpperBound(target_)); + // expr must be in [k * mod_, pos_max_expr_mod + k * mod_] or in + // [neg_max_expr_mod - k * mod_, -k * mod_] for some k >= 0. If the expr + // domain is in one of the following intervals, the constraint is always + // false: + // - ]-(k + 1) * mod_, neg_max_expr_mod - k * mod_[ + // - ]pos_max_expr_mod + k * mod_ , (k + 1) * mod_[ + const IntegerValue min_expr = integer_trail_.LowerBound(expr_); + const IntegerValue max_expr = integer_trail_.UpperBound(expr_); + // Compute the smallest k such that max_expr < (k + 1) * mod_. + IntegerValue k = MathUtil::FloorOfRatio(max_expr, mod_); + if (k >= 0 && min_expr > pos_max_expr_mod + k * mod_) { + const IntegerValue min_target = integer_trail_.LowerBound(target_); + // TODO(user): relax the reason in a better way. + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {expr_.GreaterOrEqual(pos_max_expr_mod + k * mod_ + 1), + expr_.LowerOrEqual((k + 1) * mod_ - 1), + target_.GreaterOrEqual(min_target), + target_.LowerOrEqual(pos_max_expr_mod)}); + } + // Compute the smallest k such that min_expr > -(k + 1) * mod_. + k = MathUtil::FloorOfRatio(-min_expr, mod_); + if (k >= 0 && max_expr < neg_max_expr_mod - k * mod_) { + const IntegerValue max_target = integer_trail_.UpperBound(target_); + // TODO(user): relax the reason in a better way. + return enforcement_propagator_.PropagateWhenFalse( + enforcement_id_, /*literal_reason=*/{}, + {expr_.GreaterOrEqual(-(k + 1) * mod_ + 1), + expr_.LowerOrEqual(neg_max_expr_mod - k * mod_ - 1), + target_.GreaterOrEqual(neg_max_expr_mod), + target_.LowerOrEqual(max_target)}); + } + return true; +} + bool FixedModuloPropagator::PropagateSignsAndTargetRange() { // Initial domain reduction on the target. - if (integer_trail_->UpperBound(target_) >= mod_) { - if (!integer_trail_->SafeEnqueue(target_.LowerOrEqual(mod_ - 1), {})) { + if (integer_trail_.UpperBound(target_) >= mod_) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.LowerOrEqual(mod_ - 1), {})) { return false; } } - if (integer_trail_->LowerBound(target_) <= -mod_) { - if (!integer_trail_->SafeEnqueue(target_.GreaterOrEqual(1 - mod_), {})) { + if (integer_trail_.LowerBound(target_) <= -mod_) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.GreaterOrEqual(1 - mod_), {})) { return false; } } // The sign of target_ is fixed by the sign of expr_. - if (integer_trail_->LowerBound(expr_) >= 0 && - integer_trail_->LowerBound(target_) < 0) { + if (integer_trail_.LowerBound(expr_) >= 0 && + integer_trail_.LowerBound(target_) < 0) { // expr >= 0 => target >= 0. - if (!integer_trail_->SafeEnqueue(target_.GreaterOrEqual(0), - {expr_.GreaterOrEqual(0)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + target_.GreaterOrEqual(0), + {expr_.GreaterOrEqual(0)})) { return false; } } - if (integer_trail_->UpperBound(expr_) <= 0 && - integer_trail_->UpperBound(target_) > 0) { + if (integer_trail_.UpperBound(expr_) <= 0 && + integer_trail_.UpperBound(target_) > 0) { // expr <= 0 => target <= 0. - if (!integer_trail_->SafeEnqueue(target_.LowerOrEqual(0), - {expr_.LowerOrEqual(0)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + target_.LowerOrEqual(0), + {expr_.LowerOrEqual(0)})) { return false; } } - if (integer_trail_->LowerBound(target_) > 0 && - integer_trail_->LowerBound(expr_) <= 0) { + if (integer_trail_.LowerBound(target_) > 0 && + integer_trail_.LowerBound(expr_) <= 0) { // target > 0 => expr > 0. - if (!integer_trail_->SafeEnqueue(expr_.GreaterOrEqual(1), - {target_.GreaterOrEqual(1)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + expr_.GreaterOrEqual(1), + {target_.GreaterOrEqual(1)})) { return false; } } - if (integer_trail_->UpperBound(target_) < 0 && - integer_trail_->UpperBound(expr_) >= 0) { + if (integer_trail_.UpperBound(target_) < 0 && + integer_trail_.UpperBound(expr_) >= 0) { // target < 0 => expr < 0. - if (!integer_trail_->SafeEnqueue(expr_.LowerOrEqual(-1), - {target_.LowerOrEqual(-1)})) { + if (!enforcement_propagator_.SafeEnqueue(enforcement_id_, + expr_.LowerOrEqual(-1), + {target_.LowerOrEqual(-1)})) { return false; } } @@ -1562,68 +1830,72 @@ bool FixedModuloPropagator::PropagateSignsAndTargetRange() { } bool FixedModuloPropagator::PropagateOuterBounds() { - const IntegerValue min_expr = integer_trail_->LowerBound(expr_); - const IntegerValue max_expr = integer_trail_->UpperBound(expr_); - const IntegerValue min_target = integer_trail_->LowerBound(target_); - const IntegerValue max_target = integer_trail_->UpperBound(target_); + const IntegerValue min_expr = integer_trail_.LowerBound(expr_); + const IntegerValue max_expr = integer_trail_.UpperBound(expr_); + const IntegerValue min_target = integer_trail_.LowerBound(target_); + const IntegerValue max_target = integer_trail_.UpperBound(target_); if (max_expr % mod_ > max_target) { - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, expr_.LowerOrEqual((max_expr / mod_) * mod_ + max_target), - {integer_trail_->UpperBoundAsLiteral(target_), - integer_trail_->UpperBoundAsLiteral(expr_)})) { + {integer_trail_.UpperBoundAsLiteral(target_), + integer_trail_.UpperBoundAsLiteral(expr_)})) { return false; } } if (min_expr % mod_ < min_target) { - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, expr_.GreaterOrEqual((min_expr / mod_) * mod_ + min_target), - {integer_trail_->LowerBoundAsLiteral(expr_), - integer_trail_->LowerBoundAsLiteral(target_)})) { + {integer_trail_.LowerBoundAsLiteral(expr_), + integer_trail_.LowerBoundAsLiteral(target_)})) { return false; } } if (min_expr / mod_ == max_expr / mod_) { if (min_target < min_expr % mod_) { - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.GreaterOrEqual(min_expr - (min_expr / mod_) * mod_), - {integer_trail_->LowerBoundAsLiteral(target_), - integer_trail_->UpperBoundAsLiteral(target_), - integer_trail_->LowerBoundAsLiteral(expr_), - integer_trail_->UpperBoundAsLiteral(expr_)})) { + {integer_trail_.LowerBoundAsLiteral(target_), + integer_trail_.UpperBoundAsLiteral(target_), + integer_trail_.LowerBoundAsLiteral(expr_), + integer_trail_.UpperBoundAsLiteral(expr_)})) { return false; } } if (max_target > max_expr % mod_) { - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.LowerOrEqual(max_expr - (max_expr / mod_) * mod_), - {integer_trail_->LowerBoundAsLiteral(target_), - integer_trail_->UpperBoundAsLiteral(target_), - integer_trail_->LowerBoundAsLiteral(expr_), - integer_trail_->UpperBoundAsLiteral(expr_)})) { + {integer_trail_.LowerBoundAsLiteral(target_), + integer_trail_.UpperBoundAsLiteral(target_), + integer_trail_.LowerBoundAsLiteral(expr_), + integer_trail_.UpperBoundAsLiteral(expr_)})) { return false; } } } else if (min_expr / mod_ == 0 && min_target < 0) { // expr == target when expr <= 0. if (min_target < min_expr) { - if (!integer_trail_->SafeEnqueue( - target_.GreaterOrEqual(min_expr), - {integer_trail_->LowerBoundAsLiteral(target_), - integer_trail_->LowerBoundAsLiteral(expr_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.GreaterOrEqual(min_expr), + {integer_trail_.LowerBoundAsLiteral(target_), + integer_trail_.LowerBoundAsLiteral(expr_)})) { return false; } } } else if (max_expr / mod_ == 0 && max_target > 0) { // expr == target when expr >= 0. if (max_target > max_expr) { - if (!integer_trail_->SafeEnqueue( - target_.LowerOrEqual(max_expr), - {integer_trail_->UpperBoundAsLiteral(target_), - integer_trail_->UpperBoundAsLiteral(expr_)})) { + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, target_.LowerOrEqual(max_expr), + {integer_trail_.UpperBoundAsLiteral(target_), + integer_trail_.UpperBoundAsLiteral(expr_)})) { return false; } } @@ -1632,37 +1904,39 @@ bool FixedModuloPropagator::PropagateOuterBounds() { return true; } -bool FixedModuloPropagator::PropagateBoundsWhenExprIsPositive( +bool FixedModuloPropagator::PropagateBoundsWhenExprIsNonNegative( AffineExpression expr, AffineExpression target) { - const IntegerValue min_target = integer_trail_->LowerBound(target); + const IntegerValue min_target = integer_trail_.LowerBound(target); DCHECK_GE(min_target, 0); - const IntegerValue max_target = integer_trail_->UpperBound(target); + const IntegerValue max_target = integer_trail_.UpperBound(target); // The propagation rules below will not be triggered if the domain of target // covers [0..mod_ - 1]. if (min_target == 0 && max_target == mod_ - 1) return true; - const IntegerValue min_expr = integer_trail_->LowerBound(expr); - const IntegerValue max_expr = integer_trail_->UpperBound(expr); + const IntegerValue min_expr = integer_trail_.LowerBound(expr); + const IntegerValue max_expr = integer_trail_.UpperBound(expr); if (max_expr % mod_ < min_target) { DCHECK_GE(max_expr, 0); - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, expr.LowerOrEqual((max_expr / mod_ - 1) * mod_ + max_target), - {integer_trail_->UpperBoundAsLiteral(expr), - integer_trail_->LowerBoundAsLiteral(target), - integer_trail_->UpperBoundAsLiteral(target)})) { + {integer_trail_.UpperBoundAsLiteral(expr), + integer_trail_.LowerBoundAsLiteral(target), + integer_trail_.UpperBoundAsLiteral(target)})) { return false; } } if (min_expr % mod_ > max_target) { DCHECK_GE(min_expr, 0); - if (!integer_trail_->SafeEnqueue( + if (!enforcement_propagator_.SafeEnqueue( + enforcement_id_, expr.GreaterOrEqual((min_expr / mod_ + 1) * mod_ + min_target), - {integer_trail_->LowerBoundAsLiteral(target), - integer_trail_->UpperBoundAsLiteral(target), - integer_trail_->LowerBoundAsLiteral(expr)})) { + {integer_trail_.LowerBoundAsLiteral(target), + integer_trail_.UpperBoundAsLiteral(target), + integer_trail_.LowerBoundAsLiteral(expr)})) { return false; } } @@ -1670,11 +1944,12 @@ bool FixedModuloPropagator::PropagateBoundsWhenExprIsPositive( return true; } -void FixedModuloPropagator::RegisterWith(GenericLiteralWatcher* watcher) { +int FixedModuloPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); watcher->WatchAffineExpression(expr_, id); watcher->WatchAffineExpression(target_, id); watcher->NotifyThatPropagatorMayNotReachFixedPointInOnePass(id); + return id; } } // namespace sat diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index c3871ea0b0..e95fa9d50a 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -22,15 +22,17 @@ #include #include +#include "absl/base/attributes.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/types/span.h" +#include "ortools/sat/cp_constraints.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/linear_constraint.h" #include "ortools/sat/linear_propagation.h" #include "ortools/sat/model.h" -#include "ortools/sat/precedences.h" +#include "ortools/sat/old_precedences_propagator.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" @@ -283,17 +285,19 @@ class LinMinPropagator : public PropagatorInterface, LazyReasonInterface { // the bounds on p as this require more complex arithmetics. class ProductPropagator : public PropagatorInterface { public: - ProductPropagator(AffineExpression a, AffineExpression b, AffineExpression p, - IntegerTrail* integer_trail); + ProductPropagator(absl::Span enforcement_literals, + AffineExpression a, AffineExpression b, AffineExpression p, + Model* model); // This type is neither copyable nor movable. ProductPropagator(const ProductPropagator&) = delete; ProductPropagator& operator=(const ProductPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + // Maybe replace a_, b_ or c_ by their negation to simplify the cases. bool CanonicalizeCases(); @@ -310,8 +314,9 @@ class ProductPropagator : public PropagatorInterface { AffineExpression a_; AffineExpression b_; AffineExpression p_; - - IntegerTrail* integer_trail_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; + EnforcementId enforcement_id_; }; // Propagates num / denom = div. Basic version, we don't extract any special @@ -320,17 +325,19 @@ class ProductPropagator : public PropagatorInterface { // TODO(user): Deal with overflow. class DivisionPropagator : public PropagatorInterface { public: - DivisionPropagator(AffineExpression num, AffineExpression denom, - AffineExpression div, IntegerTrail* integer_trail); + DivisionPropagator(absl::Span enforcement_literals, + AffineExpression num, AffineExpression denom, + AffineExpression div, Model* model); // This type is neither copyable nor movable. DivisionPropagator(const DivisionPropagator&) = delete; DivisionPropagator& operator=(const DivisionPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + // Propagates the fact that the signs of each domain, if fixed, are // compatible. bool PropagateSigns(AffineExpression num, AffineExpression denom, @@ -353,49 +360,59 @@ class DivisionPropagator : public PropagatorInterface { const AffineExpression negated_denom_; const AffineExpression negated_num_; const AffineExpression negated_div_; - IntegerTrail* integer_trail_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; + EnforcementId enforcement_id_; }; // Propagates var_a / cst_b = var_c. Basic version, we don't extract any special // cases, and we only propagates the bounds. cst_b must be > 0. class FixedDivisionPropagator : public PropagatorInterface { public: - FixedDivisionPropagator(AffineExpression a, IntegerValue b, - AffineExpression c, IntegerTrail* integer_trail); + FixedDivisionPropagator(absl::Span enforcement_literals, + AffineExpression a, IntegerValue b, + AffineExpression c, Model* model); // This type is neither copyable nor movable. FixedDivisionPropagator(const FixedDivisionPropagator&) = delete; FixedDivisionPropagator& operator=(const FixedDivisionPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + const AffineExpression a_; const IntegerValue b_; const AffineExpression c_; - - IntegerTrail* integer_trail_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; + EnforcementId enforcement_id_; }; // Propagates target == expr % mod. Basic version, we don't extract any special // cases, and we only propagates the bounds. mod must be > 0. class FixedModuloPropagator : public PropagatorInterface { public: - FixedModuloPropagator(AffineExpression expr, IntegerValue mod, - AffineExpression target, IntegerTrail* integer_trail); + FixedModuloPropagator(absl::Span enforcement_literals, + AffineExpression expr, IntegerValue mod, + AffineExpression target, Model* model); // This type is neither copyable nor movable. FixedModuloPropagator(const FixedModuloPropagator&) = delete; FixedModuloPropagator& operator=(const FixedModuloPropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + + bool PropagateWhenFalseAndTargetIsPositive(AffineExpression expr, + AffineExpression target); + bool PropagateWhenFalseAndTargetDomainContainsZero(); bool PropagateSignsAndTargetRange(); - bool PropagateBoundsWhenExprIsPositive(AffineExpression expr, - AffineExpression target); + bool PropagateBoundsWhenExprIsNonNegative(AffineExpression expr, + AffineExpression target); bool PropagateOuterBounds(); const AffineExpression expr_; @@ -403,27 +420,32 @@ class FixedModuloPropagator : public PropagatorInterface { const AffineExpression target_; const AffineExpression negated_expr_; const AffineExpression negated_target_; - IntegerTrail* integer_trail_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; + EnforcementId enforcement_id_; }; // Propagates x * x = s. // TODO(user): Only works for x nonnegative. class SquarePropagator : public PropagatorInterface { public: - SquarePropagator(AffineExpression x, AffineExpression s, - IntegerTrail* integer_trail); + SquarePropagator(absl::Span enforcement_literals, + AffineExpression x, AffineExpression s, Model* model); // This type is neither copyable nor movable. SquarePropagator(const SquarePropagator&) = delete; SquarePropagator& operator=(const SquarePropagator&) = delete; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher); private: + int RegisterWith(GenericLiteralWatcher* watcher); + const AffineExpression x_; const AffineExpression s_; - IntegerTrail* integer_trail_; + const IntegerTrail& integer_trail_; + EnforcementPropagator& enforcement_propagator_; + EnforcementId enforcement_id_; }; // ============================================================================= @@ -757,77 +779,71 @@ inline std::function IsEqualToMinOf( return [&](Model* model) { AddIsEqualToMinOf(min_expr, exprs, model); }; } -template -void RegisterAndTransferOwnership(Model* model, T* ct) { - ct->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(ct); -} // Adds the constraint: a * b = p. -inline std::function ProductConstraint(AffineExpression a, - AffineExpression b, - AffineExpression p) { +inline std::function ProductConstraint( + absl::Span enforcement_literals, AffineExpression a, + AffineExpression b, AffineExpression p) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); + const IntegerTrail& integer_trail = *model->GetOrCreate(); + // TODO(user): return early if constraint is never enforced. if (a == b) { - if (integer_trail->LowerBound(a) >= 0) { - RegisterAndTransferOwnership(model, - new SquarePropagator(a, p, integer_trail)); + if (integer_trail.LowerBound(a) >= 0) { + model->TakeOwnership( + new SquarePropagator(enforcement_literals, a, p, model)); return; } - if (integer_trail->UpperBound(a) <= 0) { - RegisterAndTransferOwnership( - model, new SquarePropagator(a.Negated(), p, integer_trail)); + if (integer_trail.UpperBound(a) <= 0) { + model->TakeOwnership( + new SquarePropagator(enforcement_literals, a.Negated(), p, model)); return; } } - RegisterAndTransferOwnership(model, - new ProductPropagator(a, b, p, integer_trail)); + model->TakeOwnership( + new ProductPropagator(enforcement_literals, a, b, p, model)); }; } // Adds the constraint: num / denom = div. (denom > 0). -inline std::function DivisionConstraint(AffineExpression num, - AffineExpression denom, - AffineExpression div) { +inline std::function DivisionConstraint( + absl::Span enforcement_literals, AffineExpression num, + AffineExpression denom, AffineExpression div) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); + const IntegerTrail& integer_trail = *model->GetOrCreate(); + // TODO(user): return early if constraint is never enforced. DivisionPropagator* constraint; - if (integer_trail->UpperBound(denom) < 0) { - constraint = new DivisionPropagator(num.Negated(), denom.Negated(), div, - integer_trail); - + if (integer_trail.UpperBound(denom) < 0) { + constraint = new DivisionPropagator(enforcement_literals, num.Negated(), + denom.Negated(), div, model); } else { - constraint = new DivisionPropagator(num, denom, div, integer_trail); + constraint = + new DivisionPropagator(enforcement_literals, num, denom, div, model); } - constraint->RegisterWith(model->GetOrCreate()); model->TakeOwnership(constraint); }; } // Adds the constraint: a / b = c where b is a constant. -inline std::function FixedDivisionConstraint(AffineExpression a, - IntegerValue b, - AffineExpression c) { +inline std::function FixedDivisionConstraint( + absl::Span enforcement_literals, AffineExpression a, + IntegerValue b, AffineExpression c) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); + // TODO(user): return early if constraint is never enforced. FixedDivisionPropagator* constraint = - b > 0 ? new FixedDivisionPropagator(a, b, c, integer_trail) - : new FixedDivisionPropagator(a.Negated(), -b, c, integer_trail); - constraint->RegisterWith(model->GetOrCreate()); + b > 0 + ? new FixedDivisionPropagator(enforcement_literals, a, b, c, model) + : new FixedDivisionPropagator(enforcement_literals, a.Negated(), -b, + c, model); model->TakeOwnership(constraint); }; } // Adds the constraint: a % b = c where b is a constant. -inline std::function FixedModuloConstraint(AffineExpression a, - IntegerValue b, - AffineExpression c) { +inline std::function FixedModuloConstraint( + absl::Span enforcement_literals, AffineExpression a, + IntegerValue b, AffineExpression c) { return [=](Model* model) { - IntegerTrail* integer_trail = model->GetOrCreate(); - FixedModuloPropagator* constraint = - new FixedModuloPropagator(a, b, c, integer_trail); - constraint->RegisterWith(model->GetOrCreate()); - model->TakeOwnership(constraint); + model->TakeOwnership( + new FixedModuloPropagator(enforcement_literals, a, b, c, model)); }; } diff --git a/ortools/sat/integer_expr_test.cc b/ortools/sat/integer_expr_test.cc index ecb7d255a4..b35f3fc4d5 100644 --- a/ortools/sat/integer_expr_test.cc +++ b/ortools/sat/integer_expr_test.cc @@ -625,7 +625,7 @@ TEST(ProductConstraintTest, RandomCases) { bool perfect_propagation = true; bool ok_propagation = true; - model.Add(ProductConstraint(vars[0], vars[1], vars[2])); + model.Add(ProductConstraint({}, vars[0], vars[1], vars[2])); const bool result = model.GetOrCreate()->Propagate(); if (expected_result != result) { if (expected_result) { @@ -1072,6 +1072,74 @@ TEST(ProductPropagationTest, LargeDomain) { EXPECT_EQ(solutions, expected); } +TEST(ProductPropagationTest, AlwaysFalseWithTwoEnforcementLiterals) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const Literal c = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable p = model.Add(NewIntegerVariable(50, 100)); + // Always false if enforced (x.y always less than p). + model.Add(ProductConstraint({b, c}, x, y, p)); + // Nothing should be propagated. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(b)); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(c)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(x, 0, 5); + EXPECT_BOUNDS_EQ(y, 0, 5); + EXPECT_BOUNDS_EQ(p, 50, 100); +} + +TEST(ProductPropagationTest, AlwaysFalseWithOneUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable p = model.Add(NewIntegerVariable(50, 100)); + // Always false if enforced (x.y always less than p). + model.Add(ProductConstraint({b}, x, x, p)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(x, 0, 5); + EXPECT_BOUNDS_EQ(y, 0, 5); + EXPECT_BOUNDS_EQ(p, 50, 100); +} + +TEST(ProductPropagationTest, AlwaysFalseWithOneUnassignedEnforcementLiteral2) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable p = model.Add(NewIntegerVariable(-100, -50)); + // Always false if enforced (x.y always greater than p). + model.Add(ProductConstraint({b}, x, x, p)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(x, 0, 5); + EXPECT_BOUNDS_EQ(y, 0, 5); + EXPECT_BOUNDS_EQ(p, -100, -50); +} + +TEST(ProductPropagationTest, + NotAlwaysFalseWithOneUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable y = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable p = model.Add(NewIntegerVariable(0, 100)); + model.Add(ProductConstraint({b}, x, y, p)); + // Nothing should be propagated. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(x, 0, 5); + EXPECT_BOUNDS_EQ(y, 0, 5); + EXPECT_BOUNDS_EQ(p, 0, 100); +} + TEST(DivisionConstraintTest, CheckAllSolutions) { absl::BitGen random; const int kMaxValue = 100; @@ -1229,7 +1297,7 @@ TEST(DivisionConstraintTest, CheckAllPropagationsRandomProblem) { const IntegerVariable var_x = model.Add(NewIntegerVariable(x_min, x_max)); const IntegerVariable var_y = model.Add(NewIntegerVariable(y_min, y_max)); const IntegerVariable var_z = model.Add(NewIntegerVariable(z_min, z_max)); - model.Add(DivisionConstraint(var_x, var_y, var_z)); + model.Add(DivisionConstraint({}, var_x, var_y, var_z)); const bool result = model.GetOrCreate()->Propagate(); if (result) { EXPECT_BOUNDS_EQ(var_x, expected_x_min, expected_x_max); @@ -1241,6 +1309,54 @@ TEST(DivisionConstraintTest, CheckAllPropagationsRandomProblem) { } } +TEST(DivisionConstraintTest, AlwaysFalseWithUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable denom = model.Add(NewIntegerVariable(2, 3)); + const IntegerVariable div = model.Add(NewIntegerVariable(3, 5)); + // Always false if enforced (num / denom always less than div). + model.Add(DivisionConstraint({b}, num, denom, div)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(denom, 2, 3); + EXPECT_BOUNDS_EQ(div, 3, 5); +} + +TEST(DivisionConstraintTest, AlwaysFalseWithUnassignedEnforcementLiteral2) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable denom = model.Add(NewIntegerVariable(2, 3)); + const IntegerVariable div = model.Add(NewIntegerVariable(-5, -3)); + // Always false if enforced (num / denom always greater than div). + model.Add(DivisionConstraint({b}, num, denom, div)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(denom, 2, 3); + EXPECT_BOUNDS_EQ(div, -5, -3); +} + +TEST(DivisionConstraintTest, NotAlwaysFalseWithUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable denom = model.Add(NewIntegerVariable(2, 3)); + const IntegerVariable div = model.Add(NewIntegerVariable(1, 5)); + model.Add(DivisionConstraint({b}, num, denom, div)); + // Nothing should be propagated. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(denom, 2, 3); + EXPECT_BOUNDS_EQ(div, 1, 5); +} + TEST(DivisionConstraintTest, CheckAllSolutionsOnExprs) { absl::BitGen random; const int kMaxValue = 30; @@ -1360,7 +1476,7 @@ void TestAllDivisionValues(int64_t min_a, int64_t max_a, int64_t b, min_c == max_c ? AffineExpression(IntegerValue(min_c)) : AffineExpression(model.Add(NewIntegerVariable(min_c, max_c))); - model.Add(FixedDivisionConstraint(var_a, IntegerValue(b), var_c)); + model.Add(FixedDivisionConstraint({}, var_a, IntegerValue(b), var_c)); const bool result = model.GetOrCreate()->Propagate(); IntegerTrail* integer_trail = model.GetOrCreate(); if (result) { @@ -1394,7 +1510,7 @@ bool PropagateFixedDivision(int64_t a, int64_t max_a, int64_t b, int64_t c, Model model; const IntegerVariable var_a = model.Add(NewIntegerVariable(a, max_a)); const IntegerVariable var_c = model.Add(NewIntegerVariable(c, max_c)); - model.Add(FixedDivisionConstraint(var_a, IntegerValue(b), var_c)); + model.Add(FixedDivisionConstraint({}, var_a, IntegerValue(b), var_c)); const bool result = model.GetOrCreate()->Propagate(); if (result) { EXPECT_BOUNDS_EQ(var_a, new_a, new_max_a); @@ -1437,6 +1553,50 @@ TEST(FixedDivisionConstraintTest, ExpectedPropagation) { /*new_c=*/3, std::numeric_limits::max() / 10)); } +TEST(FixedDivisionConstraintTest, AlwaysFalseWithUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable div = model.Add(NewIntegerVariable(3, 5)); + // Always false if enforced (num / denom always less than div). + model.Add(FixedDivisionConstraint({b}, num, 2, div)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(div, 3, 5); +} + +TEST(FixedDivisionConstraintTest, + AlwaysFalseWithUnassignedEnforcementLiteral2) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable div = model.Add(NewIntegerVariable(-5, -3)); + // Always false if enforced (num / denom always greater than div). + model.Add(FixedDivisionConstraint({b}, num, 2, div)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(div, -5, -3); +} + +TEST(FixedDivisionConstraintTest, + NotAlwaysFalseWithUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable num = model.Add(NewIntegerVariable(3, 5)); + const IntegerVariable div = model.Add(NewIntegerVariable(1, 5)); + model.Add(FixedDivisionConstraint({b}, num, 2, div)); + // Nothing should be propagated. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(num, 3, 5); + EXPECT_BOUNDS_EQ(div, 1, 5); +} + TEST(ModuloConstraintTest, CheckAllSolutions) { absl::BitGen random; const int kMaxValue = 50; @@ -1531,7 +1691,7 @@ TEST(ModuloConstraintTest, CheckAllPropagationsRandomProblem) { const IntegerVariable var = model.Add(NewIntegerVariable(var_min, var_max)); const IntegerVariable target = model.Add(NewIntegerVariable(target_min, target_max)); - model.Add(FixedModuloConstraint(var, IntegerValue(mod), target)); + model.Add(FixedModuloConstraint({}, var, IntegerValue(mod), target)); const bool result = model.GetOrCreate()->Propagate(); if (result) { EXPECT_BOUNDS_EQ(var, expected_var_min, expected_var_max); @@ -1548,6 +1708,84 @@ TEST(ModuloConstraintTest, CheckAllPropagationsRandomProblem) { } } +bool TestModuloPropagationWhenFalse(int min_var, int max_var, int mod, + int min_target, int max_target) { + bool is_always_false = true; + for (int var = min_var; var <= max_var; ++var) { + for (int target = min_target; target <= max_target; ++target) { + if (var % mod == target) { + is_always_false = false; + break; + } + } + } + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable var = model.Add(NewIntegerVariable(min_var, max_var)); + const IntegerVariable target = + model.Add(NewIntegerVariable(min_target, max_target)); + model.Add(FixedModuloConstraint({b}, var, IntegerValue(mod), target)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_EQ(model.GetOrCreate()->Assignment().LiteralIsFalse(b), + is_always_false) + << "min_var = " << min_var << " max_var = " << max_var << " mod = " << mod + << " min_target = " << min_target << " max_target = " << max_target; + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsTrue(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(var, min_var, max_var); + EXPECT_BOUNDS_EQ(target, min_target, max_target); + return is_always_false; +} + +TEST(ModuloConstraintTest, CheckPropagationWhenFalse) { + bool propagated_when_false = false; + for (int min_var = -15; min_var <= 15; ++min_var) { + for (int max_var = min_var; max_var <= min_var + 5; ++max_var) { + for (int min_target = -4; min_target <= 4; ++min_target) { + for (int max_target = min_target; max_target <= 4; ++max_target) { + propagated_when_false |= TestModuloPropagationWhenFalse( + min_var, max_var, 3, min_target, max_target); + } + } + } + } + EXPECT_TRUE(propagated_when_false); +} + +TEST(ModuloConstraintTest, + CheckEnumerateAllSolutionsWithoutEnforcementLiteral) { + CpModelProto initial_model = ParseTestProto(R"pb( + variables { name: 'b' domain: 0 domain: 1 } + variables { name: 'x' domain: -10 domain: 10 } + variables { name: 'y' domain: -3 domain: 3 } + constraints { + enforcement_literal: 0 + int_mod { + target { vars: 2 coeffs: 1 } + exprs { vars: 1 coeffs: 1 } + exprs { offset: 10 } + } + } + )pb"); + absl::btree_set> solutions; + const CpSolverResponse response = + SolveAndCheck(initial_model, "", &solutions); + EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); + + CpModelProto reference_model = initial_model; + reference_model.mutable_constraints(0)->clear_enforcement_literal(); + absl::btree_set> reference_solutions; + for (int x = -10; x <= 10; ++x) { + for (int y = -3; y <= 3; ++y) { + reference_solutions.insert({0, x, y}); + } + } + const CpSolverResponse reference_response = + SolveAndCheck(reference_model, "", &reference_solutions); + EXPECT_EQ(reference_response.status(), CpSolverStatus::OPTIMAL); + EXPECT_EQ(solutions, reference_solutions); +} + bool TestSquarePropagation(std::pair initial_domain_x, std::pair initial_domain_s, std::pair expected_domain_x, @@ -1557,7 +1795,7 @@ bool TestSquarePropagation(std::pair initial_domain_x, NewIntegerVariable(initial_domain_x.first, initial_domain_x.second)); IntegerVariable s = model.Add( NewIntegerVariable(initial_domain_s.first, initial_domain_s.second)); - model.Add(ProductConstraint(x, x, s)); + model.Add(ProductConstraint({}, x, x, s)); const bool result = model.GetOrCreate()->Propagate(); if (result) { EXPECT_BOUNDS_EQ(x, expected_domain_x.first, expected_domain_x.second); @@ -1598,6 +1836,65 @@ TEST(SquareConstraintTest, LargestSquare) { {0, square * square})); } +TEST(SquareConstraintTest, AlwaysFalseWithTwoEnforcementLiterals) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const Literal c = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable s = model.Add(NewIntegerVariable(50, 100)); + // Always false if enforced (x^2 always less than s). + model.Add(ProductConstraint({b, c}, x, x, s)); + // Nothing should be propagated. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(b)); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(c)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(x, 0, 5); + EXPECT_BOUNDS_EQ(s, 50, 100); +} + +TEST(SquareConstraintTest, AlwaysFalseWithOneUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable s = model.Add(NewIntegerVariable(50, 100)); + // Always false if enforced (x^2 always less than s). + model.Add(ProductConstraint({b}, x, x, s)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(x, 0, 5); + EXPECT_BOUNDS_EQ(s, 50, 100); +} + +TEST(SquareConstraintTest, AlwaysFalseWithOneUnassignedEnforcementLiteral2) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable s = model.Add(NewIntegerVariable(-100, -50)); + // Always false if enforced (x^2 always greater than s). + model.Add(ProductConstraint({b}, x, x, s)); + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_TRUE(model.GetOrCreate()->Assignment().LiteralIsFalse(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(x, 0, 5); + EXPECT_BOUNDS_EQ(s, -100, -50); +} + +TEST(SquareConstraintTest, NotAlwaysFalseWithOneUnassignedEnforcementLiteral) { + Model model; + const Literal b = Literal(model.Add(NewBooleanVariable()), true); + const IntegerVariable x = model.Add(NewIntegerVariable(0, 5)); + const IntegerVariable s = model.Add(NewIntegerVariable(0, 100)); + model.Add(ProductConstraint({b}, x, x, s)); + // Nothing should be propagated. + EXPECT_TRUE(model.GetOrCreate()->Propagate()); + EXPECT_FALSE(model.GetOrCreate()->Assignment().LiteralIsAssigned(b)); + EXPECT_EQ(model.GetOrCreate()->num_enqueues(), 0); + EXPECT_BOUNDS_EQ(x, 0, 5); + EXPECT_BOUNDS_EQ(s, 0, 100); +} + TEST(LevelZeroEqualityTest, BasicExample) { Model model; diff --git a/ortools/sat/intervals.cc b/ortools/sat/intervals.cc index 113ad4e5d9..1e9dcb2330 100644 --- a/ortools/sat/intervals.cc +++ b/ortools/sat/intervals.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -43,7 +44,10 @@ IntervalsRepository::IntervalsRepository(Model* model) sat_solver_(model->GetOrCreate()), implications_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), - reified_precedences_(model->GetOrCreate()) {} + reified_precedences_(model->GetOrCreate()), + root_level_bounds_(model->GetOrCreate()), + linear2_bounds_(model->GetOrCreate()), + integer_encoder_(model->GetOrCreate()) {} IntervalVariable IntervalsRepository::CreateInterval(IntegerVariable start, IntegerVariable end, @@ -137,56 +141,90 @@ IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial( // task_a is currently before task_b ? // Lets not create a literal that will be propagated right away. - if (integer_trail_->UpperBound(a.start) < integer_trail_->LowerBound(b.end)) { - if (sat_solver_->CurrentDecisionLevel() == 0) { - AddConditionalAffinePrecedence(enforcement_literals, a.end, b.start, - model_); - } + const auto [expr_b_before_a, ub_b_before_a] = + EncodeDifferenceLowerThan(b.end, a.start, 0); + const RelationStatus b_before_a_root_status = + root_level_bounds_->GetLevelZeroStatus(expr_b_before_a, kMinIntegerValue, + ub_b_before_a); + if (b_before_a_root_status == RelationStatus::IS_FALSE) { + AddConditionalAffinePrecedence(enforcement_literals, a.end, b.start, + model_); + return kNoLiteralIndex; + } + const RelationStatus b_before_a_status = linear2_bounds_->GetStatus( + expr_b_before_a, kMinIntegerValue, ub_b_before_a); + if (b_before_a_status != RelationStatus::IS_UNKNOWN) { + // Abort if the relation is already known. return kNoLiteralIndex; } // task_b is before task_a ? - if (integer_trail_->UpperBound(b.start) < integer_trail_->LowerBound(a.end)) { - if (sat_solver_->CurrentDecisionLevel() == 0) { - AddConditionalAffinePrecedence(enforcement_literals, b.end, a.start, - model_); - } + const auto [expr_a_before_b, ub_a_before_b] = + EncodeDifferenceLowerThan(a.end, b.start, 0); + const RelationStatus a_before_b_root_status = + root_level_bounds_->GetLevelZeroStatus(expr_a_before_b, kMinIntegerValue, + ub_a_before_b); + if (a_before_b_root_status == RelationStatus::IS_FALSE) { + AddConditionalAffinePrecedence(enforcement_literals, b.end, a.start, + model_); return kNoLiteralIndex; } - - // Abort if the relation is already known. - if (reified_precedences_->GetLevelZeroPrecedenceStatus(a.end, b.start) == - RelationStatus::IS_TRUE || - reified_precedences_->GetLevelZeroPrecedenceStatus(b.end, a.start) == - RelationStatus::IS_TRUE) { + const RelationStatus a_before_b_status = linear2_bounds_->GetStatus( + expr_a_before_b, kMinIntegerValue, ub_a_before_b); + if (a_before_b_status != RelationStatus::IS_UNKNOWN) { + // Abort if the relation is already known. return kNoLiteralIndex; } // Create a new literal. // - // TODO(user): If there are no enforcement and we already have at one of: + // TODO(user): An alternative solution when it is enforced is to get/create // - s <=> a.end <= b.start // - t <=> b.end <= a.start - // We could use (s, not(s)) or (not(t), t) and make sure s = not(t) if both - // exists. - // - // TODO(user): Otherwise, an alternative solution is to create s and t (can be - // one more Boolean though), and have enforcement => s + t == 1. The later - // might not even be needed though, since interval equation should already - // enforce it. - const BooleanVariable boolean_var = sat_solver_->NewBooleanVariable(); - const Literal a_before_b = Literal(boolean_var, true); + // and have enforcement => s + t == 1. The later might not even be needed + // though, since interval equation should already enforce it. + Literal a_before_b; + if (enforcement_literals.empty()) { + // We don't have any enforcement literal, so we should use the existing + // ReifiedLinear2Bounds class. + LiteralIndex a_before_b_index = GetPrecedenceLiteral(a.end, b.start); + const LiteralIndex b_before_a_index = GetPrecedenceLiteral(b.end, a.start); + if (a_before_b_index == kNoLiteralIndex && + b_before_a_index == kNoLiteralIndex) { + CreatePrecedenceLiteralIfNonTrivial(a.end, b.start); + a_before_b_index = GetPrecedenceLiteral(a.end, b.start); + DCHECK_NE(a_before_b_index, kNoLiteralIndex); // We tested not trivial. + // Now associate its negation with b.end <= a.start. + reified_precedences_->AddBoundEncodingIfNonTrivial( + Literal(a_before_b_index).Negated(), expr_b_before_a, ub_b_before_a); + } else if (a_before_b_index == kNoLiteralIndex && + b_before_a_index != kNoLiteralIndex) { + // We already have a literal for b.end <= a.start. + // We can just use the negation of that literal. + a_before_b_index = Literal(b_before_a_index).NegatedIndex(); + reified_precedences_->AddBoundEncodingIfNonTrivial( + Literal(a_before_b_index), expr_a_before_b, ub_a_before_b); + } else if (a_before_b_index != kNoLiteralIndex && + b_before_a_index == kNoLiteralIndex) { + reified_precedences_->AddBoundEncodingIfNonTrivial( + Literal(a_before_b_index).Negated(), expr_b_before_a, ub_b_before_a); + } else { + // We have both literals. One must be the negation of the other. + implications_->AddImplication(Literal(a_before_b_index), + Literal(b_before_a_index).Negated()); + implications_->AddImplication(Literal(a_before_b_index).Negated(), + Literal(b_before_a_index)); + } + DCHECK_NE(a_before_b_index, kNoLiteralIndex); + a_before_b = Literal(a_before_b_index); + } else { + const BooleanVariable boolean_var = sat_solver_->NewBooleanVariable(); + a_before_b = Literal(boolean_var, true); + } + disjunctive_precedences_.insert({{a, b}, a_before_b}); disjunctive_precedences_.insert({{b, a}, a_before_b.Negated()}); - // Also insert it in precedences. - if (enforcement_literals.empty()) { - reified_precedences_->AddReifiedPrecedenceIfNonTrivial(a_before_b, a.end, - b.start); - reified_precedences_->AddReifiedPrecedenceIfNonTrivial(a_before_b.Negated(), - b.end, a.start); - } - enforcement_literals.push_back(a_before_b); AddConditionalAffinePrecedence(enforcement_literals, a.end, b.start, model_); enforcement_literals.pop_back(); @@ -212,20 +250,41 @@ IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial( bool IntervalsRepository::CreatePrecedenceLiteralIfNonTrivial( AffineExpression x, AffineExpression y) { - const LiteralIndex index = reified_precedences_->GetReifiedPrecedence(x, y); - if (index != kNoLiteralIndex) return false; + const auto [expr, ub] = EncodeDifferenceLowerThan(x, y, 0); + auto reified_bound = reified_precedences_->GetEncodedBound(expr, ub); + if (std::holds_alternative( + reified_bound)) { + const auto bound_type = + std::get(reified_bound); + if (bound_type == ReifiedLinear2Bounds::ReifiedBoundType::kAlwaysTrue || + bound_type == ReifiedLinear2Bounds::ReifiedBoundType::kAlwaysFalse) { + // Nothing to do, precedence is trivial at level zero. + return false; + } + } - // We want l => x <= y and not(l) => x > y <=> y + 1 <= x - // Do not create l if the relation is always true or false. - if (reified_precedences_->GetLevelZeroPrecedenceStatus(x, y) != - RelationStatus::IS_UNKNOWN) { + if (std::holds_alternative(reified_bound)) { + // Already created. return false; } + if (std::holds_alternative(reified_bound)) { + if (integer_encoder_->GetAssociatedLiteral( + std::get(reified_bound)) != kNoLiteralIndex) { + return false; + } + // Create a new literal from the IntegerLiteral. This makes sure + // GetPrecedenceLiteral() always returns something if this function was + // called on a non-trivial precedence. + integer_encoder_->GetOrCreateAssociatedLiteral( + std::get(reified_bound)); + return true; + } + // Create a new literal. const BooleanVariable boolean_var = sat_solver_->NewBooleanVariable(); const Literal x_before_y = Literal(boolean_var, true); - reified_precedences_->AddReifiedPrecedenceIfNonTrivial(x_before_y, x, y); + reified_precedences_->AddBoundEncodingIfNonTrivial(x_before_y, expr, ub); AffineExpression y_plus_one = y; y_plus_one.constant += 1; @@ -236,7 +295,28 @@ bool IntervalsRepository::CreatePrecedenceLiteralIfNonTrivial( LiteralIndex IntervalsRepository::GetPrecedenceLiteral( AffineExpression x, AffineExpression y) const { - return reified_precedences_->GetReifiedPrecedence(x, y); + const auto [expr, ub] = EncodeDifferenceLowerThan(x, y, 0); + auto reified_bound = reified_precedences_->GetEncodedBound(expr, ub); + if (std::holds_alternative(reified_bound)) { + return integer_encoder_->GetAssociatedLiteral( + std::get(reified_bound)); + } + if (std::holds_alternative(reified_bound)) { + return std::get(reified_bound).Index(); + } + if (std::holds_alternative( + reified_bound)) { + const auto bound_type = + std::get(reified_bound); + if (bound_type == ReifiedLinear2Bounds::ReifiedBoundType::kAlwaysTrue) { + return integer_encoder_->GetTrueLiteral().Index(); + } + if (bound_type == ReifiedLinear2Bounds::ReifiedBoundType::kAlwaysFalse) { + return integer_encoder_->GetTrueLiteral().NegatedIndex(); + } + } + + return kNoLiteralIndex; } Literal IntervalsRepository::GetOrCreatePrecedenceLiteral(AffineExpression x, @@ -247,7 +327,7 @@ Literal IntervalsRepository::GetOrCreatePrecedenceLiteral(AffineExpression x, } CHECK(CreatePrecedenceLiteralIfNonTrivial(x, y)); - const LiteralIndex index = reified_precedences_->GetReifiedPrecedence(x, y); + const LiteralIndex index = GetPrecedenceLiteral(x, y); CHECK_NE(index, kNoLiteralIndex); return Literal(index); } diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index fe4f0fde0b..03c2d8b4ab 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -191,6 +191,9 @@ class IntervalsRepository { BinaryImplicationGraph* implications_; IntegerTrail* integer_trail_; ReifiedLinear2Bounds* reified_precedences_; + RootLevelLinear2Bounds* root_level_bounds_; + Linear2Bounds* linear2_bounds_; + IntegerEncoder* integer_encoder_; // Literal indicating if the tasks is executed. Tasks that are always executed // will have a kNoLiteralIndex entry in this vector. diff --git a/ortools/sat/java/sat.swig b/ortools/sat/java/sat.swig index f2def937a9..d3ff1f798d 100644 --- a/ortools/sat/java/sat.swig +++ b/ortools/sat/java/sat.swig @@ -190,7 +190,8 @@ PROTO2_RETURN(operations_research::sat::CpSolverResponse, %feature("director") operations_research::sat::SolutionCallback; %unignore operations_research::sat::SolutionCallback; -%unignore operations_research::sat::SolutionCallback::~SolutionCallback; +%unignore operations_research::sat::SolutionCallback::SolutionCallback(); +%unignore operations_research::sat::SolutionCallback::~SolutionCallback(); %rename (bestObjectiveBound) operations_research::sat::SolutionCallback::BestObjectiveBound; %rename (numBinaryPropagations) operations_research::sat::SolutionCallback::NumBinaryPropagations; %rename (numBooleans) operations_research::sat::SolutionCallback::NumBooleans; diff --git a/ortools/sat/linear_propagation.cc b/ortools/sat/linear_propagation.cc index 330483c928..2446dd2ecd 100644 --- a/ortools/sat/linear_propagation.cc +++ b/ortools/sat/linear_propagation.cc @@ -18,30 +18,26 @@ #include #include #include -#include #include #include #include #include "absl/base/log_severity.h" #include "absl/cleanup/cleanup.h" -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/log/vlog_is_on.h" #include "absl/numeric/int128.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" +#include "ortools/sat/cp_constraints.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" -#include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" #include "ortools/util/bitset.h" @@ -51,330 +47,6 @@ namespace operations_research { namespace sat { -std::ostream& operator<<(std::ostream& os, const EnforcementStatus& e) { - switch (e) { - case EnforcementStatus::IS_FALSE: - os << "IS_FALSE"; - break; - case EnforcementStatus::CANNOT_PROPAGATE: - os << "CANNOT_PROPAGATE"; - break; - case EnforcementStatus::CAN_PROPAGATE: - os << "CAN_PROPAGATE"; - break; - case EnforcementStatus::IS_ENFORCED: - os << "IS_ENFORCED"; - break; - } - return os; -} - -EnforcementPropagator::EnforcementPropagator(Model* model) - : SatPropagator("EnforcementPropagator"), - trail_(*model->GetOrCreate()), - assignment_(trail_.Assignment()), - integer_trail_(model->GetOrCreate()), - rev_int_repository_(model->GetOrCreate()) { - // Note that this will be after the integer trail since rev_int_repository_ - // depends on IntegerTrail. - model->GetOrCreate()->AddPropagator(this); - - // Sentinel - also start of next Register(). - starts_.push_back(0); -} - -bool EnforcementPropagator::Propagate(Trail* /*trail*/) { - rev_int_repository_->SaveStateWithStamp(&rev_stack_size_, &rev_stamp_); - while (propagation_trail_index_ < trail_.Index()) { - const Literal literal = trail_[propagation_trail_index_++]; - if (literal.Index() >= static_cast(watcher_.size())) continue; - - int new_size = 0; - auto& watch_list = watcher_[literal.Index()]; - for (const EnforcementId id : watch_list) { - const LiteralIndex index = ProcessIdOnTrue(literal, id); - if (index == kNoLiteralIndex) { - // We keep the same watcher. - watch_list[new_size++] = id; - } else { - // Change the watcher. - CHECK_NE(index, literal.Index()); - watcher_[index].push_back(id); - } - } - watch_list.resize(new_size); - - // We also mark some constraint false. - for (const EnforcementId id : watcher_[literal.NegatedIndex()]) { - ChangeStatus(id, EnforcementStatus::IS_FALSE); - } - } - rev_stack_size_ = static_cast(untrail_stack_.size()); - - // Compute the enforcement status of any constraint added at a positive level. - // This is only needed until we are back to level zero. - for (const EnforcementId id : ids_to_fix_until_next_root_level_) { - ChangeStatus(id, DebugStatus(id)); - } - if (trail_.CurrentDecisionLevel() == 0) { - ids_to_fix_until_next_root_level_.clear(); - } - - return true; -} - -void EnforcementPropagator::Untrail(const Trail& /*trail*/, int trail_index) { - // Simply revert the status change. - const int size = static_cast(untrail_stack_.size()); - for (int i = size - 1; i >= rev_stack_size_; --i) { - const auto [id, status] = untrail_stack_[i]; - statuses_[id] = status; - if (callbacks_[id] != nullptr) callbacks_[id](id, status); - } - untrail_stack_.resize(rev_stack_size_); - propagation_trail_index_ = trail_index; -} - -// Adds a new constraint to the class and returns the constraint id. -// -// Note that we accept empty enforcement list so that client code can be used -// regardless of the presence of enforcement or not. A negative id means the -// constraint is never enforced, and should be ignored. -EnforcementId EnforcementPropagator::Register( - absl::Span enforcement, - std::function callback) { - int num_true = 0; - int num_false = 0; - bool is_always_false = false; - temp_literals_.clear(); - const int level = trail_.CurrentDecisionLevel(); - for (const Literal l : enforcement) { - // Make sure we always have enough room for the literal and its negation. - const int size = std::max(l.Index().value(), l.NegatedIndex().value()) + 1; - if (size > static_cast(watcher_.size())) { - watcher_.resize(size); - } - if (assignment_.LiteralIsTrue(l)) { - if (level == 0 || trail_.Info(l.Variable()).level == 0) continue; - ++num_true; - } else if (assignment_.LiteralIsFalse(l)) { - if (level == 0 || trail_.Info(l.Variable()).level == 0) { - is_always_false = true; - break; - } - ++num_false; - } - temp_literals_.push_back(l); - } - gtl::STLSortAndRemoveDuplicates(&temp_literals_); - - // Return special indices if never/always enforced. - if (is_always_false) { - if (callback != nullptr) - callback(EnforcementId(-1), EnforcementStatus::IS_FALSE); - return EnforcementId(-1); - } - if (temp_literals_.empty()) { - if (callback != nullptr) - callback(EnforcementId(-1), EnforcementStatus::IS_ENFORCED); - return EnforcementId(-1); - } - - const EnforcementId id(static_cast(callbacks_.size())); - callbacks_.push_back(std::move(callback)); - - CHECK(!temp_literals_.empty()); - buffer_.insert(buffer_.end(), temp_literals_.begin(), temp_literals_.end()); - starts_.push_back(buffer_.size()); // Sentinel/next-start. - - // The default status at level zero. - statuses_.push_back(temp_literals_.size() == 1 - ? EnforcementStatus::CAN_PROPAGATE - : EnforcementStatus::CANNOT_PROPAGATE); - - if (temp_literals_.size() == 1) { - watcher_[temp_literals_[0].Index()].push_back(id); - } else { - // Make sure we watch correct literals. - const auto span = GetSpan(id); - int num_not_true = 0; - for (int i = 0; i < span.size(); ++i) { - if (assignment_.LiteralIsTrue(span[i])) continue; - std::swap(span[num_not_true], span[i]); - ++num_not_true; - if (num_not_true == 2) break; - } - - // We need to watch one of the literals at highest level. - if (num_not_true == 1) { - int max_level = trail_.Info(span[1].Variable()).level; - for (int i = 2; i < span.size(); ++i) { - const int level = trail_.Info(span[i].Variable()).level; - if (level > max_level) { - max_level = level; - std::swap(span[1], span[i]); - } - } - } - - watcher_[span[0].Index()].push_back(id); - watcher_[span[1].Index()].push_back(id); - } - - // Change status, call callback and set up untrail if the status is different - // from EnforcementStatus::CANNOT_PROPAGATE. - if (num_false > 0) { - ChangeStatus(id, EnforcementStatus::IS_FALSE); - } else if (num_true == temp_literals_.size()) { - ChangeStatus(id, EnforcementStatus::IS_ENFORCED); - } else if (num_true + 1 == temp_literals_.size()) { - ChangeStatus(id, EnforcementStatus::CAN_PROPAGATE); - // Because this is the default status, we still need to call the callback. - if (temp_literals_.size() == 1) { - if (callbacks_[id] != nullptr) { - callbacks_[id](id, EnforcementStatus::CAN_PROPAGATE); - } - } - } - - // Tricky: if we added something at a positive level, and its status is - // not CANNOT_PROPAGATE, then we might need to fix it on backtrack. - if (trail_.CurrentDecisionLevel() > 0 && - statuses_[id] != EnforcementStatus::CANNOT_PROPAGATE) { - ids_to_fix_until_next_root_level_.push_back(id); - } - - return id; -} - -// Add the enforcement reason to the given vector. -void EnforcementPropagator::AddEnforcementReason( - EnforcementId id, std::vector* reason) const { - for (const Literal l : GetSpan(id)) { - reason->push_back(l.Negated()); - } -} - -// Try to propagate when the enforced constraint is not satisfiable. -// This is currently in O(enforcement_size); -bool EnforcementPropagator::PropagateWhenFalse( - EnforcementId id, absl::Span literal_reason, - absl::Span integer_reason) { - temp_reason_.clear(); - LiteralIndex unique_unassigned = kNoLiteralIndex; - for (const Literal l : GetSpan(id)) { - if (assignment_.LiteralIsFalse(l)) return true; - if (assignment_.LiteralIsTrue(l)) { - temp_reason_.push_back(l.Negated()); - continue; - } - if (unique_unassigned != kNoLiteralIndex) return true; - unique_unassigned = l.Index(); - } - - temp_reason_.insert(temp_reason_.end(), literal_reason.begin(), - literal_reason.end()); - if (unique_unassigned == kNoLiteralIndex) { - return integer_trail_->ReportConflict(temp_reason_, integer_reason); - } - - // We also change the status right away. - ChangeStatus(id, EnforcementStatus::IS_FALSE); - integer_trail_->EnqueueLiteral(Literal(unique_unassigned).Negated(), - temp_reason_, integer_reason); - return true; -} - -absl::Span EnforcementPropagator::GetSpan(EnforcementId id) { - if (id < 0) return {}; - DCHECK_LE(id + 1, starts_.size()); - const int size = starts_[id + 1] - starts_[id]; - DCHECK_NE(size, 0); - return absl::MakeSpan(&buffer_[starts_[id]], size); -} - -absl::Span EnforcementPropagator::GetSpan( - EnforcementId id) const { - if (id < 0) return {}; - DCHECK_LE(id + 1, starts_.size()); - const int size = starts_[id + 1] - starts_[id]; - DCHECK_NE(size, 0); - return absl::MakeSpan(&buffer_[starts_[id]], size); -} - -LiteralIndex EnforcementPropagator::ProcessIdOnTrue(Literal watched, - EnforcementId id) { - const EnforcementStatus status = statuses_[id]; - if (status == EnforcementStatus::IS_FALSE) return kNoLiteralIndex; - - const auto span = GetSpan(id); - if (span.size() == 1) { - CHECK_EQ(status, EnforcementStatus::CAN_PROPAGATE); - ChangeStatus(id, EnforcementStatus::IS_ENFORCED); - return kNoLiteralIndex; - } - - const int watched_pos = (span[0] == watched) ? 0 : 1; - CHECK_EQ(span[watched_pos], watched); - if (assignment_.LiteralIsFalse(span[watched_pos ^ 1])) { - ChangeStatus(id, EnforcementStatus::IS_FALSE); - return kNoLiteralIndex; - } - - for (int i = 2; i < span.size(); ++i) { - const Literal l = span[i]; - if (assignment_.LiteralIsFalse(l)) { - ChangeStatus(id, EnforcementStatus::IS_FALSE); - return kNoLiteralIndex; - } - if (!assignment_.LiteralIsAssigned(l)) { - // Replace the watched literal. Note that if the other watched literal is - // true, it should be processed afterwards. We do not change the status - std::swap(span[watched_pos], span[i]); - return span[watched_pos].Index(); - } - } - - // All literal with index > 1 are true. Two case. - if (assignment_.LiteralIsTrue(span[watched_pos ^ 1])) { - // All literals are true. - ChangeStatus(id, EnforcementStatus::IS_ENFORCED); - return kNoLiteralIndex; - } else { - // The other watched literal is the last unassigned - CHECK_EQ(status, EnforcementStatus::CANNOT_PROPAGATE); - ChangeStatus(id, EnforcementStatus::CAN_PROPAGATE); - return kNoLiteralIndex; - } -} - -void EnforcementPropagator::ChangeStatus(EnforcementId id, - EnforcementStatus new_status) { - const EnforcementStatus old_status = statuses_[id]; - if (old_status == new_status) return; - if (trail_.CurrentDecisionLevel() != 0) { - untrail_stack_.push_back({id, old_status}); - } - statuses_[id] = new_status; - if (callbacks_[id] != nullptr) callbacks_[id](id, new_status); -} - -EnforcementStatus EnforcementPropagator::DebugStatus(EnforcementId id) { - if (id < 0) return EnforcementStatus::IS_ENFORCED; - - int num_true = 0; - for (const Literal l : GetSpan(id)) { - if (assignment_.LiteralIsFalse(l)) { - return EnforcementStatus::IS_FALSE; - } - if (assignment_.LiteralIsTrue(l)) ++num_true; - } - const int size = GetSpan(id).size(); - if (num_true == size) return EnforcementStatus::IS_ENFORCED; - if (num_true + 1 == size) return EnforcementStatus::CAN_PROPAGATE; - return EnforcementStatus::CANNOT_PROPAGATE; -} - LinearPropagator::LinearPropagator(Model* model) : trail_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), diff --git a/ortools/sat/linear_propagation.h b/ortools/sat/linear_propagation.h index ab4027b665..149aaf8313 100644 --- a/ortools/sat/linear_propagation.h +++ b/ortools/sat/linear_propagation.h @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -29,6 +28,7 @@ #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/sat/cp_constraints.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" @@ -44,106 +44,6 @@ namespace operations_research { namespace sat { -DEFINE_STRONG_INDEX_TYPE(EnforcementId); - -// An enforced constraint can be in one of these 4 states. -// Note that we rely on the integer encoding to take 2 bits for optimization. -enum class EnforcementStatus { - // One enforcement literal is false. - IS_FALSE = 0, - // More than two literals are unassigned. - CANNOT_PROPAGATE = 1, - // All enforcement literals are true but one. - CAN_PROPAGATE = 2, - // All enforcement literals are true. - IS_ENFORCED = 3, -}; - -std::ostream& operator<<(std::ostream& os, const EnforcementStatus& e); - -// This is meant as an helper to deal with enforcement for any constraint. -class EnforcementPropagator : public SatPropagator { - public: - explicit EnforcementPropagator(Model* model); - - // SatPropagator interface. - bool Propagate(Trail* trail) final; - void Untrail(const Trail& trail, int trail_index) final; - - // Adds a new constraint to the class and register a callback that will - // be called on status change. Note that we also call the callback with the - // initial status if different from CANNOT_PROPAGATE when added. - // - // It is better to not call this for empty enforcement list, but you can. A - // negative id means the level zero status will never change, and only the - // first call to callback() should be necessary, we don't save it. - EnforcementId Register( - absl::Span enforcement, - std::function callback = nullptr); - - // Add the enforcement reason to the given vector. - void AddEnforcementReason(EnforcementId id, - std::vector* reason) const; - - // Try to propagate when the enforced constraint is not satisfiable. - // This is currently in O(enforcement_size). - ABSL_MUST_USE_RESULT bool PropagateWhenFalse( - EnforcementId id, absl::Span literal_reason, - absl::Span integer_reason); - - EnforcementStatus Status(EnforcementId id) const { return statuses_[id]; } - - // Recompute the status from the current assignment. - // This should only used in DCHECK(). - EnforcementStatus DebugStatus(EnforcementId id); - - // Returns the enforcement literals of the given id. - absl::Span GetEnforcementLiterals(EnforcementId id) const { - if (id < 0) return {}; - return GetSpan(id); - } - - private: - absl::Span GetSpan(EnforcementId id); - absl::Span GetSpan(EnforcementId id) const; - void ChangeStatus(EnforcementId id, EnforcementStatus new_status); - - // Returns kNoLiteralIndex if nothing need to change or a new literal to - // watch. This also calls the registered callback. - LiteralIndex ProcessIdOnTrue(Literal watched, EnforcementId id); - - // External classes. - const Trail& trail_; - const VariablesAssignment& assignment_; - IntegerTrail* integer_trail_; - RevIntRepository* rev_int_repository_; - - // All enforcement will be copied there, and we will create Span out of this. - // Note that we don't store the span so that we are not invalidated on buffer_ - // resizing. - util_intops::StrongVector starts_; - std::vector buffer_; - - util_intops::StrongVector statuses_; - util_intops::StrongVector< - EnforcementId, std::function> - callbacks_; - - // Used to restore status and call callback on untrail. - std::vector> untrail_stack_; - int rev_stack_size_ = 0; - int64_t rev_stamp_ = 0; - - // We use a two watcher scheme. - util_intops::StrongVector> - watcher_; - - std::vector temp_literals_; - std::vector temp_reason_; - - std::vector ids_to_fix_until_next_root_level_; -}; - // Helper class to decide on the constraint propagation order. // // Each constraint might push some variables which might in turn make other diff --git a/ortools/sat/linear_propagation_test.cc b/ortools/sat/linear_propagation_test.cc index 356888282d..24b988d530 100644 --- a/ortools/sat/linear_propagation_test.cc +++ b/ortools/sat/linear_propagation_test.cc @@ -17,10 +17,8 @@ #include -#include "absl/log/check.h" #include "absl/types/span.h" #include "gtest/gtest.h" -#include "ortools/base/gmock.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" @@ -32,120 +30,6 @@ namespace operations_research { namespace sat { namespace { -using ::testing::ElementsAre; - -TEST(EnforcementPropagatorTest, BasicTest) { - Model model; - auto* sat_solver = model.GetOrCreate(); - auto* trail = model.GetOrCreate(); - auto* propag = model.GetOrCreate(); - sat_solver->SetNumVariables(10); - - const EnforcementId id1 = propag->Register(Literals({+1})); - const EnforcementId id2 = propag->Register(Literals({+1, +2})); - const EnforcementId id3 = propag->Register(Literals({-2})); - - EXPECT_TRUE(propag->Propagate(trail)); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::CANNOT_PROPAGATE); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); - - sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); - EXPECT_TRUE(propag->Propagate(trail)); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); - - sat_solver->EnqueueDecisionIfNotConflicting(Literal(+2)); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::IS_FALSE); - - CHECK(sat_solver->ResetToLevelZero()); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::CANNOT_PROPAGATE); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); -} - -TEST(EnforcementPropagatorTest, UntrailWork) { - Model model; - auto* sat_solver = model.GetOrCreate(); - auto* trail = model.GetOrCreate(); - auto* propag = model.GetOrCreate(); - sat_solver->SetNumVariables(10); - - const EnforcementId id1 = propag->Register(Literals({+1})); - const EnforcementId id2 = propag->Register(Literals({+2})); - const EnforcementId id3 = propag->Register(Literals({+3})); - - EXPECT_TRUE(propag->Propagate(trail)); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::CAN_PROPAGATE); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); - - sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); - EXPECT_TRUE(propag->Propagate(trail)); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::CAN_PROPAGATE); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); - - sat_solver->EnqueueDecisionIfNotConflicting(Literal(+2)); - EXPECT_TRUE(propag->Propagate(trail)); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); - const int level = sat_solver->CurrentDecisionLevel(); - - sat_solver->EnqueueDecisionIfNotConflicting(Literal(+3)); - EXPECT_TRUE(propag->Propagate(trail)); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::IS_ENFORCED); - - sat_solver->Backtrack(level); - EXPECT_EQ(propag->Status(id1), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id2), EnforcementStatus::IS_ENFORCED); - EXPECT_EQ(propag->Status(id3), EnforcementStatus::CAN_PROPAGATE); -} - -TEST(EnforcementPropagatorTest, AddingAtPositiveLevelTrue) { - Model model; - auto* sat_solver = model.GetOrCreate(); - auto* trail = model.GetOrCreate(); - auto* propag = model.GetOrCreate(); - sat_solver->SetNumVariables(10); - - EXPECT_TRUE(propag->Propagate(trail)); - sat_solver->EnqueueDecisionIfNotConflicting(Literal(+1)); - EXPECT_TRUE(propag->Propagate(trail)); - - const EnforcementId id = propag->Register(std::vector{+1}); - EXPECT_EQ(propag->Status(id), EnforcementStatus::IS_ENFORCED); - - sat_solver->Backtrack(0); - EXPECT_TRUE(propag->Propagate(trail)); - EXPECT_EQ(propag->Status(id), EnforcementStatus::CAN_PROPAGATE); -} - -TEST(EnforcementPropagatorTest, AddingAtPositiveLevelFalse) { - Model model; - auto* sat_solver = model.GetOrCreate(); - auto* trail = model.GetOrCreate(); - auto* propag = model.GetOrCreate(); - sat_solver->SetNumVariables(10); - - EXPECT_TRUE(propag->Propagate(trail)); - sat_solver->EnqueueDecisionIfNotConflicting(Literal(-1)); - EXPECT_TRUE(propag->Propagate(trail)); - - const EnforcementId id = propag->Register(std::vector{+1}); - EXPECT_EQ(propag->Status(id), EnforcementStatus::IS_FALSE); - - sat_solver->Backtrack(0); - EXPECT_TRUE(propag->Propagate(trail)); - EXPECT_EQ(propag->Status(id), EnforcementStatus::CAN_PROPAGATE); -} - // TEST copied from integer_expr test with little modif to use the new propag. IntegerVariable AddWeightedSum(const absl::Span vars, const absl::Span coeffs, diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 0bd15c832a..29e57fbf8e 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -1422,15 +1422,18 @@ void TryToLinearizeConstraint(const CpModelProto& /*model_proto*/, break; } case ConstraintProto::ConstraintCase::kIntProd: { - const LinearArgumentProto& int_prod = ct.int_prod(); - if (int_prod.exprs_size() == 2 && - LinearExpressionProtosAreEqual(int_prod.exprs(0), - int_prod.exprs(1))) { - AppendSquareRelaxation(ct, model, relaxation); - AddSquareCutGenerator(ct, linearization_level, model, relaxation); - } else { - // No relaxation, just a cut generator . - AddIntProdCutGenerator(ct, linearization_level, model, relaxation); + // TODO(user): add support for enforcement literals? + if (!HasEnforcementLiteral(ct)) { + const LinearArgumentProto& int_prod = ct.int_prod(); + if (int_prod.exprs_size() == 2 && + LinearExpressionProtosAreEqual(int_prod.exprs(0), + int_prod.exprs(1))) { + AppendSquareRelaxation(ct, model, relaxation); + AddSquareCutGenerator(ct, linearization_level, model, relaxation); + } else { + // No relaxation, just a cut generator. + AddIntProdCutGenerator(ct, linearization_level, model, relaxation); + } } break; } diff --git a/ortools/sat/lp_utils.cc b/ortools/sat/lp_utils.cc index caf9478faa..c1e337f218 100644 --- a/ortools/sat/lp_utils.cc +++ b/ortools/sat/lp_utils.cc @@ -22,6 +22,7 @@ #include #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -725,37 +726,8 @@ std::vector DetectImpliedIntegers(MPModelProto* mp_model, return var_scaling; } -namespace { - -// We use a class to reuse the temporary memory. -struct ConstraintScaler { - // Scales an individual constraint. - ConstraintProto* AddConstraint(const MPModelProto& mp_model, - const MPConstraintProto& mp_constraint, - CpModelProto* cp_model); - - bool keep_names = false; - double max_relative_coeff_error = 0.0; - double max_absolute_rhs_error = 0.0; - double max_scaling_factor = 0.0; - double min_scaling_factor = std::numeric_limits::infinity(); - - double wanted_precision = 1e-6; - int64_t scaling_target = int64_t{1} << 50; - std::vector var_indices; - std::vector coefficients; - std::vector lower_bounds; - std::vector upper_bounds; -}; - -ConstraintProto* ConstraintScaler::AddConstraint( - const MPModelProto& mp_model, const MPConstraintProto& mp_constraint, - CpModelProto* cp_model) { - if (mp_constraint.lower_bound() == -kInfinity && - mp_constraint.upper_bound() == kInfinity) { - return nullptr; - } - +absl::Status ConstraintScaler::ScaleAndAddConstraint( + const MPConstraintProto& mp_constraint, CpModelProto* cp_model) { auto* constraint = cp_model->add_constraints(); if (keep_names) constraint->set_name(mp_constraint.name()); auto* arg = constraint->mutable_linear(); @@ -788,12 +760,9 @@ ConstraintProto* ConstraintScaler::AddConstraint( coefficients, lower_bounds, upper_bounds, scaling_target, wanted_precision, &relative_coeff_error, &scaled_sum_error); if (scaling_factor == 0.0) { - // TODO(user): Report error properly instead of ignoring constraint. Note - // however that this likely indicate a coefficient of inf in the constraint, - // so we should probably abort before reaching here. - LOG(DFATAL) << "Scaling factor of zero while scaling constraint: " - << ProtobufShortDebugString(mp_constraint); - return nullptr; + return absl::InvalidArgumentError( + absl::StrCat("Scaling factor of zero while scaling constraint: ", + ProtobufShortDebugString(mp_constraint))); } const int64_t gcd = ComputeGcdOfRoundedDoubles(coefficients, scaling_factor); @@ -853,7 +822,14 @@ ConstraintProto* ConstraintScaler::AddConstraint( .value()); } - return constraint; + return absl::OkStatus(); +} + +namespace { + +bool ConstraintIsAlwaysTrue(const MPConstraintProto& mp_constraint) { + return mp_constraint.lower_bound() == -kInfinity && + mp_constraint.upper_bound() == kInfinity; } // TODO(user): unit test this. @@ -1031,7 +1007,14 @@ bool ConvertMPModelProtoToCpModelProto(const SatParameters& params, // Add the constraints. We scale each of them individually. for (const MPConstraintProto& mp_constraint : mp_model.constraint()) { - scaler.AddConstraint(mp_model, mp_constraint, cp_model); + if (ConstraintIsAlwaysTrue(mp_constraint)) continue; + + const absl::Status status = + scaler.ScaleAndAddConstraint(mp_constraint, cp_model); + if (!status.ok()) { + SOLVER_LOG(logger, "Error while scaling constraint. ", status.message()); + return false; + } } for (const MPGeneralConstraintProto& general_constraint : mp_model.general_constraint()) { @@ -1041,14 +1024,22 @@ bool ConvertMPModelProtoToCpModelProto(const SatParameters& params, general_constraint.indicator_constraint(); const MPConstraintProto& mp_constraint = indicator_constraint.constraint(); - ConstraintProto* ct = - scaler.AddConstraint(mp_model, mp_constraint, cp_model); - if (ct == nullptr) continue; + if (ConstraintIsAlwaysTrue(mp_constraint)) continue; + + const int new_ct_index = cp_model->constraints().size(); + const absl::Status status = + scaler.ScaleAndAddConstraint(mp_constraint, cp_model); + if (!status.ok()) { + SOLVER_LOG(logger, "Error while scaling constraint. ", + status.message()); + return false; + } // Add the indicator. const int var = indicator_constraint.var_index(); const int value = indicator_constraint.var_value(); - ct->add_enforcement_literal(value == 1 ? var : NegatedRef(var)); + cp_model->mutable_constraints(new_ct_index) + ->add_enforcement_literal(value == 1 ? var : NegatedRef(var)); break; } case MPGeneralConstraintProto::kAndConstraint: { diff --git a/ortools/sat/lp_utils.h b/ortools/sat/lp_utils.h index 61ad09d75a..a029fbd516 100644 --- a/ortools/sat/lp_utils.h +++ b/ortools/sat/lp_utils.h @@ -71,6 +71,49 @@ double FindBestScalingAndComputeErrors( double wanted_absolute_activity_precision, double* relative_coeff_error, double* scaled_sum_error); +// Helper to scale MPConstraintProto to CpModelProto::ConstraintProto. +// We use a class to reuse the temporary memory when we scale many constraints. +// +// Note that this can be used to scale any constraint, one just has to fill a +// MPConstraintProto using variable indices that correspond to the given +// CpModelProto variables. +struct ConstraintScaler { + // Scales an individual constraint and add it to the given CpModelProto. + // + // We use the domain of the variables to derive error bounds and scale the + // constraint as best as we can within "wanted_precision" and + // "scaling_target". We usually scale with power of two scaling factor or + // a rational scaling factor if we detect a good one via FindRationalFactor(). + // + // Returns an error if the given constraint contained huge coefficient or + // infinity. Note that we do not consider it an error if the wanted precision + // is not reached (best effort). One can check the error statistics field + // below and decide when there are too high and report an error separately. + absl::Status ScaleAndAddConstraint(const MPConstraintProto& mp_constraint, + CpModelProto* cp_model); + + // Statistics over all scaled constraints. This can be inspected to know the + // final error produced by ScaleAndAddConstraint(). + double max_relative_coeff_error = 0.0; + double max_absolute_rhs_error = 0.0; + double max_scaling_factor = 0.0; + double min_scaling_factor = std::numeric_limits::infinity(); + + // Parameters. Whether we ignore or copy the mp_constraint.name() field. + bool keep_names = false; + + // Parameters passed to FindBestScalingAndComputeErrors(), see documentation + // there to understand their meaning. + double wanted_precision = 1e-6; + int64_t scaling_target = int64_t{1} << 50; + + // Private temporary field to reuse memory. + std::vector var_indices; + std::vector coefficients; + std::vector lower_bounds; + std::vector upper_bounds; +}; + // Multiplies all continuous variable by the given scaling parameters and change // the rest of the model accordingly. The returned vector contains the scaling // of each variable (will always be 1.0 for integers) and can be used to recover diff --git a/ortools/sat/old_precedences_propagator.cc b/ortools/sat/old_precedences_propagator.cc new file mode 100644 index 0000000000..c066367765 --- /dev/null +++ b/ortools/sat/old_precedences_propagator.cc @@ -0,0 +1,667 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/old_precedences_propagator.h" + +#include + +#include +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/vlog_is_on.h" +#include "absl/types/span.h" +#include "ortools/base/logging.h" +#include "ortools/base/stl_util.h" +#include "ortools/base/strong_vector.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/synchronization.h" +#include "ortools/util/bitset.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { + +namespace { + +void AppendLowerBoundReasonIfValid(IntegerVariable var, + const IntegerTrail& i_trail, + std::vector* reason) { + if (var != kNoIntegerVariable) { + reason->push_back(i_trail.LowerBoundAsLiteral(var)); + } +} + +} // namespace + +PrecedencesPropagator::~PrecedencesPropagator() { + if (!VLOG_IS_ON(1)) return; + if (shared_stats_ == nullptr) return; + std::vector> stats; + stats.push_back({"precedences/num_cycles", num_cycles_}); + stats.push_back({"precedences/num_pushes", num_pushes_}); + stats.push_back( + {"precedences/num_enforcement_pushes", num_enforcement_pushes_}); + shared_stats_->AddStats(stats); +} + +bool PrecedencesPropagator::Propagate(Trail* trail) { return Propagate(); } + +bool PrecedencesPropagator::Propagate() { + while (propagation_trail_index_ < trail_->Index()) { + const Literal literal = (*trail_)[propagation_trail_index_++]; + if (literal.Index() >= literal_to_new_impacted_arcs_.size()) continue; + + // IMPORTANT: Because of the way Untrail() work, we need to add all the + // potential arcs before we can abort. It is why we iterate twice here. + for (const ArcIndex arc_index : + literal_to_new_impacted_arcs_[literal.Index()]) { + if (--arc_counts_[arc_index] == 0) { + const ArcInfo& arc = arcs_[arc_index]; + PushConditionalRelations(arc); + impacted_arcs_[arc.tail_var].push_back(arc_index); + } + } + + // Iterate again to check for a propagation and indirectly update + // modified_vars_. + for (const ArcIndex arc_index : + literal_to_new_impacted_arcs_[literal.Index()]) { + if (arc_counts_[arc_index] > 0) continue; + const ArcInfo& arc = arcs_[arc_index]; + const IntegerValue new_head_lb = + integer_trail_->LowerBound(arc.tail_var) + ArcOffset(arc); + if (new_head_lb > integer_trail_->LowerBound(arc.head_var)) { + if (!EnqueueAndCheck(arc, new_head_lb, trail_)) return false; + } + } + } + + // Do the actual propagation of the IntegerVariable bounds. + InitializeBFQueueWithModifiedNodes(); + if (!BellmanFordTarjan(trail_)) return false; + + // We can only test that no propagation is left if we didn't enqueue new + // literal in the presence of optional variables. + // + // TODO(user): Because of our code to deal with InPropagationLoop(), this is + // not always true. Find a cleaner way to DCHECK() while not failing in this + // corner case. + if (/*DISABLES CODE*/ (false) && + propagation_trail_index_ == trail_->Index()) { + DCHECK(NoPropagationLeft(*trail_)); + } + + // Propagate the presence literals of the arcs that can't be added. + PropagateOptionalArcs(trail_); + + // Clean-up modified_vars_ to do as little as possible on the next call. + modified_vars_.ClearAndResize(integer_trail_->NumIntegerVariables()); + return true; +} + +bool PrecedencesPropagator::PropagateOutgoingArcs(IntegerVariable var) { + CHECK_NE(var, kNoIntegerVariable); + if (var >= impacted_arcs_.size()) return true; + for (const ArcIndex arc_index : impacted_arcs_[var]) { + const ArcInfo& arc = arcs_[arc_index]; + const IntegerValue new_head_lb = + integer_trail_->LowerBound(arc.tail_var) + ArcOffset(arc); + if (new_head_lb > integer_trail_->LowerBound(arc.head_var)) { + if (!EnqueueAndCheck(arc, new_head_lb, trail_)) return false; + } + } + return true; +} + +// TODO(user): Remove literal fixed at level zero from there. +void PrecedencesPropagator::PushConditionalRelations(const ArcInfo& arc) { + // We currently do not handle variable size in the reasons. + // TODO(user): we could easily take a level zero ArcOffset() instead, or + // add this to the reason though. + if (arc.offset_var != kNoIntegerVariable) return; + const IntegerValue offset = ArcOffset(arc); + relations_->PushConditionalRelation( + arc.presence_literals, + LinearExpression2::Difference(arc.tail_var, arc.head_var), -offset); +} + +void PrecedencesPropagator::Untrail(const Trail& trail, int trail_index) { + if (propagation_trail_index_ > trail_index) { + // This means that we already propagated all there is to propagate + // at the level trail_index, so we can safely clear modified_vars_ in case + // it wasn't already done. + modified_vars_.ClearAndResize(integer_trail_->NumIntegerVariables()); + } + while (propagation_trail_index_ > trail_index) { + const Literal literal = trail[--propagation_trail_index_]; + if (literal.Index() >= literal_to_new_impacted_arcs_.size()) continue; + for (const ArcIndex arc_index : + literal_to_new_impacted_arcs_[literal.Index()]) { + if (arc_counts_[arc_index]++ == 0) { + const ArcInfo& arc = arcs_[arc_index]; + impacted_arcs_[arc.tail_var].pop_back(); + } + } + } +} + +void PrecedencesPropagator::AdjustSizeFor(IntegerVariable i) { + const int index = std::max(i.value(), NegationOf(i).value()); + if (index >= impacted_arcs_.size()) { + // TODO(user): only watch lower bound of the relevant variable instead + // of watching everything in [0, max_index_of_variable_used_in_this_class). + for (IntegerVariable var(impacted_arcs_.size()); var <= index; ++var) { + watcher_->WatchLowerBound(var, watcher_id_); + } + impacted_arcs_.resize(index + 1); + impacted_potential_arcs_.resize(index + 1); + } +} + +void PrecedencesPropagator::AddArc( + IntegerVariable tail, IntegerVariable head, IntegerValue offset, + IntegerVariable offset_var, absl::Span presence_literals) { + AdjustSizeFor(tail); + AdjustSizeFor(head); + if (offset_var != kNoIntegerVariable) AdjustSizeFor(offset_var); + + // This arc is present iff all the literals here are true. + absl::InlinedVector enforcement_literals; + { + for (const Literal l : presence_literals) { + enforcement_literals.push_back(l); + } + gtl::STLSortAndRemoveDuplicates(&enforcement_literals); + + if (trail_->CurrentDecisionLevel() == 0) { + int new_size = 0; + for (const Literal l : enforcement_literals) { + if (trail_->Assignment().LiteralIsTrue(Literal(l))) { + continue; // At true, ignore this literal. + } else if (trail_->Assignment().LiteralIsFalse(Literal(l))) { + return; // At false, ignore completely this arc. + } + enforcement_literals[new_size++] = l; + } + enforcement_literals.resize(new_size); + } + } + + if (head == tail) { + // A self-arc is either plain SAT or plain UNSAT or it forces something on + // the given offset_var or presence_literal_index. In any case it could be + // presolved in something more efficient. + VLOG(1) << "Self arc! This could be presolved. " + << "var:" << tail << " offset:" << offset + << " offset_var:" << offset_var + << " conditioned_by:" << presence_literals; + } + + // Remove the offset_var if it is fixed. + // TODO(user): We should also handle the case where tail or head is fixed. + if (offset_var != kNoIntegerVariable) { + const IntegerValue lb = integer_trail_->LevelZeroLowerBound(offset_var); + if (lb == integer_trail_->LevelZeroUpperBound(offset_var)) { + offset += lb; + offset_var = kNoIntegerVariable; + } + } + + // Deal first with impacted_potential_arcs_/potential_arcs_. + if (!enforcement_literals.empty()) { + const OptionalArcIndex arc_index(potential_arcs_.size()); + potential_arcs_.push_back( + {tail, head, offset, offset_var, enforcement_literals}); + impacted_potential_arcs_[tail].push_back(arc_index); + impacted_potential_arcs_[NegationOf(head)].push_back(arc_index); + if (offset_var != kNoIntegerVariable) { + impacted_potential_arcs_[offset_var].push_back(arc_index); + } + } + + // Now deal with impacted_arcs_/arcs_. + struct InternalArc { + IntegerVariable tail_var; + IntegerVariable head_var; + IntegerVariable offset_var; + }; + std::vector to_add; + if (offset_var == kNoIntegerVariable) { + // a + offset <= b and -b + offset <= -a + to_add.push_back({tail, head, kNoIntegerVariable}); + to_add.push_back({NegationOf(head), NegationOf(tail), kNoIntegerVariable}); + } else { + // tail (a) and offset_var (b) are symmetric, so we add: + // - a + b + offset <= c + to_add.push_back({tail, head, offset_var}); + to_add.push_back({offset_var, head, tail}); + // - a - c + offset <= -b + to_add.push_back({tail, NegationOf(offset_var), NegationOf(head)}); + to_add.push_back({NegationOf(head), NegationOf(offset_var), tail}); + // - b - c + offset <= -a + to_add.push_back({offset_var, NegationOf(tail), NegationOf(head)}); + to_add.push_back({NegationOf(head), NegationOf(tail), offset_var}); + } + for (const InternalArc a : to_add) { + // Since we add a new arc, we will need to consider its tail during the next + // propagation. Note that the size of modified_vars_ will be automatically + // updated when new integer variables are created since we register it with + // IntegerTrail in this class constructor. + // + // TODO(user): Adding arcs and then calling Untrail() before Propagate() + // will cause this mecanism to break. Find a more robust implementation. + // + // TODO(user): In some rare corner case, rescanning the whole list of arc + // leaving tail_var can make AddVar() have a quadratic complexity where it + // shouldn't. A better solution would be to see if this new arc currently + // propagate something, and if it does, just update the lower bound of + // a.head_var and let the normal "is modified" mecanism handle any eventual + // follow up propagations. + modified_vars_.Set(a.tail_var); + + // If a.head_var is optional, we can potentially remove some literal from + // enforcement_literals. + const ArcIndex arc_index(arcs_.size()); + arcs_.push_back( + {a.tail_var, a.head_var, offset, a.offset_var, enforcement_literals}); + auto& presence_literals = arcs_.back().presence_literals; + + if (presence_literals.empty()) { + impacted_arcs_[a.tail_var].push_back(arc_index); + } else { + for (const Literal l : presence_literals) { + if (l.Index() >= literal_to_new_impacted_arcs_.size()) { + literal_to_new_impacted_arcs_.resize(l.Index().value() + 1); + } + literal_to_new_impacted_arcs_[l.Index()].push_back(arc_index); + } + } + + if (trail_->CurrentDecisionLevel() == 0) { + arc_counts_.push_back(presence_literals.size()); + } else { + arc_counts_.push_back(0); + for (const Literal l : presence_literals) { + if (!trail_->Assignment().LiteralIsTrue(l)) { + ++arc_counts_.back(); + } + } + CHECK(presence_literals.empty() || arc_counts_.back() > 0); + } + } +} + +bool PrecedencesPropagator::AddPrecedenceWithOffsetIfNew(IntegerVariable i1, + IntegerVariable i2, + IntegerValue offset) { + DCHECK_EQ(trail_->CurrentDecisionLevel(), 0); + if (i1 < impacted_arcs_.size() && i2 < impacted_arcs_.size()) { + for (const ArcIndex index : impacted_arcs_[i1]) { + const ArcInfo& arc = arcs_[index]; + if (arc.head_var == i2) { + const IntegerValue current = ArcOffset(arc); + if (offset <= current) { + return false; + } else { + // TODO(user): Modify arc in place! + } + break; + } + } + } + + AddPrecedenceWithOffset(i1, i2, offset); + return true; +} + +// TODO(user): On jobshop problems with a lot of tasks per machine (500), this +// takes up a big chunk of the running time even before we find a solution. +// This is because, for each lower bound changed, we inspect 500 arcs even +// though they will never be propagated because the other bound is still at the +// horizon. Find an even sparser algorithm? +void PrecedencesPropagator::PropagateOptionalArcs(Trail* trail) { + for (const IntegerVariable var : modified_vars_.PositionsSetAtLeastOnce()) { + // The variables are not in increasing order, so we need to continue. + if (var >= impacted_potential_arcs_.size()) continue; + + // Note that we can currently check the same ArcInfo up to 3 times, one for + // each of the arc variables: tail, NegationOf(head) and offset_var. + for (const OptionalArcIndex arc_index : impacted_potential_arcs_[var]) { + const ArcInfo& arc = potential_arcs_[arc_index]; + int num_not_true = 0; + Literal to_propagate; + for (const Literal l : arc.presence_literals) { + if (!trail->Assignment().LiteralIsTrue(l)) { + ++num_not_true; + to_propagate = l; + } + } + if (num_not_true != 1) continue; + if (trail->Assignment().LiteralIsFalse(to_propagate)) continue; + + // Test if this arc can be present or not. + // Important arc.tail_var can be different from var here. + const IntegerValue tail_lb = integer_trail_->LowerBound(arc.tail_var); + const IntegerValue head_ub = integer_trail_->UpperBound(arc.head_var); + if (tail_lb + ArcOffset(arc) > head_ub) { + integer_reason_.clear(); + integer_reason_.push_back( + integer_trail_->LowerBoundAsLiteral(arc.tail_var)); + integer_reason_.push_back( + integer_trail_->UpperBoundAsLiteral(arc.head_var)); + AppendLowerBoundReasonIfValid(arc.offset_var, *integer_trail_, + &integer_reason_); + literal_reason_.clear(); + for (const Literal l : arc.presence_literals) { + if (l != to_propagate) literal_reason_.push_back(l.Negated()); + } + ++num_enforcement_pushes_; + integer_trail_->EnqueueLiteral(to_propagate.Negated(), literal_reason_, + integer_reason_); + } + } + } +} + +IntegerValue PrecedencesPropagator::ArcOffset(const ArcInfo& arc) const { + return arc.offset + (arc.offset_var == kNoIntegerVariable + ? IntegerValue(0) + : integer_trail_->LowerBound(arc.offset_var)); +} + +bool PrecedencesPropagator::EnqueueAndCheck(const ArcInfo& arc, + IntegerValue new_head_lb, + Trail* trail) { + ++num_pushes_; + DCHECK_GT(new_head_lb, integer_trail_->LowerBound(arc.head_var)); + + // Compute the reason for new_head_lb. + // + // TODO(user): do like for clause and keep the negation of + // arc.presence_literals? I think we could change the integer.h API to accept + // true literal like for IntegerVariable, it is really confusing currently. + literal_reason_.clear(); + for (const Literal l : arc.presence_literals) { + literal_reason_.push_back(l.Negated()); + } + + integer_reason_.clear(); + integer_reason_.push_back(integer_trail_->LowerBoundAsLiteral(arc.tail_var)); + AppendLowerBoundReasonIfValid(arc.offset_var, *integer_trail_, + &integer_reason_); + + // The code works without this block since Enqueue() below can already take + // care of conflicts. However, it is better to deal with the conflict + // ourselves because we can be smarter about the reason this way. + // + // The reason for a "precedence" conflict is always a linear reason + // involving the tail lower_bound, the head upper bound and eventually the + // size lower bound. Because of that, we can use the RelaxLinearReason() + // code. + if (new_head_lb > integer_trail_->UpperBound(arc.head_var)) { + const IntegerValue slack = + new_head_lb - integer_trail_->UpperBound(arc.head_var) - 1; + integer_reason_.push_back( + integer_trail_->UpperBoundAsLiteral(arc.head_var)); + std::vector coeffs(integer_reason_.size(), IntegerValue(1)); + integer_trail_->RelaxLinearReason(slack, coeffs, &integer_reason_); + return integer_trail_->ReportConflict(literal_reason_, integer_reason_); + } + + return integer_trail_->Enqueue( + IntegerLiteral::GreaterOrEqual(arc.head_var, new_head_lb), + literal_reason_, integer_reason_); +} + +bool PrecedencesPropagator::NoPropagationLeft(const Trail& trail) const { + const int num_nodes = impacted_arcs_.size(); + for (IntegerVariable var(0); var < num_nodes; ++var) { + for (const ArcIndex arc_index : impacted_arcs_[var]) { + const ArcInfo& arc = arcs_[arc_index]; + if (integer_trail_->LowerBound(arc.tail_var) + ArcOffset(arc) > + integer_trail_->LowerBound(arc.head_var)) { + return false; + } + } + } + return true; +} + +void PrecedencesPropagator::InitializeBFQueueWithModifiedNodes() { + // Sparse clear of the queue. TODO(user): only use the sparse version if + // queue.size() is small or use SparseBitset. + const int num_nodes = impacted_arcs_.size(); + bf_in_queue_.resize(num_nodes, false); + for (const int node : bf_queue_) bf_in_queue_[node] = false; + bf_queue_.clear(); + DCHECK(std::none_of(bf_in_queue_.begin(), bf_in_queue_.end(), + [](bool v) { return v; })); + for (const IntegerVariable var : modified_vars_.PositionsSetAtLeastOnce()) { + if (var >= num_nodes) continue; + bf_queue_.push_back(var.value()); + bf_in_queue_[var.value()] = true; + } +} + +void PrecedencesPropagator::CleanUpMarkedArcsAndParents() { + // To be sparse, we use the fact that each node with a parent must be in + // modified_vars_. + const int num_nodes = impacted_arcs_.size(); + for (const IntegerVariable var : modified_vars_.PositionsSetAtLeastOnce()) { + if (var >= num_nodes) continue; + const ArcIndex parent_arc_index = bf_parent_arc_of_[var.value()]; + if (parent_arc_index != -1) { + arcs_[parent_arc_index].is_marked = false; + bf_parent_arc_of_[var.value()] = -1; + bf_can_be_skipped_[var.value()] = false; + } + } + DCHECK(std::none_of(bf_parent_arc_of_.begin(), bf_parent_arc_of_.end(), + [](ArcIndex v) { return v != -1; })); + DCHECK(std::none_of(bf_can_be_skipped_.begin(), bf_can_be_skipped_.end(), + [](bool v) { return v; })); +} + +bool PrecedencesPropagator::DisassembleSubtree( + int source, int target, std::vector* can_be_skipped) { + // Note that we explore a tree, so we can do it in any order, and the one + // below seems to be the fastest. + tmp_vector_.clear(); + tmp_vector_.push_back(source); + while (!tmp_vector_.empty()) { + const int tail = tmp_vector_.back(); + tmp_vector_.pop_back(); + for (const ArcIndex arc_index : impacted_arcs_[IntegerVariable(tail)]) { + const ArcInfo& arc = arcs_[arc_index]; + if (arc.is_marked) { + arc.is_marked = false; // mutable. + if (arc.head_var.value() == target) return true; + DCHECK(!(*can_be_skipped)[arc.head_var.value()]); + (*can_be_skipped)[arc.head_var.value()] = true; + tmp_vector_.push_back(arc.head_var.value()); + } + } + } + return false; +} + +void PrecedencesPropagator::AnalyzePositiveCycle( + ArcIndex first_arc, Trail* trail, std::vector* must_be_all_true, + std::vector* literal_reason, + std::vector* integer_reason) { + must_be_all_true->clear(); + literal_reason->clear(); + integer_reason->clear(); + + // Follow bf_parent_arc_of_[] to find the cycle containing first_arc. + const IntegerVariable first_arc_head = arcs_[first_arc].head_var; + ArcIndex arc_index = first_arc; + std::vector arc_on_cycle; + + // Just to be safe and avoid an infinite loop we use the fact that the maximum + // cycle size on a graph with n nodes is of size n. If we have more in the + // code below, it means first_arc is not part of a cycle according to + // bf_parent_arc_of_[], which should never happen. + const int num_nodes = impacted_arcs_.size(); + while (arc_on_cycle.size() <= num_nodes) { + arc_on_cycle.push_back(arc_index); + const ArcInfo& arc = arcs_[arc_index]; + if (arc.tail_var == first_arc_head) break; + arc_index = bf_parent_arc_of_[arc.tail_var.value()]; + CHECK_NE(arc_index, ArcIndex(-1)); + } + CHECK_NE(arc_on_cycle.size(), num_nodes + 1) << "Infinite loop."; + + // Compute the reason for this cycle. + IntegerValue sum(0); + for (const ArcIndex arc_index : arc_on_cycle) { + const ArcInfo& arc = arcs_[arc_index]; + sum += ArcOffset(arc); + AppendLowerBoundReasonIfValid(arc.offset_var, *integer_trail_, + integer_reason); + for (const Literal l : arc.presence_literals) { + literal_reason->push_back(l.Negated()); + } + } + + // TODO(user): what if the sum overflow? this is just a check so I guess + // we don't really care, but fix the issue. + CHECK_GT(sum, 0); +} + +// Note that in our settings it is important to use an algorithm that tries to +// minimize the number of integer_trail_->Enqueue() as much as possible. +// +// TODO(user): The current algorithm is quite efficient, but there is probably +// still room for improvements. +bool PrecedencesPropagator::BellmanFordTarjan(Trail* trail) { + const int num_nodes = impacted_arcs_.size(); + + // These vector are reset by CleanUpMarkedArcsAndParents() so resize is ok. + bf_can_be_skipped_.resize(num_nodes, false); + bf_parent_arc_of_.resize(num_nodes, ArcIndex(-1)); + const auto cleanup = + ::absl::MakeCleanup([this]() { CleanUpMarkedArcsAndParents(); }); + + // The queue initialization is done by InitializeBFQueueWithModifiedNodes(). + while (!bf_queue_.empty()) { + const int node = bf_queue_.front(); + bf_queue_.pop_front(); + bf_in_queue_[node] = false; + + // TODO(user): we don't need bf_can_be_skipped_ since we can detect this + // if this node has a parent arc which is not marked. Investigate if it is + // faster without the vector. + // + // TODO(user): An alternative algorithm is to remove all these nodes from + // the queue instead of simply marking them. This should also lead to a + // better "relaxation" order of the arcs. It is however a bit more work to + // remove them since we need to track their position. + if (bf_can_be_skipped_[node]) { + DCHECK_NE(bf_parent_arc_of_[node], -1); + DCHECK(!arcs_[bf_parent_arc_of_[node]].is_marked); + continue; + } + + const IntegerValue tail_lb = + integer_trail_->LowerBound(IntegerVariable(node)); + for (const ArcIndex arc_index : impacted_arcs_[IntegerVariable(node)]) { + const ArcInfo& arc = arcs_[arc_index]; + DCHECK_EQ(arc.tail_var, node); + const IntegerValue candidate = tail_lb + ArcOffset(arc); + if (candidate > integer_trail_->LowerBound(arc.head_var)) { + if (!EnqueueAndCheck(arc, candidate, trail)) return false; + + // This is the Tarjan contribution to Bellman-Ford. This code detect + // positive cycle, and because it disassemble the subtree while doing + // so, the cost is amortized during the algorithm execution. Another + // advantages is that it will mark the node explored here as skippable + // which will avoid to propagate them too early (knowing that they will + // need to be propagated again later). + if (DisassembleSubtree(arc.head_var.value(), arc.tail_var.value(), + &bf_can_be_skipped_)) { + std::vector must_be_all_true; + AnalyzePositiveCycle(arc_index, trail, &must_be_all_true, + &literal_reason_, &integer_reason_); + if (must_be_all_true.empty()) { + ++num_cycles_; + return integer_trail_->ReportConflict(literal_reason_, + integer_reason_); + } else { + gtl::STLSortAndRemoveDuplicates(&must_be_all_true); + for (const Literal l : must_be_all_true) { + if (trail_->Assignment().LiteralIsFalse(l)) { + literal_reason_.push_back(l); + return integer_trail_->ReportConflict(literal_reason_, + integer_reason_); + } + } + for (const Literal l : must_be_all_true) { + if (trail_->Assignment().LiteralIsTrue(l)) continue; + integer_trail_->EnqueueLiteral(l, literal_reason_, + integer_reason_); + } + + // We just marked some optional variable as ignored, no need + // to update bf_parent_arc_of_[]. + continue; + } + } + + // We need to enforce the invariant that only the arc_index in + // bf_parent_arc_of_[] are marked (but not necessarily all of them + // since we unmark some in DisassembleSubtree()). + if (bf_parent_arc_of_[arc.head_var.value()] != -1) { + arcs_[bf_parent_arc_of_[arc.head_var.value()]].is_marked = false; + } + + // Tricky: We just enqueued the fact that the lower-bound of head is + // candidate. However, because the domain of head may be discrete, it is + // possible that the lower-bound of head is now higher than candidate! + // If this is the case, we don't update bf_parent_arc_of_[] so that we + // don't wrongly detect a positive weight cycle because of this "extra + // push". + const IntegerValue new_bound = integer_trail_->LowerBound(arc.head_var); + if (new_bound == candidate) { + bf_parent_arc_of_[arc.head_var.value()] = arc_index; + arcs_[arc_index].is_marked = true; + } else { + // We still unmark any previous dependency, since we have pushed the + // value of arc.head_var further. + bf_parent_arc_of_[arc.head_var.value()] = -1; + } + + // We do not re-enqueue if we are in a propagation loop and new_bound + // was not pushed to candidate or higher. + bf_can_be_skipped_[arc.head_var.value()] = false; + if (!bf_in_queue_[arc.head_var.value()] && new_bound >= candidate) { + bf_queue_.push_back(arc.head_var.value()); + bf_in_queue_[arc.head_var.value()] = true; + } + } + } + } + return true; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/old_precedences_propagator.h b/ortools/sat/old_precedences_propagator.h new file mode 100644 index 0000000000..0ae3524c75 --- /dev/null +++ b/ortools/sat/old_precedences_propagator.h @@ -0,0 +1,356 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_OLD_PRECEDENCES_PROPAGATOR_H_ +#define OR_TOOLS_SAT_OLD_PRECEDENCES_PROPAGATOR_H_ + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "ortools/base/strong_vector.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/sat/synchronization.h" +#include "ortools/util/bitset.h" +#include "ortools/util/strong_integers.h" + +namespace operations_research { +namespace sat { + +// ============================================================================= +// Old precedences propagator. +// +// This is superseded by the new LinearPropagator and should only be used if the +// option 'new_linear_propagation' is false. We still keep it around to +// benchmark and test the new code vs this one. +// ============================================================================= + +// This class implement a propagator on simple inequalities between integer +// variables of the form (i1 + offset <= i2). The offset can be constant or +// given by the value of a third integer variable. Offsets can also be negative. +// +// The algorithm works by mapping the problem onto a graph where the edges carry +// the offset and the nodes correspond to one of the two bounds of an integer +// variable (lower_bound or -upper_bound). It then find the fixed point using an +// incremental variant of the Bellman-Ford(-Tarjan) algorithm. +// +// This is also known as an "integer difference logic theory" in the SMT world. +// Another word is "separation logic". +// +// TODO(user): We could easily generalize the code to support any relation of +// the form a*X + b*Y + c*Z >= rhs (or <=). Do that since this class should be +// a lot faster at propagating small linear inequality than the generic +// propagator and the overhead of supporting coefficient should not be too bad. +class PrecedencesPropagator : public SatPropagator, PropagatorInterface { + public: + explicit PrecedencesPropagator(Model* model) + : SatPropagator("PrecedencesPropagator"), + relations_(model->GetOrCreate()), + trail_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + shared_stats_(model->Mutable()), + watcher_(model->GetOrCreate()), + watcher_id_(watcher_->Register(this)) { + model->GetOrCreate()->AddPropagator(this); + integer_trail_->RegisterWatcher(&modified_vars_); + watcher_->SetPropagatorPriority(watcher_id_, 0); + } + + // This type is neither copyable nor movable. + PrecedencesPropagator(const PrecedencesPropagator&) = delete; + PrecedencesPropagator& operator=(const PrecedencesPropagator&) = delete; + ~PrecedencesPropagator() override; + + bool Propagate() final; + bool Propagate(Trail* trail) final; + void Untrail(const Trail& trail, int trail_index) final; + + // Propagates all the outgoing arcs of the given variable (and only those). It + // is more efficient to do all these propagation in one go by calling + // Propagate(), but for scheduling problem, we wants to propagate right away + // the end of an interval when its start moved. + bool PropagateOutgoingArcs(IntegerVariable var); + + // Add a precedence relation (i1 + offset <= i2) between integer variables. + // + // Important: The optionality of the variable should be marked BEFORE this + // is called. + void AddPrecedence(IntegerVariable i1, IntegerVariable i2); + void AddPrecedenceWithOffset(IntegerVariable i1, IntegerVariable i2, + IntegerValue offset); + void AddPrecedenceWithVariableOffset(IntegerVariable i1, IntegerVariable i2, + IntegerVariable offset_var); + + // Same as above, but the relation is only true when the given literal is. + void AddConditionalPrecedence(IntegerVariable i1, IntegerVariable i2, + Literal l); + void AddConditionalPrecedenceWithOffset(IntegerVariable i1, + IntegerVariable i2, + IntegerValue offset, Literal l); + + // Generic function that cover all of the above case and more. + void AddPrecedenceWithAllOptions(IntegerVariable i1, IntegerVariable i2, + IntegerValue offset, + IntegerVariable offset_var, + absl::Span presence_literals); + + // This version check current precedence. It is however "slow". + bool AddPrecedenceWithOffsetIfNew(IntegerVariable i1, IntegerVariable i2, + IntegerValue offset); + + private: + DEFINE_STRONG_INDEX_TYPE(ArcIndex); + DEFINE_STRONG_INDEX_TYPE(OptionalArcIndex); + + // Information about an individual arc. + struct ArcInfo { + IntegerVariable tail_var; + IntegerVariable head_var; + + IntegerValue offset; + IntegerVariable offset_var; // kNoIntegerVariable if none. + + // This arc is "present" iff all these literals are true. + absl::InlinedVector presence_literals; + + // Used temporarily by our implementation of the Bellman-Ford algorithm. It + // should be false at the beginning of BellmanFordTarjan(). + mutable bool is_marked; + }; + + // Internal functions to add new precedence relations. + // + // Note that internally, we only propagate lower bounds, so each time we add + // an arc, we actually create two of them: one on the given variables, and one + // on their negation. + void AdjustSizeFor(IntegerVariable i); + void AddArc(IntegerVariable tail, IntegerVariable head, IntegerValue offset, + IntegerVariable offset_var, + absl::Span presence_literals); + + // Enqueue a new lower bound for the variable arc.head_lb that was deduced + // from the current value of arc.tail_lb and the offset of this arc. + bool EnqueueAndCheck(const ArcInfo& arc, IntegerValue new_head_lb, + Trail* trail); + IntegerValue ArcOffset(const ArcInfo& arc) const; + + // Inspect all the optional arcs that needs inspection (to stay sparse) and + // check if their presence literal can be propagated to false. + void PropagateOptionalArcs(Trail* trail); + + // The core algorithm implementation is split in these functions. One must + // first call InitializeBFQueueWithModifiedNodes() that will push all the + // IntegerVariable whose lower bound has been modified since the last call. + // Then, BellmanFordTarjan() will take care of all the propagation and returns + // false in case of conflict. Internally, it uses DisassembleSubtree() which + // is the Tarjan variant to detect a possible positive cycle. Before exiting, + // it will call CleanUpMarkedArcsAndParents(). + // + // The Tarjan version of the Bellam-Ford algorithm is really nice in our + // context because it was really easy to make it incremental. Moreover, it + // supports batch increment! + // + // This implementation is kind of unique because of our context and the fact + // that it is incremental, but a good reference is "Negative-cycle detection + // algorithms", Boris V. Cherkassky, Andrew V. Goldberg, 1996, + // http://people.cs.nctu.edu.tw/~tjshen/doc/ne.pdf + void InitializeBFQueueWithModifiedNodes(); + bool BellmanFordTarjan(Trail* trail); + bool DisassembleSubtree(int source, int target, + std::vector* can_be_skipped); + void AnalyzePositiveCycle(ArcIndex first_arc, Trail* trail, + std::vector* must_be_all_true, + std::vector* literal_reason, + std::vector* integer_reason); + void CleanUpMarkedArcsAndParents(); + + // Loops over all the arcs and verify that there is no propagation left. + // This is only meant to be used in a DCHECK() and is not optimized. + bool NoPropagationLeft(const Trail& trail) const; + + // Update relations_. + void PushConditionalRelations(const ArcInfo& arc); + + // External class needed to get the IntegerVariable lower bounds and Enqueue + // new ones. + EnforcedLinear2Bounds* relations_; + Trail* trail_; + IntegerTrail* integer_trail_; + SharedStatistics* shared_stats_ = nullptr; + GenericLiteralWatcher* watcher_; + int watcher_id_; + + // The key to our incrementality. This will be cleared once the propagation + // is done, and automatically updated by the integer_trail_ with all the + // IntegerVariable that changed since the last clear. + SparseBitset modified_vars_; + + // An arc needs to be inspected for propagation (i.e. is impacted) if its + // tail_var changed. If an arc has 3 variables (tail, offset, head), it will + // appear as 6 different entries in the arcs_ vector, one for each variable + // and its negation, each time with a different tail. + // + // TODO(user): rearranging the index so that the arc of the same node are + // consecutive like in StaticGraph should have a big performance impact. + // + // TODO(user): We do not need to store ArcInfo.tail_var here. + util_intops::StrongVector> + impacted_arcs_; + util_intops::StrongVector arcs_; + + // This is similar to impacted_arcs_/arcs_ but it is only used to propagate + // one of the presence literals when the arc cannot be present. An arc needs + // to appear only once in potential_arcs_, but it will be referenced by + // all its variable in impacted_potential_arcs_. + util_intops::StrongVector> + impacted_potential_arcs_; + util_intops::StrongVector potential_arcs_; + + // Each time a literal becomes true, this list the set of arcs for which we + // need to decrement their count. When an arc count reach zero, it must be + // added to the set of impacted_arcs_. Note that counts never becomes + // negative. + // + // TODO(user): Try a one-watcher approach instead. Note that in most cases + // arc should be controlled by 1 or 2 literals, so not sure it is worth it. + util_intops::StrongVector> + literal_to_new_impacted_arcs_; + util_intops::StrongVector arc_counts_; + + // Temp vectors to hold the reason of an assignment. + std::vector literal_reason_; + std::vector integer_reason_; + + // Temp vectors for the Bellman-Ford algorithm. The graph in which this + // algorithm works is in one to one correspondence with the IntegerVariable in + // impacted_arcs_. + std::deque bf_queue_; + std::vector bf_in_queue_; + std::vector bf_can_be_skipped_; + std::vector bf_parent_arc_of_; + + // Temp vector used by the tree traversal in DisassembleSubtree(). + std::vector tmp_vector_; + + // Stats. + int64_t num_cycles_ = 0; + int64_t num_pushes_ = 0; + int64_t num_enforcement_pushes_ = 0; +}; + +// ============================================================================= +// Implementation of the small API functions below. +// ============================================================================= + +inline void PrecedencesPropagator::AddPrecedence(IntegerVariable i1, + IntegerVariable i2) { + AddArc(i1, i2, /*offset=*/IntegerValue(0), /*offset_var=*/kNoIntegerVariable, + {}); +} + +inline void PrecedencesPropagator::AddPrecedenceWithOffset( + IntegerVariable i1, IntegerVariable i2, IntegerValue offset) { + AddArc(i1, i2, offset, /*offset_var=*/kNoIntegerVariable, {}); +} + +inline void PrecedencesPropagator::AddConditionalPrecedence(IntegerVariable i1, + IntegerVariable i2, + Literal l) { + AddArc(i1, i2, /*offset=*/IntegerValue(0), /*offset_var=*/kNoIntegerVariable, + {l}); +} + +inline void PrecedencesPropagator::AddConditionalPrecedenceWithOffset( + IntegerVariable i1, IntegerVariable i2, IntegerValue offset, Literal l) { + AddArc(i1, i2, offset, /*offset_var=*/kNoIntegerVariable, {l}); +} + +inline void PrecedencesPropagator::AddPrecedenceWithVariableOffset( + IntegerVariable i1, IntegerVariable i2, IntegerVariable offset_var) { + AddArc(i1, i2, /*offset=*/IntegerValue(0), offset_var, {}); +} + +inline void PrecedencesPropagator::AddPrecedenceWithAllOptions( + IntegerVariable i1, IntegerVariable i2, IntegerValue offset, + IntegerVariable offset_var, absl::Span presence_literals) { + AddArc(i1, i2, offset, offset_var, presence_literals); +} + +// ============================================================================= +// Model based functions. +// ============================================================================= + +// l => (a + b <= ub). +inline void AddConditionalSum2LowerOrEqual( + absl::Span enforcement_literals, IntegerVariable a, + IntegerVariable b, int64_t ub, Model* model) { + // TODO(user): Refactor to be sure we do not miss any level zero relations. + if (enforcement_literals.empty()) { + LinearExpression2 expr(a, b, 1, 1); + model->GetOrCreate()->AddUpperBound( + expr, IntegerValue(ub)); + } + + PrecedencesPropagator* p = model->GetOrCreate(); + p->AddPrecedenceWithAllOptions(a, NegationOf(b), IntegerValue(-ub), + kNoIntegerVariable, enforcement_literals); +} + +// l => (a + b + c <= ub). +// +// TODO(user): Use level zero bounds to infer binary precedence relations? +inline void AddConditionalSum3LowerOrEqual( + absl::Span enforcement_literals, IntegerVariable a, + IntegerVariable b, IntegerVariable c, int64_t ub, Model* model) { + PrecedencesPropagator* p = model->GetOrCreate(); + p->AddPrecedenceWithAllOptions(a, NegationOf(c), IntegerValue(-ub), b, + enforcement_literals); +} + +// a == b. +// +// ABSL_DEPRECATED("Use linear constraint API instead") +inline std::function Equality(IntegerVariable a, + IntegerVariable b) { + return [=](Model* model) { + auto* precedences = model->GetOrCreate(); + precedences->AddPrecedence(a, b); + precedences->AddPrecedence(b, a); + }; +} + +// is_le => (a + offset <= b). +// +// ABSL_DEPRECATED("Use linear constraint API instead") +inline std::function ConditionalLowerOrEqualWithOffset( + IntegerVariable a, IntegerVariable b, int64_t offset, Literal is_le) { + return [=](Model* model) { + PrecedencesPropagator* p = model->GetOrCreate(); + p->AddConditionalPrecedenceWithOffset(a, b, IntegerValue(offset), is_le); + }; +} + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_OLD_PRECEDENCES_PROPAGATOR_H_ diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index 339c6ece5e..b1108dce7d 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -16,14 +16,13 @@ #include #include -#include #include #include #include #include +#include #include -#include "absl/cleanup/cleanup.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -34,9 +33,7 @@ #include "absl/types/span.h" #include "ortools/base/logging.h" #include "ortools/base/mathutil.h" -#include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" -#include "ortools/graph/graph.h" #include "ortools/graph/topologicalsorter.h" #include "ortools/sat/clause.h" #include "ortools/sat/cp_constraints.h" @@ -47,7 +44,6 @@ #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" -#include "ortools/util/bitset.h" #include "ortools/util/logging.h" #include "ortools/util/strong_integers.h" #include "ortools/util/time_limit.h" @@ -210,13 +206,28 @@ RootLevelLinear2Bounds::~RootLevelLinear2Bounds() { RelationStatus RootLevelLinear2Bounds::GetLevelZeroStatus( LinearExpression2 expr, IntegerValue lb, IntegerValue ub) const { + IntegerValue known_ub = integer_trail_->LevelZeroUpperBound(expr); + IntegerValue known_lb = integer_trail_->LevelZeroLowerBound(expr); + + if (lb <= known_lb && ub >= known_ub) return RelationStatus::IS_TRUE; + if (lb > known_ub || ub < known_lb) return RelationStatus::IS_FALSE; + expr.SimpleCanonicalization(); - if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { + if (expr.coeffs[0] == 0) { return RelationStatus::IS_UNKNOWN; } - const IntegerValue known_ub = LevelZeroUpperBound(expr); - expr.Negate(); - const IntegerValue known_lb = -LevelZeroUpperBound(expr); + DCHECK_NE(expr.coeffs[1], 0); + const IntegerValue gcd = expr.DivideByGcd(); + ub = FloorRatio(ub, gcd); + const LinearExpression2Index index = lin2_indices_->GetIndex(expr); + + if (index == kNoLinearExpression2Index) { + return RelationStatus::IS_UNKNOWN; + } + + known_ub = std::min(known_ub, GetUpperBoundNoTrail(index)); + known_lb = std::max(known_lb, -GetUpperBoundNoTrail(NegationOf(index))); + if (lb <= known_lb && ub >= known_ub) return RelationStatus::IS_TRUE; if (lb > known_ub || ub < known_lb) return RelationStatus::IS_FALSE; @@ -718,629 +729,6 @@ void EnforcedLinear2Bounds::CollectPrecedences( } } -namespace { - -void AppendLowerBoundReasonIfValid(IntegerVariable var, - const IntegerTrail& i_trail, - std::vector* reason) { - if (var != kNoIntegerVariable) { - reason->push_back(i_trail.LowerBoundAsLiteral(var)); - } -} - -} // namespace - -PrecedencesPropagator::~PrecedencesPropagator() { - if (!VLOG_IS_ON(1)) return; - if (shared_stats_ == nullptr) return; - std::vector> stats; - stats.push_back({"precedences/num_cycles", num_cycles_}); - stats.push_back({"precedences/num_pushes", num_pushes_}); - stats.push_back( - {"precedences/num_enforcement_pushes", num_enforcement_pushes_}); - shared_stats_->AddStats(stats); -} - -bool PrecedencesPropagator::Propagate(Trail* trail) { return Propagate(); } - -bool PrecedencesPropagator::Propagate() { - while (propagation_trail_index_ < trail_->Index()) { - const Literal literal = (*trail_)[propagation_trail_index_++]; - if (literal.Index() >= literal_to_new_impacted_arcs_.size()) continue; - - // IMPORTANT: Because of the way Untrail() work, we need to add all the - // potential arcs before we can abort. It is why we iterate twice here. - for (const ArcIndex arc_index : - literal_to_new_impacted_arcs_[literal.Index()]) { - if (--arc_counts_[arc_index] == 0) { - const ArcInfo& arc = arcs_[arc_index]; - PushConditionalRelations(arc); - impacted_arcs_[arc.tail_var].push_back(arc_index); - } - } - - // Iterate again to check for a propagation and indirectly update - // modified_vars_. - for (const ArcIndex arc_index : - literal_to_new_impacted_arcs_[literal.Index()]) { - if (arc_counts_[arc_index] > 0) continue; - const ArcInfo& arc = arcs_[arc_index]; - const IntegerValue new_head_lb = - integer_trail_->LowerBound(arc.tail_var) + ArcOffset(arc); - if (new_head_lb > integer_trail_->LowerBound(arc.head_var)) { - if (!EnqueueAndCheck(arc, new_head_lb, trail_)) return false; - } - } - } - - // Do the actual propagation of the IntegerVariable bounds. - InitializeBFQueueWithModifiedNodes(); - if (!BellmanFordTarjan(trail_)) return false; - - // We can only test that no propagation is left if we didn't enqueue new - // literal in the presence of optional variables. - // - // TODO(user): Because of our code to deal with InPropagationLoop(), this is - // not always true. Find a cleaner way to DCHECK() while not failing in this - // corner case. - if (/*DISABLES CODE*/ (false) && - propagation_trail_index_ == trail_->Index()) { - DCHECK(NoPropagationLeft(*trail_)); - } - - // Propagate the presence literals of the arcs that can't be added. - PropagateOptionalArcs(trail_); - - // Clean-up modified_vars_ to do as little as possible on the next call. - modified_vars_.ClearAndResize(integer_trail_->NumIntegerVariables()); - return true; -} - -bool PrecedencesPropagator::PropagateOutgoingArcs(IntegerVariable var) { - CHECK_NE(var, kNoIntegerVariable); - if (var >= impacted_arcs_.size()) return true; - for (const ArcIndex arc_index : impacted_arcs_[var]) { - const ArcInfo& arc = arcs_[arc_index]; - const IntegerValue new_head_lb = - integer_trail_->LowerBound(arc.tail_var) + ArcOffset(arc); - if (new_head_lb > integer_trail_->LowerBound(arc.head_var)) { - if (!EnqueueAndCheck(arc, new_head_lb, trail_)) return false; - } - } - return true; -} - -// TODO(user): Remove literal fixed at level zero from there. -void PrecedencesPropagator::PushConditionalRelations(const ArcInfo& arc) { - // We currently do not handle variable size in the reasons. - // TODO(user): we could easily take a level zero ArcOffset() instead, or - // add this to the reason though. - if (arc.offset_var != kNoIntegerVariable) return; - const IntegerValue offset = ArcOffset(arc); - relations_->PushConditionalRelation( - arc.presence_literals, - LinearExpression2::Difference(arc.tail_var, arc.head_var), -offset); -} - -void PrecedencesPropagator::Untrail(const Trail& trail, int trail_index) { - if (propagation_trail_index_ > trail_index) { - // This means that we already propagated all there is to propagate - // at the level trail_index, so we can safely clear modified_vars_ in case - // it wasn't already done. - modified_vars_.ClearAndResize(integer_trail_->NumIntegerVariables()); - } - while (propagation_trail_index_ > trail_index) { - const Literal literal = trail[--propagation_trail_index_]; - if (literal.Index() >= literal_to_new_impacted_arcs_.size()) continue; - for (const ArcIndex arc_index : - literal_to_new_impacted_arcs_[literal.Index()]) { - if (arc_counts_[arc_index]++ == 0) { - const ArcInfo& arc = arcs_[arc_index]; - impacted_arcs_[arc.tail_var].pop_back(); - } - } - } -} - -void PrecedencesPropagator::AdjustSizeFor(IntegerVariable i) { - const int index = std::max(i.value(), NegationOf(i).value()); - if (index >= impacted_arcs_.size()) { - // TODO(user): only watch lower bound of the relevant variable instead - // of watching everything in [0, max_index_of_variable_used_in_this_class). - for (IntegerVariable var(impacted_arcs_.size()); var <= index; ++var) { - watcher_->WatchLowerBound(var, watcher_id_); - } - impacted_arcs_.resize(index + 1); - impacted_potential_arcs_.resize(index + 1); - } -} - -void PrecedencesPropagator::AddArc( - IntegerVariable tail, IntegerVariable head, IntegerValue offset, - IntegerVariable offset_var, absl::Span presence_literals) { - AdjustSizeFor(tail); - AdjustSizeFor(head); - if (offset_var != kNoIntegerVariable) AdjustSizeFor(offset_var); - - // This arc is present iff all the literals here are true. - absl::InlinedVector enforcement_literals; - { - for (const Literal l : presence_literals) { - enforcement_literals.push_back(l); - } - gtl::STLSortAndRemoveDuplicates(&enforcement_literals); - - if (trail_->CurrentDecisionLevel() == 0) { - int new_size = 0; - for (const Literal l : enforcement_literals) { - if (trail_->Assignment().LiteralIsTrue(Literal(l))) { - continue; // At true, ignore this literal. - } else if (trail_->Assignment().LiteralIsFalse(Literal(l))) { - return; // At false, ignore completely this arc. - } - enforcement_literals[new_size++] = l; - } - enforcement_literals.resize(new_size); - } - } - - if (head == tail) { - // A self-arc is either plain SAT or plain UNSAT or it forces something on - // the given offset_var or presence_literal_index. In any case it could be - // presolved in something more efficient. - VLOG(1) << "Self arc! This could be presolved. " - << "var:" << tail << " offset:" << offset - << " offset_var:" << offset_var - << " conditioned_by:" << presence_literals; - } - - // Remove the offset_var if it is fixed. - // TODO(user): We should also handle the case where tail or head is fixed. - if (offset_var != kNoIntegerVariable) { - const IntegerValue lb = integer_trail_->LevelZeroLowerBound(offset_var); - if (lb == integer_trail_->LevelZeroUpperBound(offset_var)) { - offset += lb; - offset_var = kNoIntegerVariable; - } - } - - // Deal first with impacted_potential_arcs_/potential_arcs_. - if (!enforcement_literals.empty()) { - const OptionalArcIndex arc_index(potential_arcs_.size()); - potential_arcs_.push_back( - {tail, head, offset, offset_var, enforcement_literals}); - impacted_potential_arcs_[tail].push_back(arc_index); - impacted_potential_arcs_[NegationOf(head)].push_back(arc_index); - if (offset_var != kNoIntegerVariable) { - impacted_potential_arcs_[offset_var].push_back(arc_index); - } - } - - // Now deal with impacted_arcs_/arcs_. - struct InternalArc { - IntegerVariable tail_var; - IntegerVariable head_var; - IntegerVariable offset_var; - }; - std::vector to_add; - if (offset_var == kNoIntegerVariable) { - // a + offset <= b and -b + offset <= -a - to_add.push_back({tail, head, kNoIntegerVariable}); - to_add.push_back({NegationOf(head), NegationOf(tail), kNoIntegerVariable}); - } else { - // tail (a) and offset_var (b) are symmetric, so we add: - // - a + b + offset <= c - to_add.push_back({tail, head, offset_var}); - to_add.push_back({offset_var, head, tail}); - // - a - c + offset <= -b - to_add.push_back({tail, NegationOf(offset_var), NegationOf(head)}); - to_add.push_back({NegationOf(head), NegationOf(offset_var), tail}); - // - b - c + offset <= -a - to_add.push_back({offset_var, NegationOf(tail), NegationOf(head)}); - to_add.push_back({NegationOf(head), NegationOf(tail), offset_var}); - } - for (const InternalArc a : to_add) { - // Since we add a new arc, we will need to consider its tail during the next - // propagation. Note that the size of modified_vars_ will be automatically - // updated when new integer variables are created since we register it with - // IntegerTrail in this class constructor. - // - // TODO(user): Adding arcs and then calling Untrail() before Propagate() - // will cause this mecanism to break. Find a more robust implementation. - // - // TODO(user): In some rare corner case, rescanning the whole list of arc - // leaving tail_var can make AddVar() have a quadratic complexity where it - // shouldn't. A better solution would be to see if this new arc currently - // propagate something, and if it does, just update the lower bound of - // a.head_var and let the normal "is modified" mecanism handle any eventual - // follow up propagations. - modified_vars_.Set(a.tail_var); - - // If a.head_var is optional, we can potentially remove some literal from - // enforcement_literals. - const ArcIndex arc_index(arcs_.size()); - arcs_.push_back( - {a.tail_var, a.head_var, offset, a.offset_var, enforcement_literals}); - auto& presence_literals = arcs_.back().presence_literals; - - if (presence_literals.empty()) { - impacted_arcs_[a.tail_var].push_back(arc_index); - } else { - for (const Literal l : presence_literals) { - if (l.Index() >= literal_to_new_impacted_arcs_.size()) { - literal_to_new_impacted_arcs_.resize(l.Index().value() + 1); - } - literal_to_new_impacted_arcs_[l.Index()].push_back(arc_index); - } - } - - if (trail_->CurrentDecisionLevel() == 0) { - arc_counts_.push_back(presence_literals.size()); - } else { - arc_counts_.push_back(0); - for (const Literal l : presence_literals) { - if (!trail_->Assignment().LiteralIsTrue(l)) { - ++arc_counts_.back(); - } - } - CHECK(presence_literals.empty() || arc_counts_.back() > 0); - } - } -} - -bool PrecedencesPropagator::AddPrecedenceWithOffsetIfNew(IntegerVariable i1, - IntegerVariable i2, - IntegerValue offset) { - DCHECK_EQ(trail_->CurrentDecisionLevel(), 0); - if (i1 < impacted_arcs_.size() && i2 < impacted_arcs_.size()) { - for (const ArcIndex index : impacted_arcs_[i1]) { - const ArcInfo& arc = arcs_[index]; - if (arc.head_var == i2) { - const IntegerValue current = ArcOffset(arc); - if (offset <= current) { - return false; - } else { - // TODO(user): Modify arc in place! - } - break; - } - } - } - - AddPrecedenceWithOffset(i1, i2, offset); - return true; -} - -// TODO(user): On jobshop problems with a lot of tasks per machine (500), this -// takes up a big chunk of the running time even before we find a solution. -// This is because, for each lower bound changed, we inspect 500 arcs even -// though they will never be propagated because the other bound is still at the -// horizon. Find an even sparser algorithm? -void PrecedencesPropagator::PropagateOptionalArcs(Trail* trail) { - for (const IntegerVariable var : modified_vars_.PositionsSetAtLeastOnce()) { - // The variables are not in increasing order, so we need to continue. - if (var >= impacted_potential_arcs_.size()) continue; - - // Note that we can currently check the same ArcInfo up to 3 times, one for - // each of the arc variables: tail, NegationOf(head) and offset_var. - for (const OptionalArcIndex arc_index : impacted_potential_arcs_[var]) { - const ArcInfo& arc = potential_arcs_[arc_index]; - int num_not_true = 0; - Literal to_propagate; - for (const Literal l : arc.presence_literals) { - if (!trail->Assignment().LiteralIsTrue(l)) { - ++num_not_true; - to_propagate = l; - } - } - if (num_not_true != 1) continue; - if (trail->Assignment().LiteralIsFalse(to_propagate)) continue; - - // Test if this arc can be present or not. - // Important arc.tail_var can be different from var here. - const IntegerValue tail_lb = integer_trail_->LowerBound(arc.tail_var); - const IntegerValue head_ub = integer_trail_->UpperBound(arc.head_var); - if (tail_lb + ArcOffset(arc) > head_ub) { - integer_reason_.clear(); - integer_reason_.push_back( - integer_trail_->LowerBoundAsLiteral(arc.tail_var)); - integer_reason_.push_back( - integer_trail_->UpperBoundAsLiteral(arc.head_var)); - AppendLowerBoundReasonIfValid(arc.offset_var, *integer_trail_, - &integer_reason_); - literal_reason_.clear(); - for (const Literal l : arc.presence_literals) { - if (l != to_propagate) literal_reason_.push_back(l.Negated()); - } - ++num_enforcement_pushes_; - integer_trail_->EnqueueLiteral(to_propagate.Negated(), literal_reason_, - integer_reason_); - } - } - } -} - -IntegerValue PrecedencesPropagator::ArcOffset(const ArcInfo& arc) const { - return arc.offset + (arc.offset_var == kNoIntegerVariable - ? IntegerValue(0) - : integer_trail_->LowerBound(arc.offset_var)); -} - -bool PrecedencesPropagator::EnqueueAndCheck(const ArcInfo& arc, - IntegerValue new_head_lb, - Trail* trail) { - ++num_pushes_; - DCHECK_GT(new_head_lb, integer_trail_->LowerBound(arc.head_var)); - - // Compute the reason for new_head_lb. - // - // TODO(user): do like for clause and keep the negation of - // arc.presence_literals? I think we could change the integer.h API to accept - // true literal like for IntegerVariable, it is really confusing currently. - literal_reason_.clear(); - for (const Literal l : arc.presence_literals) { - literal_reason_.push_back(l.Negated()); - } - - integer_reason_.clear(); - integer_reason_.push_back(integer_trail_->LowerBoundAsLiteral(arc.tail_var)); - AppendLowerBoundReasonIfValid(arc.offset_var, *integer_trail_, - &integer_reason_); - - // The code works without this block since Enqueue() below can already take - // care of conflicts. However, it is better to deal with the conflict - // ourselves because we can be smarter about the reason this way. - // - // The reason for a "precedence" conflict is always a linear reason - // involving the tail lower_bound, the head upper bound and eventually the - // size lower bound. Because of that, we can use the RelaxLinearReason() - // code. - if (new_head_lb > integer_trail_->UpperBound(arc.head_var)) { - const IntegerValue slack = - new_head_lb - integer_trail_->UpperBound(arc.head_var) - 1; - integer_reason_.push_back( - integer_trail_->UpperBoundAsLiteral(arc.head_var)); - std::vector coeffs(integer_reason_.size(), IntegerValue(1)); - integer_trail_->RelaxLinearReason(slack, coeffs, &integer_reason_); - return integer_trail_->ReportConflict(literal_reason_, integer_reason_); - } - - return integer_trail_->Enqueue( - IntegerLiteral::GreaterOrEqual(arc.head_var, new_head_lb), - literal_reason_, integer_reason_); -} - -bool PrecedencesPropagator::NoPropagationLeft(const Trail& trail) const { - const int num_nodes = impacted_arcs_.size(); - for (IntegerVariable var(0); var < num_nodes; ++var) { - for (const ArcIndex arc_index : impacted_arcs_[var]) { - const ArcInfo& arc = arcs_[arc_index]; - if (integer_trail_->LowerBound(arc.tail_var) + ArcOffset(arc) > - integer_trail_->LowerBound(arc.head_var)) { - return false; - } - } - } - return true; -} - -void PrecedencesPropagator::InitializeBFQueueWithModifiedNodes() { - // Sparse clear of the queue. TODO(user): only use the sparse version if - // queue.size() is small or use SparseBitset. - const int num_nodes = impacted_arcs_.size(); - bf_in_queue_.resize(num_nodes, false); - for (const int node : bf_queue_) bf_in_queue_[node] = false; - bf_queue_.clear(); - DCHECK(std::none_of(bf_in_queue_.begin(), bf_in_queue_.end(), - [](bool v) { return v; })); - for (const IntegerVariable var : modified_vars_.PositionsSetAtLeastOnce()) { - if (var >= num_nodes) continue; - bf_queue_.push_back(var.value()); - bf_in_queue_[var.value()] = true; - } -} - -void PrecedencesPropagator::CleanUpMarkedArcsAndParents() { - // To be sparse, we use the fact that each node with a parent must be in - // modified_vars_. - const int num_nodes = impacted_arcs_.size(); - for (const IntegerVariable var : modified_vars_.PositionsSetAtLeastOnce()) { - if (var >= num_nodes) continue; - const ArcIndex parent_arc_index = bf_parent_arc_of_[var.value()]; - if (parent_arc_index != -1) { - arcs_[parent_arc_index].is_marked = false; - bf_parent_arc_of_[var.value()] = -1; - bf_can_be_skipped_[var.value()] = false; - } - } - DCHECK(std::none_of(bf_parent_arc_of_.begin(), bf_parent_arc_of_.end(), - [](ArcIndex v) { return v != -1; })); - DCHECK(std::none_of(bf_can_be_skipped_.begin(), bf_can_be_skipped_.end(), - [](bool v) { return v; })); -} - -bool PrecedencesPropagator::DisassembleSubtree( - int source, int target, std::vector* can_be_skipped) { - // Note that we explore a tree, so we can do it in any order, and the one - // below seems to be the fastest. - tmp_vector_.clear(); - tmp_vector_.push_back(source); - while (!tmp_vector_.empty()) { - const int tail = tmp_vector_.back(); - tmp_vector_.pop_back(); - for (const ArcIndex arc_index : impacted_arcs_[IntegerVariable(tail)]) { - const ArcInfo& arc = arcs_[arc_index]; - if (arc.is_marked) { - arc.is_marked = false; // mutable. - if (arc.head_var.value() == target) return true; - DCHECK(!(*can_be_skipped)[arc.head_var.value()]); - (*can_be_skipped)[arc.head_var.value()] = true; - tmp_vector_.push_back(arc.head_var.value()); - } - } - } - return false; -} - -void PrecedencesPropagator::AnalyzePositiveCycle( - ArcIndex first_arc, Trail* trail, std::vector* must_be_all_true, - std::vector* literal_reason, - std::vector* integer_reason) { - must_be_all_true->clear(); - literal_reason->clear(); - integer_reason->clear(); - - // Follow bf_parent_arc_of_[] to find the cycle containing first_arc. - const IntegerVariable first_arc_head = arcs_[first_arc].head_var; - ArcIndex arc_index = first_arc; - std::vector arc_on_cycle; - - // Just to be safe and avoid an infinite loop we use the fact that the maximum - // cycle size on a graph with n nodes is of size n. If we have more in the - // code below, it means first_arc is not part of a cycle according to - // bf_parent_arc_of_[], which should never happen. - const int num_nodes = impacted_arcs_.size(); - while (arc_on_cycle.size() <= num_nodes) { - arc_on_cycle.push_back(arc_index); - const ArcInfo& arc = arcs_[arc_index]; - if (arc.tail_var == first_arc_head) break; - arc_index = bf_parent_arc_of_[arc.tail_var.value()]; - CHECK_NE(arc_index, ArcIndex(-1)); - } - CHECK_NE(arc_on_cycle.size(), num_nodes + 1) << "Infinite loop."; - - // Compute the reason for this cycle. - IntegerValue sum(0); - for (const ArcIndex arc_index : arc_on_cycle) { - const ArcInfo& arc = arcs_[arc_index]; - sum += ArcOffset(arc); - AppendLowerBoundReasonIfValid(arc.offset_var, *integer_trail_, - integer_reason); - for (const Literal l : arc.presence_literals) { - literal_reason->push_back(l.Negated()); - } - } - - // TODO(user): what if the sum overflow? this is just a check so I guess - // we don't really care, but fix the issue. - CHECK_GT(sum, 0); -} - -// Note that in our settings it is important to use an algorithm that tries to -// minimize the number of integer_trail_->Enqueue() as much as possible. -// -// TODO(user): The current algorithm is quite efficient, but there is probably -// still room for improvements. -bool PrecedencesPropagator::BellmanFordTarjan(Trail* trail) { - const int num_nodes = impacted_arcs_.size(); - - // These vector are reset by CleanUpMarkedArcsAndParents() so resize is ok. - bf_can_be_skipped_.resize(num_nodes, false); - bf_parent_arc_of_.resize(num_nodes, ArcIndex(-1)); - const auto cleanup = - ::absl::MakeCleanup([this]() { CleanUpMarkedArcsAndParents(); }); - - // The queue initialization is done by InitializeBFQueueWithModifiedNodes(). - while (!bf_queue_.empty()) { - const int node = bf_queue_.front(); - bf_queue_.pop_front(); - bf_in_queue_[node] = false; - - // TODO(user): we don't need bf_can_be_skipped_ since we can detect this - // if this node has a parent arc which is not marked. Investigate if it is - // faster without the vector. - // - // TODO(user): An alternative algorithm is to remove all these nodes from - // the queue instead of simply marking them. This should also lead to a - // better "relaxation" order of the arcs. It is however a bit more work to - // remove them since we need to track their position. - if (bf_can_be_skipped_[node]) { - DCHECK_NE(bf_parent_arc_of_[node], -1); - DCHECK(!arcs_[bf_parent_arc_of_[node]].is_marked); - continue; - } - - const IntegerValue tail_lb = - integer_trail_->LowerBound(IntegerVariable(node)); - for (const ArcIndex arc_index : impacted_arcs_[IntegerVariable(node)]) { - const ArcInfo& arc = arcs_[arc_index]; - DCHECK_EQ(arc.tail_var, node); - const IntegerValue candidate = tail_lb + ArcOffset(arc); - if (candidate > integer_trail_->LowerBound(arc.head_var)) { - if (!EnqueueAndCheck(arc, candidate, trail)) return false; - - // This is the Tarjan contribution to Bellman-Ford. This code detect - // positive cycle, and because it disassemble the subtree while doing - // so, the cost is amortized during the algorithm execution. Another - // advantages is that it will mark the node explored here as skippable - // which will avoid to propagate them too early (knowing that they will - // need to be propagated again later). - if (DisassembleSubtree(arc.head_var.value(), arc.tail_var.value(), - &bf_can_be_skipped_)) { - std::vector must_be_all_true; - AnalyzePositiveCycle(arc_index, trail, &must_be_all_true, - &literal_reason_, &integer_reason_); - if (must_be_all_true.empty()) { - ++num_cycles_; - return integer_trail_->ReportConflict(literal_reason_, - integer_reason_); - } else { - gtl::STLSortAndRemoveDuplicates(&must_be_all_true); - for (const Literal l : must_be_all_true) { - if (trail_->Assignment().LiteralIsFalse(l)) { - literal_reason_.push_back(l); - return integer_trail_->ReportConflict(literal_reason_, - integer_reason_); - } - } - for (const Literal l : must_be_all_true) { - if (trail_->Assignment().LiteralIsTrue(l)) continue; - integer_trail_->EnqueueLiteral(l, literal_reason_, - integer_reason_); - } - - // We just marked some optional variable as ignored, no need - // to update bf_parent_arc_of_[]. - continue; - } - } - - // We need to enforce the invariant that only the arc_index in - // bf_parent_arc_of_[] are marked (but not necessarily all of them - // since we unmark some in DisassembleSubtree()). - if (bf_parent_arc_of_[arc.head_var.value()] != -1) { - arcs_[bf_parent_arc_of_[arc.head_var.value()]].is_marked = false; - } - - // Tricky: We just enqueued the fact that the lower-bound of head is - // candidate. However, because the domain of head may be discrete, it is - // possible that the lower-bound of head is now higher than candidate! - // If this is the case, we don't update bf_parent_arc_of_[] so that we - // don't wrongly detect a positive weight cycle because of this "extra - // push". - const IntegerValue new_bound = integer_trail_->LowerBound(arc.head_var); - if (new_bound == candidate) { - bf_parent_arc_of_[arc.head_var.value()] = arc_index; - arcs_[arc_index].is_marked = true; - } else { - // We still unmark any previous dependency, since we have pushed the - // value of arc.head_var further. - bf_parent_arc_of_[arc.head_var.value()] = -1; - } - - // We do not re-enqueue if we are in a propagation loop and new_bound - // was not pushed to candidate or higher. - bf_can_be_skipped_[arc.head_var.value()] = false; - if (!bf_in_queue_[arc.head_var.value()] && new_bound >= candidate) { - bf_queue_.push_back(arc.head_var.value()); - bf_in_queue_[arc.head_var.value()] = true; - } - } - } - } - return true; -} - void BinaryRelationRepository::Add(Literal lit, LinearExpression2 expr, IntegerValue lhs, IntegerValue rhs) { expr.MakeVariablesPositive(); @@ -1705,8 +1093,9 @@ int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints( } ReifiedLinear2Bounds::ReifiedLinear2Bounds(Model* model) - : integer_encoder_(model->GetOrCreate()), - best_root_level_bounds_(model->GetOrCreate()) { + : best_root_level_bounds_(model->GetOrCreate()), + lin2_indices_(model->GetOrCreate()), + shared_stats_(model->GetOrCreate()) { int index = 0; model->GetOrCreate()->callbacks.push_back( [index = index, trail = model->GetOrCreate(), this]() mutable { @@ -1721,14 +1110,18 @@ ReifiedLinear2Bounds::ReifiedLinear2Bounds(Model* model) if (relevant_true_literals.empty()) return true; // Linear scan. - for (const auto [l, expr, ub] : all_reified_relations_) { + for (const auto [l, expr_index, ub] : all_reified_relations_) { if (relevant_true_literals.contains(l)) { - best_root_level_bounds_->Add(expr, kMinIntegerValue, ub); - VLOG(2) << "New fixed precedence: " << expr << " <= " << ub + ++num_relations_fixed_at_root_level_; + best_root_level_bounds_->AddUpperBound(expr_index, ub); + VLOG(2) << "New fixed precedence: " + << lin2_indices_->GetExpression(expr_index) << " <= " << ub << " (was reified by " << l << ")"; } else if (relevant_true_literals.contains(l.Negated())) { - best_root_level_bounds_->Add(expr, ub + 1, kMaxIntegerValue); - VLOG(2) << "New fixed precedence: " << expr << " > " << ub + ++num_relations_fixed_at_root_level_; + best_root_level_bounds_->AddLowerBound(expr_index, ub + 1); + VLOG(2) << "New fixed precedence: " + << lin2_indices_->GetExpression(expr_index) << " > " << ub << " (was reified by not(" << l << "))"; } } @@ -1736,49 +1129,98 @@ ReifiedLinear2Bounds::ReifiedLinear2Bounds(Model* model) }); } -Linear2BoundsFromLinear3::~Linear2BoundsFromLinear3() { +ReifiedLinear2Bounds::~ReifiedLinear2Bounds() { if (!VLOG_IS_ON(1)) return; std::vector> stats; stats.push_back( - {"Linear2BoundsFromLinear3/num_affine_updates", num_affine_updates_}); + {"ReifiedLinear2Bounds/num_linear3_relations", num_linear3_relations_}); + stats.push_back( + {"ReifiedLinear2Bounds/num_literal_relations", relation_to_lit_.size()}); + stats.push_back({"ReifiedLinear2Bounds/num_relations_fixed_at_root_level", + num_relations_fixed_at_root_level_}); shared_stats_->AddStats(stats); } -RelationStatus ReifiedLinear2Bounds::GetLevelZeroPrecedenceStatus( - AffineExpression a, AffineExpression b) const { - const auto [expr, ub] = EncodeDifferenceLowerThan(a, b, 0); - return best_root_level_bounds_->GetLevelZeroStatus(expr, kMinIntegerValue, - ub); -} - -void ReifiedLinear2Bounds::AddReifiedPrecedenceIfNonTrivial( - Literal l, AffineExpression a, AffineExpression b) { - const auto [expr, ub] = EncodeDifferenceLowerThan(a, b, 0); +void ReifiedLinear2Bounds::AddBoundEncodingIfNonTrivial(Literal l, + LinearExpression2 expr, + IntegerValue ub) { + DCHECK(expr.IsCanonicalized()); + DCHECK_EQ(expr.DivideByGcd(), 1); const RelationStatus status = best_root_level_bounds_->GetLevelZeroStatus(expr, kMinIntegerValue, ub); if (status != RelationStatus::IS_UNKNOWN) return; - relation_to_lit_.insert({{expr, ub}, l}); - + if (expr.vars[0] == kNoIntegerVariable) { + // For a single Affine GetEncodedBound() will return the IntegerLiteral + // without needing any indexing. + return; + } + const LinearExpression2Index expr_index = lin2_indices_->AddOrGet(expr); + relation_to_lit_.insert({{expr_index, ub}, l}); variable_appearing_in_reified_relations_.insert(l.Variable()); - all_reified_relations_.push_back({l, expr, ub}); + all_reified_relations_.push_back({l, expr_index, ub}); } -LiteralIndex ReifiedLinear2Bounds::GetReifiedPrecedence(AffineExpression a, - AffineExpression b) { - const auto [expr, ub] = EncodeDifferenceLowerThan(a, b, 0); +std::variant +ReifiedLinear2Bounds::GetEncodedBound(LinearExpression2 expr, IntegerValue ub) { + DCHECK(expr.IsCanonicalized()); + DCHECK_EQ(expr.DivideByGcd(), 1); const RelationStatus status = best_root_level_bounds_->GetLevelZeroStatus(expr, kMinIntegerValue, ub); if (status == RelationStatus::IS_TRUE) { - return integer_encoder_->GetTrueLiteral().Index(); + return ReifiedBoundType::kAlwaysTrue; } if (status == RelationStatus::IS_FALSE) { - return integer_encoder_->GetFalseLiteral().Index(); + return ReifiedBoundType::kAlwaysFalse; + } + if (expr.vars[0] == kNoIntegerVariable) { + DCHECK_NE(expr.vars[1], kNoIntegerVariable); + DCHECK_EQ(expr.coeffs[1], 1); + return IntegerLiteral::LowerOrEqual(expr.vars[1], ub); } - const auto it = relation_to_lit_.find({expr, ub}); - if (it == relation_to_lit_.end()) return kNoLiteralIndex; - return it->second; + const LinearExpression2Index expr_index = lin2_indices_->GetIndex(expr); + if (expr_index == kNoLinearExpression2Index) { + return ReifiedBoundType::kNoLiteralStored; + } + const auto it = relation_to_lit_.find({expr_index, ub}); + if (it != relation_to_lit_.end()) return it->second; + if (linear3_bounds_.size() <= expr_index) { + return ReifiedBoundType::kNoLiteralStored; + } + const auto [affine_expr, divisor] = linear3_bounds_[expr_index]; + if (divisor == 0) { + return ReifiedBoundType::kNoLiteralStored; + } + const IntegerValue affine_bound = CapProdI(ub, divisor); + if (affine_bound == kMaxIntegerValue) { + return ReifiedBoundType::kNoLiteralStored; + } + return affine_expr.LowerOrEqual(affine_bound); +} + +void ReifiedLinear2Bounds::AddLinear3(absl::Span vars, + absl::Span coeffs, + int64_t activity) { + DCHECK_EQ(vars.size(), 3); + DCHECK_EQ(coeffs.size(), 3); + for (int i = 0; i < vars.size(); ++i) { + LinearExpression2 expr(vars[i], vars[(i + 1) % 3], coeffs[i], + coeffs[(i + 1) % 3]); + AffineExpression affine_expr(vars[(i + 2) % 3], -coeffs[(i + 2) % 3], + activity); + expr.SimpleCanonicalization(); + const IntegerValue gcd = expr.DivideByGcd(); + const LinearExpression2Index expr_index = lin2_indices_->AddOrGet(expr); + if (linear3_bounds_.size() <= expr_index) { + linear3_bounds_.resize(expr_index + 1, {AffineExpression(), 0}); + } + auto& [old_affine_expr, old_divisor] = linear3_bounds_[expr_index]; + if (old_divisor == 0 || old_divisor > gcd) { + linear3_bounds_[expr_index] = {affine_expr, gcd}; + if (old_divisor == 0) ++num_linear3_relations_; + } + } } Linear2BoundsFromLinear3::Linear2BoundsFromLinear3(Model* model) @@ -1790,6 +1232,14 @@ Linear2BoundsFromLinear3::Linear2BoundsFromLinear3(Model* model) root_level_bounds_(model->GetOrCreate()), lin2_indices_(model->GetOrCreate()) {} +Linear2BoundsFromLinear3::~Linear2BoundsFromLinear3() { + if (!VLOG_IS_ON(1)) return; + std::vector> stats; + stats.push_back( + {"Linear2BoundsFromLinear3/num_affine_updates", num_affine_updates_}); + shared_stats_->AddStats(stats); +} + // Note that for speed we do not compare to the trivial or root level bounds. // // It is okay to still store it in the hash-map, since at worst we will have no @@ -1929,5 +1379,26 @@ void Linear2Bounds::AddReasonForUpperBoundLowerThan( integer_reason); } +RelationStatus Linear2Bounds::GetStatus(LinearExpression2 expr, IntegerValue lb, + IntegerValue ub) const { + expr.SimpleCanonicalization(); + const IntegerValue gcd = expr.DivideByGcd(); + const LinearExpression2Index index = lin2_indices_->GetIndex(expr); + IntegerValue known_ub; + IntegerValue known_lb; + if (index == kNoLinearExpression2Index) { + known_ub = CapProdI(gcd, integer_trail_->UpperBound(expr)); + expr.Negate(); + known_lb = -CapProdI(gcd, integer_trail_->UpperBound(expr)); + } else { + known_ub = CapProdI(gcd, UpperBound(index)); + known_lb = -CapProdI(gcd, UpperBound(NegationOf(index))); + } + if (lb <= known_lb && ub >= known_ub) return RelationStatus::IS_TRUE; + if (lb > known_ub || ub < known_lb) return RelationStatus::IS_FALSE; + + return RelationStatus::IS_UNKNOWN; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index 59b0d0f064..2b120268e2 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -16,10 +16,9 @@ #include #include -#include -#include #include #include +#include #include #include "absl/container/btree_set.h" @@ -30,16 +29,13 @@ #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" -#include "ortools/graph/graph.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" -#include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" -#include "ortools/util/bitset.h" #include "ortools/util/rev.h" #include "ortools/util/strong_integers.h" @@ -188,6 +184,15 @@ class RootLevelLinear2Bounds { return AddUpperBound(lin2_indices_->AddOrGet(expr), ub); } + bool AddLowerBound(LinearExpression2 expr, IntegerValue lb) { + expr.Negate(); + return AddUpperBound(expr, -lb); + } + + bool AddLowerBound(LinearExpression2Index index, IntegerValue lb) { + return AddUpperBound(NegationOf(index), -lb); + } + // All modifications go through this function. bool AddUpperBound(LinearExpression2Index index, IntegerValue ub); @@ -606,32 +611,49 @@ class Linear2BoundsFromLinear3 { // TODO(user): Merge with BinaryRelationRepository. Note that this one provides // different indexing though, so it could be kept separate. -// TODO(user): Use LinearExpression2 instead of pairs of AffineExpression for -// consistency with other classes. class ReifiedLinear2Bounds { public: explicit ReifiedLinear2Bounds(Model* model); + ~ReifiedLinear2Bounds(); - // Return the status of a <= b; - RelationStatus GetLevelZeroPrecedenceStatus(AffineExpression a, - AffineExpression b) const; - - // Register the fact that l <=> ( a <= b ). + // Register the fact that l <=> ( expr <= ub ). + // `expr` must already be canonicalized and gcd-reduced. // These are considered equivalence relation. - void AddReifiedPrecedenceIfNonTrivial(Literal l, AffineExpression a, - AffineExpression b); + void AddBoundEncodingIfNonTrivial(Literal l, LinearExpression2 expr, + IntegerValue ub); - // Returns kNoLiteralIndex if we don't have a literal <=> ( a <= b ), or - // returns that literal if we have one. Note that we will return the - // true/false literal if the status is known at level zero. - LiteralIndex GetReifiedPrecedence(AffineExpression a, AffineExpression b); + // Add a linear3 of the form vars[i]*coeffs[i] = activity that is not + // enforced and valid at level zero. + void AddLinear3(absl::Span vars, + absl::Span coeffs, int64_t activity); + + // Returns ReifiedBoundType if we don't have a literal <=> ( expr <= ub ), or + // returns that literal if we have one. `expr` must be canonicalized and + // gcd-reduced. + enum class ReifiedBoundType { + kNoLiteralStored, + kAlwaysTrue, + kAlwaysFalse, + }; + std::variant GetEncodedBound( + LinearExpression2 expr, IntegerValue ub); private: - IntegerEncoder* integer_encoder_; RootLevelLinear2Bounds* best_root_level_bounds_; + Linear2Indices* lin2_indices_; + SharedStatistics* shared_stats_; + + // This stores divisor * linear2 = AffineExpression similarly to + // Linear2BoundsFromLinear3. The difference here is that we only store linear3 + // that are equality, but irrespective of whether it constraint any linear2 at + // the current level. If there is more than one expression for a given + // linear2, we will keep the one with the smallest divisor. + util_intops::StrongVector> + linear3_bounds_; // This stores relations l <=> (linear2 <= rhs). - absl::flat_hash_map, Literal> + absl::flat_hash_map, Literal> relation_to_lit_; // This is used to detect relations that become fixed at level zero and @@ -639,8 +661,11 @@ class ReifiedLinear2Bounds { // we fix variable, a linear scan shouldn't be too bad and is relatively // compact memory wise. absl::flat_hash_set variable_appearing_in_reified_relations_; - std::vector> + std::vector> all_reified_relations_; + + int64_t num_linear3_relations_ = 0; + int64_t num_relations_fixed_at_root_level_ = 0; }; // Simple wrapper around the different repositories for bounds of linear2. @@ -665,6 +690,9 @@ class Linear2Bounds { std::vector* literal_reason, std::vector* integer_reason) const; + RelationStatus GetStatus(LinearExpression2 expr, IntegerValue lb, + IntegerValue ub) const; + // Like UpperBound() but do not consider the bounds coming from // the individual variable bounds. This is faster. IntegerValue NonTrivialUpperBound(LinearExpression2Index lin2_index) const; @@ -720,228 +748,6 @@ class GreaterThanAtLeastOneOfDetector { BinaryRelationRepository& repository_; }; -// ============================================================================= -// Old precedences propagator. -// -// This is superseded by the new LinearPropagator and should only be used if the -// option 'new_linear_propagation' is false. We still keep it around to -// benchmark and test the new code vs this one. -// ============================================================================= - -// This class implement a propagator on simple inequalities between integer -// variables of the form (i1 + offset <= i2). The offset can be constant or -// given by the value of a third integer variable. Offsets can also be negative. -// -// The algorithm works by mapping the problem onto a graph where the edges carry -// the offset and the nodes correspond to one of the two bounds of an integer -// variable (lower_bound or -upper_bound). It then find the fixed point using an -// incremental variant of the Bellman-Ford(-Tarjan) algorithm. -// -// This is also known as an "integer difference logic theory" in the SMT world. -// Another word is "separation logic". -// -// TODO(user): We could easily generalize the code to support any relation of -// the form a*X + b*Y + c*Z >= rhs (or <=). Do that since this class should be -// a lot faster at propagating small linear inequality than the generic -// propagator and the overhead of supporting coefficient should not be too bad. -class PrecedencesPropagator : public SatPropagator, PropagatorInterface { - public: - explicit PrecedencesPropagator(Model* model) - : SatPropagator("PrecedencesPropagator"), - relations_(model->GetOrCreate()), - trail_(model->GetOrCreate()), - integer_trail_(model->GetOrCreate()), - shared_stats_(model->Mutable()), - watcher_(model->GetOrCreate()), - watcher_id_(watcher_->Register(this)) { - model->GetOrCreate()->AddPropagator(this); - integer_trail_->RegisterWatcher(&modified_vars_); - watcher_->SetPropagatorPriority(watcher_id_, 0); - } - - // This type is neither copyable nor movable. - PrecedencesPropagator(const PrecedencesPropagator&) = delete; - PrecedencesPropagator& operator=(const PrecedencesPropagator&) = delete; - ~PrecedencesPropagator() override; - - bool Propagate() final; - bool Propagate(Trail* trail) final; - void Untrail(const Trail& trail, int trail_index) final; - - // Propagates all the outgoing arcs of the given variable (and only those). It - // is more efficient to do all these propagation in one go by calling - // Propagate(), but for scheduling problem, we wants to propagate right away - // the end of an interval when its start moved. - bool PropagateOutgoingArcs(IntegerVariable var); - - // Add a precedence relation (i1 + offset <= i2) between integer variables. - // - // Important: The optionality of the variable should be marked BEFORE this - // is called. - void AddPrecedence(IntegerVariable i1, IntegerVariable i2); - void AddPrecedenceWithOffset(IntegerVariable i1, IntegerVariable i2, - IntegerValue offset); - void AddPrecedenceWithVariableOffset(IntegerVariable i1, IntegerVariable i2, - IntegerVariable offset_var); - - // Same as above, but the relation is only true when the given literal is. - void AddConditionalPrecedence(IntegerVariable i1, IntegerVariable i2, - Literal l); - void AddConditionalPrecedenceWithOffset(IntegerVariable i1, - IntegerVariable i2, - IntegerValue offset, Literal l); - - // Generic function that cover all of the above case and more. - void AddPrecedenceWithAllOptions(IntegerVariable i1, IntegerVariable i2, - IntegerValue offset, - IntegerVariable offset_var, - absl::Span presence_literals); - - // This version check current precedence. It is however "slow". - bool AddPrecedenceWithOffsetIfNew(IntegerVariable i1, IntegerVariable i2, - IntegerValue offset); - - private: - DEFINE_STRONG_INDEX_TYPE(ArcIndex); - DEFINE_STRONG_INDEX_TYPE(OptionalArcIndex); - - // Information about an individual arc. - struct ArcInfo { - IntegerVariable tail_var; - IntegerVariable head_var; - - IntegerValue offset; - IntegerVariable offset_var; // kNoIntegerVariable if none. - - // This arc is "present" iff all these literals are true. - absl::InlinedVector presence_literals; - - // Used temporarily by our implementation of the Bellman-Ford algorithm. It - // should be false at the beginning of BellmanFordTarjan(). - mutable bool is_marked; - }; - - // Internal functions to add new precedence relations. - // - // Note that internally, we only propagate lower bounds, so each time we add - // an arc, we actually create two of them: one on the given variables, and one - // on their negation. - void AdjustSizeFor(IntegerVariable i); - void AddArc(IntegerVariable tail, IntegerVariable head, IntegerValue offset, - IntegerVariable offset_var, - absl::Span presence_literals); - - // Enqueue a new lower bound for the variable arc.head_lb that was deduced - // from the current value of arc.tail_lb and the offset of this arc. - bool EnqueueAndCheck(const ArcInfo& arc, IntegerValue new_head_lb, - Trail* trail); - IntegerValue ArcOffset(const ArcInfo& arc) const; - - // Inspect all the optional arcs that needs inspection (to stay sparse) and - // check if their presence literal can be propagated to false. - void PropagateOptionalArcs(Trail* trail); - - // The core algorithm implementation is split in these functions. One must - // first call InitializeBFQueueWithModifiedNodes() that will push all the - // IntegerVariable whose lower bound has been modified since the last call. - // Then, BellmanFordTarjan() will take care of all the propagation and returns - // false in case of conflict. Internally, it uses DisassembleSubtree() which - // is the Tarjan variant to detect a possible positive cycle. Before exiting, - // it will call CleanUpMarkedArcsAndParents(). - // - // The Tarjan version of the Bellam-Ford algorithm is really nice in our - // context because it was really easy to make it incremental. Moreover, it - // supports batch increment! - // - // This implementation is kind of unique because of our context and the fact - // that it is incremental, but a good reference is "Negative-cycle detection - // algorithms", Boris V. Cherkassky, Andrew V. Goldberg, 1996, - // http://people.cs.nctu.edu.tw/~tjshen/doc/ne.pdf - void InitializeBFQueueWithModifiedNodes(); - bool BellmanFordTarjan(Trail* trail); - bool DisassembleSubtree(int source, int target, - std::vector* can_be_skipped); - void AnalyzePositiveCycle(ArcIndex first_arc, Trail* trail, - std::vector* must_be_all_true, - std::vector* literal_reason, - std::vector* integer_reason); - void CleanUpMarkedArcsAndParents(); - - // Loops over all the arcs and verify that there is no propagation left. - // This is only meant to be used in a DCHECK() and is not optimized. - bool NoPropagationLeft(const Trail& trail) const; - - // Update relations_. - void PushConditionalRelations(const ArcInfo& arc); - - // External class needed to get the IntegerVariable lower bounds and Enqueue - // new ones. - EnforcedLinear2Bounds* relations_; - Trail* trail_; - IntegerTrail* integer_trail_; - SharedStatistics* shared_stats_ = nullptr; - GenericLiteralWatcher* watcher_; - int watcher_id_; - - // The key to our incrementality. This will be cleared once the propagation - // is done, and automatically updated by the integer_trail_ with all the - // IntegerVariable that changed since the last clear. - SparseBitset modified_vars_; - - // An arc needs to be inspected for propagation (i.e. is impacted) if its - // tail_var changed. If an arc has 3 variables (tail, offset, head), it will - // appear as 6 different entries in the arcs_ vector, one for each variable - // and its negation, each time with a different tail. - // - // TODO(user): rearranging the index so that the arc of the same node are - // consecutive like in StaticGraph should have a big performance impact. - // - // TODO(user): We do not need to store ArcInfo.tail_var here. - util_intops::StrongVector> - impacted_arcs_; - util_intops::StrongVector arcs_; - - // This is similar to impacted_arcs_/arcs_ but it is only used to propagate - // one of the presence literals when the arc cannot be present. An arc needs - // to appear only once in potential_arcs_, but it will be referenced by - // all its variable in impacted_potential_arcs_. - util_intops::StrongVector> - impacted_potential_arcs_; - util_intops::StrongVector potential_arcs_; - - // Each time a literal becomes true, this list the set of arcs for which we - // need to decrement their count. When an arc count reach zero, it must be - // added to the set of impacted_arcs_. Note that counts never becomes - // negative. - // - // TODO(user): Try a one-watcher approach instead. Note that in most cases - // arc should be controlled by 1 or 2 literals, so not sure it is worth it. - util_intops::StrongVector> - literal_to_new_impacted_arcs_; - util_intops::StrongVector arc_counts_; - - // Temp vectors to hold the reason of an assignment. - std::vector literal_reason_; - std::vector integer_reason_; - - // Temp vectors for the Bellman-Ford algorithm. The graph in which this - // algorithm works is in one to one correspondence with the IntegerVariable in - // impacted_arcs_. - std::deque bf_queue_; - std::vector bf_in_queue_; - std::vector bf_can_be_skipped_; - std::vector bf_parent_arc_of_; - - // Temp vector used by the tree traversal in DisassembleSubtree(). - std::vector tmp_vector_; - - // Stats. - int64_t num_cycles_ = 0; - int64_t num_pushes_ = 0; - int64_t num_enforcement_pushes_ = 0; -}; - // This can be in a hot-loop, so we want to inline it if possible. inline IntegerValue Linear2Bounds::NonTrivialUpperBound( LinearExpression2Index lin2_index) const { @@ -952,99 +758,6 @@ inline IntegerValue Linear2Bounds::NonTrivialUpperBound( ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(lin2_index)); return ub; } - -// ============================================================================= -// Implementation of the small API functions below. -// ============================================================================= - -inline void PrecedencesPropagator::AddPrecedence(IntegerVariable i1, - IntegerVariable i2) { - AddArc(i1, i2, /*offset=*/IntegerValue(0), /*offset_var=*/kNoIntegerVariable, - {}); -} - -inline void PrecedencesPropagator::AddPrecedenceWithOffset( - IntegerVariable i1, IntegerVariable i2, IntegerValue offset) { - AddArc(i1, i2, offset, /*offset_var=*/kNoIntegerVariable, {}); -} - -inline void PrecedencesPropagator::AddConditionalPrecedence(IntegerVariable i1, - IntegerVariable i2, - Literal l) { - AddArc(i1, i2, /*offset=*/IntegerValue(0), /*offset_var=*/kNoIntegerVariable, - {l}); -} - -inline void PrecedencesPropagator::AddConditionalPrecedenceWithOffset( - IntegerVariable i1, IntegerVariable i2, IntegerValue offset, Literal l) { - AddArc(i1, i2, offset, /*offset_var=*/kNoIntegerVariable, {l}); -} - -inline void PrecedencesPropagator::AddPrecedenceWithVariableOffset( - IntegerVariable i1, IntegerVariable i2, IntegerVariable offset_var) { - AddArc(i1, i2, /*offset=*/IntegerValue(0), offset_var, {}); -} - -inline void PrecedencesPropagator::AddPrecedenceWithAllOptions( - IntegerVariable i1, IntegerVariable i2, IntegerValue offset, - IntegerVariable offset_var, absl::Span presence_literals) { - AddArc(i1, i2, offset, offset_var, presence_literals); -} - -// ============================================================================= -// Model based functions. -// ============================================================================= - -// l => (a + b <= ub). -inline void AddConditionalSum2LowerOrEqual( - absl::Span enforcement_literals, IntegerVariable a, - IntegerVariable b, int64_t ub, Model* model) { - // TODO(user): Refactor to be sure we do not miss any level zero relations. - if (enforcement_literals.empty()) { - LinearExpression2 expr(a, b, 1, 1); - model->GetOrCreate()->AddUpperBound( - expr, IntegerValue(ub)); - } - - PrecedencesPropagator* p = model->GetOrCreate(); - p->AddPrecedenceWithAllOptions(a, NegationOf(b), IntegerValue(-ub), - kNoIntegerVariable, enforcement_literals); -} - -// l => (a + b + c <= ub). -// -// TODO(user): Use level zero bounds to infer binary precedence relations? -inline void AddConditionalSum3LowerOrEqual( - absl::Span enforcement_literals, IntegerVariable a, - IntegerVariable b, IntegerVariable c, int64_t ub, Model* model) { - PrecedencesPropagator* p = model->GetOrCreate(); - p->AddPrecedenceWithAllOptions(a, NegationOf(c), IntegerValue(-ub), b, - enforcement_literals); -} - -// a == b. -// -// ABSL_DEPRECATED("Use linear constraint API instead") -inline std::function Equality(IntegerVariable a, - IntegerVariable b) { - return [=](Model* model) { - auto* precedences = model->GetOrCreate(); - precedences->AddPrecedence(a, b); - precedences->AddPrecedence(b, a); - }; -} - -// is_le => (a + offset <= b). -// -// ABSL_DEPRECATED("Use linear constraint API instead") -inline std::function ConditionalLowerOrEqualWithOffset( - IntegerVariable a, IntegerVariable b, int64_t offset, Literal is_le) { - return [=](Model* model) { - PrecedencesPropagator* p = model->GetOrCreate(); - p->AddConditionalPrecedenceWithOffset(a, b, IntegerValue(offset), is_le); - }; -} - inline LinearExpression2Index Linear2Indices::GetIndex( LinearExpression2 expr) const { if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { diff --git a/ortools/sat/primary_variables.cc b/ortools/sat/primary_variables.cc index c32b1b0654..f074bb6c16 100644 --- a/ortools/sat/primary_variables.cc +++ b/ortools/sat/primary_variables.cc @@ -56,14 +56,14 @@ void GetRelationshipForConstraint(const ConstraintProto& ct, deducible_vars->clear(); input_vars->clear(); *preferred_to_deduce = -1; + if (!ct.enforcement_literal().empty()) { + return; + } switch (ct.constraint_case()) { case ConstraintProto::kLinear: { if (ReadDomainFromProto(ct.linear()).Size() != 1) { return; } - if (!ct.enforcement_literal().empty()) { - return; - } for (int i = 0; i < ct.linear().vars_size(); ++i) { if (ct.linear().coeffs(i) == 0) continue; deducible_vars->insert(ct.linear().vars(i)); diff --git a/ortools/sat/python/BUILD.bazel b/ortools/sat/python/BUILD.bazel index 58ab55d178..ffd20a2ebd 100644 --- a/ortools/sat/python/BUILD.bazel +++ b/ortools/sat/python/BUILD.bazel @@ -16,7 +16,7 @@ load("@pip_deps//:requirements.bzl", "requirement") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@rules_python//python:py_library.bzl", "py_library") -load("@rules_python//python:py_test.bzl", "py_test") +load("@rules_python//python:py_test.bzl", "py_test") cc_library( name = "linear_expr_doc", @@ -28,6 +28,7 @@ cc_library( srcs = ["linear_expr.cc"], hdrs = ["linear_expr.h"], deps = [ + "//ortools/base:string_view_migration", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:cp_model_utils", "//ortools/util:fp_roundtrip_conv", diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index cd97ae8aa1..3eefdfbdbd 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -45,18 +45,14 @@ Other methods and functions listed are primarily used for developing OR-Tools, rather than for solving specific optimization problems. """ +from collections.abc import Callable, Iterable, Sequence +import copy import threading import time from typing import ( Any, - Callable, - Dict, - Iterable, Optional, - Sequence, - Tuple, Union, - cast, overload, ) import warnings @@ -69,12 +65,14 @@ from ortools.util.python import sorted_interval_list # Import external types. BoundedLinearExpression = cmh.BoundedLinearExpression +Constraint = cmh.Constraint CpModelProto = cmh.CpModelProto CpSolverResponse = cmh.CpSolverResponse CpSolverStatus = cmh.CpSolverStatus Domain = sorted_interval_list.Domain FlatFloatExpr = cmh.FlatFloatExpr FlatIntExpr = cmh.FlatIntExpr +IntervalVar = cmh.IntervalVar IntVar = cmh.IntVar LinearExpr = cmh.LinearExpr NotBooleanVariable = cmh.NotBooleanVariable @@ -180,319 +178,19 @@ VariableT = Union["IntVar", IntegralT] LinearExprT = Union[LinearExpr, "IntVar", IntegralT] ObjLinearExprT = Union[LinearExpr, NumberT] -ArcT = Tuple[IntegralT, IntegralT, LiteralT] +ArcT = tuple[IntegralT, IntegralT, LiteralT] _IndexOrSeries = Union[pd.Index, pd.Series] -def short_name(model: cmh.CpModelProto, i: int) -> str: - """Returns a short name of an integer variable, or its negation.""" - if i >= 0: - return str(IntVar(model, i)) - else: - return f"not({IntVar(model, -i - 1)})" - - -def short_expr_name( - model: cmh.CpModelProto, - e: cmh.LinearExpressionProto, -) -> str: - """Pretty-print LinearExpressionProto instances.""" - if not e.vars: - return str(e.offset) - if len(e.vars) == 1: - var_name = short_name(model, e.vars[0]) - coeff = e.coeffs[0] - result = "" - if coeff == 1: - result = var_name - elif coeff == -1: - result = f"-{var_name}" - elif coeff != 0: - result = f"{coeff} * {var_name}" - if e.offset > 0: - result = f"{result} + {e.offset}" - elif e.offset < 0: - result = f"{result} - {-e.offset}" - return result - # TODO(user): Support more than affine expressions. - return str(e) - - -def arg_is_boolean(x: Any) -> bool: - """Checks if the x is a boolean.""" - if isinstance(x, bool): - return True - if isinstance(x, np.bool_): - return True - return False - - -def rebuild_from_linear_expression_proto( - proto: cmh.LinearExpressionProto, - model_proto: cmh.CpModelProto, -) -> LinearExprT: - """Recreate a LinearExpr from a LinearExpressionProto.""" - num_elements = len(proto.vars) - if num_elements == 0: - return proto.offset - elif num_elements == 1: - var = IntVar(model_proto, proto.vars[0]) - return LinearExpr.affine( - var, proto.coeffs[0], proto.offset - ) # pytype: disable=bad-return-type - else: - variables = [] - for var_index in range(len(proto.vars)): - var = IntVar(model_proto, var_index) - variables.append(var) - if proto.offset != 0: - coeffs = [] - coeffs.extend(proto.coeffs) - coeffs.append(1) - variables.append(proto.offset) - return LinearExpr.weighted_sum(variables, coeffs) - else: - return LinearExpr.weighted_sum(variables, proto.coeffs) - - -class Constraint: - """Base class for constraints. - - Constraints are built by the CpModel through the add methods. - Once created by the CpModel class, they are automatically added to the model. - The purpose of this class is to allow specification of enforcement literals - for this constraint. - - b = model.new_bool_var('b') - x = model.new_int_var(0, 10, 'x') - y = model.new_int_var(0, 10, 'y') - - model.add(x + 2 * y == 5).only_enforce_if(b.negated()) - """ - - def __init__(self, cp_model: "CpModel", index: Optional[int] = None) -> None: - self.__cp_model: "CpModel" = cp_model - if index is None: - self.__index: int = len(cp_model.proto.constraints) - cp_model.proto.constraints.add() - else: - self.__index: int = index - - @overload - def only_enforce_if(self, literals: Iterable[LiteralT]) -> "Constraint": ... - - @overload - def only_enforce_if(self, *literals: LiteralT) -> "Constraint": ... - - def only_enforce_if(self, *literals) -> "Constraint": - """Adds one or more enforcement literals to the constraint. - - This method adds one or more literals (that is, a boolean variable or its - negation) as enforcement literals. The conjunction of all these literals - determines whether the constraint is active or not. It acts as an - implication, so if the conjunction is true, it implies that the constraint - must be enforced. If it is false, then the constraint is ignored. - - BoolOr, BoolAnd, and linear constraints all support enforcement literals. - - Args: - *literals: One or more Boolean literals. - - Returns: - self. - """ - cmh.CpSatHelper.add_enforcement_literals( - self.__index, - self.__cp_model.expand_literals_to_index_list(literals), - self.__cp_model.proto, - ) - return self - - def with_name(self, name: str) -> "Constraint": - """Sets the name of the constraint.""" - if name: - cmh.CpSatHelper.set_ct_name(self.__index, name, self.__cp_model.proto) - else: - cmh.CpSatHelper.clear_ct_name(self.__index, self.__cp_model.proto) - return self - - @property - def name(self) -> str: - """Returns the name of the constraint.""" - return cmh.CpSatHelper.ct_name(self.__index, self.__cp_model.proto) - - @property - def index(self) -> int: - """Returns the index of the constraint in the model.""" - return self.__index - - @property - def proto(self) -> cmh.ConstraintProto: - """Returns the constraint protobuf.""" - return self.__cp_model.proto.constraints[self.__index] - - def __str__(self) -> str: - return ( - f"Constraint({self.__index}," - f" {self.__cp_model.proto.constraints[self.__index]})" - ) - - # Pre PEP8 compatibility. - # pylint: disable=invalid-name - OnlyEnforceIf = only_enforce_if - WithName = with_name - - def Name(self) -> str: - return self.name - - def Index(self) -> int: - return self.index - - def Proto(self) -> cmh.ConstraintProto: - return self.proto - - # pylint: enable=invalid-name - - -class IntervalVar: - """Represents an Interval variable. - - An interval variable is both a constraint and a variable. It is defined by - three integer variables: start, size, and end. - - It is a constraint because, internally, it enforces that start + size == end. - - It is also a variable as it can appear in specific scheduling constraints: - NoOverlap, NoOverlap2D, Cumulative. - - Optionally, an enforcement literal can be added to this constraint, in which - case these scheduling constraints will ignore interval variables with - enforcement literals assigned to false. Conversely, these constraints will - also set these enforcement literals to false if they cannot fit these - intervals into the schedule. - - Raises: - ValueError: if start, size, end are not defined, or have the wrong type. - """ - - def __init__( - self, - model: cmh.CpModelProto, - start: Union[cmh.LinearExpressionProto, int], - size: Optional[cmh.LinearExpressionProto], - end: Optional[cmh.LinearExpressionProto], - is_present_index: Optional[int], - name: Optional[str], - ) -> None: - self.__model: cmh.CpModelProto = model - self.__index: int - self.__ct: cmh.ConstraintProto - # As with the IntVar::__init__ method, we hack the __init__ method to - # support two use cases: - # case 1: called when creating a new interval variable. - # {start|size|end} are linear expressions, is_present_index is either - # None or the index of a Boolean literal. name is a string - # case 2: called when querying an existing interval variable. - # start_index is an int, all parameters after are None. - if isinstance(start, int): - if size is not None: - raise ValueError("size should be None") - if end is not None: - raise ValueError("end should be None") - if is_present_index is not None: - raise ValueError("is_present_index should be None") - self.__index = cast(int, start) - self.__ct = model.constraints[self.__index] - else: - self.__index = len(model.constraints) - self.__ct = self.__model.constraints.add() - if start is None: - raise TypeError("start is not defined") - self.__ct.interval.start.copy_from(start) - if size is None: - raise TypeError("size is not defined") - self.__ct.interval.size.copy_from(size) - if end is None: - raise TypeError("end is not defined") - self.__ct.interval.end.copy_from(end) - if is_present_index is not None: - self.__ct.enforcement_literal.append(is_present_index) - if name: - self.__ct.name = name - - @property - def index(self) -> int: - """Returns the index of the interval constraint in the model.""" - return self.__index - - @property - def proto(self) -> cmh.ConstraintProto: - """Returns the interval protobuf.""" - return self.__model.constraints[self.__index] - - @property - def model_proto(self) -> cmh.CpModelProto: - """Returns the model protobuf.""" - return self.__model - - def __str__(self): - return self.proto.name - - def __repr__(self): - interval = self.proto.interval - if self.proto.enforcement_literal: - return ( - f"{self.proto.name}(start =" - f" {short_expr_name(self.__model, interval.start)}, size =" - f" {short_expr_name(self.__model, interval.size)}, end =" - f" {short_expr_name(self.__model, interval.end)}, is_present =" - f" {short_name(self.__model, self.proto.enforcement_literal[0])})" - ) - else: - return ( - f"{self.proto.name}(start =" - f" {short_expr_name(self.__model, interval.start)}, size =" - f" {short_expr_name(self.__model, interval.size)}, end =" - f" {short_expr_name(self.__model, interval.end)})" - ) - - @property - def name(self) -> str: - if not self.proto or not self.proto.name: - return "" - return self.proto.name - - def start_expr(self) -> LinearExprT: - return rebuild_from_linear_expression_proto( - self.proto.interval.start, self.__model - ) - - def size_expr(self) -> LinearExprT: - return rebuild_from_linear_expression_proto( - self.proto.interval.size, self.__model - ) - - def end_expr(self) -> LinearExprT: - return rebuild_from_linear_expression_proto( - self.proto.interval.end, self.__model - ) - - # Pre PEP8 compatibility. - # pylint: disable=invalid-name - def Name(self) -> str: - return self.name - - def Index(self) -> int: - return self.index - - def Proto(self) -> cmh.ConstraintProto: - return self.proto - - StartExpr = start_expr - SizeExpr = size_expr - EndExpr = end_expr - - # pylint: enable=invalid-name +# Helper functions. +def snake_case_to_camel_case(name: str) -> str: + """Converts a snake_case name to CamelCase.""" + words = name.split("_") + return ( + "".join(word.capitalize() for word in words) + .replace("2d", "2D") + .replace("Xor", "XOr") + ) def object_is_a_true_literal(literal: LiteralT) -> bool: @@ -503,8 +201,11 @@ def object_is_a_true_literal(literal: LiteralT) -> bool: if isinstance(literal, cmh.NotBooleanVariable): proto = literal.negated().proto return len(proto.domain) == 2 and proto.domain[0] == 0 and proto.domain[1] == 0 + if isinstance(literal, (bool, np.bool_)): + return bool(literal) if isinstance(literal, IntegralTypes): - return int(literal) == 1 + literal_as_int = int(literal) + return literal_as_int == 1 or literal_as_int == ~False return False @@ -516,12 +217,50 @@ def object_is_a_false_literal(literal: LiteralT) -> bool: if isinstance(literal, cmh.NotBooleanVariable): proto = literal.negated().proto return len(proto.domain) == 2 and proto.domain[0] == 1 and proto.domain[1] == 1 + if isinstance(literal, (bool, np.bool_)): + return not bool(literal) if isinstance(literal, IntegralTypes): - return int(literal) == 0 + literal_as_int = int(literal) + return literal_as_int == 0 or literal_as_int == ~True return False -class CpModel: +def _get_index(obj: _IndexOrSeries) -> pd.Index: + """Returns the indices of `obj` as a `pd.Index`.""" + if isinstance(obj, pd.Series): + return obj.index + return obj + + +@overload +def _convert_to_series_and_validate_index( + value_or_series: Union[LinearExprT, pd.Series], index: pd.Index +) -> pd.Series: ... + + +@overload +def _convert_to_series_and_validate_index( + value_or_series: Union[LiteralT, pd.Series], index: pd.Index +) -> pd.Series: ... + + +@overload +def _convert_to_series_and_validate_index( + value_or_series: Union[IntegralT, pd.Series], index: pd.Index +) -> pd.Series: ... + + +def _convert_to_series_and_validate_index(value_or_series, index): + """Returns a pd.Series of the given index with the corresponding values.""" + if isinstance(value_or_series, pd.Series): + if value_or_series.index.equals(index): + return value_or_series + else: + raise ValueError("index does not match") + return pd.Series(data=value_or_series, index=index) + + +class CpModel(cmh.CpBaseModel): """Methods for building a CP model. Methods beginning with: @@ -530,22 +269,32 @@ class CpModel: * ```add_``` create new constraints and add them to the model. """ - def __init__(self) -> None: - self.__model: cmh.CpModelProto = cmh.CpModelProto() - self.__constant_map: Dict[IntegralT, int] = {} + def __init__(self, model_proto: Optional[cmh.CpModelProto] = None) -> None: + cmh.CpBaseModel.__init__(self, model_proto) + self._add_pre_pep8_methods() + + def _add_pre_pep8_methods(self) -> None: + for method_name in dir(self): + if callable(getattr(self, method_name)) and ( + method_name.startswith("add_") + or method_name.startswith("new_") + or method_name.startswith("clear_") + ): + pre_pep8_name = snake_case_to_camel_case(method_name) + setattr(self, pre_pep8_name, getattr(self, method_name)) # Naming. @property def name(self) -> str: """Returns the name of the model.""" - if not self.__model or not self.__model.name: + if not self.model_proto or not self.model_proto.name: return "" - return self.__model.name + return self.model_proto.name @name.setter def name(self, name: str): """Sets the name of the model.""" - self.__model.name = name + self.model_proto.name = name # Integer variable. def new_int_var(self, lb: IntegralT, ub: IntegralT, name: str) -> IntVar: @@ -564,7 +313,7 @@ class CpModel: a variable whose domain is [lb, ub]. """ return ( - IntVar(self.__model) + IntVar(self.model_proto) .with_name(name) .with_domain(sorted_interval_list.Domain(lb, ub)) ) @@ -585,20 +334,19 @@ class CpModel: Returns: a variable whose domain is the given domain. """ - return IntVar(self.__model).with_name(name).with_domain(domain) + return IntVar(self.model_proto).with_name(name).with_domain(domain) def new_bool_var(self, name: str) -> IntVar: """Creates a 0-1 variable with the given name.""" return ( - IntVar(self.__model) + IntVar(self.model_proto) .with_name(name) .with_domain(sorted_interval_list.Domain(0, 1)) ) def new_constant(self, value: IntegralT) -> IntVar: """Declares a constant integer.""" - index: int = self.get_or_make_index_from_constant(value) - return IntVar(self.__model, index) + return IntVar(self.model_proto, self.get_or_make_index_from_constant(value)) def new_int_var_series( self, @@ -643,17 +391,13 @@ class CpModel: f" upper_bound={upper_bounds} for variable set={name}" ) - lower_bounds = _convert_to_integral_series_and_validate_index( - lower_bounds, index - ) - upper_bounds = _convert_to_integral_series_and_validate_index( - upper_bounds, index - ) + lower_bounds = _convert_to_series_and_validate_index(lower_bounds, index) + upper_bounds = _convert_to_series_and_validate_index(upper_bounds, index) return pd.Series( index=index, data=[ # pylint: disable=g-complex-comprehension - IntVar(self.__model) + IntVar(self.model_proto) .with_name(f"{name}[{i}]") .with_domain( sorted_interval_list.Domain(lower_bounds[i], upper_bounds[i]) @@ -688,7 +432,7 @@ class CpModel: index=index, data=[ # pylint: disable=g-complex-comprehension - IntVar(self.__model) + IntVar(self.model_proto) .with_name(f"{name}[{i}]") .with_domain(sorted_interval_list.Domain(0, 1)) for i in index @@ -718,7 +462,7 @@ class CpModel: "Cannot add a linear expression containing floating point" f" coefficients or constants: {type(linear_expr).__name__!r}" ) - return self.add(ble) + return self._add_bounded_linear_expression(ble) if isinstance(linear_expr, IntegralTypes): if not domain.contains(int(linear_expr)): return self.add_bool_or([]) # Evaluate to false. @@ -742,15 +486,10 @@ class CpModel: TypeError: If the `ct` is not a `BoundedLinearExpression` or a Boolean. """ if isinstance(ct, BoundedLinearExpression): - return Constraint( - self, - cmh.CpSatHelper.add_bounded_linear_expression_to_model( - ct, self.__model - ), - ) - if ct and arg_is_boolean(ct): + return self._add_bounded_linear_expression(ct) + if ct and self.is_boolean_value(ct): return self.add_bool_or([True]) - if not ct and arg_is_boolean(ct): + if not ct and self.is_boolean_value(ct): return self.add_bool_or([]) # Evaluate to false. raise TypeError(f"not supported: CpModel.add({type(ct).__name__!r})") @@ -773,13 +512,7 @@ class CpModel: Returns: An instance of the `Constraint` class. """ - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - expanded = expand_exprs_generator_or_tuple(expressions) - model_ct.all_diff.exprs.extend( - self.parse_linear_expression(x) for x in expanded - ) - return ct + return self._add_all_different(*expressions) def add_element( self, @@ -807,10 +540,7 @@ class CpModel: expression: LinearExprT = list(expressions)[int(index)] return self.add(expression == target) - return Constraint( - self, - cmh.CpSatHelper.add_element(index, expressions, target, self.__model), - ) + return self._add_element(index, expressions, target) def add_circuit(self, arcs: Sequence[ArcT]) -> Constraint: """Adds Circuit(arcs). @@ -836,13 +566,7 @@ class CpModel: """ if not arcs: raise ValueError("add_circuit expects a non-empty array of arcs") - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - for arc in arcs: - model_ct.circuit.tails.append(arc[0]) - model_ct.circuit.heads.append(arc[1]) - model_ct.circuit.literals.append(self.get_or_make_boolean_index(arc[2])) - return ct + return self._add_circuit(arcs) def add_multiple_circuit(self, arcs: Sequence[ArcT]) -> Constraint: """Adds a multiple circuit constraint, aka the 'VRP' constraint. @@ -870,13 +594,7 @@ class CpModel: """ if not arcs: raise ValueError("add_multiple_circuit expects a non-empty array of arcs") - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - for arc in arcs: - model_ct.routes.tails.append(arc[0]) - model_ct.routes.heads.append(arc[1]) - model_ct.routes.literals.append(self.get_or_make_boolean_index(arc[2])) - return ct + return self._add_routes(arcs) def add_allowed_assignments( self, @@ -910,27 +628,7 @@ class CpModel: "add_allowed_assignments expects a non-empty expressions array" ) - ct: Constraint = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.table.exprs.extend( - [self.parse_linear_expression(e) for e in expressions] - ) - arity: int = len(expressions) - for one_tuple in tuples_list: - if len(one_tuple) != arity: - raise TypeError(f"Tuple {one_tuple!r} has the wrong arity") - - # duck-typing (no explicit type checks here) - try: - for one_tuple in tuples_list: - model_ct.table.values.extend(one_tuple) - except ValueError as ex: - raise TypeError( - "add_xxx_assignment: Not an integer or does not fit in an int64_t:" - f" {type(ex.args).__name__!r}" - ) from ex - - return ct + return self._add_table(expressions, tuples_list, False) def add_forbidden_assignments( self, @@ -963,17 +661,14 @@ class CpModel: "add_forbidden_assignments expects a non-empty expressions array" ) - index: int = len(self.__model.constraints) - ct: Constraint = self.add_allowed_assignments(expressions, tuples_list) - self.__model.constraints[index].table.negated = True - return ct + return self._add_table(expressions, tuples_list, True) def add_automaton( self, transition_expressions: Sequence[LinearExprT], starting_state: IntegralT, final_states: Sequence[IntegralT], - transition_triples: Sequence[Tuple[IntegralT, IntegralT, IntegralT]], + transition_triples: Sequence[tuple[IntegralT, IntegralT, IntegralT]], ) -> Constraint: """Adds an automaton constraint. @@ -1028,21 +723,12 @@ class CpModel: if not transition_triples: raise ValueError("add_automaton expects some transition triples") - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.automaton.exprs.extend( - [self.parse_linear_expression(e) for e in transition_expressions] + return self._add_automaton( + transition_expressions, + starting_state, + final_states, + transition_triples, ) - model_ct.automaton.starting_state = starting_state - for v in final_states: - model_ct.automaton.final_states.append(v) - for t in transition_triples: - if len(t) != 3: - raise TypeError(f"Tuple {t!r} has the wrong arity (!= 3)") - model_ct.automaton.transition_tail.append(t[0]) - model_ct.automaton.transition_label.append(t[1]) - model_ct.automaton.transition_head.append(t[2]) - return ct def add_inverse( self, @@ -1073,18 +759,12 @@ class CpModel: "In the inverse constraint, the two array variables and" " inverse_variables must have the same length." ) - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.inverse.f_direct.extend([self.get_or_make_index(x) for x in variables]) - model_ct.inverse.f_inverse.extend( - [self.get_or_make_index(x) for x in inverse_variables] - ) - return ct + return self._add_inverse(variables, inverse_variables) def add_reservoir_constraint( self, - times: Iterable[LinearExprT], - level_changes: Iterable[LinearExprT], + times: Sequence[LinearExprT], + level_changes: Sequence[LinearExprT], min_level: int, max_level: int, ) -> Constraint: @@ -1123,32 +803,19 @@ class CpModel: ValueError: if min_level > 0 """ - if max_level < min_level: - raise ValueError("Reservoir constraint must have a max_level >= min_level") - - if max_level < 0: - raise ValueError("Reservoir constraint must have a max_level >= 0") - - if min_level > 0: - raise ValueError("Reservoir constraint must have a min_level <= 0") - - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.reservoir.time_exprs.extend( - [self.parse_linear_expression(x) for x in times] + return self._add_reservoir( + times, + level_changes, + [], + min_level, + max_level, ) - model_ct.reservoir.level_changes.extend( - [self.parse_linear_expression(x) for x in level_changes] - ) - model_ct.reservoir.min_level = min_level - model_ct.reservoir.max_level = max_level - return ct def add_reservoir_constraint_with_active( self, - times: Iterable[LinearExprT], - level_changes: Iterable[LinearExprT], - actives: Iterable[LiteralT], + times: Sequence[LinearExprT], + level_changes: Sequence[LinearExprT], + actives: Sequence[LiteralT], min_level: int, max_level: int, ) -> Constraint: @@ -1205,52 +872,28 @@ class CpModel: if min_level > 0: raise ValueError("Reservoir constraint must have a min_level <= 0") - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.reservoir.time_exprs.extend( - [self.parse_linear_expression(x) for x in times] + if not times: + raise ValueError("Reservoir constraint must have a non-empty times array") + + return self._add_reservoir( + times, + level_changes, + actives, + min_level, + max_level, ) - model_ct.reservoir.level_changes.extend( - [self.parse_linear_expression(x) for x in level_changes] - ) - model_ct.reservoir.active_literals.extend( - [self.get_or_make_boolean_index(x) for x in actives] - ) - model_ct.reservoir.min_level = min_level - model_ct.reservoir.max_level = max_level - return ct def add_map_domain( self, var: IntVar, bool_var_array: Iterable[IntVar], offset: IntegralT = 0 ): """Adds `var == i + offset <=> bool_var_array[i] == true for all i`.""" - for i, bool_var in enumerate(bool_var_array): - b_index = bool_var.index - var_index = var.index - model_ct = self.__model.constraints.add() - model_ct.linear.vars.append(var_index) - model_ct.linear.coeffs.append(1) - offset_as_int = int(offset) - model_ct.linear.domain.extend([offset_as_int + i, offset_as_int + i]) - model_ct.enforcement_literal.append(b_index) - - model_ct = self.__model.constraints.add() - model_ct.linear.vars.append(var_index) - model_ct.linear.coeffs.append(1) - model_ct.enforcement_literal.append(-b_index - 1) - if offset + i - 1 >= INT_MIN: - model_ct.linear.domain.extend([INT_MIN, offset_as_int + i - 1]) - if offset + i + 1 <= INT_MAX: - model_ct.linear.domain.extend([offset_as_int + i + 1, INT_MAX]) + self.add(var == i + offset).only_enforce_if(bool_var) + self.add(var != i + offset).only_enforce_if(~bool_var) def add_implication(self, a: LiteralT, b: LiteralT) -> Constraint: """Adds `a => b` (`a` implies `b`).""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.bool_or.literals.append(self.get_or_make_boolean_index(b)) - model_ct.enforcement_literal.append(self.get_or_make_boolean_index(a)) - return ct + return self.add_bool_and(b).only_enforce_if(a) @overload def add_bool_or(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1260,9 +903,9 @@ class CpModel: def add_bool_or(self, *literals): """Adds `Or(literals) == true`: sum(literals) >= 1.""" - lits = self.expand_literals_to_index_list(literals) - index: int = cmh.CpSatHelper.add_bool_or(lits, self.__model) - return Constraint(self, index) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.bool_or, *literals + ) @overload def add_at_least_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1272,7 +915,9 @@ class CpModel: def add_at_least_one(self, *literals): """Same as `add_bool_or`: `sum(literals) >= 1`.""" - return self.add_bool_or(*literals) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.bool_or, *literals + ) @overload def add_at_most_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1282,9 +927,9 @@ class CpModel: def add_at_most_one(self, *literals) -> Constraint: """Adds `AtMostOne(literals)`: `sum(literals) <= 1`.""" - lits = self.expand_literals_to_index_list(literals) - index: int = cmh.CpSatHelper.add_at_most_one(lits, self.__model) - return Constraint(self, index) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.at_most_one, *literals + ) @overload def add_exactly_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1294,9 +939,9 @@ class CpModel: def add_exactly_one(self, *literals): """Adds `ExactlyOne(literals)`: `sum(literals) == 1`.""" - lits = self.expand_literals_to_index_list(literals) - index: int = cmh.CpSatHelper.add_exactly_one(lits, self.__model) - return Constraint(self, index) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.exactly_one, *literals + ) @overload def add_bool_and(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1306,9 +951,9 @@ class CpModel: def add_bool_and(self, *literals): """Adds `And(literals) == true`.""" - lits = self.expand_literals_to_index_list(literals) - index: int = cmh.CpSatHelper.add_bool_and(lits, self.__model) - return Constraint(self, index) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.bool_and, *literals + ) @overload def add_bool_xor(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1328,9 +973,9 @@ class CpModel: Returns: An `Constraint` object. """ - lits = self.expand_literals_to_index_list(literals) - index: int = cmh.CpSatHelper.add_bool_xor(lits, self.__model) - return Constraint(self, index) + return self._add_bool_argument_constraint( + cmh.BoolArgumentConstraint.bool_xor, *literals + ) @overload def add_min_equality( @@ -1344,15 +989,8 @@ class CpModel: def add_min_equality(self, target, *expressions) -> Constraint: """Adds `target == Min(expressions)`.""" - exprs = [e for e in expand_exprs_generator_or_tuple(expressions)] - return Constraint( - self, - cmh.CpSatHelper.add_linear_argument_constraint( - "min", - target, - exprs, - self.__model, - ), + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.min, target, *expressions ) @overload @@ -1367,35 +1005,22 @@ class CpModel: def add_max_equality(self, target, *expressions) -> Constraint: """Adds `target == Max(expressions)`.""" - exprs = [e for e in expand_exprs_generator_or_tuple(expressions)] - return Constraint( - self, - cmh.CpSatHelper.add_linear_argument_constraint( - "max", - target, - exprs, - self.__model, - ), + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.max, target, *expressions ) def add_division_equality( self, target: LinearExprT, num: LinearExprT, denom: LinearExprT ) -> Constraint: """Adds `target == num // denom` (integer division rounded towards 0).""" - return Constraint( - self, - cmh.CpSatHelper.add_linear_argument_constraint( - "div", target, [num, denom], self.__model - ), + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.div, target, [num, denom] ) def add_abs_equality(self, target: LinearExprT, expr: LinearExprT) -> Constraint: """Adds `target == Abs(expr)`.""" - return Constraint( - self, - cmh.CpSatHelper.add_linear_argument_constraint( - "max", target, [expr, -expr], self.__model - ), + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.max, target, [expr, -expr] ) def add_modulo_equality( @@ -1420,11 +1045,8 @@ class CpModel: Returns: A `Constraint` object. """ - return Constraint( - self, - cmh.CpSatHelper.add_linear_argument_constraint( - "mod", target, [expr, mod], self.__model - ), + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.mod, target, [expr, mod] ) def add_multiplication_equality( @@ -1433,14 +1055,8 @@ class CpModel: *expressions: Union[Iterable[LinearExprT], LinearExprT], ) -> Constraint: """Adds `target == expressions[0] * .. * expressions[n]`.""" - return Constraint( - self, - cmh.CpSatHelper.add_linear_argument_constraint( - "prod", - target, - expand_exprs_generator_or_tuple(expressions), - self.__model, - ), + return self._add_linear_argument_constraint( + cmh.LinearArgumentConstraint.prod, target, *expressions ) # Scheduling support @@ -1464,30 +1080,7 @@ class CpModel: Returns: An `IntervalVar` object. """ - - start_expr = self.parse_linear_expression(start) - size_expr = self.parse_linear_expression(size) - end_expr = self.parse_linear_expression(end) - if len(start_expr.vars) > 1: - raise TypeError( - "cp_model.new_interval_var: start must be 1-var affine or constant." - ) - if len(size_expr.vars) > 1: - raise TypeError( - "cp_model.new_interval_var: size must be 1-var affine or constant." - ) - if len(end_expr.vars) > 1: - raise TypeError( - "cp_model.new_interval_var: end must be 1-var affine or constant." - ) - return IntervalVar( - self.__model, - start_expr, - size_expr, - end_expr, - None, - name, - ) + return self._new_interval_var(name, start, size, end, []) def new_interval_var_series( self, @@ -1526,9 +1119,9 @@ class CpModel: if not name.isidentifier(): raise ValueError(f"name={name!r} is not a valid identifier") - starts = _convert_to_linear_expr_series_and_validate_index(starts, index) - sizes = _convert_to_linear_expr_series_and_validate_index(sizes, index) - ends = _convert_to_linear_expr_series_and_validate_index(ends, index) + starts = _convert_to_series_and_validate_index(starts, index) + sizes = _convert_to_series_and_validate_index(sizes, index) + ends = _convert_to_series_and_validate_index(ends, index) interval_array = [] for i in index: interval_array.append( @@ -1557,21 +1150,7 @@ class CpModel: Returns: An `IntervalVar` object. """ - start_expr = self.parse_linear_expression(start) - size_expr = self.parse_linear_expression(size) - end_expr = self.parse_linear_expression(start + size) - if len(start_expr.vars) > 1: - raise TypeError( - "cp_model.new_interval_var: start must be affine or constant." - ) - return IntervalVar( - self.__model, - start_expr, - size_expr, - end_expr, - None, - name, - ) + return self._new_interval_var(name, start, size, start + size, []) def new_fixed_size_interval_var_series( self, @@ -1606,8 +1185,8 @@ class CpModel: if not name.isidentifier(): raise ValueError(f"name={name!r} is not a valid identifier") - starts = _convert_to_linear_expr_series_and_validate_index(starts, index) - sizes = _convert_to_integral_series_and_validate_index(sizes, index) + starts = _convert_to_series_and_validate_index(starts, index) + sizes = _convert_to_series_and_validate_index(sizes, index) interval_array = [] for i in index: interval_array.append( @@ -1647,31 +1226,12 @@ class CpModel: Returns: An `IntervalVar` object. """ - - # Creates the IntervalConstraintProto object. - is_present_index = self.get_or_make_boolean_index(is_present) - start_expr = self.parse_linear_expression(start) - size_expr = self.parse_linear_expression(size) - end_expr = self.parse_linear_expression(end) - if len(start_expr.vars) > 1: - raise TypeError( - "cp_model.new_interval_var: start must be affine or constant." - ) - if len(size_expr.vars) > 1: - raise TypeError( - "cp_model.new_interval_var: size must be affine or constant." - ) - if len(end_expr.vars) > 1: - raise TypeError( - "cp_model.new_interval_var: end must be affine or constant." - ) - return IntervalVar( - self.__model, - start_expr, - size_expr, - end_expr, - is_present_index, + return self._new_interval_var( name, + start, + size, + end, + [is_present], ) def new_optional_interval_var_series( @@ -1715,10 +1275,10 @@ class CpModel: if not name.isidentifier(): raise ValueError(f"name={name!r} is not a valid identifier") - starts = _convert_to_linear_expr_series_and_validate_index(starts, index) - sizes = _convert_to_linear_expr_series_and_validate_index(sizes, index) - ends = _convert_to_linear_expr_series_and_validate_index(ends, index) - are_present = _convert_to_literal_series_and_validate_index(are_present, index) + starts = _convert_to_series_and_validate_index(starts, index) + sizes = _convert_to_series_and_validate_index(sizes, index) + ends = _convert_to_series_and_validate_index(ends, index) + are_present = _convert_to_series_and_validate_index(are_present, index) interval_array = [] for i in index: @@ -1755,21 +1315,12 @@ class CpModel: Returns: An `IntervalVar` object. """ - start_expr = self.parse_linear_expression(start) - size_expr = self.parse_linear_expression(size) - end_expr = self.parse_linear_expression(start + size) - if len(start_expr.vars) > 1: - raise TypeError( - "cp_model.new_interval_var: start must be affine or constant." - ) - is_present_index = self.get_or_make_boolean_index(is_present) - return IntervalVar( - self.__model, - start_expr, - size_expr, - end_expr, - is_present_index, + return self._new_interval_var( name, + start, + size, + start + size, + [is_present], ) def new_optional_fixed_size_interval_var_series( @@ -1809,9 +1360,9 @@ class CpModel: if not name.isidentifier(): raise ValueError(f"name={name!r} is not a valid identifier") - starts = _convert_to_linear_expr_series_and_validate_index(starts, index) - sizes = _convert_to_integral_series_and_validate_index(sizes, index) - are_present = _convert_to_literal_series_and_validate_index(are_present, index) + starts = _convert_to_series_and_validate_index(starts, index) + sizes = _convert_to_series_and_validate_index(sizes, index) + are_present = _convert_to_series_and_validate_index(are_present, index) interval_array = [] for i in index: interval_array.append( @@ -1824,24 +1375,19 @@ class CpModel: ) return pd.Series(index=index, data=interval_array) - def add_no_overlap(self, interval_vars: Iterable[IntervalVar]) -> Constraint: + def add_no_overlap(self, intervals: Iterable[IntervalVar]) -> Constraint: """Adds NoOverlap(interval_vars). A NoOverlap constraint ensures that all present intervals do not overlap in time. Args: - interval_vars: The list of interval variables to constrain. + intervals: The list of interval variables to constrain. Returns: An instance of the `Constraint` class. """ - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.no_overlap.intervals.extend( - [self.get_interval_index(x) for x in interval_vars] - ) - return ct + return self._add_no_overlap(intervals) def add_no_overlap_2d( self, @@ -1864,15 +1410,7 @@ class CpModel: Returns: An instance of the `Constraint` class. """ - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.no_overlap_2d.x_intervals.extend( - [self.get_interval_index(x) for x in x_intervals] - ) - model_ct.no_overlap_2d.y_intervals.extend( - [self.get_interval_index(x) for x in y_intervals] - ) - return ct + return self._add_no_overlap_2d(x_intervals, y_intervals) def add_cumulative( self, @@ -1899,15 +1437,7 @@ class CpModel: Returns: An instance of the `Constraint` class. """ - cumulative = Constraint(self) - model_ct = self.__model.constraints[cumulative.index] - model_ct.cumulative.intervals.extend( - [self.get_interval_index(x) for x in intervals] - ) - for d in demands: - model_ct.cumulative.demands.append(self.parse_linear_expression(d)) - model_ct.cumulative.capacity.copy_from(self.parse_linear_expression(capacity)) - return cumulative + return self._add_cumulative(intervals, demands, capacity) # Support for model cloning. def clone(self) -> "CpModel": @@ -1917,19 +1447,19 @@ class CpModel: clone.rebuild_constant_map() return clone - def rebuild_constant_map(self): - """Internal method used during model cloning.""" - for i, var in enumerate(self.__model.variables): - if len(var.domain) == 2 and var.domain[0] == var.domain[1]: - self.__constant_map[var.domain[0]] = i + def __copy__(self): + return CpModel(self.model_proto) + + def __deepcopy__(self, memo): + return CpModel(copy.deepcopy(self.model_proto, memo)) def get_bool_var_from_proto_index(self, index: int) -> IntVar: """Returns an already created Boolean variable from its index.""" - if index < 0 or index >= len(self.__model.variables): + if index < 0 or index >= len(self.model_proto.variables): raise ValueError( f"get_bool_var_from_proto_index: out of bound index {index}" ) - result = IntVar(self.__model, index) + result = IntVar(self.model_proto, index) if not result.is_boolean: raise TypeError( f"get_bool_var_from_proto_index: index {index} does not reference a" @@ -1939,133 +1469,67 @@ class CpModel: def get_int_var_from_proto_index(self, index: int) -> IntVar: """Returns an already created integer variable from its index.""" - if index < 0 or index >= len(self.__model.variables): + if index < 0 or index >= len(self.model_proto.variables): raise ValueError( f"get_int_var_from_proto_index: out of bound index {index}" ) - return IntVar(self.__model, index) + return IntVar(self.model_proto, index) def get_interval_var_from_proto_index(self, index: int) -> IntervalVar: """Returns an already created interval variable from its index.""" - if index < 0 or index >= len(self.__model.constraints): + if index < 0 or index >= len(self.model_proto.constraints): raise ValueError( f"get_interval_var_from_proto_index: out of bound index {index}" ) - ct = self.__model.constraints[index] + ct = self.model_proto.constraints[index] if not ct.has_interval(): raise ValueError( f"get_interval_var_from_proto_index: index {index} does not" " reference an" + " interval variable" ) - return IntervalVar(self.__model, index, None, None, None, None) - - # Helpers. + return IntervalVar(self.model_proto, index) def __str__(self) -> str: - return str(self.__model) + return str(self.model_proto) @property def proto(self) -> cmh.CpModelProto: """Returns the underlying CpModelProto.""" - return self.__model + return self.model_proto def negated(self, index: int) -> int: return -index - 1 - def get_or_make_index(self, arg: VariableT) -> int: - """Returns the index of a variable, its negation, or a number.""" - if isinstance(arg, IntVar): - return arg.index - if isinstance(arg, IntegralTypes): - return self.get_or_make_index_from_constant(arg) - raise TypeError( - f"NotSupported: model.get_or_make_index({type(arg).__name__!r})" - ) - - def get_or_make_boolean_index(self, arg: LiteralT) -> int: - """Returns an index from a boolean expression.""" - if isinstance(arg, IntVar): - self.assert_is_boolean_variable(arg) - return arg.index - if isinstance(arg, cmh.NotBooleanVariable): - self.assert_is_boolean_variable(arg.negated()) - return arg.index - if isinstance(arg, IntegralTypes): - if arg == ~int(False): - return self.get_or_make_index_from_constant(1) - if arg == ~int(True): - return self.get_or_make_index_from_constant(0) - arg_as_int: int = int(arg) - if arg_as_int < 0 or arg_as_int > 1: - raise TypeError(f"Not a boolean: {arg}") - return self.get_or_make_index_from_constant(arg_as_int) - if arg_is_boolean(arg): - return self.get_or_make_index_from_constant(int(arg)) - raise TypeError( - "not supported:" f" model.get_or_make_boolean_index({type(arg).__name__!r})" - ) - - def get_interval_index(self, arg: IntervalVar) -> int: - if not isinstance(arg, IntervalVar): - raise TypeError( - f"NotSupported: model.get_interval_index({type(arg).__name__!r})" - ) - return arg.index - - def get_or_make_index_from_constant(self, value: IntegralT) -> int: - if value in self.__constant_map: - return self.__constant_map[value] - constant_var = self.new_int_var(value, value, "") - self.__constant_map[value] = constant_var.index - return constant_var.index - - def parse_linear_expression( - self, linear_expr: LinearExprT, negate: bool = False - ) -> cmh.LinearExpressionProto: - """Returns a LinearExpressionProto built from a LinearExpr instance.""" - result: cmh.LinearExpressionProto = cmh.LinearExpressionProto() - mult = -1 if negate else 1 - if isinstance(linear_expr, IntegralTypes): - result.offset = int(linear_expr) * mult - return result - - # Raises TypeError if linear_expr is not an integer. - flat_expr = cmh.FlatIntExpr(linear_expr) - result.offset = flat_expr.offset * mult - for var in flat_expr.vars: - result.vars.append(var.index) - for coeff in flat_expr.coeffs: - result.coeffs.append(coeff * mult) - return result - - def _set_objective(self, obj: ObjLinearExprT, minimize: bool): + def _set_objective(self, obj: ObjLinearExprT, maximize: bool): """Sets the objective of the model.""" self.clear_objective() if isinstance(obj, IntegralTypes): - self.__model.objective.offset = int(obj) - self.__model.objective.scaling_factor = 1.0 + self.model_proto.objective.offset = int(obj) + self.model_proto.objective.scaling_factor = 1.0 elif isinstance(obj, LinearExpr): if obj.is_integer(): int_obj = cmh.FlatIntExpr(obj) for var in int_obj.vars: - self.__model.objective.vars.append(var.index) - if minimize: - self.__model.objective.scaling_factor = 1.0 - self.__model.objective.offset = int_obj.offset - self.__model.objective.coeffs.extend(int_obj.coeffs) - else: - self.__model.objective.scaling_factor = -1.0 - self.__model.objective.offset = -int_obj.offset + self.model_proto.objective.vars.append(var.index) + if maximize: + self.model_proto.objective.scaling_factor = -1.0 + self.model_proto.objective.offset = -int_obj.offset for c in int_obj.coeffs: - self.__model.objective.coeffs.append(-c) + self.model_proto.objective.coeffs.append(-c) + else: + self.model_proto.objective.scaling_factor = 1.0 + self.model_proto.objective.offset = int_obj.offset + self.model_proto.objective.coeffs.extend(int_obj.coeffs) else: float_obj = cmh.FlatFloatExpr(obj) for var in float_obj.vars: - self.__model.floating_point_objective.vars.append(var.index) - self.__model.floating_point_objective.coeffs.extend(float_obj.coeffs) - self.__model.floating_point_objective.maximize = not minimize - self.__model.floating_point_objective.offset = float_obj.offset + self.model_proto.floating_point_objective.vars.append(var.index) + self.model_proto.floating_point_objective.coeffs.extend( + float_obj.coeffs + ) + self.model_proto.floating_point_objective.maximize = maximize + self.model_proto.floating_point_objective.offset = float_obj.offset else: raise TypeError( f"TypeError: {type(obj).__name__!r} is not a valid objective" @@ -2073,20 +1537,21 @@ class CpModel: def minimize(self, obj: ObjLinearExprT): """Sets the objective of the model to minimize(obj).""" - self._set_objective(obj, minimize=True) + self._set_objective(obj, maximize=False) def maximize(self, obj: ObjLinearExprT): """Sets the objective of the model to maximize(obj).""" - self._set_objective(obj, minimize=False) + self._set_objective(obj, maximize=True) def has_objective(self) -> bool: return ( - self.__model.has_objective() or self.__model.has_floating_point_objective() + self.model_proto.has_objective() + or self.model_proto.has_floating_point_objective() ) def clear_objective(self): - self.__model.clear_objective() - self.__model.clear_floating_point_objective() + self.model_proto.clear_objective() + self.model_proto.clear_floating_point_objective() def add_decision_strategy( self, @@ -2105,7 +1570,7 @@ class CpModel: solve() will fail. """ - strategy: cmh.DecisionStrategyProto = self.__model.search_strategy.add() + strategy: cmh.DecisionStrategyProto = self.model_proto.search_strategy.add() for v in variables: expr = strategy.exprs.add() if v.index >= 0: @@ -2121,11 +1586,11 @@ class CpModel: def model_stats(self) -> str: """Returns a string containing some model statistics.""" - return cmh.CpSatHelper.model_stats(self.__model) + return cmh.CpSatHelper.model_stats(self.model_proto) def validate(self) -> str: """Returns a string indicating that the model is invalid.""" - return cmh.CpSatHelper.validate_model(self.__model) + return cmh.CpSatHelper.validate_model(self.model_proto) def export_to_file(self, file: str) -> bool: """Write the model as a protocol buffer to 'file'. @@ -2138,14 +1603,14 @@ class CpModel: Returns: True if the model was correctly written. """ - return cmh.CpSatHelper.write_model_to_file(self.__model, file) + return cmh.CpSatHelper.write_model_to_file(self.model_proto, file) def remove_all_names(self) -> None: """Removes all names from the model.""" - self.__model.clear_name() - for v in self.__model.variables: + self.model_proto.clear_name() + for v in self.model_proto.variables: v.clear_name() - for c in self.__model.constraints: + for c in self.model_proto.constraints: c.clear_name() @overload @@ -2157,19 +1622,19 @@ class CpModel: def add_hint(self, var, value) -> None: """Adds 'var == value' as a hint to the solver.""" if var.index >= 0: - self.__model.solution_hint.vars.append(self.get_or_make_index(var)) - self.__model.solution_hint.values.append(int(value)) + self.model_proto.solution_hint.vars.append(var.index) + self.model_proto.solution_hint.values.append(int(value)) else: - self.__model.solution_hint.vars.append(self.negated(var.index)) - self.__model.solution_hint.values.append(int(not value)) + self.model_proto.solution_hint.vars.append(self.negated(var.index)) + self.model_proto.solution_hint.values.append(int(not value)) def clear_hints(self): """Removes any solution hint from the model.""" - self.__model.clear_solution_hint() + self.model_proto.clear_solution_hint() def add_assumption(self, lit: LiteralT) -> None: """Adds the literal to the model as assumptions.""" - self.__model.assumptions.append(self.get_or_make_boolean_index(lit)) + self.model_proto.assumptions.append(self.get_or_make_boolean_index(lit)) def add_assumptions(self, literals: Iterable[LiteralT]) -> None: """Adds the literals to the model as assumptions.""" @@ -2178,41 +1643,7 @@ class CpModel: def clear_assumptions(self) -> None: """Removes all assumptions from the model.""" - self.__model.clear_assumptions() - - # Helpers. - def assert_is_boolean_variable(self, x: LiteralT) -> None: - if isinstance(x, IntVar): - var = self.__model.variables[x.index] - if len(var.domain) != 2 or var.domain[0] < 0 or var.domain[1] > 1: - raise TypeError( - f"TypeError: {type(x).__name__!r} is not a boolean variable" - ) - elif not isinstance(x, cmh.NotBooleanVariable): - raise TypeError( - f"TypeError: {type(x).__name__!r} is not a boolean variable" - ) - - def expand_literals_generator_or_tuple( - self, args: Union[Tuple[LiteralT, ...], Iterable[LiteralT]] - ): - if hasattr(args, "__len__"): # Tuple - if len(args) != 1: - return args - if isinstance(args[0], (NumberTypes, cmh.Literal)): - return args - # Generator - return args[0] - - def expand_literals_to_index_list( - self, - literals: Union[Tuple[LiteralT, ...], Iterable[LiteralT]], - ) -> list[int]: - """Expands a tuple or generator of literals to a list of indices.""" - return [ - self.get_or_make_boolean_index(lit) - for lit in self.expand_literals_generator_or_tuple(literals) - ] + self.model_proto.clear_assumptions() # Compatibility with pre PEP8 # pylint: disable=invalid-name @@ -2226,48 +1657,7 @@ class CpModel: def Proto(self) -> cmh.CpModelProto: return self.proto - NewIntVar = new_int_var - NewIntVarFromDomain = new_int_var_from_domain - NewBoolVar = new_bool_var - NewConstant = new_constant - NewIntVarSeries = new_int_var_series - NewBoolVarSeries = new_bool_var_series - AddLinearConstraint = add_linear_constraint - AddLinearExpressionInDomain = add_linear_expression_in_domain Add = add - AddAllDifferent = add_all_different - AddElement = add_element - AddCircuit = add_circuit - AddMultipleCircuit = add_multiple_circuit - AddAllowedAssignments = add_allowed_assignments - AddForbiddenAssignments = add_forbidden_assignments - AddAutomaton = add_automaton - AddInverse = add_inverse - AddReservoirConstraint = add_reservoir_constraint - AddReservoirConstraintWithActive = add_reservoir_constraint_with_active - AddImplication = add_implication - AddBoolOr = add_bool_or - AddAtLeastOne = add_at_least_one - AddAtMostOne = add_at_most_one - AddExactlyOne = add_exactly_one - AddBoolAnd = add_bool_and - AddBoolXOr = add_bool_xor - AddMinEquality = add_min_equality - AddMaxEquality = add_max_equality - AddDivisionEquality = add_division_equality - AddAbsEquality = add_abs_equality - AddModuloEquality = add_modulo_equality - AddMultiplicationEquality = add_multiplication_equality - NewIntervalVar = new_interval_var - NewIntervalVarSeries = new_interval_var_series - NewFixedSizeIntervalVar = new_fixed_size_interval_var - NewOptionalIntervalVar = new_optional_interval_var - NewOptionalIntervalVarSeries = new_optional_interval_var_series - NewOptionalFixedSizeIntervalVar = new_optional_fixed_size_interval_var - NewOptionalFixedSizeIntervalVarSeries = new_optional_fixed_size_interval_var_series - AddNoOverlap = add_no_overlap - AddNoOverlap2D = add_no_overlap_2d - AddCumulative = add_cumulative Clone = clone GetBoolVarFromProtoIndex = get_bool_var_from_proto_index GetIntVarFromProtoIndex = get_int_var_from_proto_index @@ -2275,32 +1665,16 @@ class CpModel: Minimize = minimize Maximize = maximize HasObjective = has_objective - ClearObjective = clear_objective - AddDecisionStrategy = add_decision_strategy ModelStats = model_stats Validate = validate ExportToFile = export_to_file - AddHint = add_hint - ClearHints = clear_hints - AddAssumption = add_assumption - AddAssumptions = add_assumptions - ClearAssumptions = clear_assumptions + + # add_XXX, new_XXX, and clear_XXX methods are already duplicated + # automatically. # pylint: enable=invalid-name -def expand_exprs_generator_or_tuple( - expressions: Union[Tuple[LinearExprT, ...], Iterable[LinearExprT]], -) -> Union[Iterable[LinearExprT], LinearExprT]: - if hasattr(expressions, "__len__"): # Tuple - if len(expressions) != 1: - return expressions - if isinstance(expressions[0], (NumberTypes, LinearExpr)): - return expressions - # Generator - return expressions[0] - - class CpSolver: """Main solver class. @@ -2313,7 +1687,7 @@ class CpSolver: """ def __init__(self) -> None: - self.__response_wrapper: Optional[cmh.ResponseWrapper] = None + self.__response: Optional[cmh.CpSolverResponse] = None self.parameters: cmh.SatParameters = cmh.SatParameters() self.log_callback: Optional[Callable[[str], None]] = None self.best_bound_callback: Optional[Callable[[float], None]] = None @@ -2339,9 +1713,7 @@ class CpSolver: if self.best_bound_callback is not None: self.__solve_wrapper.add_best_bound_callback(self.best_bound_callback) - self.__response_wrapper = ( - self.__solve_wrapper.solve_and_return_response_wrapper(model.proto) - ) + self.__response = self.__solve_wrapper.solve(model.proto) if solution_callback is not None: self.__solve_wrapper.clear_solution_callback(solution_callback) @@ -2349,7 +1721,7 @@ class CpSolver: with self.__lock: self.__solve_wrapper = None - return self.__response_wrapper.status() + return self.__response.status def stop_search(self) -> None: """Stops the current search asynchronously.""" @@ -2359,7 +1731,7 @@ class CpSolver: def value(self, expression: LinearExprT) -> int: """Returns the value of a linear expression after solve.""" - return self._checked_response.value(expression) + return cmh.ResponseHelper.value(self._checked_response, expression) def values(self, variables: _IndexOrSeries) -> pd.Series: """Returns the values of the input variables. @@ -2379,16 +1751,15 @@ class CpSolver: Raises: RuntimeError: if solve() has not been called. """ - if self.__response_wrapper is None: - raise RuntimeError("solve() has not been called.") + response: cmh.CpSolverResponse = self._checked_response return pd.Series( - data=[self.__response_wrapper.value(var) for var in variables], + data=[cmh.ResponseHelper.value(response, var) for var in variables], index=_get_index(variables), ) def float_value(self, expression: LinearExprT) -> float: """Returns the value of a linear expression after solve.""" - return self._checked_response.float_value(expression) + return cmh.ResponseHelper.float_value(self._checked_response, expression) def float_values(self, expressions: _IndexOrSeries) -> pd.Series: """Returns the float values of the input linear expressions. @@ -2408,16 +1779,17 @@ class CpSolver: Raises: RuntimeError: if solve() has not been called. """ - if self.__response_wrapper is None: - raise RuntimeError("solve() has not been called.") + response: cmh.CpSolverResponse = self._checked_response return pd.Series( - data=[self.__response_wrapper.float_value(expr) for expr in expressions], + data=[ + cmh.ResponseHelper.float_value(response, expr) for expr in expressions + ], index=_get_index(expressions), ) def boolean_value(self, literal: LiteralT) -> bool: """Returns the boolean value of a literal after solve.""" - return self._checked_response.boolean_value(literal) + return cmh.ResponseHelper.boolean_value(self._checked_response, literal) def boolean_values(self, variables: _IndexOrSeries) -> pd.Series: """Returns the values of the input variables. @@ -2437,11 +1809,11 @@ class CpSolver: Raises: RuntimeError: if solve() has not been called. """ - if self.__response_wrapper is None: - raise RuntimeError("solve() has not been called.") + response: cmh.CpSolverResponse = self._checked_response return pd.Series( data=[ - self.__response_wrapper.boolean_value(literal) for literal in variables + cmh.ResponseHelper.boolean_value(response, literal) + for literal in variables ], index=_get_index(variables), ) @@ -2449,65 +1821,67 @@ class CpSolver: @property def objective_value(self) -> float: """Returns the value of the objective after solve.""" - return self._checked_response.objective_value() + return self._checked_response.objective_value @property def best_objective_bound(self) -> float: """Returns the best lower (upper) bound found when min(max)imizing.""" - return self._checked_response.best_objective_bound() + return self._checked_response.best_objective_bound @property def num_booleans(self) -> int: """Returns the number of boolean variables managed by the SAT solver.""" - return self._checked_response.num_booleans() + return self._checked_response.num_booleans @property def num_conflicts(self) -> int: """Returns the number of conflicts since the creation of the solver.""" - return self._checked_response.num_conflicts() + return self._checked_response.num_conflicts @property def num_branches(self) -> int: """Returns the number of search branches explored by the solver.""" - return self._checked_response.num_branches() + return self._checked_response.num_branches @property def num_boolean_propagations(self) -> int: """Returns the number of Boolean propagations done by the solver.""" - return self._checked_response.num_boolean_propagations() + return self._checked_response.num_boolean_propagations @property def num_integer_propagations(self) -> int: """Returns the number of integer propagations done by the solver.""" - return self._checked_response.num_integer_propagations() + return self._checked_response.num_integer_propagations @property def deterministic_time(self) -> float: """Returns the deterministic time in seconds since the creation of the solver.""" - return self._checked_response.deterministic_time() + return self._checked_response.deterministic_time @property def wall_time(self) -> float: """Returns the wall time in seconds since the creation of the solver.""" - return self._checked_response.wall_time() + return self._checked_response.wall_time @property def user_time(self) -> float: """Returns the user time in seconds since the creation of the solver.""" - return self._checked_response.user_time() + return self._checked_response.user_time @property def response_proto(self) -> cmh.CpSolverResponse: """Returns the response object.""" - return self._checked_response.response() + return self._checked_response def response_stats(self) -> str: """Returns some statistics on the solution found as a string.""" - return self._checked_response.response_stats() + return cmh.CpSatHelper.solver_response_stats(self._checked_response) def sufficient_assumptions_for_infeasibility(self) -> Sequence[int]: """Returns the indices of the infeasible assumptions.""" - return self._checked_response.sufficient_assumptions_for_infeasibility() + return cmh.ResponseHelper.sufficient_assumptions_for_infeasibility( + self._checked_response + ) def status_name(self, status: Optional[Any] = None) -> str: """Returns the name of the status returned by solve().""" @@ -2524,14 +1898,14 @@ class CpSolver: Raises: RuntimeError: if solve() has not been called. """ - return self._checked_response.solution_info() + return self._checked_response.solution_info @property - def _checked_response(self) -> cmh.ResponseWrapper: + def _checked_response(self) -> cmh.CpSolverResponse: """Checks solve() has been called, and returns a response wrapper.""" - if self.__response_wrapper is None: + if self.__response is None: raise RuntimeError("solve() has not been called.") - return self.__response_wrapper + return self.__response # Compatibility with pre PEP8 # pylint: disable=invalid-name @@ -2539,11 +1913,8 @@ class CpSolver: def BestObjectiveBound(self) -> float: return self.best_objective_bound - def BooleanValue(self, literal: LiteralT) -> bool: - return self.boolean_value(literal) - - def BooleanValues(self, variables: _IndexOrSeries) -> pd.Series: - return self.boolean_values(variables) + BooleanValue = boolean_value + BooleanValues = boolean_values def NumBooleans(self) -> int: return self.num_booleans @@ -2560,36 +1931,18 @@ class CpSolver: def ResponseProto(self) -> cmh.CpSolverResponse: return self.response_proto - def ResponseStats(self) -> str: - return self.response_stats() - - def Solve( - self, - model: CpModel, - solution_callback: Optional["CpSolverSolutionCallback"] = None, - ) -> cmh.CpSolverStatus: - return self.solve(model, solution_callback) - - def SolutionInfo(self) -> str: - return self.solution_info() - - def StatusName(self, status: Optional[Any] = None) -> str: - return self.status_name(status) - - def StopSearch(self) -> None: - self.stop_search() - - def SufficientAssumptionsForInfeasibility(self) -> Sequence[int]: - return self.sufficient_assumptions_for_infeasibility() + ResponseStats = response_stats + Solve = solve + SolutionInfo = solution_info + StatusName = status_name + StopSearch = stop_search + SufficientAssumptionsForInfeasibility = sufficient_assumptions_for_infeasibility def UserTime(self) -> float: return self.user_time - def Value(self, expression: LinearExprT) -> int: - return self.value(expression) - - def Values(self, variables: _IndexOrSeries) -> pd.Series: - return self.values(variables) + Value = value + Values = values def WallTime(self) -> float: return self.wall_time @@ -2668,10 +2021,13 @@ class CpSolverSolutionCallback(cmh.SolutionCallback): def __init__(self) -> None: cmh.SolutionCallback.__init__(self) + # pylint: disable=invalid-name def OnSolutionCallback(self) -> None: """Proxy for the same method in snake case.""" self.on_solution_callback() + # pylint: enable=invalid-name + def boolean_value(self, lit: LiteralT) -> bool: """Returns the boolean value of a boolean literal. @@ -2886,91 +2242,3 @@ class VarArraySolutionPrinter(CpSolverSolutionCallback): def solution_count(self) -> int: """Returns the number of solutions found.""" return self.__solution_count - - -def _get_index(obj: _IndexOrSeries) -> pd.Index: - """Returns the indices of `obj` as a `pd.Index`.""" - if isinstance(obj, pd.Series): - return obj.index - return obj - - -def _convert_to_integral_series_and_validate_index( - value_or_series: Union[IntegralT, pd.Series], index: pd.Index -) -> pd.Series: - """Returns a pd.Series of the given index with the corresponding values. - - Args: - value_or_series: the values to be converted (if applicable). - index: the index of the resulting pd.Series. - - Returns: - pd.Series: The set of values with the given index. - - Raises: - TypeError: If the type of `value_or_series` is not recognized. - ValueError: If the index does not match. - """ - if isinstance(value_or_series, IntegralTypes): - return pd.Series(data=value_or_series, index=index) - elif isinstance(value_or_series, pd.Series): - if value_or_series.index.equals(index): - return value_or_series - else: - raise ValueError("index does not match") - else: - raise TypeError(f"invalid type={type(value_or_series).__name__!r}") - - -def _convert_to_linear_expr_series_and_validate_index( - value_or_series: Union[LinearExprT, pd.Series], index: pd.Index -) -> pd.Series: - """Returns a pd.Series of the given index with the corresponding values. - - Args: - value_or_series: the values to be converted (if applicable). - index: the index of the resulting pd.Series. - - Returns: - pd.Series: The set of values with the given index. - - Raises: - TypeError: If the type of `value_or_series` is not recognized. - ValueError: If the index does not match. - """ - if isinstance(value_or_series, IntegralTypes): - return pd.Series(data=value_or_series, index=index) - elif isinstance(value_or_series, pd.Series): - if value_or_series.index.equals(index): - return value_or_series - else: - raise ValueError("index does not match") - else: - raise TypeError(f"invalid type={type(value_or_series).__name__!r}") - - -def _convert_to_literal_series_and_validate_index( - value_or_series: Union[LiteralT, pd.Series], index: pd.Index -) -> pd.Series: - """Returns a pd.Series of the given index with the corresponding values. - - Args: - value_or_series: the values to be converted (if applicable). - index: the index of the resulting pd.Series. - - Returns: - pd.Series: The set of values with the given index. - - Raises: - TypeError: If the type of `value_or_series` is not recognized. - ValueError: If the index does not match. - """ - if isinstance(value_or_series, IntegralTypes): - return pd.Series(data=value_or_series, index=index) - elif isinstance(value_or_series, pd.Series): - if value_or_series.index.equals(index): - return value_or_series - else: - raise ValueError("index does not match") - else: - raise TypeError(f"invalid type={type(value_or_series).__name__!r}") diff --git a/ortools/sat/python/cp_model_helper.cc b/ortools/sat/python/cp_model_helper.cc index 50f5bc2752..6675d7486e 100644 --- a/ortools/sat/python/cp_model_helper.cc +++ b/ortools/sat/python/cp_model_helper.cc @@ -18,19 +18,25 @@ #include #include #include +#include #include +#include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" +#include "absl/log/check.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "ortools/base/string_view_migration.h" #include "ortools/port/proto_utils.h" // IWYU: keep #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/python/linear_expr.h" #include "ortools/sat/python/linear_expr_doc.h" +#include "ortools/sat/sat_parameters.pb.h" // IWYU: keep #include "ortools/sat/swig_helper.h" #include "ortools/util/saturated_arithmetic.h" #include "ortools/util/sorted_interval_list.h" @@ -85,79 +91,48 @@ class PySolutionCallback : public SolutionCallback { } }; -// A class to wrap a C++ CpSolverResponse in a Python object, avoid the proto -// conversion back to python. -class ResponseWrapper { +class ResponseHelper { public: - explicit ResponseWrapper(const CpSolverResponse& response) - : response_(response) {} - - double BestObjectiveBound() const { return response_.best_objective_bound(); } - - bool BooleanValue(std::shared_ptr lit) const { + static bool BooleanValue(std::shared_ptr response, + std::shared_ptr lit) { const int index = lit->index(); if (index >= 0) { - return response_.solution(index) != 0; + return response->solution(index) != 0; } else { - return response_.solution(NegatedRef(index)) == 0; + return response->solution(NegatedRef(index)) == 0; } } - bool FixedBooleanValue(bool lit) const { return lit; } - - double DeterministicTime() const { return response_.deterministic_time(); } - - int64_t NumBinaryPropagations() const { - return response_.num_binary_propagations(); + static bool FixedBooleanValue(std::shared_ptr response, + bool lit) { + return lit; } - int64_t NumBooleans() const { return response_.num_booleans(); } - - int64_t NumBranches() const { return response_.num_branches(); } - - int64_t NumConflicts() const { return response_.num_conflicts(); } - - int64_t NumIntegerPropagations() const { - return response_.num_integer_propagations(); - } - - int64_t NumRestarts() const { return response_.num_restarts(); } - - double ObjectiveValue() const { return response_.objective_value(); } - - const CpSolverResponse& Response() const { return response_; } - - std::string ResponseStats() const { - return CpSatHelper::SolverResponseStats(response_); - } - - std::string SolutionInfo() const { - return google::protobuf::StringCopy(response_.solution_info()); - } - - std::vector SufficientAssumptionsForInfeasibility() const { + static std::vector SufficientAssumptionsForInfeasibility( + std::shared_ptr response) { return std::vector( - response_.sufficient_assumptions_for_infeasibility().begin(), - response_.sufficient_assumptions_for_infeasibility().end()); + response->sufficient_assumptions_for_infeasibility().begin(), + response->sufficient_assumptions_for_infeasibility().end()); } - CpSolverStatus Status() const { return response_.status(); } - - double UserTime() const { return response_.user_time(); } - - double FloatValue(std::shared_ptr expr) const { + static double FloatValue(std::shared_ptr response, + std::shared_ptr expr) { FloatExprVisitor visitor; visitor.AddToProcess(expr, 1); - return visitor.Evaluate(response_); + return visitor.Evaluate(*response); } - double FixedFloatValue(double value) const { return value; } + static double FixedFloatValue(std::shared_ptr response, + double value) { + return value; + } - int64_t Value(std::shared_ptr expr) const { + static int64_t Value(std::shared_ptr response, + std::shared_ptr expr) { int64_t value; IntExprVisitor visitor; visitor.AddToProcess(expr, 1); - if (!visitor.Evaluate(response_, &value)) { + if (!visitor.Evaluate(*response, &value)) { ThrowError(PyExc_ValueError, absl::StrCat("Failed to evaluate linear expression: ", expr->DebugString())); @@ -165,12 +140,10 @@ class ResponseWrapper { return value; } - int64_t FixedValue(int64_t value) const { return value; } - - double WallTime() const { return response_.wall_time(); } - - private: - const CpSolverResponse response_; + static int64_t FixedValue(std::shared_ptr response, + int64_t value) { + return value; + } }; // Checks that the result is not null and throws an error if it is. @@ -439,19 +412,320 @@ void LinearExprToProto(const py::handle& arg, int64_t multiplier, } else if (py::isinstance(arg)) { int64_t value = arg.cast(); proto->set_offset(value * multiplier); + } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer") && + getattr(arg, "is_integer")().cast()) { + int64_t value = arg.cast(); + proto->set_offset(value * multiplier); } else { py::type objtype = py::type::of(arg); const std::string type_name = objtype.attr("__name__").cast(); + py::print(arg); ThrowError(PyExc_TypeError, absl::StrCat("Cannot convert '", absl::CEscape(type_name), "' to a linear expression.")); } } -int AddBoundedLinearExpressionToModel( - BoundedLinearExpression* ble, std::shared_ptr model_proto) { - const int index = model_proto->constraints_size(); - ConstraintProto* ct = model_proto->add_constraints(); +class Constraint; +class IntervalVar; + +enum class BoolArgumentConstraint { + kAtMostOne, + kBoolAnd, + kBoolOr, + kBoolXor, + kExactlyOne, +}; + +enum class LinearArgumentConstraint { + kDiv, + kMax, + kMin, + kMod, + kProd, +}; + +class CpBaseModel : public std::enable_shared_from_this { + public: + explicit CpBaseModel(std::shared_ptr model_proto) + : model_proto_(model_proto == nullptr ? std::make_shared() + : model_proto), + numpy_bool_type_(py::dtype::of().attr("type").cast()) { + if (model_proto != nullptr) RebuildConstantMap(); + } + + std::shared_ptr model_proto() const { return model_proto_; } + + int GetOrMakeIndexFromConstant(int64_t value) { + auto it = cache_.find(value); + if (it != cache_.end()) return it->second; + const int index = model_proto_->variables_size(); + IntegerVariableProto* const_var = model_proto_->add_variables(); + const_var->add_domain(value); + const_var->add_domain(value); + cache_[value] = index; + return index; + } + + void RebuildConstantMap() { + cache_.clear(); + for (int i = 0; i < model_proto_->variables_size(); ++i) { + const IntegerVariableProto& var = model_proto_->variables(i); + if (var.domain_size() == 2 && var.domain(0) == var.domain(1) && + var.name().empty()) { // Constants do not have names. + cache_[var.domain(0)] = i; + } + } + } + + int GetOrMakeBooleanIndex(py::handle literal) { + if (py::isinstance(literal)) { + std::shared_ptr var = literal.cast>(); + AssertVariableIsBoolean(var); + return var->index(); + } else if (py::isinstance(literal)) { + std::shared_ptr not_var = + literal.cast>(); + AssertVariableIsBoolean(not_var); + return not_var->index(); + } else if (IsBooleanValue(literal)) { + const bool value = literal.cast(); + if (value) { + return GetOrMakeIndexFromConstant(1); + } else { + return GetOrMakeIndexFromConstant(0); + } + } else if (py::isinstance(literal)) { + const int64_t value = literal.cast(); + if (value == 1 || value == -1) { // -1 = ~False. + return GetOrMakeIndexFromConstant(1); + } + if (value == 0 || value == -2) { // -2 = ~True. + return GetOrMakeIndexFromConstant(0); + } + ThrowError(PyExc_TypeError, absl::StrCat("Invalid literal: ", value)); + } else { + py::type objtype = py::type::of(literal); + const std::string type_name = + objtype.attr("__name__").cast(); + ThrowError(PyExc_TypeError, absl::StrCat("Invalid boolean literal: '", + absl::CEscape(type_name), "'")); + } + return 0; // Unreachable. + } + + int GetOrMakeVariableIndex(py::handle arg) { + if (py::isinstance(arg)) { + std::shared_ptr var = arg.cast>(); + return var->index(); + } else if (py::isinstance(arg)) { + return GetOrMakeIndexFromConstant(arg.cast()); + } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer") && + getattr(arg, "is_integer")().cast()) { + return GetOrMakeIndexFromConstant(arg.cast()); + } else { + py::type objtype = py::type::of(arg); + const std::string type_name = + objtype.attr("__name__").cast(); + ThrowError(PyExc_TypeError, + absl::StrCat("GetOrMakeVariableIndex() only accept integer " + "variables or constants as argument: '", + absl::CEscape(type_name), "'")); + } + return 0; // Unreachable. + } + + void AssertVariableIsBoolean(std::shared_ptr literal) { + IntegerVariableProto* var = + model_proto_->mutable_variables(PositiveRef(literal->index())); + if (var->domain_size() != 2 || var->domain(0) < 0 || var->domain(1) > 1) { + ThrowError(PyExc_TypeError, absl::StrCat("Invalid boolean literal: ", + literal->ToString())); + } + } + + bool IsBooleanValue(py::handle value) { + return py::isinstance(value) || + py::isinstance(value, numpy_bool_type_); + } + + std::shared_ptr AddAllDifferentInternal(py::args exprs); + + std::shared_ptr AddAutomatonInternal( + py::sequence transition_expressions, int64_t starting_state, + const std::vector& final_states, + const std::vector>& transition_triples); + + std::shared_ptr AddBoolArgumentConstraintInternal( + BoolArgumentConstraint type, py::args literals); + + std::shared_ptr AddBoundedLinearExpressionInternal( + BoundedLinearExpression* ble); + + std::shared_ptr AddElementInternal(const py::handle& index, + py::sequence exprs, + const py::handle& target); + + std::shared_ptr AddInverseInternal(py::sequence direct, + py::sequence inverse); + + std::shared_ptr AddLinearArgumentConstraintInternal( + LinearArgumentConstraint type, const py::handle& target, py::args exprs); + + std::shared_ptr AddReservoirInternal(py::sequence times, + py::sequence level_changes, + py::sequence actives, + int64_t min_level, + int64_t max_level); + + std::shared_ptr AddTableInternal( + py::sequence exprs, const std::vector>& tuples, + bool negated); + + std::shared_ptr NewIntervalVarInternal(const std::string& name, + const py::handle& start, + const py::handle& size, + const py::handle& end, + py::sequence literals); + + std::shared_ptr AddNoOverlapInternal( + const std::vector>& intervals); + + std::shared_ptr AddNoOverlap2DInternal( + const std::vector>& x_intervals, + const std::vector>& y_intervals); + + std::shared_ptr AddCumulativeInternal( + const std::vector>& intervals, + py::sequence demands, const py::handle& capacity); + + std::shared_ptr AddCircuitInternal( + const std::vector>& arcs); + + std::shared_ptr AddRoutesInternal( + const std::vector>& arcs); + + private: + std::shared_ptr model_proto_; + absl::flat_hash_map cache_; + py::type numpy_bool_type_; +}; + +class Constraint { + public: + // We need to store the CpBaseModel to convert enforcement literals to + // indices. + Constraint(std::shared_ptr model, int index) + : model_(model), index_(index) {} + + int index() const { return index_; } + + std::shared_ptr model_proto() const { + return model_->model_proto(); + } + + ConstraintProto* proto() const { + return model_->model_proto()->mutable_constraints(index_); + } + + std::shared_ptr model() const { return model_; } + + std::string name() const { return proto()->name(); } + void SetName(const std::string& name) { proto()->set_name(name); } + void ClearName() { proto()->clear_name(); } + + std::string ToString() const { + return absl::StrCat("Constraint(index=", index_, ", ", + ProtobufDebugString(*proto()), ")"); + } + + private: + std::shared_ptr model_; + int index_; +}; + +std::shared_ptr CpBaseModel::AddAllDifferentInternal( + py::args exprs) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + if (exprs.size() == 1 && py::isinstance(exprs[0])) { + for (const auto& expr : exprs[0]) { + LinearExprToProto(expr, 1, ct->mutable_all_diff()->add_exprs()); + } + } else { + for (const auto& expr : exprs) { + LinearExprToProto(expr, 1, ct->mutable_all_diff()->add_exprs()); + } + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddAutomatonInternal( + py::sequence transition_expressions, int64_t starting_state, + const std::vector& final_states, + const std::vector>& transition_triples) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + for (const auto& expr : transition_expressions) { + LinearExprToProto(expr, 1, ct->mutable_automaton()->add_exprs()); + } + ct->mutable_automaton()->set_starting_state(starting_state); + ct->mutable_automaton()->mutable_final_states()->Add(final_states.begin(), + final_states.end()); + for (const auto& tuple : transition_triples) { + if (tuple.size() != 3) { + ThrowError(PyExc_ValueError, + absl::StrCat("transition (", absl::StrJoin(tuple, ","), + ") has the wrong arity != 3")); + } + ct->mutable_automaton()->add_transition_tail(tuple[0]); + ct->mutable_automaton()->add_transition_label(tuple[1]); + ct->mutable_automaton()->add_transition_head(tuple[2]); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddBoolArgumentConstraintInternal( + BoolArgumentConstraint type, py::args literals) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + BoolArgumentProto* proto = nullptr; + switch (type) { + case BoolArgumentConstraint::kAtMostOne: + proto = ct->mutable_at_most_one(); + break; + case BoolArgumentConstraint::kBoolAnd: + proto = ct->mutable_bool_and(); + break; + case BoolArgumentConstraint::kBoolOr: + proto = ct->mutable_bool_or(); + break; + case BoolArgumentConstraint::kBoolXor: + proto = ct->mutable_bool_xor(); + break; + case BoolArgumentConstraint::kExactlyOne: + proto = ct->mutable_exactly_one(); + break; + default: + ThrowError(PyExc_ValueError, + absl::StrCat("Unknown boolean argument constraint: ", type)); + } + if (literals.size() == 1 && py::isinstance(literals[0])) { + for (const auto& literal : literals[0]) { + proto->add_literals(GetOrMakeBooleanIndex(literal)); + } + } else { + for (const auto& literal : literals) { + proto->add_literals(GetOrMakeBooleanIndex(literal)); + } + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddBoundedLinearExpressionInternal( + BoundedLinearExpression* ble) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); for (const auto& var : ble->vars()) { ct->mutable_linear()->add_vars(var->index()); } @@ -468,115 +742,353 @@ int AddBoundedLinearExpressionToModel( ct->mutable_linear()->add_domain(CapSub(bound, offset)); } } - return index; + return std::make_shared(shared_from_this(), ct_index); } -int AddBoolOr(const std::vector& literals, - std::shared_ptr model_proto) { - const int index = model_proto->constraints_size(); - ConstraintProto* ct = model_proto->add_constraints(); - ct->mutable_bool_or()->mutable_literals()->Add(literals.begin(), - literals.end()); - return index; -} - -int AddBoolAnd(const std::vector& literals, - std::shared_ptr model_proto) { - const int index = model_proto->constraints_size(); - ConstraintProto* ct = model_proto->add_constraints(); - ct->mutable_bool_and()->mutable_literals()->Add(literals.begin(), - literals.end()); - return index; -} - -int AddBoolXOr(const std::vector& literals, - std::shared_ptr model_proto) { - const int index = model_proto->constraints_size(); - ConstraintProto* ct = model_proto->add_constraints(); - ct->mutable_bool_xor()->mutable_literals()->Add(literals.begin(), - literals.end()); - return index; -} - -int AddAtMostOne(const std::vector& literals, - std::shared_ptr model_proto) { - const int index = model_proto->constraints_size(); - ConstraintProto* ct = model_proto->add_constraints(); - ct->mutable_at_most_one()->mutable_literals()->Add(literals.begin(), - literals.end()); - return index; -} - -int AddExactlyOne(const std::vector& literals, - std::shared_ptr model_proto) { - const int index = model_proto->constraints_size(); - ConstraintProto* ct = model_proto->add_constraints(); - ct->mutable_exactly_one()->mutable_literals()->Add(literals.begin(), - literals.end()); - return index; -} - -int AddElement(const py::handle& index, py::sequence exprs, - const py::handle& target, - std::shared_ptr model_proto) { - const int ct_index = model_proto->constraints_size(); - ConstraintProto* ct = model_proto->add_constraints(); +std::shared_ptr CpBaseModel::AddElementInternal( + const py::handle& index, py::sequence exprs, const py::handle& target) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); LinearExprToProto(index, 1, ct->mutable_element()->mutable_linear_index()); for (const auto& expr : exprs) { LinearExprToProto(expr, 1, ct->mutable_element()->add_exprs()); } LinearExprToProto(target, 1, ct->mutable_element()->mutable_linear_target()); - return ct_index; + return std::make_shared(shared_from_this(), ct_index); } -int AddLinearArgumentConstraint(const std::string& name, - const py::handle& target, py::sequence exprs, - std::shared_ptr model_proto) { - const int ct_index = model_proto->constraints_size(); - ConstraintProto* ct = model_proto->add_constraints(); +std::shared_ptr CpBaseModel::AddInverseInternal( + py::sequence direct, py::sequence inverse) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + for (const auto& var : direct) { + ct->mutable_inverse()->add_f_direct(GetOrMakeVariableIndex(var)); + } + for (const auto& var : inverse) { + ct->mutable_inverse()->add_f_inverse(GetOrMakeVariableIndex(var)); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddLinearArgumentConstraintInternal( + LinearArgumentConstraint type, const py::handle& target, py::args exprs) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); LinearArgumentProto* proto; int64_t multiplier = 1; - if (name == "min") { - proto = ct->mutable_lin_max(); - multiplier = -1; - } else if (name == "max") { - proto = ct->mutable_lin_max(); - } else if (name == "prod") { - proto = ct->mutable_int_prod(); - } else if (name == "div") { - proto = ct->mutable_int_div(); - } else if (name == "mod") { - proto = ct->mutable_int_mod(); - } else { - ThrowError(PyExc_ValueError, - absl::StrCat("Unknown integer argument constraint: ", name)); + switch (type) { + case LinearArgumentConstraint::kDiv: + proto = ct->mutable_int_div(); + break; + case LinearArgumentConstraint::kMax: + proto = ct->mutable_lin_max(); + break; + case LinearArgumentConstraint::kMin: + proto = ct->mutable_lin_max(); + multiplier = -1; + break; + case LinearArgumentConstraint::kMod: + proto = ct->mutable_int_mod(); + break; + case LinearArgumentConstraint::kProd: + proto = ct->mutable_int_prod(); + break; + default: + ThrowError(PyExc_ValueError, + absl::StrCat("Unknown integer argument constraint: ", type)); } LinearExprToProto(target, multiplier, proto->mutable_target()); - for (const auto& expr : exprs) { - LinearExprToProto(expr, multiplier, proto->add_exprs()); + + if (exprs.size() == 1 && py::isinstance(exprs[0])) { + for (const auto& expr : exprs[0]) { + LinearExprToProto(expr, multiplier, proto->add_exprs()); + } + } else { + for (const auto& expr : exprs) { + LinearExprToProto(expr, multiplier, proto->add_exprs()); + } } - return ct_index; + return std::make_shared(shared_from_this(), ct_index); } -void AddEnforcementLiterals(int index, const std::vector& literals, - std::shared_ptr model_proto) { - ConstraintProto* ct = model_proto->mutable_constraints(index); - ct->mutable_enforcement_literal()->Add(literals.begin(), literals.end()); +std::shared_ptr CpBaseModel::AddReservoirInternal( + py::sequence times, py::sequence level_changes, py::sequence actives, + int64_t min_level, int64_t max_level) { + const int ct_index = model_proto_->constraints_size(); + ReservoirConstraintProto* proto = + model_proto_->add_constraints()->mutable_reservoir(); + for (const auto& time : times) { + LinearExprToProto(time, 1, proto->add_time_exprs()); + } + for (const auto& change : level_changes) { + LinearExprToProto(change, 1, proto->add_level_changes()); + } + for (const auto& active : actives) { + proto->add_active_literals(GetOrMakeBooleanIndex(active)); + } + proto->set_min_level(min_level); + proto->set_max_level(max_level); + return std::make_shared(shared_from_this(), ct_index); } -void SetCtName(int index, const std::string& name, - std::shared_ptr model_proto) { - model_proto->mutable_constraints(index)->set_name(name); +std::shared_ptr CpBaseModel::AddTableInternal( + py::sequence exprs, const std::vector>& tuples, + bool negated) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + const int num_exprs = exprs.size(); + for (const auto& expr : exprs) { + LinearExprToProto(expr, 1, ct->mutable_table()->add_exprs()); + } + for (const auto& tuple : tuples) { + if (tuple.size() != num_exprs) { + ThrowError(PyExc_ValueError, + absl::StrCat("Tuple (", absl::StrJoin(tuple, ","), + ") has the wrong arity != ", num_exprs)); + } + ct->mutable_table()->mutable_values()->Add(tuple.begin(), tuple.end()); + } + ct->mutable_table()->set_negated(negated); + return std::make_shared(shared_from_this(), ct_index); } -std::string GetCtName(int index, std::shared_ptr model_proto) { - return model_proto->constraints(index).name(); +std::string ShortName(int literal, std::shared_ptr model_proto) { + const int var = PositiveRef(literal); + const IntegerVariableProto& var_proto = model_proto->variables(var); + const std::string& var_name = + var_proto.name().empty() ? absl::StrCat("i", var) : var_proto.name(); + if (literal < 0) { + return absl::StrCat("not(", var_name, ")"); + } else { + return var_name; + } } -void ClearCtName(int index, std::shared_ptr model_proto) { - model_proto->mutable_constraints(index)->clear_name(); +std::string ShortExprName(const LinearExpressionProto& expr, + std::shared_ptr model_proto) { + if (expr.vars().empty()) { + return absl::StrCat(expr.offset()); + } else { + const IntegerVariableProto& var_proto = + model_proto->variables(expr.vars(0)); + const std::string& var_name = var_proto.name().empty() + ? absl::StrCat("i", expr.vars(0)) + : var_proto.name(); + const int64_t coeff = expr.coeffs(0); + std::string result; + if (coeff == 1) { + result = var_name; + } else if (coeff == -1) { + result = absl::StrCat("-", var_name); + } else if (coeff != 0) { + result = absl::StrCat(coeff, " * ", var_name); + } + if (expr.offset() > 0) { + absl::StrAppend(&result, " + ", expr.offset()); + } else if (expr.offset() < 0) { + absl::StrAppend(&result, " - ", -expr.offset()); + } + return result; + } +} + +std::shared_ptr RebuildFromLinearExpressionProto( + const LinearExpressionProto& proto, + std::shared_ptr model_proto) { + if (proto.vars().empty()) { + return LinearExpr::ConstantInt(proto.offset()); + } else if (proto.vars_size() == 1) { + return LinearExpr::AffineInt( + std::make_shared(model_proto, proto.vars(0)), proto.coeffs(0), + proto.offset()); + } else { + std::vector> vars; + vars.reserve(proto.vars_size()); + for (const int var : proto.vars()) { + vars.push_back(std::make_shared(model_proto, var)); + } + return std::make_shared(vars, proto.coeffs(), + proto.offset()); + } +} + +class IntervalVar { + public: + IntervalVar(std::shared_ptr model_proto, int index) + : model_proto_(model_proto), index_(index) { + DCHECK_GE(index, 0); + } + + int index() const { return index_; } + + std::shared_ptr model_proto() const { return model_proto_; } + + ConstraintProto* proto() const { + return model_proto_->mutable_constraints(index_); + } + + std::string ToString() const { + const std::string name = proto()->name(); + if (name.empty()) { + return absl::StrCat("iv", index_); + } else { + return name; + } + } + + std::string DebugString() const { + if (proto()->enforcement_literal().empty()) { + return absl::StrCat( + name(), "(start = ", + ShortExprName(proto()->interval().start(), model_proto()), + ", size = ", ShortExprName(proto()->interval().size(), model_proto()), + ", end = ", ShortExprName(proto()->interval().end(), model_proto()), + ")"); + } else { + return absl::StrCat( + name(), "(start = ", + ShortExprName(proto()->interval().start(), model_proto()), + ", size = ", ShortExprName(proto()->interval().size(), model_proto()), + ", end = ", ShortExprName(proto()->interval().end(), model_proto()), + ", is_present = ", + ShortName(proto()->enforcement_literal(0), model_proto()), ")"); + } + } + + std::string name() const { return proto()->name(); } + + void SetName(const std::string& name) { proto()->set_name(name); } + + std::shared_ptr StartExpr() const { + return RebuildFromLinearExpressionProto(proto()->interval().start(), + model_proto_); + } + std::shared_ptr SizeExpr() const { + return RebuildFromLinearExpressionProto(proto()->interval().size(), + model_proto_); + } + std::shared_ptr EndExpr() const { + return RebuildFromLinearExpressionProto(proto()->interval().end(), + model_proto_); + } + + private: + std::shared_ptr model_proto_; + int index_; +}; + +std::shared_ptr CpBaseModel::NewIntervalVarInternal( + const std::string& name, const py::handle& start, const py::handle& size, + const py::handle& end, py::sequence literals) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + if (!name.empty()) ct->set_name(name); + LinearExprToProto(start, 1, ct->mutable_interval()->mutable_start()); + LinearExprToProto(size, 1, ct->mutable_interval()->mutable_size()); + LinearExprToProto(end, 1, ct->mutable_interval()->mutable_end()); + for (const auto& lit : literals) { + ct->add_enforcement_literal(GetOrMakeBooleanIndex(lit)); + } + const std::string method = literals.empty() + ? "cp_model.new_interval_var" + : "cp_model.new_optional_interval_var"; + if (ct->interval().start().vars_size() > 1) { + ThrowError(PyExc_TypeError, + absl::StrCat(method, ": start must be affine or constant.")); + } + if (ct->interval().size().vars_size() > 1) { + ThrowError(PyExc_TypeError, + absl::StrCat(method, ": size must be affine or constant.")); + } + if (ct->interval().end().vars_size() > 1) { + ThrowError(PyExc_TypeError, + absl::StrCat(method, ": end must be affine or constant.")); + } + return std::make_shared(model_proto_, ct_index); +} + +std::shared_ptr CpBaseModel::AddNoOverlapInternal( + const std::vector>& intervals) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + ct->mutable_no_overlap()->mutable_intervals()->Reserve(intervals.size()); + for (const std::shared_ptr& interval : intervals) { + ct->mutable_no_overlap()->add_intervals(interval->index()); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddNoOverlap2DInternal( + const std::vector>& x_intervals, + const std::vector>& y_intervals) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + ct->mutable_no_overlap_2d()->mutable_x_intervals()->Reserve( + x_intervals.size()); + for (const std::shared_ptr& x_interval : x_intervals) { + ct->mutable_no_overlap_2d()->add_x_intervals(x_interval->index()); + } + ct->mutable_no_overlap_2d()->mutable_y_intervals()->Reserve( + y_intervals.size()); + for (const std::shared_ptr& y_interval : y_intervals) { + ct->mutable_no_overlap_2d()->add_y_intervals(y_interval->index()); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddCumulativeInternal( + const std::vector>& intervals, + const py::sequence demands, const py::handle& capacity) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + CumulativeConstraintProto* proto = ct->mutable_cumulative(); + + proto->mutable_intervals()->Reserve(intervals.size()); + for (const std::shared_ptr& interval : intervals) { + proto->add_intervals(interval->index()); + } + + proto->mutable_demands()->Reserve(demands.size()); + for (const auto& expr : demands) { + LinearExprToProto(expr, 1, proto->add_demands()); + } + + LinearExprToProto(capacity, 1, proto->mutable_capacity()); + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddCircuitInternal( + const std::vector>& arcs) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + CircuitConstraintProto* proto = ct->mutable_circuit(); + proto->mutable_tails()->Reserve(arcs.size()); + proto->mutable_heads()->Reserve(arcs.size()); + proto->mutable_literals()->Reserve(arcs.size()); + for (const auto& [tail, head, lit] : arcs) { + proto->add_tails(tail); + proto->add_heads(head); + proto->add_literals(GetOrMakeBooleanIndex(lit)); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddRoutesInternal( + const std::vector>& arcs) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + RoutesConstraintProto* proto = ct->mutable_routes(); + proto->mutable_tails()->Reserve(arcs.size()); + proto->mutable_heads()->Reserve(arcs.size()); + proto->mutable_literals()->Reserve(arcs.size()); + for (const auto& [tail, head, lit] : arcs) { + proto->add_tails(tail); + proto->add_heads(head); + proto->add_literals(GetOrMakeBooleanIndex(lit)); + } + return std::make_shared(shared_from_this(), ct_index); } PYBIND11_MODULE(cp_model_helper, m) { @@ -596,7 +1108,7 @@ PYBIND11_MODULE(cp_model_helper, m) { .def("NumConflicts", &SolutionCallback::NumConflicts) .def("NumIntegerPropagations", &SolutionCallback::NumIntegerPropagations) .def("ObjectiveValue", &SolutionCallback::ObjectiveValue) - .def("Response", &SolutionCallback::Response) + .def("Response", &SolutionCallback::SharedResponse) .def("SolutionBooleanValue", &SolutionCallback::SolutionBooleanValue, py::arg("index")) .def("SolutionIntegerValue", &SolutionCallback::SolutionIntegerValue, @@ -606,17 +1118,8 @@ PYBIND11_MODULE(cp_model_helper, m) { .def("WallTime", &SolutionCallback::WallTime) .def( "Value", - [](const SolutionCallback& callback, - std::shared_ptr expr) { - int64_t value; - IntExprVisitor visitor; - visitor.AddToProcess(expr, 1); - if (!visitor.Evaluate(callback.Response(), &value)) { - ThrowError(PyExc_ValueError, - absl::StrCat("Failed to evaluate linear expression: ", - expr->DebugString())); - } - return value; + [](const SolutionCallback& self, std::shared_ptr expr) { + return ResponseHelper::Value(self.SharedResponse(), expr); }, "Returns the value of a linear expression after solve.") .def( @@ -624,11 +1127,8 @@ PYBIND11_MODULE(cp_model_helper, m) { "Returns the value of a linear expression after solve.") .def( "FloatValue", - [](const SolutionCallback& callback, - std::shared_ptr expr) { - FloatExprVisitor visitor; - visitor.AddToProcess(expr, 1.0); - return visitor.Evaluate(callback.Response()); + [](const SolutionCallback& self, std::shared_ptr expr) { + return ResponseHelper::FloatValue(self.SharedResponse(), expr); }, "Returns the value of a floating point linear expression after " "solve.") @@ -639,42 +1139,31 @@ PYBIND11_MODULE(cp_model_helper, m) { "solve.") .def( "BooleanValue", - [](const SolutionCallback& callback, std::shared_ptr lit) { - return callback.SolutionBooleanValue(lit->index()); + [](const SolutionCallback& self, std::shared_ptr lit) { + return ResponseHelper::BooleanValue(self.SharedResponse(), lit); }, "Returns the Boolean value of a literal after solve.") .def( "BooleanValue", [](const SolutionCallback&, bool lit) { return lit; }, "Returns the Boolean value of a literal after solve."); - py::class_(m, "ResponseWrapper") - .def("best_objective_bound", &ResponseWrapper::BestObjectiveBound) - .def("boolean_value", &ResponseWrapper::BooleanValue, - py::arg("lit").none(false)) - .def("boolean_value", &ResponseWrapper::FixedBooleanValue, - py::arg("lit").none(false)) - .def("deterministic_time", &ResponseWrapper::DeterministicTime) - .def("num_binary_propagations", &ResponseWrapper::NumBinaryPropagations) - .def("num_booleans", &ResponseWrapper::NumBooleans) - .def("num_branches", &ResponseWrapper::NumBranches) - .def("num_conflicts", &ResponseWrapper::NumConflicts) - .def("num_integer_propagations", &ResponseWrapper::NumIntegerPropagations) - .def("num_restarts", &ResponseWrapper::NumRestarts) - .def("objective_value", &ResponseWrapper::ObjectiveValue) - .def("response", &ResponseWrapper::Response) - .def("response_stats", &ResponseWrapper::ResponseStats) - .def("solution_info", &ResponseWrapper::SolutionInfo) - .def("status", &ResponseWrapper::Status) - .def("sufficient_assumptions_for_infeasibility", - &ResponseWrapper::SufficientAssumptionsForInfeasibility) - .def("user_time", &ResponseWrapper::UserTime) - .def("float_value", &ResponseWrapper::FloatValue, - py::arg("expr").none(false)) - .def("float_value", &ResponseWrapper::FixedFloatValue, - py::arg("value").none(false)) - .def("value", &ResponseWrapper::Value, py::arg("expr").none(false)) - .def("value", &ResponseWrapper::FixedValue, py::arg("value").none(false)) - .def("wall_time", &ResponseWrapper::WallTime); + py::class_(m, "ResponseHelper") + .def_static("boolean_value", &ResponseHelper::BooleanValue, + py::arg("response").none(false), py::arg("lit").none(false)) + .def_static("boolean_value", &ResponseHelper::FixedBooleanValue, + py::arg("response").none(false), py::arg("lit").none(false)) + .def_static("float_value", &ResponseHelper::FloatValue, + py::arg("response").none(false), py::arg("expr").none(false)) + .def_static("float_value", &ResponseHelper::FixedFloatValue, + py::arg("response").none(false), py::arg("value").none(false)) + .def_static("sufficient_assumptions_for_infeasibility", + &ResponseHelper::SufficientAssumptionsForInfeasibility, + py::arg("response").none(false)) + .def_static("value", &ResponseHelper::Value, + py::arg("response").none(false), py::arg("expr").none(false)) + .def_static("value", &ResponseHelper::FixedValue, + py::arg("response").none(false), + py::arg("value").none(false)); py::class_(m, "SolveWrapper") .def(py::init<>()) @@ -744,20 +1233,6 @@ PYBIND11_MODULE(cp_model_helper, m) { return result; }, py::arg("model_proto").none(false)) - .def("solve_and_return_response_wrapper", - [](ExtSolveWrapper* solve_wrapper, - std::shared_ptr model_proto) -> ResponseWrapper { - const auto result = [=]() -> ResponseWrapper { - ::py::gil_scoped_release release; - return ResponseWrapper(solve_wrapper->Solve(*model_proto)); - }(); - if (solve_wrapper->local_error_already_set_.has_value()) { - solve_wrapper->local_error_already_set_->restore(); - solve_wrapper->local_error_already_set_.reset(); - throw py::error_already_set(); - } - return result; - }) .def("stop_search", &SolveWrapper::StopSearch); py::class_(m, "CpSatHelper") @@ -767,39 +1242,8 @@ PYBIND11_MODULE(cp_model_helper, m) { py::arg("response")) .def_static("validate_model", &CpSatHelper::ValidateModel, py::arg("model_proto")) - .def_static("variable_domain", &CpSatHelper::VariableDomain, - py::arg("variable_proto")) .def_static("write_model_to_file", &CpSatHelper::WriteModelToFile, - py::arg("model_proto"), py::arg("filename")) - .def_static("set_ct_name", &SetCtName, py::arg("index"), py::arg("name"), - py::arg("model_proto")) - .def_static("ct_name", &GetCtName, py::arg("index"), - py::arg("model_proto")) - .def_static("clear_ct_name", &ClearCtName, py::arg("index"), - py::arg("model_proto")) - .def_static("add_bool_or", &AddBoolOr, py::arg("literals"), - py::arg("model_proto").none(false)) - .def_static("add_bool_and", &AddBoolAnd, py::arg("literals"), - py::arg("model_proto").none(false)) - .def_static("add_bool_xor", &AddBoolXOr, py::arg("literals"), - py::arg("model_proto").none(false)) - .def_static("add_at_most_one", &AddAtMostOne, py::arg("literals"), - py::arg("model_proto").none(false)) - .def_static("add_exactly_one", &AddExactlyOne, py::arg("literals"), - py::arg("model_proto").none(false)) - .def_static("add_element", &AddElement, py::arg("index").none(false), - py::arg("expressions"), py::arg("target").none(false), - py::arg("model_proto").none(false)) - .def_static("add_linear_argument_constraint", - &AddLinearArgumentConstraint, py::arg("name").none(false), - py::arg("target").none(false), py::arg("exprs"), - py::arg("model_proto").none(false)) - .def_static("add_enforcement_literals", &AddEnforcementLiterals, - py::arg("index"), py::arg("literals"), - py::arg("model_proto").none(false)) - .def_static("add_bounded_linear_expression_to_model", - &AddBoundedLinearExpressionToModel, py::arg("ble"), - py::arg("model_proto")); + py::arg("model_proto"), py::arg("filename")); py::class_>( m, "LinearExpr", DOC(operations_research, sat, python, LinearExpr)) @@ -1233,7 +1677,7 @@ PYBIND11_MODULE(cp_model_helper, m) { "not supported."); }) .def("__hash__", &Literal::Hash) - // PEP8 Compatibility. + // Pre PEP8 compatibility layer. .def("Not", &Literal::negated) .def("Index", &Literal::index); @@ -1244,42 +1688,36 @@ PYBIND11_MODULE(cp_model_helper, m) { .def(py::init>()) // new variable. .def_property_readonly( "proto", &IntVar::proto, py::return_value_policy::reference, - py::keep_alive<1, 0>() - // DOC(operations_research, sat, python, IntVar, proto) - ) - .def_property_readonly( - "model_proto", &IntVar::model_proto - // DOC(operations_research, sat, python, IntVar, model_proto) - ) + py::keep_alive<1, 0>(), + "Returns the IntegerVariableProto of this variable.") + .def_property_readonly("model_proto", &IntVar::model_proto, + "Returns the CP model protobuf") .def_property_readonly( "index", &IntVar::index, py::return_value_policy::reference, DOC(operations_research, sat, python, IntVar, index)) .def_property_readonly( "is_boolean", &IntVar::is_boolean, DOC(operations_research, sat, python, IntVar, is_boolean)) - .def_property( - "name", &IntVar::name, &IntVar::SetName //, py::arg("name") - // DOC(operations_research, - // sat, python, IntVar, name) - ) + .def_property("name", &IntVar::name, &IntVar::SetName, + "The name of the variable.") .def( "with_name", [](std::shared_ptr self, const std::string& name) { self->SetName(name); return self; }, - py::arg("name")) - .def_property( - "domain", &IntVar::domain, &IntVar::SetDomain //, py::arg("domain") - // DOC(operations_research, sat, python, IntVar, domain) - ) + py::arg("name"), + "Sets the name of the variable and returns the variable.") + .def_property("domain", &IntVar::domain, &IntVar::SetDomain, + "The domain of the variable.") .def( "with_domain", [](std::shared_ptr self, const Domain& domain) { self->SetDomain(domain); return self; }, - py::arg("domain")) + py::arg("domain"), + "Sets the domain of the variable and returns the variable.") .def("__str__", &IntVar::ToString) .def("__repr__", &IntVar::DebugString) .def( @@ -1318,7 +1756,7 @@ PYBIND11_MODULE(cp_model_helper, m) { return std::make_shared( t[0].cast>(), t[1].cast()); })) - // PEP8 Compatibility. + // Pre PEP8 compatibility layer. .def("Name", &IntVar::name) .def("Proto", &IntVar::proto) .def("Not", @@ -1358,7 +1796,7 @@ PYBIND11_MODULE(cp_model_helper, m) { [](std::shared_ptr not_var) -> std::shared_ptr { return not_var->negated(); }, DOC(operations_research, sat, python, NotBooleanVariable, negated)) - // PEP8 Compatibility. + // Pre PEP8 compatibility layer. .def( "Not", [](std::shared_ptr not_var) @@ -1388,6 +1826,324 @@ PYBIND11_MODULE(cp_model_helper, m) { "not supported.")); return false; }); + + py::enum_(m, "BoolArgumentConstraint") + .value("at_most_one", BoolArgumentConstraint::kAtMostOne) + .value("bool_and", BoolArgumentConstraint::kBoolAnd) + .value("bool_or", BoolArgumentConstraint::kBoolOr) + .value("bool_xor", BoolArgumentConstraint::kBoolXor) + .value("exactly_one", BoolArgumentConstraint::kExactlyOne) + .export_values(); + + py::enum_(m, "LinearArgumentConstraint") + .value("div", LinearArgumentConstraint::kDiv) + .value("max", LinearArgumentConstraint::kMax) + .value("min", LinearArgumentConstraint::kMin) + .value("mod", LinearArgumentConstraint::kMod) + .value("prod", LinearArgumentConstraint::kProd) + .export_values(); + + py::class_>( + m, "CpBaseModel", "Base class for the CP model.") + .def(py::init>()) + .def_property_readonly("model_proto", &CpBaseModel::model_proto, + "Returns the CP model protobuf") + .def("get_or_make_index_from_constant", + &CpBaseModel::GetOrMakeIndexFromConstant, py::arg("value"), + "Returns the index of the given constant value.") + .def("get_or_make_boolean_index", &CpBaseModel::GetOrMakeBooleanIndex, + py::arg("value"), "Returns the index of the given boolean value.") + .def("get_or_make_variable_index", &CpBaseModel::GetOrMakeVariableIndex, + py::arg("arg"), + "Returns the index of the given variable or constant variable.") + .def("is_boolean_value", &CpBaseModel::IsBooleanValue, py::arg("value")) + .def("rebuild_constant_map", &CpBaseModel::RebuildConstantMap) + .def("_add_all_different", &CpBaseModel::AddAllDifferentInternal) + .def("_add_automaton", &CpBaseModel::AddAutomatonInternal, + py::arg("transition_expressions"), py::arg("starting_state"), + py::arg("final_states"), py::arg("transition_triples")) + .def("_add_bool_argument_constraint", + &CpBaseModel::AddBoolArgumentConstraintInternal, py::arg("name")) + .def("_add_bounded_linear_expression", + &CpBaseModel::AddBoundedLinearExpressionInternal, py::arg("ble")) + .def("_add_element", &CpBaseModel::AddElementInternal, + py::arg("index").none(false), py::arg("expressions"), + py::arg("target").none(false)) + .def("_add_linear_argument_constraint", + &CpBaseModel::AddLinearArgumentConstraintInternal, + py::arg("name").none(false), py::arg("target").none(false)) + .def("_add_inverse", &CpBaseModel::AddInverseInternal, py::arg("direct"), + py::arg("inverse")) + .def("_add_reservoir", &CpBaseModel::AddReservoirInternal, + py::arg("times"), py::arg("level_changes"), py::arg("actives"), + py::arg("min_level"), py::arg("max_level")) + .def("_add_table", &CpBaseModel::AddTableInternal, py::arg("expressions"), + py::arg("values"), py::arg("negated")) + // Scheduling support. + .def("_new_interval_var", &CpBaseModel::NewIntervalVarInternal, + py::arg("name"), py::arg("start"), py::arg("size"), py::arg("end"), + py::arg("Literals")) + .def("_add_no_overlap", &CpBaseModel::AddNoOverlapInternal, + py::arg("intervals")) + .def("_add_no_overlap_2d", &CpBaseModel::AddNoOverlap2DInternal, + py::arg("x_intervals"), py::arg("y_intervals")) + .def("_add_cumulative", &CpBaseModel::AddCumulativeInternal, + py::arg("intervals"), py::arg("demands"), py::arg("capacity")) + // Routing support. + .def("_add_circuit", &CpBaseModel::AddCircuitInternal, py::arg("arcs")) + .def("_add_routes", &CpBaseModel::AddRoutesInternal, py::arg("arcs")); + + static const char* kConstraintDoc = R"doc( + Base class for constraints. + + Constraints are built by the CpModel through the add methods. + Once created by the CpModel class, they are automatically added to the model. + The purpose of this class is to allow specification of enforcement literals + for this constraint. + + b = model.new_bool_var('b') + x = model.new_int_var(0, 10, 'x') + y = model.new_int_var(0, 10, 'y') + + model.add(x + 2 * y == 5).only_enforce_if(b.negated()) + )doc"; + + static const char* kConstraintOnlyEnforceIfDoc = R"doc( + Adds one or more enforcement literals to the constraint. + + This method adds one or more literals (that is, a boolean variable or its + negation) as enforcement literals. The conjunction of all these literals + determines whether the constraint is active or not. It acts as an + implication, so if the conjunction is true, it implies that the constraint + must be enforced. If it is false, then the constraint is ignored. + + BoolOr, BoolAnd, and linear constraints all support enforcement literals. + + Args: + *literals: One or more Boolean literals. + + Returns: + self.)doc"; + + py::class_>(m, "Constraint", + kConstraintDoc) + .def(py::init, int>()) + .def_property_readonly( + "index", &Constraint::index, + "Returns the index of the constraint in the model protobuf.") + .def_property_readonly("model_proto", &Constraint::model_proto, + "Returns the model protobuf.") + .def_property_readonly("proto", &Constraint::proto, + py::return_value_policy::reference, + py::keep_alive<1, 0>(), + "Returns the ConstraintProto of this constraint.") + .def_property("name", &Constraint::name, &Constraint::SetName, + "The name of the constraint.") + .def( + "with_name", + [](Constraint* self, const std::string& name) { + if (name.empty()) { + self->ClearName(); + } else { + self->SetName(name); + } + return self; + }, + "Sets the name of the constraint and returns the constraints") + .def( + "only_enforce_if", + [](std::shared_ptr self, + std::shared_ptr literal) { + self->proto()->add_enforcement_literal(literal->index()); + return self; + }, + py::arg("literal"), kConstraintOnlyEnforceIfDoc) + .def( + "only_enforce_if", + [](std::shared_ptr self, bool literal) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeIndexFromConstant(literal)); + return self; + }, + py::arg("literal"), kConstraintOnlyEnforceIfDoc) + .def( + "only_enforce_if", + [](std::shared_ptr self, + const std::vector>& literals) { + for (const std::shared_ptr& literal : literals) { + self->proto()->add_enforcement_literal(literal->index()); + } + }, + py::arg("literals"), kConstraintOnlyEnforceIfDoc) + .def( + "only_enforce_if", + [](std::shared_ptr self, py::args literals) { + if (literals.size() == 1 && + py::isinstance(literals[0])) { + py::sequence seq = literals[0].cast(); + for (const auto& literal : seq) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeBooleanIndex(literal)); + } + } else { + for (const auto& literal : literals) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeBooleanIndex(literal)); + } + } + }, + kConstraintOnlyEnforceIfDoc) + // Pre PEP8 compatibility. + .def("Name", &Constraint::name) + .def("Index", &Constraint::index) + .def("Proto", &Constraint::proto) + .def("WithName", + [](Constraint* self, const std::string& name) { + if (name.empty()) { + self->ClearName(); + } else { + self->SetName(name); + } + return self; + }) + .def("OnlyEnforceIf", [](std::shared_ptr self, + py::args literals) { + if (literals.size() == 1 && py::isinstance(literals[0])) { + py::sequence seq = literals[0].cast(); + for (const auto& literal : seq) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeBooleanIndex(literal)); + } + } else { + for (const auto& literal : literals) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeBooleanIndex(literal)); + } + } + }); + + static const char* kIntervalVarDoc = R"doc( +Represents an Interval variable. + +An interval variable is both a constraint and a variable. It is defined by +three integer variables: start, size, and end. + +It is a constraint because, internally, it enforces that start + size == end. + +It is also a variable as it can appear in specific scheduling constraints: +NoOverlap, NoOverlap2D, Cumulative. + +Optionally, an enforcement literal can be added to this constraint, in which +case these scheduling constraints will ignore interval variables with +enforcement literals assigned to false. Conversely, these constraints will +also set these enforcement literals to false if they cannot fit these +intervals into the schedule. + +Raises: + ValueError: if start, size, end are not defined, or have the wrong type. +)doc"; + + py::class_>(m, "IntervalVar", + kIntervalVarDoc) + .def(py::init, int>()) + .def_property_readonly("index", &IntervalVar::index, + "Returns the index of the interval variable.") + .def_property_readonly("model_proto", &IntervalVar::model_proto, + "Returns the model protobuf.") + .def_property_readonly( + "proto", &IntervalVar::proto, py::return_value_policy::reference, + py::keep_alive<1, 0>(), "Returns the interval constraint protobuf.") + .def_property("name", &IntervalVar::name, &IntervalVar::SetName, + "The name of the interval variable.") + .def( + "start_expr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.start().vars().empty()) { + return py::cast(proto.start().offset()); + } else { + return py::cast(self->StartExpr()); + } + }, + "Returns the start expression of the interval variable.") + .def( + "size_expr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.size().vars().empty()) { + return py::cast(proto.size().offset()); + } else { + return py::cast(self->SizeExpr()); + } + }, + "Returns the size expression of the interval variable.") + .def( + "end_expr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.end().vars().empty()) { + return py::cast(proto.end().offset()); + } else { + return py::cast(self->EndExpr()); + } + }, + "Returns the end expression of the interval variable.") + .def("__str__", &IntervalVar::ToString) + .def("__repr__", &IntervalVar::DebugString) + .def(py::pickle( + [](std::shared_ptr p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(p->model_proto(), p->index()); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 2) throw std::runtime_error("Invalid state!"); + + return std::make_shared( + t[0].cast>(), t[1].cast()); + })) + // Pre PEP8 compatibility layer. + .def("Proto", &IntervalVar::proto) + .def("Index", &IntervalVar::index) + .def("Name", &IntervalVar::name) + .def("StartExpr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.start().vars().empty()) { + return py::cast(proto.start().offset()); + } else { + return py::cast(self->StartExpr()); + } + }) + .def("SizeExpr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.size().vars().empty()) { + return py::cast(proto.size().offset()); + } else { + return py::cast(self->SizeExpr()); + } + }) + .def("EndExpr", [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.end().vars().empty()) { + return py::cast(proto.end().offset()); + } else { + return py::cast(self->EndExpr()); + } + }); + + m.def( + "rebuild_from_linear_expression_proto", + [](const LinearExpressionProto& proto, + std::shared_ptr model_proto) -> py::object { + if (proto.vars().empty()) { + return py::cast(proto.offset()); + } else { + return py::cast(RebuildFromLinearExpressionProto(proto, model_proto)); + } + }, + py::arg("proto"), py::arg("model_proto")); + #define IMPORT_PROTO_WRAPPER_CODE #include "ortools/sat/python/proto_builder_pybind11.h" #undef IMPORT_PROTO_WRAPPER_CODE diff --git a/ortools/sat/python/cp_model_helper_test.py b/ortools/sat/python/cp_model_helper_test.py index 71b90845f2..b5b83640e9 100644 --- a/ortools/sat/python/cp_model_helper_test.py +++ b/ortools/sat/python/cp_model_helper_test.py @@ -51,20 +51,6 @@ class CpModelHelperTest(absltest.TestCase): super().tearDown() sys.stdout.flush() - def test_variable_domain(self): - model_string = """ - variables { domain: [ -10, 10 ] } - variables { domain: [ -5, -5, 3, 6 ] } - """ - model = cmh.CpModelProto() - self.assertTrue(model.parse_text_format(model_string)) - - d0 = cmh.CpSatHelper.variable_domain(model.variables[0]) - d1 = cmh.CpSatHelper.variable_domain(model.variables[1]) - - self.assertEqual(d0.flattened_intervals(), [-10, 10]) - self.assertEqual(d1.flattened_intervals(), [-5, -5, 3, 6]) - def test_simple_solve(self): model_string = """ variables { domain: -10 domain: 10 } @@ -101,10 +87,10 @@ class CpModelHelperTest(absltest.TestCase): self.assertTrue(model.parse_text_format(model_string)) solve_wrapper = cmh.SolveWrapper() - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + response = solve_wrapper.solve(model) - self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) - self.assertEqual(30.0, response_wrapper.objective_value()) + self.assertEqual(cmh.OPTIMAL, response.status) + self.assertEqual(30.0, response.objective_value) def test_simple_solve_with_core(self): model_string = """ @@ -146,10 +132,10 @@ class CpModelHelperTest(absltest.TestCase): solve_wrapper = cmh.SolveWrapper() solve_wrapper.set_parameters(parameters) - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + response = solve_wrapper.solve(model) - self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) - self.assertEqual(30.0, response_wrapper.objective_value()) + self.assertEqual(cmh.OPTIMAL, response.status) + self.assertEqual(30.0, response.objective_value) def test_simple_solve_with_proto_api(self): model = cmh.CpModelProto() @@ -168,14 +154,11 @@ class CpModelHelperTest(absltest.TestCase): model.objective.scaling_factor = -1 solve_wrapper = cmh.SolveWrapper() - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + response = solve_wrapper.solve(model) - self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) - self.assertEqual(30.0, response_wrapper.objective_value()) - self.assertEqual(30.0, response_wrapper.best_objective_bound()) - self.assertRaises(TypeError, response_wrapper.value, None) - self.assertRaises(TypeError, response_wrapper.float_value, None) - self.assertRaises(TypeError, response_wrapper.boolean_value, None) + self.assertEqual(cmh.OPTIMAL, response.status) + self.assertEqual(30.0, response.objective_value) + self.assertEqual(30.0, response.best_objective_bound) def test_solution_callback(self): model_string = """ @@ -193,10 +176,10 @@ class CpModelHelperTest(absltest.TestCase): params = cmh.SatParameters() params.enumerate_all_solutions = True solve_wrapper.set_parameters(params) - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + response = solve_wrapper.solve(model) self.assertEqual(5, callback.solution_count()) - self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) + self.assertEqual(cmh.OPTIMAL, response.status) def test_best_bound_callback(self): model_string = """ @@ -222,10 +205,10 @@ class CpModelHelperTest(absltest.TestCase): params.linearization_level = 2 params.log_search_progress = True solve_wrapper.set_parameters(params) - response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) + response = solve_wrapper.solve(model) self.assertEqual(2.6, best_bound_callback.best_bound) - self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) + self.assertEqual(cmh.OPTIMAL, response.status) def test_model_stats(self): model_string = """ diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index c446246ef8..92a807759c 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import copy import itertools import sys @@ -184,12 +185,13 @@ class CpModelTest(absltest.TestCase): sys.stdout.flush() def test_is_boolean(self): - self.assertTrue(cp_model.arg_is_boolean(True)) - self.assertTrue(cp_model.arg_is_boolean(False)) - self.assertFalse(cp_model.arg_is_boolean(1)) - self.assertFalse(cp_model.arg_is_boolean(0)) - self.assertTrue(cp_model.arg_is_boolean(np.bool_(1))) - self.assertTrue(cp_model.arg_is_boolean(np.bool_(0))) + model = cp_model.CpModel() + self.assertTrue(model.is_boolean_value(True)) + self.assertTrue(model.is_boolean_value(False)) + self.assertFalse(model.is_boolean_value(1)) + self.assertFalse(model.is_boolean_value(0)) + self.assertTrue(model.is_boolean_value(np.bool_(1))) + self.assertTrue(model.is_boolean_value(np.bool_(0))) def test_create_integer_variable(self) -> None: model = cp_model.CpModel() @@ -220,6 +222,12 @@ class CpModelTest(absltest.TestCase): variables = set() variables.add(var_a) + accumulator = collections.defaultdict(int) + accumulator[var_a] += 1 + self.assertEqual(accumulator[var_a], 1) + accumulator[model.get_int_var_from_proto_index(var_a.index)] += 3 + self.assertEqual(accumulator[var_a], 4) + def test_literal(self) -> None: model = cp_model.CpModel() x = model.new_bool_var("x") @@ -475,6 +483,18 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints, 3) self.assertEqual(-4, model.proto.constraints[2].enforcement_literal[0]) self.assertEqual(2, model.proto.constraints[2].enforcement_literal[1]) + model.add_linear_constraint(x + 4 * y, 0, 10).only_enforce_if([b, c, False]) + self.assertLen(model.proto.constraints, 4) + self.assertEqual(2, model.proto.constraints[3].enforcement_literal[0]) + self.assertEqual(3, model.proto.constraints[3].enforcement_literal[1]) + self.assertEqual(4, model.proto.constraints[3].enforcement_literal[2]) + model.add_linear_constraint(x + 5 * y, 0, 10).only_enforce_if( + c.negated(), b, np.True_ + ) + self.assertLen(model.proto.constraints, 5) + self.assertEqual(-4, model.proto.constraints[4].enforcement_literal[0]) + self.assertEqual(2, model.proto.constraints[4].enforcement_literal[1]) + self.assertEqual(5, model.proto.constraints[4].enforcement_literal[2]) def test_names(self) -> None: model = cp_model.CpModel() @@ -922,7 +942,8 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints, 1) self.assertLen(model.proto.constraints[0].table.exprs, 5) self.assertLen(model.proto.constraints[0].table.values, 15) - with self.assertRaises(TypeError): + self.assertFalse(model.proto.constraints[0].table.negated) + with self.assertRaises(ValueError): model.add_allowed_assignments( x, [(0, 1, 2, 3, 4), (4, 3, 2, 1, 1), (0, 0, 0, 0)], @@ -945,7 +966,7 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints[0].table.values, 15) self.assertTrue(model.proto.constraints[0].table.negated) self.assertRaises( - TypeError, + ValueError, model.add_forbidden_assignments, x, [(0, 1, 2, 3, 4), (4, 3, 2, 1, 1), (0, 0, 0, 0)], @@ -969,7 +990,7 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints[0].automaton.transition_label, 4) self.assertLen(model.proto.constraints[0].automaton.final_states, 2) self.assertEqual(0, model.proto.constraints[0].automaton.starting_state) - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): model.add_automaton( x, 0, @@ -1254,10 +1275,10 @@ class CpModelTest(absltest.TestCase): model.add_implication(x, y) self.assertLen(model.proto.variables, 2) self.assertLen(model.proto.constraints, 1) - self.assertLen(model.proto.constraints[0].bool_or.literals, 1) + self.assertLen(model.proto.constraints[0].bool_and.literals, 1) self.assertLen(model.proto.constraints[0].enforcement_literal, 1) self.assertEqual(x.index, model.proto.constraints[0].enforcement_literal[0]) - self.assertEqual(y.index, model.proto.constraints[0].bool_or.literals[0]) + self.assertEqual(y.index, model.proto.constraints[0].bool_and.literals[0]) def test_bool_or(self) -> None: model = cp_model.CpModel() @@ -1281,12 +1302,14 @@ class CpModelTest(absltest.TestCase): model.add_bool_or(True, x[0], x[2]) model.add_bool_or(False, x[0]) model.add_bool_or(x[i] for i in [0, 2, 3, 4]) + model.add_bool_or(x[3]) self.assertLen(model.proto.variables, 7) - self.assertLen(model.proto.constraints, 4) + self.assertLen(model.proto.constraints, 5) self.assertLen(model.proto.constraints[0].bool_or.literals, 5) self.assertLen(model.proto.constraints[1].bool_or.literals, 3) self.assertLen(model.proto.constraints[2].bool_or.literals, 2) self.assertLen(model.proto.constraints[3].bool_or.literals, 4) + self.assertLen(model.proto.constraints[4].bool_or.literals, 1) def test_at_least_one(self) -> None: model = cp_model.CpModel() @@ -1388,7 +1411,7 @@ class CpModelTest(absltest.TestCase): proto.coeffs.append(1) proto.vars.append(y.index) proto.coeffs.append(2) - expr1 = cp_model.rebuild_from_linear_expression_proto(proto, model.proto) + expr1 = cmh.rebuild_from_linear_expression_proto(proto, model.proto) canonical_expr1 = cmh.FlatIntExpr(expr1) self.assertEqual(canonical_expr1.vars[0], x) self.assertEqual(canonical_expr1.vars[1], y) @@ -1399,7 +1422,7 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, canonical_expr1.vars[0].negated) proto.offset = 2 - expr2 = cp_model.rebuild_from_linear_expression_proto(proto, model.proto) + expr2 = cmh.rebuild_from_linear_expression_proto(proto, model.proto) canonical_expr2 = cmh.FlatIntExpr(expr2) self.assertEqual(canonical_expr2.vars[0], x) self.assertEqual(canonical_expr2.vars[1], y) @@ -1606,7 +1629,7 @@ class CpModelTest(absltest.TestCase): def test_model_errors(self) -> None: model = cp_model.CpModel() self.assertRaises(TypeError, model.add, "dummy") - self.assertRaises(TypeError, model.get_or_make_index, "dummy") + self.assertRaises(TypeError, model.get_or_make_variable_index, "dummy") self.assertRaises(TypeError, model.minimize, "dummy") def test_solver_errors(self) -> None: @@ -1778,6 +1801,14 @@ class CpModelTest(absltest.TestCase): self.assertTrue(cp_model.object_is_a_false_literal(False)) self.assertFalse(cp_model.object_is_a_true_literal(False)) self.assertFalse(cp_model.object_is_a_false_literal(True)) + self.assertFalse(cp_model.object_is_a_true_literal(~True)) + self.assertFalse(cp_model.object_is_a_false_literal(~False)) + self.assertTrue(cp_model.object_is_a_true_literal(~False)) + self.assertTrue(cp_model.object_is_a_false_literal(~True)) + self.assertTrue(cp_model.object_is_a_true_literal(np.True_)) + self.assertTrue(cp_model.object_is_a_false_literal(np.False_)) + self.assertFalse(cp_model.object_is_a_true_literal(np.False_)) + self.assertFalse(cp_model.object_is_a_false_literal(np.True_)) def test_solve_minimize_with_solution_callback(self) -> None: model = cp_model.CpModel() @@ -2660,6 +2691,24 @@ TRFM""" self.assertEqual(-2, prod.coefficient) self.assertEqual(2, prod.offset) + def test_pre_pep8(self): + model = cp_model.CpModel() + x = [model.NewBoolVar(f"x{i}") for i in range(5)] + model.AddBoolOr(x) + self.assertLen(model.proto.variables, 5) + self.assertLen(model.proto.constraints, 1) + self.assertLen(model.proto.constraints[0].bool_or.literals, 5) + + model_copy = copy.copy(model) + self.assertTrue(hasattr(model_copy, "AddBoolOr")) + self.assertTrue(hasattr(model_copy, "AddBoolXOr")) + self.assertTrue(hasattr(model_copy, "AddNoOverlap2D")) + + model_deepcopy = copy.deepcopy(model) + self.assertTrue(hasattr(model_deepcopy, "AddBoolOr")) + self.assertTrue(hasattr(model_deepcopy, "AddBoolXOr")) + self.assertTrue(hasattr(model_deepcopy, "AddNoOverlap2D")) + if __name__ == "__main__": absltest.main() diff --git a/ortools/sat/python/linear_expr.cc b/ortools/sat/python/linear_expr.cc index 0b87598b67..309c8f8d14 100644 --- a/ortools/sat/python/linear_expr.cc +++ b/ortools/sat/python/linear_expr.cc @@ -25,7 +25,9 @@ #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/base/string_view_migration.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/util/fp_roundtrip_conv.h" @@ -798,10 +800,10 @@ std::string IntVar::name() const { if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { return ""; } - return model_proto_->variables(index_).name(); + return google::protobuf::StringCopy(model_proto_->variables(index_).name()); } -void IntVar::SetName(const std::string& name) { +void IntVar::SetName(absl::string_view name) { if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { return; } diff --git a/ortools/sat/python/linear_expr.h b/ortools/sat/python/linear_expr.h index 3e74f256a5..bd9acb5c04 100644 --- a/ortools/sat/python/linear_expr.h +++ b/ortools/sat/python/linear_expr.h @@ -24,6 +24,7 @@ #include "absl/container/fixed_array.h" #include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/util/sorted_interval_list.h" @@ -520,7 +521,7 @@ class IntVar : public Literal { /// Overwrite the name of the variable. If name is empty, this method clears /// the name of the variable. - void SetName(const std::string& name); + void SetName(absl::string_view name); /// Returns a copy of the domain of the variable. Domain domain() const; diff --git a/ortools/sat/samples/cumulative_variable_profile_sample_sat.py b/ortools/sat/samples/cumulative_variable_profile_sample_sat.py index d03bdec0af..bfc4dd740c 100644 --- a/ortools/sat/samples/cumulative_variable_profile_sample_sat.py +++ b/ortools/sat/samples/cumulative_variable_profile_sample_sat.py @@ -22,6 +22,7 @@ import pandas as pd from ortools.sat.python import cp_model +# [START data_model] def create_data_model() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """Creates the dataframes that describes the model.""" diff --git a/ortools/sat/samples/ranking_circuit_sample_sat.py b/ortools/sat/samples/ranking_circuit_sample_sat.py index 0e91333ba0..69a9cb7833 100644 --- a/ortools/sat/samples/ranking_circuit_sample_sat.py +++ b/ortools/sat/samples/ranking_circuit_sample_sat.py @@ -14,8 +14,8 @@ """Code sample to demonstrates how to rank intervals using a circuit.""" -from typing import List, Sequence +from collections.abc import Sequence from ortools.sat.python import cp_model @@ -56,7 +56,7 @@ def rank_tasks_with_circuit( num_tasks = len(starts) all_tasks = range(num_tasks) - arcs: List[cp_model.ArcT] = [] + arcs: list[cp_model.ArcT] = [] for i in all_tasks: # if node i is first. start_lit = model.new_bool_var(f"start_{i}") diff --git a/ortools/sat/samples/sequences_in_no_overlap_sample_sat.py b/ortools/sat/samples/sequences_in_no_overlap_sample_sat.py index 7e1ff86c73..46ef5034f8 100644 --- a/ortools/sat/samples/sequences_in_no_overlap_sample_sat.py +++ b/ortools/sat/samples/sequences_in_no_overlap_sample_sat.py @@ -14,7 +14,7 @@ """Implements sequence constraints in a no_overlap constraint.""" -from typing import Dict, List, Sequence, Tuple +from collections.abc import Sequence from ortools.sat.python import cp_model @@ -26,9 +26,9 @@ def sequence_constraints_with_circuit( task_types: Sequence[str], lengths: Sequence[cp_model.IntVar], cumuls: Sequence[cp_model.IntVar], - sequence_length_constraints: Dict[str, Tuple[int, int]], - sequence_cumul_constraints: Dict[str, Tuple[int, int, int]], -) -> Sequence[Tuple[cp_model.IntVar, int]]: + sequence_length_constraints: dict[str, tuple[int, int]], + sequence_cumul_constraints: dict[str, tuple[int, int, int]], +) -> Sequence[tuple[cp_model.IntVar, int]]: """This method enforces constraints on sequences of tasks of the same type. This method assumes that all durations are strictly positive. @@ -64,7 +64,7 @@ def sequence_constraints_with_circuit( num_tasks = len(starts) all_tasks = range(num_tasks) - arcs: List[cp_model.ArcT] = [] + arcs: list[cp_model.ArcT] = [] for i in all_tasks: # if node i is first. start_lit = model.new_bool_var(f"start_{i}") diff --git a/ortools/sat/samples/transitions_in_no_overlap_sample_sat.py b/ortools/sat/samples/transitions_in_no_overlap_sample_sat.py index 5cbf236b37..441e549061 100644 --- a/ortools/sat/samples/transitions_in_no_overlap_sample_sat.py +++ b/ortools/sat/samples/transitions_in_no_overlap_sample_sat.py @@ -14,7 +14,8 @@ """Implements transition times and costs in a no_overlap constraint.""" -from typing import Dict, List, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Union from ortools.sat.python import cp_model @@ -24,9 +25,9 @@ def transitive_reduction_with_circuit_delays_and_penalties( starts: Sequence[cp_model.IntVar], durations: Sequence[int], presences: Sequence[Union[cp_model.IntVar, bool]], - penalties: Dict[Tuple[int, int], int], - delays: Dict[Tuple[int, int], int], -) -> Sequence[Tuple[cp_model.IntVar, int]]: + penalties: dict[tuple[int, int], int], + delays: dict[tuple[int, int], int], +) -> Sequence[tuple[cp_model.IntVar, int]]: """This method uses a circuit constraint to rank tasks. This method assumes that all starts are disjoint, meaning that all tasks have @@ -63,7 +64,7 @@ def transitive_reduction_with_circuit_delays_and_penalties( num_tasks = len(starts) all_tasks = range(num_tasks) - arcs: List[cp_model.ArcT] = [] + arcs: list[cp_model.ArcT] = [] penalty_terms = [] for i in all_tasks: # if node i is first. diff --git a/ortools/sat/scheduling_cuts.cc b/ortools/sat/scheduling_cuts.cc index 4c62279d7b..97381e44f2 100644 --- a/ortools/sat/scheduling_cuts.cc +++ b/ortools/sat/scheduling_cuts.cc @@ -1138,8 +1138,8 @@ void CtExhaustiveHelper::BuildPredecessors( predecessors_.clear(); if (events.size() > 100) return; - ReifiedLinear2Bounds* binary_relations = - model->GetOrCreate(); + RootLevelLinear2Bounds* root_level_bounds = + model->GetOrCreate(); std::vector sorted_events(events.begin(), events.end()); std::sort(sorted_events.begin(), sorted_events.end(), @@ -1151,7 +1151,8 @@ void CtExhaustiveHelper::BuildPredecessors( for (const auto& e1 : sorted_events) { for (const auto& e2 : sorted_events) { if (e2.task_index == e1.task_index) continue; - if (binary_relations->GetLevelZeroPrecedenceStatus(e2.end, e1.start) == + const auto [expr, ub] = EncodeDifferenceLowerThan(e2.end, e1.start, 0); + if (root_level_bounds->GetLevelZeroStatus(expr, kMinIntegerValue, ub) == RelationStatus::IS_TRUE) { while (predecessors_.size() <= e1.task_index) predecessors_.Add({}); predecessors_.AppendToLastVector(e2.task_index); diff --git a/ortools/sat/scheduling_helpers.cc b/ortools/sat/scheduling_helpers.cc index 9d25b17a7f..7016ed3280 100644 --- a/ortools/sat/scheduling_helpers.cc +++ b/ortools/sat/scheduling_helpers.cc @@ -482,17 +482,47 @@ void SchedulingConstraintHelper::AddReasonForBeingBeforeAssumingNoOverlap( AddOtherReason(before); AddOtherReason(after); - // Prefer the linear2 explanation as it is more likely this comes from - // level zero or a single enforcement literal. - // We need Start(after) >= End(before) - SizeMin(before). - // we rewrite as "End(before) - Start(after) <= SizeMin(before). - const auto [expr, ub] = - EncodeDifferenceLowerThan(ends_[before], starts_[after], SizeMin(before)); - if (linear2_bounds_->UpperBound(expr) <= ub) { - AddSizeMinReason(before); - linear2_bounds_->AddReasonForUpperBoundLowerThan(expr, ub, &literal_reason_, - &integer_reason_); - return; + // We compute this as an optimization, since for fixed sizes all linear2 + // options are equivalent. + const bool fixed_size = sizes_[before].var == kNoIntegerVariable && + sizes_[after].var == kNoIntegerVariable && + starts_[before].var == ends_[before].var && + starts_[after].var == ends_[after].var; + + // Prefer the straightforward linear2 explanation as it is more likely this + // comes from level zero or a single enforcement literal. Also handle the + // fixed size case. This explains with at most two integer bounds. + { + const auto [expr, ub] = + EncodeDifferenceLowerThan(starts_[before], ends_[after], -1); + if (fixed_size || linear2_bounds_->UpperBound(expr) <= ub) { + linear2_bounds_->AddReasonForUpperBoundLowerThan( + expr, ub, &literal_reason_, &integer_reason_); + return; + } + } + + // Another choice of Linear2. We need Start(before) < End(after). We rewrite + // as: + // End(before) - Size(before) < Start(after) + Size(after) + // + // Note that the generic code below not based on linear2 will not work + // if we can only get a good bound for End(before)-Start(after), so we also + // handle it here even if the linear2 is not known. + // + // TODO(user): check also the linear2 constructed with (end_before, + // end_after) and (start_after, start_before). Or maybe keep a index of pair + // of intervals that have non-trivial linear2 bounds and use that instead. + { + const auto [expr, ub] = EncodeDifferenceLowerThan( + ends_[before], starts_[after], SizeMin(before) + SizeMin(after) - 1); + if (linear2_bounds_->UpperBound(expr) <= ub) { + AddSizeMinReason(before); + AddSizeMinReason(after); + linear2_bounds_->AddReasonForUpperBoundLowerThan( + expr, ub, &literal_reason_, &integer_reason_); + return; + } } // We will explain StartMax(before) < EndMin(after); diff --git a/ortools/sat/swig_helper.cc b/ortools/sat/swig_helper.cc index b235b1d050..6da693027a 100644 --- a/ortools/sat/swig_helper.cc +++ b/ortools/sat/swig_helper.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include "absl/log/check.h" @@ -33,67 +34,75 @@ namespace operations_research { namespace sat { +SolutionCallback::SolutionCallback() { + // We create a dummy response. + response_ = std::make_shared(); +} + SolutionCallback::~SolutionCallback() = default; void SolutionCallback::Run( const operations_research::sat::CpSolverResponse& response) const { - response_ = response; + response_ = std::make_shared(response); has_response_ = true; OnSolutionCallback(); } int64_t SolutionCallback::NumBooleans() const { - return response_.num_booleans(); + return response_->num_booleans(); } int64_t SolutionCallback::NumBranches() const { - return response_.num_branches(); + return response_->num_branches(); } int64_t SolutionCallback::NumConflicts() const { - return response_.num_conflicts(); + return response_->num_conflicts(); } int64_t SolutionCallback::NumBinaryPropagations() const { - return response_.num_binary_propagations(); + return response_->num_binary_propagations(); } int64_t SolutionCallback::NumIntegerPropagations() const { - return response_.num_integer_propagations(); + return response_->num_integer_propagations(); } -double SolutionCallback::WallTime() const { return response_.wall_time(); } +double SolutionCallback::WallTime() const { return response_->wall_time(); } -double SolutionCallback::UserTime() const { return response_.user_time(); } +double SolutionCallback::UserTime() const { return response_->user_time(); } double SolutionCallback::DeterministicTime() const { - return response_.deterministic_time(); + return response_->deterministic_time(); } double SolutionCallback::ObjectiveValue() const { - return response_.objective_value(); + return response_->objective_value(); } double SolutionCallback::BestObjectiveBound() const { - return response_.best_objective_bound(); + return response_->best_objective_bound(); } int64_t SolutionCallback::SolutionIntegerValue(int index) const { - return index >= 0 ? response_.solution(index) - : -response_.solution(-index - 1); + return index >= 0 ? response_->solution(index) + : -response_->solution(-index - 1); } bool SolutionCallback::SolutionBooleanValue(int index) const { - return index >= 0 ? response_.solution(index) != 0 - : response_.solution(-index - 1) == 0; + return index >= 0 ? response_->solution(index) != 0 + : response_->solution(-index - 1) == 0; } void SolutionCallback::StopSearch() const { if (wrapper_ != nullptr) wrapper_->StopSearch(); } -const operations_research::sat::CpSolverResponse& SolutionCallback::Response() - const { +operations_research::sat::CpSolverResponse SolutionCallback::Response() const { + return *response_; +} + +std::shared_ptr SolutionCallback::SharedResponse() const { return response_; } diff --git a/ortools/sat/swig_helper.h b/ortools/sat/swig_helper.h index 8c5890eccd..9bedcb4599 100644 --- a/ortools/sat/swig_helper.h +++ b/ortools/sat/swig_helper.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "ortools/sat/cp_model.pb.h" @@ -33,6 +34,8 @@ class SolveWrapper; // See http://www.swig.org/Doc4.0/SWIGDocumentation.html#CSharp_directors. class SolutionCallback { public: + SolutionCallback(); + virtual ~SolutionCallback(); virtual void OnSolutionCallback() const = 0; @@ -66,7 +69,9 @@ class SolutionCallback { // Stops the search. void StopSearch() const; - const operations_research::sat::CpSolverResponse& Response() const; + operations_research::sat::CpSolverResponse Response() const; + + std::shared_ptr SharedResponse() const; // We use mutable and non const methods to overcome SWIG difficulties. void SetWrapperClass(SolveWrapper* wrapper) const; @@ -76,7 +81,7 @@ class SolutionCallback { bool HasResponse() const; private: - mutable CpSolverResponse response_; + mutable std::shared_ptr response_; mutable bool has_response_ = false; mutable SolveWrapper* wrapper_ = nullptr; }; diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 6b4344cc08..10f6c64bb7 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -503,7 +503,7 @@ void SharedResponseManager::UpdateInnerObjectiveBounds( // UNKNOWN -> FEASIBLE -> OPTIMAL // UNKNOWN -> INFEASIBLE void SharedResponseManager::NotifyThatImprovingProblemIsInfeasible( - const std::string& worker_info) { + absl::string_view worker_info) { absl::MutexLock mutex_lock(&mutex_); if (best_status_ == CpSolverStatus::FEASIBLE || best_status_ == CpSolverStatus::OPTIMAL) { diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index c6babb7d5f..1eb9590fcc 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -485,7 +485,7 @@ class SharedResponseManager { // // Note that this shouldn't be called before the solution is actually // reported. We check for this case in NewSolution(). - void NotifyThatImprovingProblemIsInfeasible(const std::string& worker_info); + void NotifyThatImprovingProblemIsInfeasible(absl::string_view worker_info); // Adds to the shared response a subset of assumptions that are enough to // make the problem infeasible. diff --git a/ortools/sat/util.cc b/ortools/sat/util.cc index dd90fbbcc1..8f212eacf8 100644 --- a/ortools/sat/util.cc +++ b/ortools/sat/util.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -1052,5 +1053,56 @@ std::vector FindMostDiverseSubset(int k, int n, return result; } +std::vector> HeuristicallySplitLongLinear( + absl::Span coeffs) { + std::vector> result; + if (coeffs.empty()) return result; + + // Split an interval [0, size) into num_parts mostly equal parts. + const auto append_splits = [&result](int offset, int size, int num_parts) { + int previous_start = 0; + for (int64_t b = 0; b < num_parts; ++b) { + const int next_start = static_cast(b + 1) * size / num_parts; + result.push_back({offset + previous_start, next_start - previous_start}); + previous_start = next_start; + } + }; + + const int num_terms = coeffs.size(); + const int num_buckets = static_cast(std::round(std::sqrt(num_terms))); + + int num_differents = 1; + for (int i = 1; i < num_terms; ++i) { + if (coeffs[i - 1] != coeffs[i]) ++num_differents; + } + + // If we don't have many different coefficients, we always create parts + // with exactly the same coeffs. We split large part evenly into size / + // expected_part_size. + if (num_differents < 20) { + const int expected_part_size = num_terms / num_buckets; + + for (int i = 0; i < num_terms;) { + int j = i + 1; + for (; j < num_terms; ++j) { + if (coeffs[j] != coeffs[i]) break; + } + + const int local_size = j - i; + const int num_local_buckets = + MathUtil::CeilOfRatio(local_size, expected_part_size); + append_splits(i, local_size, num_local_buckets); + + i = j; + } + + return result; + } + + // Otherwise we just split into num_buckets buckets. + append_splits(0, num_terms, num_buckets); + return result; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 3ec89ce7a9..030fc254da 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -406,6 +406,22 @@ std::vector FindMostDiverseSubset(int k, int n, std::vector& buffer, int always_pick_mask = 0); +// HEURISTIC. Try to "cut" the list into roughly sqrt(size) equally sized parts. +// We try to keep the same coefficients in the same buckets. +// The list is assumed to be sorted. +// Return a list of pair (start, size) for each part. +// +// Context: Currently when we load long linear constraint (more than 100 terms), +// to keep the propagation and reason shorts, we always split them by adding +// intermediate variable corresponding to the sum of a subpart. We just do that +// in the CP-engine, not in the LP though. using sub-part with the same coeff +// seems to help and kind of make sense. +// +// TODO(user): This sounds sub-optimal, we should also try to add variables for +// common part between constraints, like what some of the presolve is doing. +std::vector> HeuristicallySplitLongLinear( + absl::Span coeffs); + // Simple DP to compute the maximum reachable value of a "subset sum" under // a given bound (inclusive). Note that we abort as soon as the computation // become too important. diff --git a/ortools/sat/util_test.cc b/ortools/sat/util_test.cc index 9aaef75c60..76f1dd3407 100644 --- a/ortools/sat/util_test.cc +++ b/ortools/sat/util_test.cc @@ -56,6 +56,7 @@ namespace { using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::Pair; TEST(CompactVectorVectorTest, EmptyCornerCases) { CompactVectorVector storage; @@ -1249,6 +1250,21 @@ TEST(FindMostDiverseSubsetTest, RandomButAlwaysPickZero) { EXPECT_EQ(best_seen, result_value); } +TEST(HeuristicallySplitLongLinearTest, BasicExamples) { + EXPECT_THAT(HeuristicallySplitLongLinear({1, 2, 3}), + ElementsAre(Pair(0, 1), Pair(1, 1), Pair(2, 1))); + EXPECT_THAT(HeuristicallySplitLongLinear({1, 1, 2, 3}), + ElementsAre(Pair(0, 2), Pair(2, 1), Pair(3, 1))); + + // The number of part is not ideal here. + EXPECT_THAT( + HeuristicallySplitLongLinear({1, 1, 1, 1, 1, 2, 3}), + ElementsAre(Pair(0, 1), Pair(1, 2), Pair(3, 2), Pair(5, 1), Pair(6, 1))); + + EXPECT_THAT(HeuristicallySplitLongLinear({1, 1, 1, 1, 3, 3, 3, 3, 3}), + ElementsAre(Pair(0, 2), Pair(2, 2), Pair(4, 2), Pair(6, 3))); +} + } // namespace } // namespace sat } // namespace operations_research diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index 11bb2324bb..010e359323 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -50,28 +50,44 @@ namespace operations_research::sat { namespace { -// We restart the shared tree 10 times after (on average) 2 tree assignments per -// worker. -const int kAssignmentsPerWorkerPerRestart = 2; const int kNumInitialRestarts = 10; // If you build a tree by expanding the nodes with minimal depth+discrepancy, -// the number of leaves when all nodes with a given value have been split +// the number of leaves when all nodes less than a given value have been split // follows the fibonacci sequence: -// num_leaves(0) := 2; -// num_leaves(1) := 3; +// num_leaves(0) := 1; +// num_leaves(1) := 2; // num_leaves(n) := num_leaves(n-1) + num_leaves(n-2) // This function returns f(n) := min({i | num_leaves(i) >= n}) int MaxAllowedDiscrepancyPlusDepth(int num_leaves) { int i = 0; int a = 1; int b = 2; - while (b < num_leaves) { + while (a < num_leaves) { std::tie(a, b) = std::make_pair(b, a + b); ++i; } return i; } + +// Returns the maximum depth of any leaf in the shared tree. +// This is an upper bound that can be computed without needing a lock on the +// shared tree. +int MaxPossibleLeafDepth(const SatParameters& params) { + const int num_leaves = params.shared_tree_open_leaves_per_worker() * + params.shared_tree_num_workers(); + switch (params.shared_tree_split_strategy()) { + case SatParameters::SPLIT_STRATEGY_DISCREPANCY: + case SatParameters::SPLIT_STRATEGY_AUTO: + return MaxAllowedDiscrepancyPlusDepth(num_leaves) + + params.shared_tree_balance_tolerance(); + case SatParameters::SPLIT_STRATEGY_BALANCED_TREE: + return std::ceil(std::log2(num_leaves)) + + params.shared_tree_balance_tolerance(); + default: + return num_leaves; + } +} } // namespace Literal ProtoLiteral::Decode(CpModelMapping* mapping, @@ -217,6 +233,7 @@ absl::Span ProtoTrail::Implications(int level) const { SharedTreeManager::SharedTreeManager(Model* model) : params_(*model->GetOrCreate()), num_workers_(params_.shared_tree_num_workers()), + max_path_depth_(MaxPossibleLeafDepth(params_)), shared_response_manager_(model->GetOrCreate()), num_splits_wanted_( num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1), @@ -276,10 +293,9 @@ bool SharedTreeManager::SyncTree(ProtoTrail& path) { return false; } // Restart after processing updates - we might learn a new objective bound. - // Do initial restarts once the tree has been split a reasonable number of - // times. - if (num_leaves_assigned_since_restart_ > - kAssignmentsPerWorkerPerRestart * num_workers_ && + // Do initial restarts once each worker has had the chance to be assigned a + // leaf. + if (num_leaves_assigned_since_restart_ >= num_workers_ && num_restarts_ < kNumInitialRestarts) { RestartLockHeld(); path.Clear(); @@ -290,30 +306,39 @@ bool SharedTreeManager::SyncTree(ProtoTrail& path) { return true; } -void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) { - absl::MutexLock mutex_lock(&mu_); - if (!IsValid(path)) return; +int SharedTreeManager::TrySplitTree(absl::Span decisions, + ProtoTrail& path) { + decisions = decisions.subspan(0, max_path_depth_ - path.MaxLevel()); + if (decisions.empty()) return 0; + absl::MutexLock l(&mu_); + for (int i = 0; i < decisions.size(); ++i) { + if (!TrySplitTreeLockHeld(decisions[i], path)) return i; + } + return decisions.size(); +} + +bool SharedTreeManager::TrySplitTreeLockHeld(ProtoLiteral decision, + ProtoTrail& path) { + if (!IsValid(path)) return false; std::vector> nodes = GetAssignedNodes(path); if (nodes.back().first->closed) { VLOG(2) << "Cannot split closed node"; - return; + return false; } if (nodes.back().first->children[0] != nullptr) { LOG_IF(WARNING, nodes.size() > 1) << "Cannot resplit previously split node @ " << nodes.back().second << "/" << nodes.size(); - return; + return false; } if (nodes_.size() + 2 > max_nodes_) { VLOG(2) << "Too many nodes to accept split"; - return; + return false; } if (num_splits_wanted_ <= 0) { VLOG(2) << "Enough splits for now"; - return; + return false; } - const int num_desired_leaves = - params_.shared_tree_open_leaves_per_worker() * num_workers_; if (params_.shared_tree_split_strategy() == SatParameters::SPLIT_STRATEGY_DISCREPANCY || params_.shared_tree_split_strategy() == @@ -328,11 +353,9 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) { } // TODO(user): Need to write up the shape this creates. // This rule will allow twice as many leaves in the preferred subtree. - if (discrepancy + path.MaxLevel() > - MaxAllowedDiscrepancyPlusDepth(num_desired_leaves) + - params_.shared_tree_balance_tolerance()) { + if (discrepancy + path.MaxLevel() >= max_path_depth_) { VLOG(2) << "Too high discrepancy to accept split"; - return; + return false; } } else if (params_.shared_tree_split_strategy() == SatParameters::SPLIT_STRATEGY_OBJECTIVE_LB) { @@ -340,14 +363,7 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) { VLOG(2) << "Can only split nodes with minimum objective lb, " << nodes.back().first->objective_lb << " > " << nodes.front().first->objective_lb; - return; - } - } else if (params_.shared_tree_split_strategy() == - SatParameters::SPLIT_STRATEGY_BALANCED_TREE) { - if (path.MaxLevel() + 1 > - log2(num_desired_leaves) + params_.shared_tree_balance_tolerance()) { - VLOG(2) << "Tree too unbalanced to accept split"; - return; + return false; } } VLOG_EVERY_N(2, 10) << unassigned_leaves_.size() << " unassigned leaves, " @@ -356,6 +372,7 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) { Split(nodes, decision); auto [new_leaf, level] = nodes.back(); path.PushLevel(new_leaf->literal, new_leaf->objective_lb, new_leaf->id); + return true; } void SharedTreeManager::ReplaceTree(ProtoTrail& path) { @@ -727,9 +744,6 @@ bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) { const auto& decision_policy = heuristics_->decision_policies[heuristics_->policy_index]; const int next_level = sat_solver_->CurrentDecisionLevel() + 1; - if (next_level == assigned_tree_.MaxLevel() + 1) { - new_split_available_ = true; - } CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); if (next_level <= assigned_tree_.MaxLevel()) { VLOG(2) << "Following shared trail depth=" << next_level << " " @@ -744,26 +758,26 @@ bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) { return helper_->GetDecision(decision_policy, decision_index); } -void SharedTreeWorker::MaybeProposeSplit() { - if (!new_split_available_ || - sat_solver_->CurrentDecisionLevel() < assigned_tree_.MaxLevel() + 1 || - time_limit_->GetElapsedDeterministicTime() < next_split_dtime_) { +void SharedTreeWorker::MaybeProposeSplits() { + if (time_limit_->GetElapsedDeterministicTime() <= next_split_dtime_) { return; } - new_split_available_ = false; - const Literal split_decision = - sat_solver_->Decisions()[assigned_tree_.MaxLevel()].literal; - const std::optional encoded = EncodeDecision(split_decision); - if (encoded.has_value()) { - next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() + - parameters_->shared_tree_split_min_dtime(); - CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); - manager_->ProposeSplit(assigned_tree_, *encoded); - if (assigned_tree_.MaxLevel() > assigned_tree_literals_.size()) { - assigned_tree_literals_.push_back(split_decision); - assigned_tree_implications_.push_back({}); - } - CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); + next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() + + parameters_->shared_tree_split_min_dtime(); + tmp_splits_.clear(); + const int max_split_level = + std::min(trail_->CurrentDecisionLevel(), manager_->MaxPathDepth()); + for (int i = assigned_tree_.MaxLevel(); i < max_split_level; ++i) { + const Literal split_decision = sat_solver_->Decisions()[i].literal; + const std::optional encoded = EncodeDecision(split_decision); + if (!encoded.has_value()) break; + tmp_splits_.push_back(*encoded); + } + const int splits_accepted = + manager_->TrySplitTree(tmp_splits_, assigned_tree_); + for (int i = 0; i < splits_accepted; ++i) { + assigned_tree_literals_.push_back(DecodeDecision(tmp_splits_[i])); + assigned_tree_implications_.push_back({}); } } @@ -810,6 +824,10 @@ bool SharedTreeWorker::SyncWithSharedTree() { assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset()); restart_policy_->Reset(); earliest_replacement_dtime_ = 0; + if (assigned_tree_.MaxLevel() > 0) { + next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() + + parameters_->shared_tree_split_min_dtime(); + } if (parameters_->shared_tree_worker_enable_phase_sharing()) { VLOG(2) << "Importing phase of length: " << assigned_tree_.TargetPhase().size(); @@ -893,7 +911,7 @@ SatSolver::Status SharedTreeWorker::Search( if (!helper_->TakeDecision(decision)) { return sat_solver_->UnsatStatus(); } - MaybeProposeSplit(); + MaybeProposeSplits(); } return SatSolver::LIMIT_REACHED; diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h index d2f7463f8b..580aa4196a 100644 --- a/ortools/sat/work_assignment.h +++ b/ortools/sat/work_assignment.h @@ -208,6 +208,7 @@ class SharedTreeManager { int NumWorkers() const { return num_workers_; } int NumNodes() const ABSL_LOCKS_EXCLUDED(mu_); + int MaxPathDepth() const { return max_path_depth_; } // Syncs the state of path with the shared search tree. // Clears `path` and returns false if the assigned subtree is closed or a @@ -221,9 +222,12 @@ class SharedTreeManager { // solutions. Clears path. void CloseTree(ProtoTrail& path, int level); - // Called by workers in order to split the shared tree. - // `path` may or may not be extended by one level, branching on `decision`. - void ProposeSplit(ProtoTrail& path, ProtoLiteral decision); + // Attempts to split the tree repeatedly with the given decisions. + // `path` will be extended with the accepted splits, the opposite branches + // will be added as unassigned leaves. + // Returns the number of splits accepted. + int TrySplitTree(absl::Span decisions, ProtoTrail& path) + ABSL_LOCKS_EXCLUDED(mu_); void Restart() { absl::MutexLock l(&mu_); @@ -259,6 +263,8 @@ class SharedTreeManager { // Returns the NodeTrailInfo for `node` or it's closest non-closed, // non-implied ancestor. `node` must be valid, never returns nullptr. NodeTrailInfo* GetTrailInfo(Node* node); + bool TrySplitTreeLockHeld(ProtoLiteral decision, ProtoTrail& path) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); void Split(std::vector>& nodes, ProtoLiteral lit) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); Node* MakeSubtree(Node* parent, ProtoLiteral literal) @@ -274,6 +280,7 @@ class SharedTreeManager { mutable absl::Mutex mu_; const SatParameters& params_; const int num_workers_; + const int max_path_depth_; SharedResponseManager* const shared_response_manager_; // Stores the node id of the root, this is used to handle global restarts. @@ -320,7 +327,7 @@ class SharedTreeWorker { Literal DecodeDecision(ProtoLiteral literal); std::optional EncodeDecision(Literal decision); bool NextDecision(LiteralIndex* decision_index); - void MaybeProposeSplit(); + void MaybeProposeSplits(); bool ShouldReplaceSubtree(); bool FinishedMinRestarts() const { return assigned_tree_.MaxLevel() > 0 && @@ -360,12 +367,7 @@ class SharedTreeWorker { std::vector> assigned_tree_implications_; double next_split_dtime_ = 0; - // True if the last decision may split the assigned tree and has not yet been - // proposed to the SharedTreeManager. - // We propagate the decision before sharing with the SharedTreeManager so we - // don't share any decision that immediately leads to conflict. - bool new_split_available_ = false; - + std::vector tmp_splits_; std::vector reason_; // Stores the average LBD of learned clauses for each tree assigned since it // was assigned. diff --git a/ortools/sat/work_assignment_test.cc b/ortools/sat/work_assignment_test.cc index 90a8c287d4..d1c36acfd6 100644 --- a/ortools/sat/work_assignment_test.cc +++ b/ortools/sat/work_assignment_test.cc @@ -266,7 +266,7 @@ TEST(SharedTreeManagerTest, SplitTest) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail shared_trail; - shared_tree_manager->ProposeSplit(shared_trail, {-1, 0}); + shared_tree_manager->TrySplitTree({{-1, 0}}, shared_trail); EXPECT_EQ(shared_trail.MaxLevel(), 1); } @@ -287,7 +287,7 @@ TEST(SharedTreeManagerTest, RestartTest) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail shared_trail; - shared_tree_manager->ProposeSplit(shared_trail, {-1, 0}); + shared_tree_manager->TrySplitTree({{-1, 0}}, shared_trail); shared_tree_manager->Restart(); shared_tree_manager->SyncTree(shared_trail); @@ -310,7 +310,7 @@ TEST(SharedTreeManagerTest, RestartTestWithLevelZeroImplications) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail shared_trail; - shared_tree_manager->ProposeSplit(shared_trail, {-1, 0}); + shared_tree_manager->TrySplitTree({{-1, 0}}, shared_trail); shared_tree_manager->CloseTree(shared_trail, 1); shared_tree_manager->SyncTree(shared_trail); shared_tree_manager->ReplaceTree(shared_trail); @@ -337,7 +337,7 @@ TEST(SharedTreeManagerTest, SharedBranchingTest) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail trail1, trail2; - shared_tree_manager->ProposeSplit(trail1, {-1, 0}); + shared_tree_manager->TrySplitTree({{-1, 0}}, trail1); shared_tree_manager->ReplaceTree(trail2); EXPECT_EQ(trail1.MaxLevel(), 1); @@ -365,7 +365,7 @@ TEST(SharedTreeManagerTest, ObjectiveLbSplitTest) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail trail1, trail2; - shared_tree_manager->ProposeSplit(trail1, {-1, 0}); + shared_tree_manager->TrySplitTree({{-1, 0}}, trail1); ASSERT_EQ(trail1.MaxLevel(), 1); trail1.SetObjectiveLb(1, 2); shared_tree_manager->SyncTree(trail1); @@ -374,7 +374,8 @@ TEST(SharedTreeManagerTest, ObjectiveLbSplitTest) { trail2.SetObjectiveLb(1, 1); shared_tree_manager->SyncTree(trail2); // Reject this split because it is not at the global lower bound. - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 3}); + ASSERT_FALSE( + shared_tree_manager->TrySplitTree({{int_var.index(), 3}}, trail1)); EXPECT_EQ(response_manager->GetInnerObjectiveLowerBound(), 1); EXPECT_EQ(shared_tree_manager->NumNodes(), 3); @@ -388,9 +389,10 @@ TEST(SharedTreeManagerTest, DiscrepancySplitTestOneLeafPerWorker) { model_builder.Maximize(int_var); Model model; SatParameters params; - params.set_num_workers(4); - params.set_shared_tree_num_workers(4); + params.set_num_workers(5); + params.set_shared_tree_num_workers(5); params.set_shared_tree_open_leaves_per_worker(1); + params.set_shared_tree_balance_tolerance(0); params.set_cp_model_presolve(false); params.set_shared_tree_split_strategy( SatParameters::SPLIT_STRATEGY_DISCREPANCY); @@ -401,19 +403,23 @@ TEST(SharedTreeManagerTest, DiscrepancySplitTestOneLeafPerWorker) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail trail1, trail2; - shared_tree_manager->ProposeSplit(trail1, {-1, 0}); - shared_tree_manager->SyncTree(trail1); + // Reject the last split: splitting at 3 depth + 0 discrepancy is not minimal. + EXPECT_EQ(shared_tree_manager->TrySplitTree({{-1, 0}, + {int_var.index(), 3}, + {int_var.index(), 4}, + {int_var.index(), 5}}, + trail1), + 3); shared_tree_manager->ReplaceTree(trail2); - shared_tree_manager->ProposeSplit(trail2, {int_var.index(), 3}); - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 3}); - // Reject this split: 2 depth + 1 discrepancy is not minimal. - shared_tree_manager->ProposeSplit(trail2, {int_var.index(), 5}); - // Reject this split: 2 depth + 0 discrepancy is not minimal. - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 5}); + // Reject the 2nd split: 2 depth + 1 discrepancy is not minimal. + EXPECT_EQ(shared_tree_manager->TrySplitTree( + {{int_var.index(), 3}, {int_var.index(), 5}}, trail2), + 1); - EXPECT_EQ(trail1.MaxLevel(), 2); + EXPECT_EQ(shared_tree_manager->MaxPathDepth(), 3); + EXPECT_EQ(trail1.MaxLevel(), 3); EXPECT_EQ(trail2.MaxLevel(), 2); - EXPECT_EQ(shared_tree_manager->NumNodes(), 7); + EXPECT_EQ(shared_tree_manager->NumNodes(), 9); } TEST(SharedTreeManagerTest, DiscrepancySplitTest) { @@ -426,10 +432,11 @@ TEST(SharedTreeManagerTest, DiscrepancySplitTest) { SatParameters params; params.set_num_workers(2); params.set_shared_tree_num_workers(2); - params.set_shared_tree_open_leaves_per_worker(2); + params.set_shared_tree_open_leaves_per_worker(2.5); params.set_cp_model_presolve(false); params.set_shared_tree_split_strategy( SatParameters::SPLIT_STRATEGY_DISCREPANCY); + params.set_shared_tree_balance_tolerance(0); model.Add(NewSatParameters(params)); LoadVariables(model_builder.Build(), false, &model); auto* response_manager = model.GetOrCreate(); @@ -437,19 +444,18 @@ TEST(SharedTreeManagerTest, DiscrepancySplitTest) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail trail1, trail2; - shared_tree_manager->ProposeSplit(trail1, {-1, 0}); - shared_tree_manager->SyncTree(trail1); + EXPECT_EQ(shared_tree_manager->TrySplitTree( + {{-1, 0}, {int_var.index(), 3}, {int_var.index(), 5}}, trail1), + 3); shared_tree_manager->ReplaceTree(trail2); - shared_tree_manager->ProposeSplit(trail2, {int_var.index(), 3}); - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 3}); - // Reject this split: 2 depth + 1 discrepancy is not minimal. - shared_tree_manager->ProposeSplit(trail2, {int_var.index(), 5}); - // Reject this split: 2 depth + 0 discrepancy is not minimal. - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 5}); + EXPECT_EQ(shared_tree_manager->TrySplitTree( + {{int_var.index(), 3}, {int_var.index(), 5}}, trail2), + 1); - EXPECT_EQ(trail1.MaxLevel(), 2); + EXPECT_EQ(shared_tree_manager->MaxPathDepth(), 3); + EXPECT_EQ(trail1.MaxLevel(), 3); EXPECT_EQ(trail2.MaxLevel(), 2); - EXPECT_EQ(shared_tree_manager->NumNodes(), 7); + EXPECT_EQ(shared_tree_manager->NumNodes(), 9); } TEST(SharedTreeManagerTest, BalancedSplitTestOneLeafPerWorker) { @@ -474,17 +480,24 @@ TEST(SharedTreeManagerTest, BalancedSplitTestOneLeafPerWorker) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail trail1, trail2; - shared_tree_manager->ProposeSplit(trail1, {-1, 0}); + EXPECT_EQ(shared_tree_manager->TrySplitTree({{int_var.index(), 3}, + {int_var.index(), 2}, + {int_var.index(), 1}, + {int_var.index(), 0}}, + trail1), + 3); shared_tree_manager->SyncTree(trail1); + // Trees are assigned in FIFO order, so this will be the subtree at depth 1. shared_tree_manager->ReplaceTree(trail2); - shared_tree_manager->SyncTree(trail2); - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 3}); - // Reject this split because it creates an unbalanced tree - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 5}); - shared_tree_manager->ProposeSplit(trail2, {int_var.index(), 3}); + // Reject the final split because there are too many leaves, even though the + // depth is ok. + EXPECT_EQ(shared_tree_manager->TrySplitTree( + {{int_var.index(), 5}, {int_var.index(), 4}}, trail2), + 1); - EXPECT_EQ(shared_tree_manager->NumNodes(), 7); - EXPECT_EQ(trail1.MaxLevel(), 2); + EXPECT_EQ(shared_tree_manager->MaxPathDepth(), 3); + EXPECT_EQ(shared_tree_manager->NumNodes(), 9); + EXPECT_EQ(trail1.MaxLevel(), 3); EXPECT_EQ(trail2.MaxLevel(), 2); } @@ -496,8 +509,8 @@ TEST(SharedTreeManagerTest, BalancedSplitTest) { model_builder.Maximize(int_var); Model model; SatParameters params; - params.set_num_workers(3); - params.set_shared_tree_num_workers(3); + params.set_num_workers(4); + params.set_shared_tree_num_workers(4); params.set_shared_tree_open_leaves_per_worker(2); params.set_cp_model_presolve(false); params.set_shared_tree_split_strategy( @@ -510,18 +523,24 @@ TEST(SharedTreeManagerTest, BalancedSplitTest) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail trail1, trail2; - shared_tree_manager->ProposeSplit(trail1, {-1, 0}); - shared_tree_manager->SyncTree(trail1); + EXPECT_EQ(shared_tree_manager->TrySplitTree({{int_var.index(), 3}, + {int_var.index(), 2}, + {int_var.index(), 1}, + {int_var.index(), 0}}, + trail1), + 3); shared_tree_manager->ReplaceTree(trail2); - shared_tree_manager->SyncTree(trail2); - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 3}); - // Reject this split because it creates an unbalanced tree - shared_tree_manager->ProposeSplit(trail1, {int_var.index(), 5}); - shared_tree_manager->ProposeSplit(trail2, {int_var.index(), 3}); + EXPECT_EQ(shared_tree_manager->TrySplitTree({{int_var.index(), 6}, + {int_var.index(), 5}, + {int_var.index(), 4}, + {int_var.index(), 3}}, + trail2), + 2); - EXPECT_EQ(shared_tree_manager->NumNodes(), 7); - EXPECT_EQ(trail1.MaxLevel(), 2); - EXPECT_EQ(trail2.MaxLevel(), 2); + EXPECT_EQ(shared_tree_manager->MaxPathDepth(), 3); + EXPECT_EQ(shared_tree_manager->NumNodes(), 11); + EXPECT_EQ(trail1.MaxLevel(), 3); + EXPECT_EQ(trail2.MaxLevel(), 3); } TEST(SharedTreeManagerTest, CloseTreeTest) { @@ -538,10 +557,9 @@ TEST(SharedTreeManagerTest, CloseTreeTest) { model.Add(NewSatParameters(params)); LoadVariables(model_builder.Build(), false, &model); auto* shared_tree_manager = model.GetOrCreate(); - ProtoTrail trail1, trail2, trail3; - shared_tree_manager->ProposeSplit(trail1, {-1, 0}); + ProtoTrail trail1, trail2; + EXPECT_EQ(shared_tree_manager->TrySplitTree({{-1, 0}, {1, 0}}, trail1), 2); shared_tree_manager->ReplaceTree(trail2); - shared_tree_manager->ProposeSplit(trail1, {1, 0}); shared_tree_manager->CloseTree(trail1, 1); shared_tree_manager->ReplaceTree(trail1); @@ -568,7 +586,7 @@ TEST(SharedTreeManagerTest, TrailSharing) { auto* shared_tree_manager = model.GetOrCreate(); ProtoTrail trail1, trail2; - shared_tree_manager->ProposeSplit(trail1, ProtoLiteral(0, 1)); + shared_tree_manager->TrySplitTree({ProtoLiteral(0, 1)}, trail1); trail1.AddImplication(1, ProtoLiteral(1, 1)); trail1.AddImplication(1, ProtoLiteral(1, 3)); shared_tree_manager->SyncTree(trail1);