From 4151254eba725b2f8cd7fb895bbd70018a6f3648 Mon Sep 17 00:00:00 2001 From: Corentin Le Molgat Date: Mon, 21 Jul 2025 17:25:33 +0200 Subject: [PATCH] sat: backport from main --- ortools/sat/2d_distances_propagator.cc | 211 +-- ortools/sat/2d_distances_propagator.h | 50 +- ortools/sat/BUILD.bazel | 25 +- ortools/sat/combine_solutions.cc | 21 +- ortools/sat/combine_solutions.h | 3 +- ortools/sat/constraint_violation.cc | 4 +- ortools/sat/constraint_violation.h | 4 +- ortools/sat/cp_model_lns.cc | 2 +- ortools/sat/cp_model_lns_test.cc | 8 +- ortools/sat/cp_model_loader.cc | 7 +- ortools/sat/cp_model_mapping.h | 1 - ortools/sat/cp_model_presolve.cc | 25 +- ortools/sat/cp_model_search.cc | 4 +- ortools/sat/cp_model_solver.cc | 95 +- ortools/sat/cp_model_solver_helpers.cc | 134 +- ortools/sat/cp_model_solver_helpers.h | 8 + ortools/sat/cp_model_solver_test.cc | 12 +- ortools/sat/cp_model_symmetries.cc | 4 + ortools/sat/cumulative.cc | 4 +- ortools/sat/cumulative_energy_test.cc | 10 +- ortools/sat/diffn.cc | 2 +- ortools/sat/disjunctive.cc | 104 +- ortools/sat/disjunctive.h | 12 +- ortools/sat/disjunctive_test.cc | 30 +- ortools/sat/feasibility_jump.cc | 23 +- ortools/sat/feasibility_jump.h | 2 +- ortools/sat/feasibility_pump.cc | 4 +- ortools/sat/flaky_models_test.cc | 2 +- ortools/sat/go/cpmodel/cp_model.go | 1 + ortools/sat/go/cpmodel/cp_model_test.go | 3 +- ortools/sat/integer.cc | 30 +- ortools/sat/integer.h | 49 +- ortools/sat/integer_base.cc | 121 +- ortools/sat/integer_base.h | 109 +- ortools/sat/integer_base_test.cc | 21 +- ortools/sat/integer_search.cc | 21 +- ortools/sat/intervals.cc | 24 +- ortools/sat/intervals.h | 3 +- ortools/sat/java/BUILD.bazel | 2 +- ortools/sat/java/CMakeLists.txt | 10 +- ortools/sat/java/{sat.i => sat.swig} | 4 +- ortools/sat/linear_propagation.cc | 11 +- ortools/sat/linear_propagation.h | 4 +- ortools/sat/linear_relaxation.cc | 4 +- ortools/sat/no_overlap_2d_helper.cc | 8 +- ortools/sat/opb_reader.h | 21 +- ortools/sat/parameters_validation.cc | 2 + ortools/sat/precedences.cc | 1129 ++++++++-------- ortools/sat/precedences.h | 1137 ++++++++++------- ortools/sat/precedences_test.cc | 486 ++++--- ortools/sat/primary_variables.cc | 15 + ortools/sat/python/BUILD.bazel | 1 + ortools/sat/python/CMakeLists.txt | 56 +- ortools/sat/python/cp_model.py | 818 +++++------- ortools/sat/python/cp_model_helper_test.py | 163 ++- ortools/sat/python/cp_model_numbers.py | 67 - ortools/sat/python/cp_model_numbers_test.py | 63 - ortools/sat/python/cp_model_test.py | 167 ++- .../sat/python/gen_proto_builder_pybind11.cc | 49 + ortools/sat/python/linear_expr.cc | 162 ++- ortools/sat/python/linear_expr.h | 142 +- ortools/sat/python/linear_expr_doc.h | 46 +- ortools/sat/python/wrappers.cc | 450 +++++++ ortools/sat/python/wrappers.h | 31 + ortools/sat/rins.cc | 7 +- ortools/sat/rins_test.cc | 16 +- ortools/sat/routing_cuts.cc | 143 ++- ortools/sat/routing_cuts.h | 1 + ortools/sat/routing_cuts_test.cc | 175 +-- ortools/sat/samples/assumptions_sample_sat.go | 1 + .../sat/samples/boolean_product_sample_sat.go | 3 +- ortools/sat/samples/channeling_sample_sat.go | 3 +- .../earliness_tardiness_cost_sample_sat.go | 3 +- ortools/sat/samples/no_overlap_sample_sat.go | 1 + .../sat/samples/rabbits_and_pheasants_sat.go | 1 + ortools/sat/samples/ranking_sample_sat.go | 1 + .../search_for_all_solutions_sample_sat.go | 3 +- ortools/sat/samples/simple_sat_program.go | 1 + .../samples/solution_hinting_sample_sat.go | 1 + ...print_intermediate_solutions_sample_sat.go | 3 +- .../solve_with_time_limit_sample_sat.go | 3 +- .../sat/samples/step_function_sample_sat.go | 3 +- ortools/sat/sat_decision.cc | 23 + ortools/sat/sat_decision.h | 24 +- ortools/sat/sat_decision_test.cc | 69 + ortools/sat/sat_parameters.proto | 28 +- ortools/sat/sat_solver.cc | 2 +- ortools/sat/scheduling_cuts.cc | 181 +-- ortools/sat/scheduling_cuts.h | 34 + ortools/sat/scheduling_cuts_test.cc | 33 +- ortools/sat/scheduling_helpers.cc | 86 +- ortools/sat/scheduling_helpers.h | 46 +- ortools/sat/shaving_solver.cc | 6 +- ortools/sat/solution_crush.cc | 34 + ortools/sat/solution_crush.h | 7 + ortools/sat/synchronization.cc | 316 ++++- ortools/sat/synchronization.h | 437 ++++++- ortools/sat/synchronization_test.cc | 20 +- ortools/sat/util.cc | 44 + ortools/sat/util.h | 16 +- ortools/sat/util_test.cc | 89 ++ ortools/sat/work_assignment.cc | 49 +- ortools/sat/work_assignment.h | 20 +- ortools/sat/work_assignment_test.cc | 12 + 104 files changed, 5372 insertions(+), 2814 deletions(-) rename ortools/sat/java/{sat.i => sat.swig} (98%) delete mode 100644 ortools/sat/python/cp_model_numbers.py delete mode 100644 ortools/sat/python/cp_model_numbers_test.py create mode 100644 ortools/sat/python/gen_proto_builder_pybind11.cc create mode 100644 ortools/sat/python/wrappers.cc create mode 100644 ortools/sat/python/wrappers.h diff --git a/ortools/sat/2d_distances_propagator.cc b/ortools/sat/2d_distances_propagator.cc index 71b44cabc3..a5e614e921 100644 --- a/ortools/sat/2d_distances_propagator.cc +++ b/ortools/sat/2d_distances_propagator.cc @@ -13,12 +13,16 @@ #include "ortools/sat/2d_distances_propagator.h" +#include #include #include #include +#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/types/span.h" @@ -30,6 +34,7 @@ #include "ortools/sat/model.h" #include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" #include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/synchronization.h" @@ -39,21 +44,17 @@ namespace sat { Precedences2DPropagator::Precedences2DPropagator( NoOverlap2DConstraintHelper* helper, Model* model) : helper_(*helper), - binary_relations_maps_(model->GetOrCreate()), - shared_stats_(model->GetOrCreate()) { + linear2_bounds_(model->GetOrCreate()), + linear2_watcher_(model->GetOrCreate()), + shared_stats_(model->GetOrCreate()), + lin2_indices_(model->GetOrCreate()), + trail_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()) { model->GetOrCreate()->SetPushAffineUbForBinaryRelation(); } -void Precedences2DPropagator::CollectPairsOfBoxesWithNonTrivialDistance() { - helper_.SynchronizeAndSetDirection(); - non_trivial_pairs_.clear(); - - struct VarUsage { - // boxes[0=x, 1=y][0=start, 1=end] - std::vector boxes[2][2]; - }; - absl::flat_hash_map var_to_box_and_coeffs; - +void Precedences2DPropagator::UpdateVarLookups() { + var_to_box_and_coeffs_.clear(); for (int dim = 0; dim < 2; ++dim) { const SchedulingConstraintHelper& dim_helper = dim == 0 ? helper_.x_helper() : helper_.y_helper(); @@ -61,101 +62,137 @@ void Precedences2DPropagator::CollectPairsOfBoxesWithNonTrivialDistance() { const absl::Span interval_points = j == 0 ? dim_helper.Starts() : dim_helper.Ends(); for (int i = 0; i < helper_.NumBoxes(); ++i) { - if (interval_points[i].var != kNoIntegerVariable) { - var_to_box_and_coeffs[PositiveVariable(interval_points[i].var)] - .boxes[dim][j] - .push_back(i); + const IntegerVariable var = interval_points[i].var; + if (var != kNoIntegerVariable) { + var_to_box_and_coeffs_[PositiveVariable(var)].boxes[dim][j].push_back( + i); } } } } +} - VLOG(2) << "CollectPairsOfBoxesWithNonTrivialDistance called, num_exprs: " - << binary_relations_maps_->GetAllExpressionsWithAffineBounds().size(); - for (const LinearExpression2& expr : - binary_relations_maps_->GetAllExpressionsWithAffineBounds()) { - auto it1 = var_to_box_and_coeffs.find(PositiveVariable(expr.vars[0])); - auto it2 = var_to_box_and_coeffs.find(PositiveVariable(expr.vars[1])); - if (it1 == var_to_box_and_coeffs.end() || - it2 == var_to_box_and_coeffs.end()) { - continue; +void Precedences2DPropagator::AddOrUpdateDataForPairOfBoxes(int box1, + int box2) { + if (box1 > box2) std::swap(box1, box2); + const auto [it, inserted] = non_trivial_pairs_index_.insert( + {std::make_pair(box1, box2), static_cast(pair_data_.size())}); + absl::InlinedVector presence_literals; + for (int dim = 0; dim < 2; ++dim) { + const SchedulingConstraintHelper& dim_helper = + dim == 0 ? helper_.x_helper() : helper_.y_helper(); + for (const int box : {box1, box2}) { + if (dim_helper.IsOptional(box)) { + presence_literals.push_back(dim_helper.PresenceLiteral(box)); + } } + } + gtl::STLSortAndRemoveDuplicates(&presence_literals); + if (inserted) { + pair_data_.emplace_back( + PairData{.pair_presence_literals = {presence_literals.begin(), + presence_literals.end()}, + .box1 = box1, + .box2 = box2}); + } + PairData& pair_data = pair_data_[it->second]; + for (int dim = 0; dim < 2; ++dim) { + const SchedulingConstraintHelper& dim_helper = + dim == 0 ? helper_.x_helper() : helper_.y_helper(); + for (int j = 0; j < 2; ++j) { + int b1 = j == 0 ? box1 : box2; + int b2 = j == 0 ? box2 : box1; + auto [start_minus_end_expr, start_minus_end_ub] = + EncodeDifferenceLowerThan(dim_helper.Starts()[b1], + dim_helper.Ends()[b2], 0); + const LinearExpression2Index start_minus_end_index = + lin2_indices_->GetIndex(start_minus_end_expr); + pair_data.start_before_end[dim][j].ub = start_minus_end_ub; + if (start_minus_end_index != kNoLinearExpression2Index) { + pair_data.start_before_end[dim][j].linear2 = start_minus_end_index; + } else { + pair_data.start_before_end[dim][j].linear2 = start_minus_end_expr; + } + } + } +} - const VarUsage& usage1 = it1->second; - const VarUsage& usage2 = it2->second; - for (int dim = 0; dim < 2; ++dim) { - const SchedulingConstraintHelper& dim_helper = - dim == 0 ? helper_.x_helper() : helper_.y_helper(); - for (const int box1 : usage1.boxes[dim][0 /* start */]) { - for (const int box2 : usage2.boxes[dim][1 /* end */]) { - if (box1 == box2) continue; - const AffineExpression& start = dim_helper.Starts()[box1]; - const AffineExpression& end = dim_helper.Ends()[box2]; - LinearExpression2 expr2; - expr2.vars[0] = start.var; - expr2.vars[1] = end.var; - expr2.coeffs[0] = start.coeff; - expr2.coeffs[1] = -end.coeff; - expr2.SimpleCanonicalization(); - expr2.DivideByGcd(); - if (expr == expr2) { - if (box1 < box2) { - non_trivial_pairs_.push_back({box1, box2}); - } else { - non_trivial_pairs_.push_back({box2, box1}); - } +void Precedences2DPropagator::CollectNewPairsOfBoxesWithNonTrivialDistance() { + const absl::Span exprs = + lin2_indices_->GetStoredLinear2Indices(); + if (exprs.size() == num_known_linear2_) { + return; + } + VLOG(2) << "CollectPairsOfBoxesWithNonTrivialDistance called, num_exprs: " + << exprs.size(); + for (; num_known_linear2_ < exprs.size(); ++num_known_linear2_) { + const LinearExpression2& positive_expr = exprs[num_known_linear2_]; + LinearExpression2 negated_expr = positive_expr; + negated_expr.Negate(); + for (const LinearExpression2& expr : {positive_expr, negated_expr}) { + auto it1 = var_to_box_and_coeffs_.find(PositiveVariable(expr.vars[0])); + auto it2 = var_to_box_and_coeffs_.find(PositiveVariable(expr.vars[1])); + if (it1 == var_to_box_and_coeffs_.end()) { + continue; + } + if (it2 == var_to_box_and_coeffs_.end()) { + continue; + } + + const VarUsage& usage1 = it1->second; + const VarUsage& usage2 = it2->second; + for (int dim = 0; dim < 2; ++dim) { + for (const int box1 : usage1.boxes[dim][0 /* start */]) { + for (const int box2 : usage2.boxes[dim][1 /* end */]) { + if (box1 == box2) continue; + AddOrUpdateDataForPairOfBoxes(box1, box2); } } } } } +} - gtl::STLSortAndRemoveDuplicates(&non_trivial_pairs_); +IntegerValue Precedences2DPropagator::UpperBound( + std::variant linear2) const { + if (std::holds_alternative(linear2)) { + return linear2_bounds_->UpperBound( + std::get(linear2)); + } else { + return integer_trail_->UpperBound(std::get(linear2)); + } } bool Precedences2DPropagator::Propagate() { - if (!helper_.SynchronizeAndSetDirection()) return false; - if (last_helper_inprocessing_count_ != helper_.InProcessingCount() || - helper_.x_helper().CurrentDecisionLevel() == 0 || - last_num_expressions_ != - binary_relations_maps_->NumExpressionsWithAffineBounds()) { + if (last_helper_inprocessing_count_ != helper_.InProcessingCount()) { + if (!helper_.SynchronizeAndSetDirection()) return false; last_helper_inprocessing_count_ = helper_.InProcessingCount(); - last_num_expressions_ = - binary_relations_maps_->NumExpressionsWithAffineBounds(); - CollectPairsOfBoxesWithNonTrivialDistance(); + UpdateVarLookups(); + num_known_linear2_ = 0; + non_trivial_pairs_index_.clear(); + pair_data_.clear(); } + CollectNewPairsOfBoxesWithNonTrivialDistance(); num_calls_++; SchedulingConstraintHelper* helpers[2] = {&helper_.x_helper(), &helper_.y_helper()}; - for (const auto& [box1, box2] : non_trivial_pairs_) { - DCHECK(box1 < helper_.NumBoxes()); - DCHECK(box2 < helper_.NumBoxes()); - DCHECK_NE(box1, box2); - if (!helper_.IsPresent(box1) || !helper_.IsPresent(box2)) { + for (const PairData& pair_data : pair_data_) { + if (!absl::c_all_of(pair_data.pair_presence_literals, + [this](const Literal& literal) { + return trail_->Assignment().LiteralIsTrue(literal); + })) { continue; } bool is_unfeasible = true; for (int dim = 0; dim < 2; dim++) { - const SchedulingConstraintHelper* helper = helpers[dim]; for (int j = 0; j < 2; j++) { - int b1 = box1; - int b2 = box2; - if (j == 1) { - std::swap(b1, b2); - } - LinearExpression2 expr; - expr.vars[0] = helper->Starts()[b1].var; - expr.vars[1] = helper->Ends()[b2].var; - expr.coeffs[0] = helper->Starts()[b1].coeff; - expr.coeffs[1] = -helper->Ends()[b2].coeff; - const IntegerValue ub_of_start_minus_end_value = - binary_relations_maps_->UpperBound(expr) + - helper->Starts()[b1].constant - helper->Ends()[b2].constant; - if (ub_of_start_minus_end_value >= 0) { + const PairData::Condition& start_before_end = + pair_data.start_before_end[dim][j]; + if (UpperBound(start_before_end.linear2) >= start_before_end.ub) { is_unfeasible = false; break; } @@ -165,7 +202,10 @@ bool Precedences2DPropagator::Propagate() { if (!is_unfeasible) continue; // We have a mandatory overlap on both x and y! Explain and propagate. + if (!helper_.SynchronizeAndSetDirection()) return false; + const int box1 = pair_data.box1; + const int box2 = pair_data.box2; helper_.ClearReason(); num_conflicts_++; @@ -177,15 +217,10 @@ bool Precedences2DPropagator::Propagate() { if (j == 1) { std::swap(b1, b2); } - LinearExpression2 expr; - expr.vars[0] = helper->Starts()[b1].var; - expr.vars[1] = helper->Ends()[b2].var; - expr.coeffs[0] = helper->Starts()[b1].coeff; - expr.coeffs[1] = -helper->Ends()[b2].coeff; - binary_relations_maps_->AddReasonForUpperBoundLowerThan( - expr, - -(helper->Starts()[b1].constant - helper->Ends()[b2].constant) - 1, - helper_.x_helper().MutableLiteralReason(), + const auto [expr, ub] = EncodeDifferenceLowerThan( + helper->Starts()[b1], helper->Ends()[b2], -1); + linear2_bounds_->AddReasonForUpperBoundLowerThan( + expr, ub, helper_.x_helper().MutableLiteralReason(), helper_.x_helper().MutableIntegerReason()); } } @@ -199,7 +234,7 @@ bool Precedences2DPropagator::Propagate() { int Precedences2DPropagator::RegisterWith(GenericLiteralWatcher* watcher) { const int id = watcher->Register(this); helper_.WatchAllBoxes(id); - binary_relations_maps_->WatchAllLinearExpressions2(id); + linear2_watcher_->WatchAllLinearExpressions2(id); return id; } @@ -208,7 +243,7 @@ Precedences2DPropagator::~Precedences2DPropagator() { std::vector> stats; stats.push_back({"Precedences2DPropagator/called", num_calls_}); stats.push_back({"Precedences2DPropagator/conflicts", num_conflicts_}); - stats.push_back({"Precedences2DPropagator/pairs", non_trivial_pairs_.size()}); + stats.push_back({"Precedences2DPropagator/pairs", pair_data_.size()}); shared_stats_->AddStats(stats); } diff --git a/ortools/sat/2d_distances_propagator.h b/ortools/sat/2d_distances_propagator.h index b05e6b1a3d..d418340914 100644 --- a/ortools/sat/2d_distances_propagator.h +++ b/ortools/sat/2d_distances_propagator.h @@ -16,22 +16,27 @@ #include #include +#include #include +#include "absl/container/fixed_array.h" +#include "absl/container/flat_hash_map.h" #include "ortools/sat/integer.h" +#include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/no_overlap_2d_helper.h" #include "ortools/sat/precedences.h" +#include "ortools/sat/sat_base.h" #include "ortools/sat/synchronization.h" namespace operations_research { namespace sat { // This class implements a propagator for non_overlap_2d constraints that uses -// the BinaryRelationsMaps to detect precedences between pairs of boxes and +// the Linear2Bounds to detect precedences between pairs of boxes and // detect a conflict if the precedences implies an overlap between the two -// boxes. For doing this efficiently, it keep track of pairs of boxes that have -// non-fixed precedences in the BinaryRelationsMaps and only check those in the +// boxes. For doing this efficiently, it keeps track of pairs of boxes that have +// non-fixed precedences in the Linear2Bounds and only check those in the // propagation. class Precedences2DPropagator : public PropagatorInterface { public: @@ -43,16 +48,47 @@ class Precedences2DPropagator : public PropagatorInterface { int RegisterWith(GenericLiteralWatcher* watcher); private: - void CollectPairsOfBoxesWithNonTrivialDistance(); + void CollectNewPairsOfBoxesWithNonTrivialDistance(); + void UpdateVarLookups(); + IntegerValue UpperBound( + std::variant linear2) const; + void AddOrUpdateDataForPairOfBoxes(int box1, int box2); - std::vector> non_trivial_pairs_; + struct PairData { + // The condition must be true if ub(linear2) < ub. + struct Condition { + // If the expression is in the Linear2Indices it is represented by its + // index, otherwise it is represented by the expression itself. + std::variant linear2; + IntegerValue ub; + }; + + absl::FixedArray pair_presence_literals; + int box1; + int box2; + // start1_before_end2[0==x, 1==y][0=start_1_end_2, 1=start_2_end_1] + Condition start_before_end[2][2]; + }; + absl::flat_hash_map, int> non_trivial_pairs_index_; + std::vector pair_data_; + struct VarUsage { + // boxes[0=x, 1=y][0=start, 1=end] + std::vector boxes[2][2]; + }; + + absl::flat_hash_map var_to_box_and_coeffs_; NoOverlap2DConstraintHelper& helper_; - BinaryRelationsMaps* binary_relations_maps_; + Linear2Bounds* linear2_bounds_; + Linear2Watcher* linear2_watcher_; SharedStatistics* shared_stats_; + Linear2Indices* lin2_indices_; + Trail* trail_; + IntegerTrail* integer_trail_; int last_helper_inprocessing_count_ = -1; - int last_num_expressions_ = -1; + int num_known_linear2_ = 0; + int64_t last_linear2_timestamp_ = -1; int64_t num_conflicts_ = 0; int64_t num_calls_ = 0; diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 58a506a26a..7acad72904 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -131,7 +131,10 @@ cc_library( ":scheduling_helpers", ":synchronization", "//ortools/base:stl_util", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:fixed_array", "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/container:inlined_vector", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/types:span", @@ -314,8 +317,10 @@ cc_library( "@abseil-cpp//absl/hash", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/random", "@abseil-cpp//absl/random:bit_gen_ref", + "@abseil-cpp//absl/random:distributions", "@abseil-cpp//absl/status", "@abseil-cpp//absl/strings", "@abseil-cpp//absl/strings:str_format", @@ -815,7 +820,6 @@ cc_library( deps = [ ":cp_model_cc_proto", ":cp_model_utils", - ":integer", ":integer_base", ":linear_constraint", ":model", @@ -1096,6 +1100,7 @@ cc_library( ":integer", ":integer_base", ":model", + ":precedences", ":presolve_context", ":presolve_util", ":probing", @@ -1133,7 +1138,6 @@ cc_library( "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:vlog_is_on", - "@abseil-cpp//absl/meta:type_traits", "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/random:distributions", "@abseil-cpp//absl/status:statusor", @@ -1501,6 +1505,7 @@ cc_library( "//ortools/util:bitset", "//ortools/util:integer_pq", "//ortools/util:strong_integers", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/types:span", ], @@ -1751,6 +1756,7 @@ cc_library( ":sat_base", "//ortools/base", "//ortools/base:strong_vector", + "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", @@ -2055,6 +2061,7 @@ cc_library( deps = [ ":clause", ":cp_constraints", + ":cp_model_mapping", ":integer", ":integer_base", ":model", @@ -2081,6 +2088,7 @@ cc_library( "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:vlog_is_on", + "@abseil-cpp//absl/strings:str_format", "@abseil-cpp//absl/types:span", ], ) @@ -2104,6 +2112,7 @@ cc_test( "//ortools/base:parse_test_proto", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/types:span", ], ) @@ -2288,6 +2297,7 @@ cc_library( ":integer", ":integer_base", ":intervals", + ":linear_propagation", ":model", ":precedences", ":sat_base", @@ -2315,6 +2325,7 @@ cc_test( ":disjunctive", ":integer", ":integer_base", + ":integer_expr", ":integer_search", ":intervals", ":model", @@ -2569,7 +2580,6 @@ cc_library( "//ortools/base:stl_util", "//ortools/base:strong_vector", "//ortools/util:logging", - "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", "@abseil-cpp//absl/base:core_headers", @@ -3173,6 +3183,7 @@ cc_library( "@abseil-cpp//absl/container:inlined_vector", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:log_streamer", + "@abseil-cpp//absl/numeric:bits", "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/random", "@abseil-cpp//absl/random:bit_gen_ref", @@ -3202,6 +3213,7 @@ cc_test( "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/numeric:bits", "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/random", "@abseil-cpp//absl/strings", @@ -3571,7 +3583,6 @@ cc_library( ":synchronization", ":timetable", ":util", - "//ortools/base:stl_util", "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", "//ortools/util:strong_integers", @@ -3850,6 +3861,7 @@ cc_test( "//ortools/base:gmock_main", "//ortools/base:parse_test_proto", "//ortools/util:random_engine", + "@abseil-cpp//absl/strings", "@abseil-cpp//absl/types:span", ], ) @@ -4017,13 +4029,17 @@ cc_binary( "//ortools/base:path", "//ortools/util:file_util", "//ortools/util:logging", + "//ortools/util:sigint", "//ortools/util:sorted_interval_list", + "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/flags:flag", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:flags", "@abseil-cpp//absl/strings", "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/synchronization", + "@abseil-cpp//absl/types:span", "@protobuf", ], ) @@ -4268,6 +4284,7 @@ cc_library( "//ortools/util:running_stat", "//ortools/util:strong_integers", "//ortools/util:time_limit", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", diff --git a/ortools/sat/combine_solutions.cc b/ortools/sat/combine_solutions.cc index 679074675e..872a9e4b4f 100644 --- a/ortools/sat/combine_solutions.cc +++ b/ortools/sat/combine_solutions.cc @@ -53,7 +53,9 @@ std::optional> FindCombinedSolution( CHECK_EQ(new_solution.size(), base_solution.size()); const std::vector< std::shared_ptr::Solution>> - solutions = response_manager->SolutionsRepository().GetBestNSolutions(10); + solutions = + response_manager->SolutionPool().BestSolutions().GetBestNSolutions( + 10); for (int sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) { std::shared_ptr::Solution> s = @@ -79,18 +81,21 @@ std::optional> FindCombinedSolution( PushedSolutionPointers PushAndMaybeCombineSolution( SharedResponseManager* response_manager, const CpModelProto& model_proto, absl::Span new_solution, const std::string& solution_info, - absl::Span base_solution, Model* model) { + std::shared_ptr::Solution> + base_solution) { PushedSolutionPointers result = {nullptr, nullptr}; - result.pushed_solution = - response_manager->NewSolution(new_solution, solution_info, model); - if (!base_solution.empty()) { + result.pushed_solution = response_manager->NewSolution( + new_solution, solution_info, nullptr, + base_solution == nullptr ? -1 : base_solution->source_id); + if (base_solution != nullptr) { std::string combined_solution_info = solution_info; std::optional> combined_solution = - FindCombinedSolution(model_proto, new_solution, base_solution, - response_manager, &combined_solution_info); + FindCombinedSolution(model_proto, new_solution, + base_solution->variable_values, response_manager, + &combined_solution_info); if (combined_solution.has_value()) { result.improved_solution = response_manager->NewSolution( - combined_solution.value(), combined_solution_info, model); + combined_solution.value(), combined_solution_info); } } return result; diff --git a/ortools/sat/combine_solutions.h b/ortools/sat/combine_solutions.h index 7106e9939e..259e1cbca9 100644 --- a/ortools/sat/combine_solutions.h +++ b/ortools/sat/combine_solutions.h @@ -49,7 +49,8 @@ struct PushedSolutionPointers { PushedSolutionPointers PushAndMaybeCombineSolution( SharedResponseManager* response_manager, const CpModelProto& model_proto, absl::Span new_solution, const std::string& solution_info, - absl::Span base_solution = {}, Model* model = nullptr); + std::shared_ptr::Solution> + base_solution); } // namespace sat } // namespace operations_research diff --git a/ortools/sat/constraint_violation.cc b/ortools/sat/constraint_violation.cc index 0cd5f80d10..f59bff555f 100644 --- a/ortools/sat/constraint_violation.cc +++ b/ortools/sat/constraint_violation.cc @@ -1500,7 +1500,7 @@ LsEvaluator::LsEvaluator(const CpModelProto& cp_model, LsEvaluator::LsEvaluator( const CpModelProto& cp_model, const SatParameters& params, const std::vector& ignored_constraints, - const std::vector& additional_constraints, + absl::Span additional_constraints, TimeLimit* time_limit) : cp_model_(cp_model), params_(params), time_limit_(time_limit) { var_to_constraints_.resize(cp_model_.variables_size()); @@ -1830,7 +1830,7 @@ void LsEvaluator::CompileOneConstraint(const ConstraintProto& ct) { void LsEvaluator::CompileConstraintsAndObjective( const std::vector& ignored_constraints, - const std::vector& additional_constraints) { + absl::Span additional_constraints) { constraints_.clear(); // The first compiled constraint is always the objective if present. diff --git a/ortools/sat/constraint_violation.h b/ortools/sat/constraint_violation.h index 54880e9f5a..cc09718d24 100644 --- a/ortools/sat/constraint_violation.h +++ b/ortools/sat/constraint_violation.h @@ -313,7 +313,7 @@ class LsEvaluator { TimeLimit* time_limit); LsEvaluator(const CpModelProto& cp_model, const SatParameters& params, const std::vector& ignored_constraints, - const std::vector& additional_constraints, + absl::Span additional_constraints, TimeLimit* time_limit); // Intersects the domain of the objective with [lb..ub]. @@ -434,7 +434,7 @@ class LsEvaluator { private: void CompileConstraintsAndObjective( const std::vector& ignored_constraints, - const std::vector& additional_constraints); + absl::Span additional_constraints); void CompileOneConstraint(const ConstraintProto& ct_proto); void BuildVarConstraintGraph(); diff --git a/ortools/sat/cp_model_lns.cc b/ortools/sat/cp_model_lns.cc index 5281a662a4..36990c16a9 100644 --- a/ortools/sat/cp_model_lns.cc +++ b/ortools/sat/cp_model_lns.cc @@ -1237,7 +1237,7 @@ CpModelProto NeighborhoodGeneratorHelper::UpdatedModelProtoCopy() const { } bool NeighborhoodGenerator::ReadyToGenerate() const { - return (helper_.shared_response().SolutionsRepository().NumSolutions() > 0); + return helper_.shared_response().HasFeasibleSolution(); } double NeighborhoodGenerator::GetUCBScore(int64_t total_num_calls) const { diff --git a/ortools/sat/cp_model_lns_test.cc b/ortools/sat/cp_model_lns_test.cc index ac44009b93..68dc701a14 100644 --- a/ortools/sat/cp_model_lns_test.cc +++ b/ortools/sat/cp_model_lns_test.cc @@ -201,8 +201,10 @@ TYPED_TEST(GeneratorTest, ReadyToGenerate) { EXPECT_FALSE(generator.ReadyToGenerate()); shared_response_manager->NewSolution(solution.solution(), solution.solution_info(), &model); - shared_response_manager->MutableSolutionsRepository()->Synchronize(); - EXPECT_EQ(1, shared_response_manager->SolutionsRepository().NumSolutions()); + shared_response_manager->Synchronize(); + EXPECT_EQ( + 1, + shared_response_manager->SolutionPool().BestSolutions().NumSolutions()); EXPECT_TRUE(generator.ReadyToGenerate()); } @@ -301,7 +303,7 @@ TEST(RelaxationInducedNeighborhoodGeneratorTest, NoNeighborhoodGeneratedRINS) { solution.add_solution(0); shared_response_manager->NewSolution(solution.solution(), solution.solution_info(), &model); - shared_response_manager->MutableSolutionsRepository()->Synchronize(); + shared_response_manager->Synchronize(); lp_solutions.NewLPSolution({0.0}); lp_solutions.Synchronize(); diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index f053299657..3a691a1ddf 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -1261,7 +1261,7 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { // Load precedences. if (!HasEnforcementLiteral(ct)) { - auto* precedences = m->GetOrCreate(); + auto* root_level_lin2_bounds = m->GetOrCreate(); // To avoid overflow in the code below, we tighten the bounds. // Note that we detect and do not add trivial relation. @@ -1272,7 +1272,7 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { if (vars.size() == 2) { LinearExpression2 expr(vars[0], vars[1], coeffs[0], coeffs[1]); - precedences->AddBounds(expr, rhs_min, rhs_max); + root_level_lin2_bounds->Add(expr, rhs_min, rhs_max); } else if (vars.size() == 3) { // TODO(user): This is a weaker duplication of the logic of // BinaryRelationsMaps, but is is useful for the transitive closure in @@ -1293,7 +1293,8 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { ? coeff * integer_trail->UpperBound(vars[other]).value() : coeff * integer_trail->LowerBound(vars[other]).value(); LinearExpression2 expr(vars[i], vars[j], coeffs[i], coeffs[j]); - precedences->AddBounds(expr, rhs_min - other_ub, rhs_max - other_lb); + root_level_lin2_bounds->Add(expr, rhs_min - other_ub, + rhs_max - other_lb); } } } diff --git a/ortools/sat/cp_model_mapping.h b/ortools/sat/cp_model_mapping.h index 1a82e4263e..5cf63e3e2f 100644 --- a/ortools/sat/cp_model_mapping.h +++ b/ortools/sat/cp_model_mapping.h @@ -24,7 +24,6 @@ #include "ortools/base/strong_vector.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" -#include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index d835856bf1..c17f6da241 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -40,7 +40,6 @@ #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/log/vlog_is_on.h" -#include "absl/meta/type_traits.h" #include "absl/numeric/int128.h" #include "absl/random/distributions.h" #include "absl/status/statusor.h" @@ -74,6 +73,7 @@ #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" #include "ortools/sat/presolve_context.h" #include "ortools/sat/presolve_util.h" #include "ortools/sat/probing.h" @@ -7886,6 +7886,28 @@ void CpModelPresolver::Probe() { prober->ProbeBooleanVariables( context_->params().probing_deterministic_time_limit()); + for (const auto& [expr, ub] : model.GetOrCreate() + ->GetSortedNonTrivialUpperBounds()) { + if (expr.vars[0] == kNoIntegerVariable || + expr.vars[1] == kNoIntegerVariable) { + continue; + } + const IntegerVariable var0 = PositiveVariable(expr.vars[0]); + const IntegerVariable var1 = PositiveVariable(expr.vars[1]); + const int proto_var0 = mapping->GetProtoVariableFromIntegerVariable(var0); + const int proto_var1 = mapping->GetProtoVariableFromIntegerVariable(var1); + if (proto_var0 < 0 || proto_var1 < 0) continue; + const int64_t coeff0 = VariableIsPositive(expr.vars[0]) + ? expr.coeffs[0].value() + : -expr.coeffs[0].value(); + const int64_t coeff1 = VariableIsPositive(expr.vars[1]) + ? expr.coeffs[1].value() + : -expr.coeffs[1].value(); + known_linear2_.Add( + GetLinearExpression2FromProto(proto_var0, coeff0, proto_var1, coeff1), + kMinIntegerValue, ub); + } + probing_timer->AddCounter("probed", prober->num_decisions()); probing_timer->AddToWork( model.GetOrCreate()->GetElapsedDeterministicTime()); @@ -8798,6 +8820,7 @@ void CpModelPresolver::ExpandObjective() { } void CpModelPresolver::MergeNoOverlapConstraints() { + PresolveTimer timer("MergeNoOverlap", logger_, time_limit_); if (context_->ModelIsUnsat()) return; if (time_limit_->LimitReached()) return; diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index ff253b4041..f95d2c7c0e 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -726,7 +726,6 @@ absl::flat_hash_map GetNamedParameters( SatParameters new_params = base_params; new_params.set_use_shared_tree_search(true); new_params.set_search_branching(SatParameters::AUTOMATIC_SEARCH); - new_params.set_linearization_level(0); // These settings don't make sense with shared tree search, turn them off as // they can break things. @@ -758,7 +757,8 @@ absl::flat_hash_map GetNamedParameters( lns_params.set_log_search_progress(false); lns_params.set_debug_crash_on_bad_hint(false); // Can happen in lns. - lns_params.set_solution_pool_size(1); // Keep the best solution found. + lns_params.set_solution_pool_size(1); // Keep the best solution found. + lns_params.set_alternative_pool_size(0); // Disable. strategies["lns"] = lns_params; // Note that we only do this for the derived parameters. The strategy "lns" diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 8aef798e1f..e4c60ab3de 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -793,40 +793,6 @@ void LogSubsolverNames(absl::Span> subsolvers, SOLVER_LOG(logger, ""); } -void LogFinalStatistics(SharedClasses* shared) { - if (!shared->logger->LoggingIsEnabled()) return; - - shared->logger->FlushPendingThrottledLogs(/*ignore_rates=*/true); - SOLVER_LOG(shared->logger, ""); - - shared->stat_tables->Display(shared->logger); - shared->response->DisplayImprovementStatistics(); - - std::vector> table; - table.push_back({"Solution repositories", "Added", "Queried", "Synchro"}); - table.push_back(shared->response->SolutionsRepository().TableLineStats()); - table.push_back(shared->ls_hints->TableLineStats()); - if (shared->lp_solutions != nullptr) { - table.push_back(shared->lp_solutions->TableLineStats()); - } - if (shared->incomplete_solutions != nullptr) { - table.push_back(shared->incomplete_solutions->TableLineStats()); - } - SOLVER_LOG(shared->logger, FormatTable(table)); - - if (shared->bounds) { - shared->bounds->LogStatistics(shared->logger); - } - - if (shared->clauses) { - shared->clauses->LogStatistics(shared->logger); - } - - // Extra logging if needed. Note that these are mainly activated on - // --vmodule *some_file*=1 and are here for development. - shared->stats->Log(shared->logger); -} - void LaunchSubsolvers(const SatParameters& params, SharedClasses* shared, std::vector>& subsolvers, absl::Span ignored) { @@ -868,7 +834,7 @@ void LaunchSubsolvers(const SatParameters& params, SharedClasses* shared, for (int i = 0; i < subsolvers.size(); ++i) { subsolvers[i].reset(); } - LogFinalStatistics(shared); + shared->LogFinalStatistics(); } bool VarIsFixed(const CpModelProto& model_proto, int i) { @@ -1124,13 +1090,18 @@ class FullProblemSolver : public SubSolver { shared_->model_proto, shared_->bounds.get(), &local_model_); } + if (shared_->linear2_bounds != nullptr) { + RegisterLinear2BoundsImport(shared_->linear2_bounds.get(), + &local_model_); + } + // Note that this is done after the loading, so we will never export // problem clauses. if (shared_->clauses != nullptr) { const int id = shared_->clauses->RegisterNewId( + local_model_.Name(), /*may_terminate_early=*/stop_at_first_solution_ && - local_model_.GetOrCreate()->has_objective()); - shared_->clauses->SetWorkerNameForId(id, local_model_.Name()); + local_model_.GetOrCreate()->has_objective()); RegisterClausesLevelZeroImport(id, shared_->clauses.get(), &local_model_); @@ -1348,36 +1319,29 @@ class LnsSolver : public SubSolver { data.task_id = task_id; data.difficulty = generator_->difficulty(); data.deterministic_limit = generator_->deterministic_limit(); + data.initial_best_objective = + shared_->response->GetBestSolutionObjective(); // Choose a base solution for this neighborhood. + const auto base_solution = + shared_->response->SolutionPool().GetSolutionToImprove(random); CpSolverResponse base_response; - { - const SharedSolutionRepository& repo = - shared_->response->SolutionsRepository(); - if (repo.NumSolutions() > 0) { - base_response.set_status(CpSolverStatus::FEASIBLE); - std::shared_ptr::Solution> - solution = repo.GetRandomBiasedSolution(random); - base_response.mutable_solution()->Assign( - solution->variable_values.begin(), - solution->variable_values.end()); + if (base_solution != nullptr) { + base_response.set_status(CpSolverStatus::FEASIBLE); + base_response.mutable_solution()->Assign( + base_solution->variable_values.begin(), + base_solution->variable_values.end()); - // Note: We assume that the solution rank is the solution internal - // objective. - data.initial_best_objective = repo.GetSolution(0)->rank; - data.base_objective = solution->rank; - } else { - base_response.set_status(CpSolverStatus::UNKNOWN); + // Note: We assume that the solution rank is the solution internal + // objective. + data.base_objective = base_solution->rank; + } else { + base_response.set_status(CpSolverStatus::UNKNOWN); - // If we do not have a solution, we use the current objective upper - // bound so that our code that compute an "objective" improvement - // works. - // - // TODO(user): this is non-deterministic. Fix. - data.initial_best_objective = - shared_->response->GetInnerObjectiveUpperBound(); - data.base_objective = data.initial_best_objective; - } + // If we do not have a solution, we use the current objective upper + // bound so that our code that compute an "objective" improvement + // works. + data.base_objective = data.initial_best_objective; } Neighborhood neighborhood = @@ -1667,9 +1631,9 @@ class LnsSolver : public SubSolver { if (absl::MakeSpan(solution_values) != absl::MakeSpan(base_response.solution())) { new_solution = true; - PushAndMaybeCombineSolution( - shared_->response, shared_->model_proto, solution_values, - solution_info, base_response.solution(), /*model=*/nullptr); + PushAndMaybeCombineSolution(shared_->response, shared_->model_proto, + solution_values, solution_info, + base_solution); } } if (!neighborhood.is_reduced && @@ -1782,7 +1746,6 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { subsolvers.push_back(std::make_unique( "synchronization_agent", [shared]() { shared->response->Synchronize(); - shared->response->MutableSolutionsRepository()->Synchronize(); shared->ls_hints->Synchronize(); if (shared->bounds != nullptr) { shared->bounds->Synchronize(); diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index 030cf124b5..e64e37ae16 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -847,6 +847,63 @@ void RegisterVariableBoundsLevelZeroImport( import_level_zero_bounds); } +void RegisterLinear2BoundsImport(SharedLinear2Bounds* shared_linear2_bounds, + Model* model) { + CHECK(shared_linear2_bounds != nullptr); + auto* cp_model_mapping = model->GetOrCreate(); + auto* root_linear2 = model->GetOrCreate(); + auto* sat_solver = model->GetOrCreate(); + const int import_id = + shared_linear2_bounds->RegisterNewImportId(model->Name()); + const auto& import_function = [import_id, shared_linear2_bounds, root_linear2, + cp_model_mapping, sat_solver, model]() { + const auto new_bounds = + shared_linear2_bounds->NewlyUpdatedBounds(import_id); + int num_imported = 0; + for (const auto& [proto_expr, bounds] : new_bounds) { + // Lets create the corresponding LinearExpression2. + LinearExpression2 expr; + if (!cp_model_mapping->IsInteger(proto_expr.vars[0]) || + !cp_model_mapping->IsInteger(proto_expr.vars[1])) { + continue; + } + for (const int i : {0, 1}) { + expr.vars[i] = cp_model_mapping->Integer(proto_expr.vars[i]); + expr.coeffs[i] = proto_expr.coeffs[i]; + } + const auto [lb, ub] = bounds; + const auto [lb_added, ub_added] = root_linear2->Add(expr, lb, ub); + if (!lb_added && !ub_added) continue; + ++num_imported; + + // TODO(user): Is it a good idea to add the linear constraint ? + // We might have many redundant linear2 relations that don't need + // propagation when we have chains of precedences. The root_linear2 should + // be up-to-date with transitive closure to avoid adding such relations + // (recompute it at level zero before this?). + // + // TODO(user): use IntegerValure directly in + // AddWeightedSumGreaterOrEqual() or use a lower-level API. + const std::vector coeffs = {expr.coeffs[0].value(), + expr.coeffs[1].value()}; + if (lb_added) { + AddWeightedSumGreaterOrEqual({}, absl::MakeSpan(expr.vars, 2), coeffs, + lb.value(), model); + if (sat_solver->ModelIsUnsat()) return false; + } + if (ub_added) { + AddWeightedSumLowerOrEqual({}, absl::MakeSpan(expr.vars, 2), coeffs, + ub.value(), model); + if (sat_solver->ModelIsUnsat()) return false; + } + } + shared_linear2_bounds->NotifyNumImported(import_id, num_imported); + return true; + }; + model->GetOrCreate()->callbacks.push_back( + import_function); +} + // Registers a callback that will report improving objective best bound. // It will be called each time new objective bound are propagated at level zero. void RegisterObjectiveBestBoundExport( @@ -1073,7 +1130,7 @@ void FillBinaryRelationRepository(const CpModelProto& model_proto, auto* encoder = model->GetOrCreate(); auto* mapping = model->GetOrCreate(); auto* repository = model->GetOrCreate(); - auto* relations_maps = model->GetOrCreate(); + auto* root_level_lin2_bounds = model->GetOrCreate(); for (const ConstraintProto& ct : model_proto.constraints()) { // Load conditional precedences and always true binary relations. @@ -1095,13 +1152,15 @@ void FillBinaryRelationRepository(const CpModelProto& model_proto, // var1_min <= var1 - delta.var2 <= var1_max, which is equivalent to // the default bounds if var2 = 0, and gives implied_lb <= var1 <= // var1_max + delta otherwise. - repository->Add(enforcement_literal, {var1, 1}, {var2, -delta}, + repository->Add(enforcement_literal, + LinearExpression2(var1, var2, 1, -delta), var1_domain.Min(), var1_domain.Max()); } else if (negated_var2 != kNoIntegerVariable) { // var1_min + delta <= var1 + delta.neg_var2 <= var1_max + delta, // which is equivalent to the default bounds if neg_var2 = 1, and // gives implied_lb <= var1 <= var1_max + delta otherwise. - repository->Add(enforcement_literal, {var1, 1}, {negated_var2, delta}, + repository->Add(enforcement_literal, + LinearExpression2(var1, negated_var2, 1, delta), var1_domain.Min() + delta, var1_domain.Max() + delta); } }; @@ -1137,23 +1196,19 @@ void FillBinaryRelationRepository(const CpModelProto& model_proto, if (ct.enforcement_literal().empty()) { if (vars.size() == 2) { - repository->Add(Literal(kNoLiteralIndex), {vars[0], coeffs[0]}, - {vars[1], coeffs[1]}, rhs_min, rhs_max); - - LinearExpression2 expr; - expr.vars[0] = vars[0]; - expr.vars[1] = vars[1]; - expr.coeffs[0] = coeffs[0]; - expr.coeffs[1] = coeffs[1]; - relations_maps->AddRelationBounds(expr, rhs_min, rhs_max); + const LinearExpression2 expr(vars[0], vars[1], coeffs[0], coeffs[1]); + root_level_lin2_bounds->Add(expr, rhs_min, rhs_max); } } else { const Literal lit = mapping->Literal(ct.enforcement_literal(0)); if (vars.size() == 1) { - repository->Add(lit, {vars[0], coeffs[0]}, {}, rhs_min, rhs_max); + repository->Add( + lit, LinearExpression2(vars[0], kNoIntegerVariable, coeffs[0], 0), + rhs_min, rhs_max); } else if (vars.size() == 2) { - repository->Add(lit, {vars[0], coeffs[0]}, {vars[1], coeffs[1]}, - rhs_min, rhs_max); + repository->Add( + lit, LinearExpression2(vars[0], vars[1], coeffs[0], coeffs[1]), + rhs_min, rhs_max); } } } @@ -1216,10 +1271,6 @@ void LoadBaseModel(const CpModelProto& model_proto, Model* model) { AddFullEncodingFromSearchBranching(model_proto, model); if (sat_solver->ModelIsUnsat()) return unsat(); - // Reserve space for the precedence relations. - model->GetOrCreate()->Resize( - model->GetOrCreate()->NumIntegerVariables().value()); - FillBinaryRelationRepository(model_proto, model); if (time_limit->LimitReached()) return; @@ -1292,7 +1343,7 @@ void LoadBaseModel(const CpModelProto& model_proto, Model* model) { model->GetOrCreate()->ProcessImplicationGraph( model->GetOrCreate()); - model->GetOrCreate()->Build(); + model->GetOrCreate()->Build(); } void LoadFeasibilityPump(const CpModelProto& model_proto, Model* model) { @@ -1794,7 +1845,7 @@ void QuickSolveWithHint(const CpModelProto& model_proto, Model* model) { // Tricky: We can only test that if we don't already have a feasible solution // like we do if the hint is complete. if (parameters->debug_crash_on_bad_hint() && - shared_response_manager->SolutionsRepository().NumSolutions() == 0 && + shared_response_manager->HasFeasibleSolution() && !model->GetOrCreate()->LimitReached() && status != SatSolver::Status::FEASIBLE) { LOG(FATAL) << "QuickSolveWithHint() didn't find a feasible solution." @@ -2092,6 +2143,10 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model) bounds->LoadDebugSolution(response->DebugSolution()); } + if (params.share_linear2_bounds()) { + linear2_bounds = std::make_unique(); + } + // Create extra shared classes if needed. Note that while these parameters // are true by default, we disable them if we don't have enough workers for // them in AdaptGlobalParameters(). @@ -2126,7 +2181,7 @@ void SharedClasses::RegisterSharedClassesInLocalModel(Model* local_model) { local_model->Register(stat_tables); // TODO(user): Use parameters and not the presence/absence of these class - // to decide when to use them. + // to decide when to use them? this is not clear. if (lp_solutions != nullptr) { local_model->Register(lp_solutions.get()); } @@ -2140,6 +2195,9 @@ void SharedClasses::RegisterSharedClassesInLocalModel(Model* local_model) { if (clauses != nullptr) { local_model->Register(clauses.get()); } + if (linear2_bounds != nullptr) { + local_model->Register(linear2_bounds.get()); + } } bool SharedClasses::SearchIsDone() { @@ -2152,5 +2210,37 @@ bool SharedClasses::SearchIsDone() { return false; } +void SharedClasses::LogFinalStatistics() { + if (!logger->LoggingIsEnabled()) return; + + logger->FlushPendingThrottledLogs(/*ignore_rates=*/true); + SOLVER_LOG(logger, ""); + + stat_tables->Display(logger); + response->DisplayImprovementStatistics(); + + std::vector> table; + table.push_back({"Solution repositories", "Added", "Queried", "Synchro"}); + response->SolutionPool().AddTableStats(&table); + table.push_back(ls_hints->TableLineStats()); + if (lp_solutions != nullptr) { + table.push_back(lp_solutions->TableLineStats()); + } + if (incomplete_solutions != nullptr) { + table.push_back(incomplete_solutions->TableLineStats()); + } + SOLVER_LOG(logger, FormatTable(table)); + + // TODO(user): we can combine the "bounds table" into one for shorter logs. + if (bounds != nullptr) bounds->LogStatistics(logger); + if (linear2_bounds != nullptr) linear2_bounds->LogStatistics(logger); + + if (clauses != nullptr) clauses->LogStatistics(logger); + + // Extra logging if needed. Note that these are mainly activated on + // --vmodule *some_file*=1 and are here for development. + stats->Log(logger); +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/cp_model_solver_helpers.h b/ortools/sat/cp_model_solver_helpers.h index 1f46f77495..af00cb3213 100644 --- a/ortools/sat/cp_model_solver_helpers.h +++ b/ortools/sat/cp_model_solver_helpers.h @@ -60,12 +60,15 @@ struct SharedClasses { std::unique_ptr lp_solutions; std::unique_ptr incomplete_solutions; std::unique_ptr clauses; + std::unique_ptr linear2_bounds; // call local_model->Register() on most of the class here, this allow to // more easily depends on one of the shared class deep within the solver. void RegisterSharedClassesInLocalModel(Model* local_model); bool SearchIsDone(); + + void LogFinalStatistics(); }; // Loads a CpModelProto inside the given model. @@ -119,6 +122,11 @@ int RegisterClausesLevelZeroImport(int id, SharedClausesManager* shared_clauses_manager, Model* model); +// This will register a level zero callback to imports new linear2 from the +// SharedLinear2Bounds. +void RegisterLinear2BoundsImport(SharedLinear2Bounds* shared_linear2_bounds, + Model* model); + void PostsolveResponseWrapper(const SatParameters& params, int num_variable_in_original_model, const CpModelProto& mapping_proto, diff --git a/ortools/sat/cp_model_solver_test.cc b/ortools/sat/cp_model_solver_test.cc index 63ab2fae0f..e3d719b400 100644 --- a/ortools/sat/cp_model_solver_test.cc +++ b/ortools/sat/cp_model_solver_test.cc @@ -109,7 +109,7 @@ TEST(StopAfterFirstSolutionTest, BooleanLinearOptimizationProblem) { Model model; SatParameters params; - params.set_num_search_workers(8); + params.set_num_workers(8); params.set_stop_after_first_solution(true); int num_solutions = 0; @@ -1070,10 +1070,12 @@ TEST(SolveCpModelTest, SolutionHintMinimizeL1DistanceTest) { // TODO(user): Instead, we might change the presolve to always try to keep the // given hint feasible. Model model; - model.Add( - NewSatParameters("repair_hint:true, stop_after_first_solution:true, " - "keep_all_feasible_solutions_in_presolve:true " - "num_workers:1")); + SatParameters params; + params.set_repair_hint(true); + params.set_stop_after_first_solution(true); + params.set_keep_all_feasible_solutions_in_presolve(true); + params.set_num_workers(1); + model.Add(NewSatParameters(params)); const CpSolverResponse response = SolveCpModel(model_proto, &model); EXPECT_THAT(response.status(), AnyOf(Eq(CpSolverStatus::OPTIMAL), Eq(CpSolverStatus::FEASIBLE))); diff --git a/ortools/sat/cp_model_symmetries.cc b/ortools/sat/cp_model_symmetries.cc index 98a1ca6ab0..a93b81c067 100644 --- a/ortools/sat/cp_model_symmetries.cc +++ b/ortools/sat/cp_model_symmetries.cc @@ -1453,6 +1453,8 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { if (row_has_at_most_one_true[row]) { context->UpdateRuleStats( "symmetry: fixed all but one to false in orbitope row"); + context->solution_crush().MaybeSwapOrbitopeColumns( + orbitope, row, num_processed_rows - 1, true); for (int j = num_processed_rows; j < num_cols; ++j) { if (!context->SetLiteralToFalse(orbitope[row][j])) return false; } @@ -1460,6 +1462,8 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { CHECK(row_has_at_most_one_false[row]); context->UpdateRuleStats( "symmetry: fixed all but one to true in orbitope row"); + context->solution_crush().MaybeSwapOrbitopeColumns( + orbitope, row, num_processed_rows - 1, false); for (int j = num_processed_rows; j < num_cols; ++j) { if (!context->SetLiteralToTrue(orbitope[row][j])) return false; } diff --git a/ortools/sat/cumulative.cc b/ortools/sat/cumulative.cc index 3d910e779d..8cde3428c0 100644 --- a/ortools/sat/cumulative.cc +++ b/ortools/sat/cumulative.cc @@ -212,8 +212,8 @@ std::function Cumulative( // having two independent constraint doing the same propagation. std::vector full_precedences; if (parameters.exploit_all_precedences()) { - model->GetOrCreate()->ComputeFullPrecedences( - index_to_end_vars, &full_precedences); + model->GetOrCreate() + ->ComputeFullPrecedences(index_to_end_vars, &full_precedences); } for (const FullIntegerPrecedence& data : full_precedences) { const int size = data.indices.size(); diff --git a/ortools/sat/cumulative_energy_test.cc b/ortools/sat/cumulative_energy_test.cc index a8b56e9905..27f4cb929c 100644 --- a/ortools/sat/cumulative_energy_test.cc +++ b/ortools/sat/cumulative_energy_test.cc @@ -176,8 +176,8 @@ bool SolveUsingNaiveModel(const EnergyInstance& instance) { std::vector intervals; std::vector consumptions; IntegerVariable one = model.Add(ConstantIntegerVariable(1)); - IntervalsRepository* intervals_repository = - model.GetOrCreate(); + auto* intervals_repository = model.GetOrCreate(); + auto* precedences = model.GetOrCreate(); for (const auto& task : instance.tasks) { if (task.is_optional) { @@ -207,7 +207,7 @@ bool SolveUsingNaiveModel(const EnergyInstance& instance) { CHECK_NE(start_expr.var, kNoIntegerVariable); const IntegerVariable start = start_expr.var; if (previous_start != kNoIntegerVariable) { - model.Add(LowerOrEqual(previous_start, start)); + precedences->AddPrecedence(previous_start, start); } else { first_start = start; } @@ -215,8 +215,8 @@ bool SolveUsingNaiveModel(const EnergyInstance& instance) { } // start[last] <= start[0] + duration_max - 1 if (previous_start != kNoIntegerVariable) { - model.Add(LowerOrEqualWithOffset(previous_start, first_start, - -task.duration_max + 1)); + precedences->AddPrecedenceWithOffset(previous_start, first_start, + -task.duration_max + 1); } } } diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index b848740606..fde8ee6a46 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -307,7 +307,7 @@ void AddNonOverlappingRectangles(const std::vector& x, return; } - // At least one of the 4 options is true. + // At least one of the 4 options is true if all boxes are present. std::vector clause = {x_ij, x_ji, y_ij, y_ji}; if (repository->IsOptional(x[i])) { clause.push_back(repository->PresenceLiteral(x[i]).Negated()); diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index 0faa84e760..c475f64cc4 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -25,6 +25,7 @@ #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/intervals.h" +#include "ortools/sat/linear_propagation.h" #include "ortools/sat/model.h" #include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" @@ -143,6 +144,9 @@ void AddDisjunctive(const std::vector& intervals, // using the fact that they are in disjunction. if (params.use_precedences_in_disjunctive_constraint() && !params.use_combined_no_overlap()) { + // Lets try to exploit linear3 too. + model->GetOrCreate()->SetPushAffineUbForBinaryRelation(); + for (const bool time_direction : {true, false}) { DisjunctivePrecedences* precedences = new DisjunctivePrecedences(time_direction, helper, model); @@ -276,8 +280,8 @@ bool DisjunctiveWithTwoItems::Propagate() { helper_->ClearReason(); helper_->AddPresenceReason(task_before); helper_->AddPresenceReason(task_after); - helper_->AddReasonForBeingBefore(task_before, task_after); - helper_->AddReasonForBeingBefore(task_after, task_before); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(task_before, task_after); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(task_after, task_before); return helper_->ReportConflict(); } @@ -295,7 +299,8 @@ bool DisjunctiveWithTwoItems::Propagate() { if (helper_->StartMin(task_after) < end_min_before) { // Reason for precedences if both present. helper_->ClearReason(); - helper_->AddReasonForBeingBefore(task_before, task_after); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(task_before, + task_after); // Reason for the bound push. helper_->AddPresenceReason(task_before); @@ -311,7 +316,8 @@ bool DisjunctiveWithTwoItems::Propagate() { if (helper_->EndMax(task_before) > start_max_after) { // Reason for precedences if both present. helper_->ClearReason(); - helper_->AddReasonForBeingBefore(task_before, task_after); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(task_before, + task_after); // Reason for the bound push. helper_->AddPresenceReason(task_after); @@ -527,7 +533,7 @@ bool DisjunctiveOverloadChecker::Propagate() { const int to_push = task_with_max_end_min.task_index; helper_->ClearReason(); helper_->AddPresenceReason(task); - helper_->AddReasonForBeingBefore(task, to_push); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(task, to_push); helper_->AddEndMinReason(task, end_min); if (!helper_->IncreaseStartMin(to_push, end_min)) { @@ -750,14 +756,19 @@ bool DisjunctiveSimplePrecedences::Propagate() { bool DisjunctiveSimplePrecedences::Push(TaskTime before, int t) { const int t_before = before.task_index; + DCHECK_NE(t_before, t); helper_->ClearReason(); helper_->AddPresenceReason(t_before); - helper_->AddReasonForBeingBefore(t_before, t); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(t_before, t); helper_->AddEndMinReason(t_before, before.time); if (!helper_->IncreaseStartMin(t, before.time)) { return false; } + if (helper_->CurrentDecisionLevel() == 0 && helper_->IsPresent(t_before) && + helper_->IsPresent(t)) { + if (!helper_->NotifyLevelZeroPrecedence(t_before, t)) return false; + } ++stats_.num_propagations; return true; } @@ -818,8 +829,8 @@ bool DisjunctiveSimplePrecedences::PropagateOneDirection() { helper_->ClearReason(); helper_->AddPresenceReason(blocking_task); helper_->AddPresenceReason(t); - helper_->AddReasonForBeingBefore(blocking_task, t); - helper_->AddReasonForBeingBefore(t, blocking_task); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(blocking_task, t); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(t, blocking_task); return helper_->ReportConflict(); } else if (end_min > best_task_before.time) { best_task_before = {t, end_min}; @@ -927,9 +938,13 @@ bool DisjunctiveDetectablePrecedences::Push(IntegerValue task_set_end_min, // Heuristic, if some tasks are known to be after the first one, // we just add the min-size as a reason. + // + // TODO(user): ideally we don't want to do that if we don't have a level + // zero precedence... if (i > critical_index && helper_->GetCurrentMinDistanceBetweenTasks( - sorted_tasks[critical_index].task, ct, - /*add_reason_if_after=*/true) >= 0) { + sorted_tasks[critical_index].task, ct) >= 0) { + helper_->AddReasonForBeingBeforeAssumingNoOverlap( + sorted_tasks[critical_index].task, ct); helper_->AddSizeMinReason(ct); } else { helper_->AddEnergyAfterReason(ct, sorted_tasks[i].size_min, window_start); @@ -937,9 +952,9 @@ bool DisjunctiveDetectablePrecedences::Push(IntegerValue task_set_end_min, // We only need the reason for being before if we don't already have // a static precedence between the tasks. - const IntegerValue dist = helper_->GetCurrentMinDistanceBetweenTasks( - ct, t, /*add_reason_if_after=*/true); + const IntegerValue dist = helper_->GetCurrentMinDistanceBetweenTasks(ct, t); if (dist >= 0) { + helper_->AddReasonForBeingBeforeAssumingNoOverlap(ct, t); energy_of_task_before += sorted_tasks[i].size_min; min_slack = std::min(min_slack, dist); } else { @@ -969,7 +984,7 @@ bool DisjunctiveDetectablePrecedences::Push(IntegerValue task_set_end_min, // Process detected precedence. if (helper_->CurrentDecisionLevel() == 0 && helper_->IsPresent(t)) { for (int i = critical_index; i < sorted_tasks.size(); ++i) { - if (!helper_->PropagatePrecedence(sorted_tasks[i].task, t)) { + if (!helper_->NotifyLevelZeroPrecedence(sorted_tasks[i].task, t)) { return false; } } @@ -1047,8 +1062,8 @@ bool DisjunctiveDetectablePrecedences::PropagateWithRanks() { helper_->ClearReason(); helper_->AddPresenceReason(blocking_task); helper_->AddPresenceReason(t); - helper_->AddReasonForBeingBefore(blocking_task, t); - helper_->AddReasonForBeingBefore(t, blocking_task); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(blocking_task, t); + helper_->AddReasonForBeingBeforeAssumingNoOverlap(t, blocking_task); return helper_->ReportConflict(); } else { if (!some_propagation && rank > highest_rank) { @@ -1207,18 +1222,16 @@ bool DisjunctivePrecedences::PropagateSubwindow() { // Note that like in Propagate() we split this set of task into critical // subpart as there is no point considering them together. // - // TODO(user): we should probably change the api to return a Span. - // // TODO(user): If more than one set of task push the same variable, we - // probabaly only want to keep the best push? Maybe we want to process them + // probably only want to keep the best push? Maybe we want to process them // in reverse order of what we do here? indices_before_.clear(); IntegerValue local_start; IntegerValue local_end; for (; global_i < size; ++global_i) { - const PrecedenceRelations::PrecedenceData& data = before_[global_i]; + const EnforcedLinear2Bounds::PrecedenceData& data = before_[global_i]; if (data.var != var) break; - const int index = data.index; + const int index = data.var_index; const auto [t, start_of_t] = window_[index]; if (global_i == global_start_i) { // First loop. local_start = start_of_t; @@ -1227,7 +1240,7 @@ bool DisjunctivePrecedences::PropagateSubwindow() { if (start_of_t >= local_end) break; local_end += helper_->SizeMin(t); } - indices_before_.push_back(index); + indices_before_.push_back({index, data.lin2_index}); } // No need to consider if we don't have at least two tasks before var. @@ -1249,18 +1262,18 @@ bool DisjunctivePrecedences::PropagateSubwindow() { int best_index = -1; const IntegerValue current_var_lb = integer_trail_->LowerBound(var); IntegerValue best_new_lb = current_var_lb; + IntegerValue min_offset_at_best = kMinIntegerValue; IntegerValue min_offset = kMaxIntegerValue; IntegerValue sum_of_duration = 0; for (int i = num_before; --i >= 0;) { - const TaskTime task_time = window_[indices_before_[i]]; + const auto [task_index, lin2_index] = indices_before_[i]; + const TaskTime task_time = window_[task_index]; const AffineExpression& end_exp = helper_->Ends()[task_time.task_index]; - // TODO(user): The hash lookup here is a bit slow, so we avoid fetching - // the offset as much as possible. Note that the alternative of storing it - // in PrecedenceData is not necessarily better and harder to update as we - // dive/backtrack. - const IntegerValue inner_offset = -precedence_relations_->UpperBound( - LinearExpression2::Difference(end_exp.var, var)); + // TODO(user): The lookup here is a bit slow, so we avoid fetching + // the offset as much as possible. + const IntegerValue inner_offset = + -linear2_bounds_->NonTrivialUpperBound(lin2_index); DCHECK_NE(inner_offset, kMinIntegerValue); // We have var >= end_exp.var + inner_offset, so @@ -1275,7 +1288,7 @@ bool DisjunctivePrecedences::PropagateSubwindow() { // This is true if we skipped all task so far in this block. if (min_offset == kMaxIntegerValue) { // If only one task is left, we can abort. - // This avoid a GetConditionalOffset() lookup. + // This avoid a linear2_bounds_ lookup. if (i == 1) break; // Lower the end_min_when_all_present for better filtering later. @@ -1292,6 +1305,7 @@ bool DisjunctivePrecedences::PropagateSubwindow() { const IntegerValue start = task_time.time; const IntegerValue new_lb = start + sum_of_duration + min_offset; if (new_lb > best_new_lb) { + min_offset_at_best = min_offset; best_new_lb = new_lb; best_index = i; } @@ -1302,21 +1316,20 @@ bool DisjunctivePrecedences::PropagateSubwindow() { DCHECK_NE(best_index, -1); helper_->ClearReason(); const IntegerValue window_start = - window_[indices_before_[best_index]].time; + window_[indices_before_[best_index].first].time; for (int i = best_index; i < num_before; ++i) { if (skip_[i]) continue; - const int ct = window_[indices_before_[i]].task_index; + const int ct = window_[indices_before_[i].first].task_index; helper_->AddPresenceReason(ct); helper_->AddEnergyAfterReason(ct, helper_->SizeMin(ct), window_start); - // Fetch the explanation. + // Fetch the explanation of (var - end) >= min_offset // This is okay if a bit slow since we only do that when we push. - const AffineExpression& end_exp = helper_->Ends()[ct]; - const LinearExpression2 expr = - LinearExpression2::Difference(end_exp.var, var); - precedence_relations_->AddReasonForUpperBoundLowerThan( - expr, precedence_relations_->UpperBound(expr), - helper_->MutableLiteralReason(), helper_->MutableIntegerReason()); + const auto [expr, ub] = EncodeDifferenceLowerThan( + helper_->Ends()[ct], var, -min_offset_at_best); + linear2_bounds_->AddReasonForUpperBoundLowerThan( + expr, ub, helper_->MutableLiteralReason(), + helper_->MutableIntegerReason()); } ++stats_.num_propagations; if (!helper_->PushIntegerLiteral( @@ -1516,8 +1529,9 @@ bool DisjunctiveNotLast::PropagateSubwindow() { helper_->AddPresenceReason(ct); helper_->AddEnergyAfterReason(ct, sorted_tasks[i].size_min, window_start); - if (helper_->GetCurrentMinDistanceBetweenTasks( - ct, t, /*add_reason_if_after=*/true) < 0) { + if (helper_->GetCurrentMinDistanceBetweenTasks(ct, t) >= 0) { + helper_->AddReasonForBeingBeforeAssumingNoOverlap(ct, t); + } else { helper_->AddStartMaxReason(ct, largest_ct_start_max); } } @@ -1763,9 +1777,11 @@ bool DisjunctiveEdgeFinding::PropagateSubwindow(IntegerValue window_end_min) { task, event_size_[event], event >= second_event ? second_start : first_start); - const IntegerValue dist = helper_->GetCurrentMinDistanceBetweenTasks( - task, gray_task, /*add_reason_if_after=*/true); - if (dist < 0) { + const IntegerValue dist = + helper_->GetCurrentMinDistanceBetweenTasks(task, gray_task); + if (dist >= 0) { + helper_->AddReasonForBeingBeforeAssumingNoOverlap(task, gray_task); + } else { all_before = false; helper_->AddEndMaxReason(task, window_end); } @@ -1788,7 +1804,7 @@ bool DisjunctiveEdgeFinding::PropagateSubwindow(IntegerValue window_end_min) { for (int i = first_event; i < window_size; ++i) { const int task = window_[i].task_index; if (!is_gray_[task]) { - if (!helper_->PropagatePrecedence(task, gray_task)) { + if (!helper_->NotifyLevelZeroPrecedence(task, gray_task)) { return false; } } diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index bb77c41b12..e50d022975 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -353,7 +353,8 @@ class DisjunctivePrecedences : public PropagatorInterface { : time_direction_(time_direction), helper_(helper), integer_trail_(model->GetOrCreate()), - precedence_relations_(model->GetOrCreate()), + precedence_relations_(model->GetOrCreate()), + linear2_bounds_(model->GetOrCreate()), stats_("DisjunctivePrecedences", model) { window_.ClearAndReserve(helper->NumTasks()); index_to_end_vars_.ClearAndReserve(helper->NumTasks()); @@ -369,20 +370,21 @@ class DisjunctivePrecedences : public PropagatorInterface { const bool time_direction_; SchedulingConstraintHelper* helper_; IntegerTrail* integer_trail_; - PrecedenceRelations* precedence_relations_; + EnforcedLinear2Bounds* precedence_relations_; + Linear2Bounds* linear2_bounds_; FixedCapacityVector window_; FixedCapacityVector index_to_end_vars_; - FixedCapacityVector indices_before_; + FixedCapacityVector> indices_before_; std::vector skip_; - std::vector before_; + std::vector before_; PropagationStatistics stats_; }; // This is an optimization for the case when we have a big number of such -// pairwise constraints. This should be roughtly equivalent to what the general +// pairwise constraints. This should be roughly equivalent to what the general // disjunctive case is doing, but it dealt with variable size better and has a // lot less overhead. class DisjunctiveWithTwoItems : public PropagatorInterface { diff --git a/ortools/sat/disjunctive_test.cc b/ortools/sat/disjunctive_test.cc index 85c1db3204..5fa0063cf8 100644 --- a/ortools/sat/disjunctive_test.cc +++ b/ortools/sat/disjunctive_test.cc @@ -31,6 +31,7 @@ #include "ortools/base/logging.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" +#include "ortools/sat/integer_expr.h" #include "ortools/sat/integer_search.h" #include "ortools/sat/intervals.h" #include "ortools/sat/model.h" @@ -238,8 +239,8 @@ TEST(DisjunctiveConstraintTest, Precedences) { Trail* trail = model.GetOrCreate(); IntegerTrail* integer_trail = model.GetOrCreate(); auto* precedences = model.GetOrCreate(); - auto* relations = model.GetOrCreate(); auto* intervals = model.GetOrCreate(); + auto* lin2_bounds = model.GetOrCreate(); const auto add_affine_coeff_one_precedence = [&](const AffineExpression e1, const AffineExpression& e2) { @@ -249,8 +250,8 @@ TEST(DisjunctiveConstraintTest, Precedences) { CHECK_EQ(e2.coeff, 1); precedences->AddPrecedenceWithOffset(e1.var, e2.var, e1.constant - e2.constant); - relations->AddUpperBound(LinearExpression2::Difference(e1.var, e2.var), - e2.constant - e1.constant); + lin2_bounds->AddUpperBound(LinearExpression2::Difference(e1.var, e2.var), + e2.constant - e1.constant); }; const int kStart(0); @@ -483,6 +484,22 @@ TEST(DisjunctiveTest, TwoIntervalsTest) { EXPECT_EQ(12, CountAllSolutions(instance, AddDisjunctive)); } +namespace { + +void AddLowerOrEqualWithOffset(AffineExpression a, IntegerVariable b, + int64_t offset, Model* model) { + const int64_t rhs = -a.constant.value() - offset; + std::vector vars = {a.var, b}; + std::vector coeffs = {a.coeff.value(), -1}; + AddWeightedSumLowerOrEqual({}, vars, coeffs, rhs, model); + + // We also need to register them. + model->GetOrCreate()->AddUpperBound( + LinearExpression2::Difference(a.var, b), rhs); +} + +} // namespace + TEST(DisjunctiveTest, Precedences) { Model model; @@ -493,10 +510,9 @@ TEST(DisjunctiveTest, Precedences) { const IntegerVariable var = model.Add(NewIntegerVariable(0, 10)); IntervalsRepository* intervals = model.GetOrCreate(); - model.Add( - AffineCoeffOneLowerOrEqualWithOffset(intervals->End(ids[0]), var, 5)); - model.Add( - AffineCoeffOneLowerOrEqualWithOffset(intervals->End(ids[1]), var, 4)); + + AddLowerOrEqualWithOffset(intervals->End(ids[0]), var, 5, &model); + AddLowerOrEqualWithOffset(intervals->End(ids[1]), var, 4, &model); EXPECT_TRUE(model.GetOrCreate()->Propagate()); EXPECT_EQ(model.Get(LowerBound(var)), (3 + 2) + std::min(4, 5)); diff --git a/ortools/sat/feasibility_jump.cc b/ortools/sat/feasibility_jump.cc index 3d57b0ff5b..64ae336a88 100644 --- a/ortools/sat/feasibility_jump.cc +++ b/ortools/sat/feasibility_jump.cc @@ -364,8 +364,7 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { // still finish each batch though). We will also reset the luby sequence. bool new_best_solution_was_found = false; if (type() == SubSolver::INCOMPLETE) { - const int64_t best = - shared_response_->SolutionsRepository().GetBestRank(); + const int64_t best = shared_response_->GetBestSolutionObjective().value(); if (best < state_->last_solution_rank) { states_->ResetLubyCounter(); new_best_solution_was_found = true; @@ -394,11 +393,9 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { new_best_solution_was_found) { if (type() == SubSolver::INCOMPLETE) { // Choose a base solution for this neighborhood. - std::shared_ptr::Solution> - solution = shared_response_->SolutionsRepository() - .GetRandomBiasedSolution(random_); - state_->solution = solution->variable_values; - state_->base_solution = solution; + state_->base_solution = + shared_response_->SolutionPool().GetSolutionToImprove(random_); + state_->solution = state_->base_solution->variable_values; ++state_->num_solutions_imported; } else { if (!first_time) { @@ -427,6 +424,10 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { } // Between chunk, we synchronize bounds. + // + // TODO(user): This do not play well with optimizing solution whose + // objective lag behind... Basically, we can run LS on old solution but will + // only consider it feasible if it improve the best known solution. bool recompute_compound_weights = false; if (linear_model_->model_proto().has_objective()) { const IntegerValue lb = shared_response_->GetInnerObjectiveLowerBound(); @@ -500,15 +501,15 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { ++state_->counters.num_batches; if (DoSomeLinearIterations() && DoSomeGeneralIterations()) { // Checks for infeasibility induced by the non supported constraints. + // + // TODO(user): Checking the objective is faster and we could avoid to + // check feasibility if we are not going to keep the solution anyway. if (SolutionIsFeasible(linear_model_->model_proto(), state_->solution)) { auto pointers = PushAndMaybeCombineSolution( shared_response_, linear_model_->model_proto(), state_->solution, absl::StrCat(name(), "_", state_->options.name(), "(", OneLineStats(), ")"), - state_->base_solution == nullptr - ? absl::Span() - : state_->base_solution->variable_values, - /*model=*/nullptr); + state_->base_solution); // If we pushed a new solution, we use it as a new "base" so that we // will have a smaller delta on the next solution we find. state_->base_solution = pointers.pushed_solution; diff --git a/ortools/sat/feasibility_jump.h b/ortools/sat/feasibility_jump.h index d1975161c7..e40f5d42b9 100644 --- a/ortools/sat/feasibility_jump.h +++ b/ortools/sat/feasibility_jump.h @@ -510,7 +510,7 @@ class FeasibilityJumpSolver : public SubSolver { if (shared_response_->ProblemIsSolved()) return false; if (shared_time_limit_->LimitReached()) return false; - return (shared_response_->SolutionsRepository().NumSolutions() > 0) == + return shared_response_->HasFeasibleSolution() == (type() == SubSolver::INCOMPLETE); } diff --git a/ortools/sat/feasibility_pump.cc b/ortools/sat/feasibility_pump.cc index 8eb2f49b19..6f2d9a1dc1 100644 --- a/ortools/sat/feasibility_pump.cc +++ b/ortools/sat/feasibility_pump.cc @@ -681,7 +681,9 @@ bool FeasibilityPump::PropagationRounding() { } if (!sat_solver_->FinishPropagation()) return false; - sat_solver_->EnqueueDecisionAndBacktrackOnConflict(to_enqueue); + const SatSolver::Status decision_status = + sat_solver_->EnqueueDecisionAndBacktrackOnConflict(to_enqueue); + if (decision_status != SatSolver::Status::FEASIBLE) return false; if (sat_solver_->ModelIsUnsat()) return false; } integer_solution_is_set_ = true; diff --git a/ortools/sat/flaky_models_test.cc b/ortools/sat/flaky_models_test.cc index c233ab9070..e00ebee981 100644 --- a/ortools/sat/flaky_models_test.cc +++ b/ortools/sat/flaky_models_test.cc @@ -90,7 +90,7 @@ TEST(FlakyTest, Issue3108) { SatParameters parameters; parameters.set_log_search_progress(true); parameters.set_cp_model_probing_level(0); - parameters.set_num_search_workers(1); + parameters.set_num_workers(1); const CpSolverResponse response = SolveWithParameters(model_proto, parameters); EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); diff --git a/ortools/sat/go/cpmodel/cp_model.go b/ortools/sat/go/cpmodel/cp_model.go index 99dbf687ca..a535284af8 100644 --- a/ortools/sat/go/cpmodel/cp_model.go +++ b/ortools/sat/go/cpmodel/cp_model.go @@ -29,6 +29,7 @@ import ( "sort" log "github.com/golang/glog" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/go/cpmodel/cp_model_test.go b/ortools/sat/go/cpmodel/cp_model_test.go index 4d8ecf7e86..d2ed3d821a 100644 --- a/ortools/sat/go/cpmodel/cp_model_test.go +++ b/ortools/sat/go/cpmodel/cp_model_test.go @@ -22,8 +22,9 @@ import ( log "github.com/golang/glog" "github.com/google/go-cmp/cmp" - cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" "google.golang.org/protobuf/testing/protocmp" + + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) func Example() { diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 3adcb4d8ba..5ecc0455a8 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -983,7 +983,8 @@ int IntegerTrail::FindTrailIndexOfVarBefore(IntegerVariable var, int IntegerTrail::FindLowestTrailIndexThatExplainBound( IntegerLiteral i_lit) const { DCHECK_LE(i_lit.bound, var_lbs_[i_lit.var]); - if (i_lit.bound <= LevelZeroLowerBound(i_lit.var)) return -1; + DCHECK(!IsTrueAtLevelZero(i_lit)); + int trail_index = var_trail_index_[i_lit.var]; // Check the validity of the cached index and use it if possible. This caching @@ -1003,6 +1004,7 @@ int IntegerTrail::FindLowestTrailIndexThatExplainBound( int prev_trail_index = trail_index; while (true) { + ++work_done_in_explain_lower_than_; if (trail_index >= var_trail_index_cache_threshold_) { var_trail_index_cache_[i_lit.var] = trail_index; } @@ -1171,10 +1173,9 @@ std::vector* IntegerTrail::InitializeConflict( lazy_reasons_.back().Explain(conflict, &tmp_queue_); } else { conflict->assign(literals_reason.begin(), literals_reason.end()); - const int num_vars = var_lbs_.size(); for (const IntegerLiteral& literal : bounds_reason) { - const int trail_index = FindLowestTrailIndexThatExplainBound(literal); - if (trail_index >= num_vars) tmp_queue_.push_back(trail_index); + if (IsTrueAtLevelZero(literal)) continue; + tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(literal)); } } return conflict; @@ -1553,9 +1554,8 @@ bool IntegerTrail::EnqueueInternal( // efficiency and a potential smaller reason. auto* conflict = InitializeConflict(i_lit, use_lazy_reason, literal_reason, integer_reason); - { - const int trail_index = FindLowestTrailIndexThatExplainBound(ub_reason); - if (trail_index >= 0) tmp_queue_.push_back(trail_index); + if (!IsTrueAtLevelZero(ub_reason)) { + tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(ub_reason)); } MergeReasonIntoInternal(conflict, NextConflictId()); return false; @@ -1771,12 +1771,10 @@ absl::Span IntegerTrail::Dependencies(int reason_index) const { int new_size = 0; int* data = trail_index_reason_buffer_.data() + start; - const int num_vars = var_lbs_.size(); for (int i = start; i < end; ++i) { - const int dep = - FindLowestTrailIndexThatExplainBound(bounds_reason_buffer_[i]); - if (dep >= num_vars) { - data[new_size++] = dep; + const IntegerLiteral to_explain = bounds_reason_buffer_[i]; + if (!IsTrueAtLevelZero(to_explain)) { + data[new_size++] = FindLowestTrailIndexThatExplainBound(to_explain); } } cached_sizes_[reason_index] = new_size; @@ -1818,14 +1816,10 @@ std::vector IntegerTrail::ReasonFor(IntegerLiteral literal) const { void IntegerTrail::MergeReasonInto(absl::Span literals, std::vector* output) const { DCHECK(tmp_queue_.empty()); - const int num_vars = var_lbs_.size(); for (const IntegerLiteral& literal : literals) { if (literal.IsAlwaysTrue()) continue; - const int trail_index = FindLowestTrailIndexThatExplainBound(literal); - - // Any indices lower than that means that there is no reason needed. - // Note that it is important for size to be signed because of -1 indices. - if (trail_index >= num_vars) tmp_queue_.push_back(trail_index); + if (IsTrueAtLevelZero(literal)) continue; + tmp_queue_.push_back(FindLowestTrailIndexThatExplainBound(literal)); } return MergeReasonIntoInternal(output, -1); } diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 1c8cbf1438..1b926092e7 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -505,6 +505,7 @@ class IntegerTrail final : public SatPropagator { // Same as above for an affine expression. IntegerValue LowerBound(AffineExpression expr) const; IntegerValue UpperBound(AffineExpression expr) const; + IntegerValue UpperBound(LinearExpression2 expr) const; bool IsFixed(AffineExpression expr) const; IntegerValue FixedValue(AffineExpression expr) const; @@ -522,6 +523,7 @@ class IntegerTrail final : public SatPropagator { // Returns the current value (if known) of an IntegerLiteral. bool IntegerLiteralIsTrue(IntegerLiteral l) const; bool IntegerLiteralIsFalse(IntegerLiteral l) const; + bool IsTrueAtLevelZero(IntegerLiteral l) const; // Returns globally valid lower/upper bound on the given integer variable. IntegerValue LevelZeroLowerBound(IntegerVariable var) const; @@ -795,39 +797,38 @@ class IntegerTrail final : public SatPropagator { void AddAllGreaterThanConstantReason(absl::Span exprs, IntegerValue target_min, std::vector* indices) const { - int64_t num_processed = 0; + constexpr int64_t check_period = 1e6; + int64_t limit_check = work_done_in_explain_lower_than_ + check_period; for (const AffineExpression& expr : exprs) { if (expr.IsConstant()) { DCHECK_GE(expr.constant, target_min); continue; } DCHECK_NE(expr.var, kNoIntegerVariable); + const IntegerLiteral to_explain = expr.GreaterOrEqual(target_min); + if (IsTrueAtLevelZero(to_explain)) continue; // On large routing problems, we can spend a lot of time in this loop. - // We check the time limit every 5 processed expressions. - if (++num_processed % 5 == 0 && time_limit_->LimitReached()) return; + if (work_done_in_explain_lower_than_ > limit_check) { + limit_check = work_done_in_explain_lower_than_ + check_period; + if (time_limit_->LimitReached()) return; + } // Skip if we already have an explanation for expr >= target_min. Note // that we already do that while processing the returned indices, so this // mainly save a FindLowestTrailIndexThatExplainBound() call per skipped // indices, which can still be costly. { - const int index = tmp_var_to_trail_index_in_queue_[expr.var]; + const int index = tmp_var_to_trail_index_in_queue_[to_explain.var]; if (index == std::numeric_limits::max()) continue; - if (index > 0 && - expr.ValueAt(integer_trail_[index].bound) >= target_min) { + if (index > 0 && integer_trail_[index].bound >= to_explain.bound) { has_dependency_ = true; continue; } } // We need to find the index that explain the bound. - // Note that this will skip if the condition is true at level zero. - const int index = - FindLowestTrailIndexThatExplainBound(expr.GreaterOrEqual(target_min)); - if (index >= 0) { - indices->push_back(index); - } + indices->push_back(FindLowestTrailIndexThatExplainBound(to_explain)); } } @@ -884,8 +885,8 @@ class IntegerTrail final : public SatPropagator { int64_t conflict_id) const; // Returns the lowest trail index of a TrailEntry that can be used to explain - // the given IntegerLiteral. The literal must be currently true (CHECKed). - // Returns -1 if the explanation is trivial. + // the given IntegerLiteral. The literal must be currently true but not true + // at level zero (DCHECKed). int FindLowestTrailIndexThatExplainBound(IntegerLiteral i_lit) const; // This must be called before Dependencies() or AppendLiteralsReason(). @@ -1032,6 +1033,8 @@ class IntegerTrail final : public SatPropagator { std::vector*> watchers_; std::vector reversible_classes_; + mutable int64_t work_done_in_explain_lower_than_ = 0; + mutable Domain temp_domain_; DelayedRootLevelDeduction* delayed_to_fix_; IntegerDomains* domains_; @@ -1375,6 +1378,20 @@ inline IntegerValue IntegerTrail::UpperBound(AffineExpression expr) const { return UpperBound(expr.var) * expr.coeff + expr.constant; } +inline IntegerValue IntegerTrail::UpperBound(LinearExpression2 expr) const { + IntegerValue result = 0; + for (int i = 0; i < 2; ++i) { + if (expr.coeffs[i] == 0) { + continue; + } else if (expr.coeffs[i] > 0) { + result += expr.coeffs[i] * UpperBound(expr.vars[i]); + } else { + result += expr.coeffs[i] * LowerBound(expr.vars[i]); + } + } + return result; +} + inline bool IntegerTrail::IsFixed(AffineExpression expr) const { if (expr.var == kNoIntegerVariable) return true; return IsFixed(expr.var); @@ -1405,6 +1422,10 @@ inline bool IntegerTrail::IntegerLiteralIsFalse(IntegerLiteral l) const { return l.bound > UpperBound(l.var); } +inline bool IntegerTrail::IsTrueAtLevelZero(IntegerLiteral l) const { + return l.bound <= LevelZeroLowerBound(l.var); +} + // The level zero bounds are stored at the beginning of the trail and they also // serves as sentinels. Their index match the variables index. inline IntegerValue IntegerTrail::LevelZeroLowerBound( diff --git a/ortools/sat/integer_base.cc b/ortools/sat/integer_base.cc index 8af3a695ca..29a7d8d186 100644 --- a/ortools/sat/integer_base.cc +++ b/ortools/sat/integer_base.cc @@ -13,11 +13,15 @@ #include "ortools/sat/integer_base.h" +#include #include #include +#include #include +#include #include "absl/log/check.h" +#include "ortools/util/bitset.h" namespace operations_research::sat { @@ -79,22 +83,19 @@ bool LinearExpression2::NegateForCanonicalization() { return negate; } -void LinearExpression2::CanonicalizeAndUpdateBounds(IntegerValue& lb, - IntegerValue& ub, - bool allow_negation) { +bool LinearExpression2::CanonicalizeAndUpdateBounds(IntegerValue& lb, + IntegerValue& ub) { SimpleCanonicalization(); - if (coeffs[0] == 0 || coeffs[1] == 0) return; // abort. + if (coeffs[0] == 0 || coeffs[1] == 0) return false; // abort. - if (allow_negation) { - const bool negated = NegateForCanonicalization(); - if (negated) { - // We need to be able to negate without overflow. - CHECK_GE(lb, kMinIntegerValue); - CHECK_LE(ub, kMaxIntegerValue); - std::swap(lb, ub); - lb = -lb; - ub = -ub; - } + const bool negated = NegateForCanonicalization(); + if (negated) { + // We need to be able to negate without overflow. + CHECK_GE(lb, kMinIntegerValue); + CHECK_LE(ub, kMaxIntegerValue); + std::swap(lb, ub); + lb = -lb; + ub = -ub; } // Do gcd division. @@ -108,27 +109,71 @@ void LinearExpression2::CanonicalizeAndUpdateBounds(IntegerValue& lb, CHECK(coeffs[0] != 0 || vars[0] == kNoIntegerVariable); CHECK(coeffs[1] != 0 || vars[1] == kNoIntegerVariable); + + return negated; } -bool BestBinaryRelationBounds::Add(LinearExpression2 expr, IntegerValue lb, - IntegerValue ub) { - expr.CanonicalizeAndUpdateBounds(lb, ub); - if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return false; +bool LinearExpression2::IsCanonicalized() const { + for (int i : {0, 1}) { + if ((vars[i] == kNoIntegerVariable) != (coeffs[i] == 0)) { + return false; + } + } + if (vars[0] >= vars[1]) return false; + + if (vars[0] == kNoIntegerVariable) return true; + + return coeffs[0] > 0 && coeffs[1] > 0; +} + +void LinearExpression2::MakeVariablesPositive() { + SimpleCanonicalization(); + for (int i = 0; i < 2; ++i) { + if (vars[i] != kNoIntegerVariable && !VariableIsPositive(vars[i])) { + coeffs[i] = -coeffs[i]; + vars[i] = NegationOf(vars[i]); + } + } +} + +std::pair +BestBinaryRelationBounds::Add(LinearExpression2 expr, IntegerValue lb, + IntegerValue ub) { + const bool negated = expr.CanonicalizeAndUpdateBounds(lb, ub); + + // We only store proper linear2. + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { + return {AddResult::INVALID, AddResult::INVALID}; + } auto [it, inserted] = best_bounds_.insert({expr, {lb, ub}}); - if (inserted) return true; + if (inserted) { + std::pair result = { + lb > kMinIntegerValue ? AddResult::ADDED : AddResult::INVALID, + ub < kMaxIntegerValue ? AddResult::ADDED : AddResult::INVALID}; + if (negated) std::swap(result.first, result.second); + return result; + } const auto [known_lb, known_ub] = it->second; - bool restricted = false; + + std::pair result = { + lb > kMinIntegerValue ? AddResult::NOT_BETTER : AddResult::INVALID, + ub < kMaxIntegerValue ? AddResult::NOT_BETTER : AddResult::INVALID}; if (lb > known_lb) { + result.first = (it->second.first == kMinIntegerValue) ? AddResult::ADDED + : AddResult::UPDATED; it->second.first = lb; - restricted = true; } if (ub < known_ub) { + result.second = (it->second.second == kMaxIntegerValue) + ? AddResult::ADDED + : AddResult::UPDATED; it->second.second = ub; - restricted = true; } - return restricted; + if (negated) std::swap(result.first, result.second); + return result; } RelationStatus BestBinaryRelationBounds::GetStatus(LinearExpression2 expr, @@ -165,4 +210,34 @@ IntegerValue BestBinaryRelationBounds::GetUpperBound( return kMaxIntegerValue; } +std::vector> +BestBinaryRelationBounds::GetSortedNonTrivialUpperBounds() const { + std::vector> root_relations_sorted; + root_relations_sorted.reserve(2 * best_bounds_.size()); + for (const auto& [expr, bounds] : best_bounds_) { + if (bounds.first != kMinIntegerValue) { + LinearExpression2 negated_expr = expr; + negated_expr.Negate(); + root_relations_sorted.push_back({negated_expr, -bounds.first}); + } + if (bounds.second != kMaxIntegerValue) { + root_relations_sorted.push_back({expr, bounds.second}); + } + } + std::sort(root_relations_sorted.begin(), root_relations_sorted.end()); + return root_relations_sorted; +} + +std::vector> +BestBinaryRelationBounds::GetSortedNonTrivialBounds() const { + std::vector> + root_relations_sorted; + root_relations_sorted.reserve(best_bounds_.size()); + for (const auto& [expr, bounds] : best_bounds_) { + root_relations_sorted.push_back({expr, bounds.first, bounds.second}); + } + std::sort(root_relations_sorted.begin(), root_relations_sorted.end()); + return root_relations_sorted; +} + } // namespace operations_research::sat diff --git a/ortools/sat/integer_base.h b/ortools/sat/integer_base.h index a86d15eb07..14aa492ced 100644 --- a/ortools/sat/integer_base.h +++ b/ortools/sat/integer_base.h @@ -41,6 +41,11 @@ namespace sat { // Callbacks that will be called when the search goes back to level 0. // Callbacks should return false if the propagation fails. +// +// We will call this after propagation has reached a fixed point. Note however +// that if any callbacks "propagate" something, the callbacks following it might +// not see a state where the propagation have been called again. +// TODO(user): maybe we should re-propagate before calling the next callback. struct LevelZeroCallbackHelper { std::vector> callbacks; }; @@ -95,6 +100,13 @@ inline IntegerValue FloorRatio(IntegerValue dividend, return result - adjust; } +// When the case positive_divisor == 1 is frequent, this is faster. +inline IntegerValue FloorRatioWithTest(IntegerValue dividend, + IntegerValue positive_divisor) { + if (positive_divisor == 1) return dividend; + return FloorRatio(dividend, positive_divisor); +} + // Overflows and saturated arithmetic. inline IntegerValue CapProdI(IntegerValue a, IntegerValue b) { @@ -369,26 +381,30 @@ struct LinearExpression2 { // This will not change any bounds on the LinearExpression2. // That is we will not potentially Negate() the expression like // CanonicalizeAndUpdateBounds() might do. - // Note that since kNoIntegerVariable=-1 and we sort the variables, if we any + // Note that since kNoIntegerVariable=-1 and we sort the variables, if we have // one zero and one non-zero we will always have the zero first. void SimpleCanonicalization(); // Fully canonicalizes the expression and updates the given bounds // accordingly. This is the same as SimpleCanonicalization(), DivideByGcd() // and the NegateForCanonicalization() with a proper updates of the bounds. - void CanonicalizeAndUpdateBounds(IntegerValue& lb, IntegerValue& ub, - bool allow_negation = false); + // Returns whether the expression was negated. + bool CanonicalizeAndUpdateBounds(IntegerValue& lb, IntegerValue& ub); // Divides the expression by the gcd of both coefficients, and returns it. // Note that we always return something >= 1 even if both coefficients are // zero. IntegerValue DivideByGcd(); + bool IsCanonicalized() const; + // Makes sure expr and -expr have the same canonical representation by // negating the expression of it is in the non-canonical form. Returns true if // the expression was negated. bool NegateForCanonicalization(); + void MakeVariablesPositive(); + absl::Span non_zero_vars() const { const int first = coeffs[0] == 0 ? 1 : 0; const int last = coeffs[1] == 0 ? 0 : 1; @@ -413,13 +429,31 @@ struct LinearExpression2 { IntegerValue coeffs[2]; IntegerVariable vars[2]; + + template + friend void AbslStringify(Sink& sink, const LinearExpression2& expr) { + absl::Format(&sink, "%d X%d + %d X%d", expr.coeffs[0].value(), + expr.vars[0].value(), expr.coeffs[1].value(), + expr.vars[1].value()); + } }; -inline std::ostream& operator<<(std::ostream& os, - const LinearExpression2& expr) { - os << absl::StrCat(expr.coeffs[0], " X", expr.vars[0], " + ", expr.coeffs[1], - " X", expr.vars[1]); - return os; +// Encodes (a - b <= ub) in (linear2 <= ub) format. +// Note that the returned expression is canonicalized and divided by its GCD. +inline std::pair EncodeDifferenceLowerThan( + AffineExpression a, AffineExpression b, IntegerValue ub) { + LinearExpression2 expr; + expr.vars[0] = a.var; + expr.coeffs[0] = a.coeff; + expr.vars[1] = b.var; + expr.coeffs[1] = -b.coeff; + IntegerValue rhs = ub + b.constant - a.constant; + + // Canonicalize. + expr.SimpleCanonicalization(); + const IntegerValue gcd = expr.DivideByGcd(); + rhs = FloorRatio(rhs, gcd); + return {std::move(expr), rhs}; } template @@ -437,9 +471,21 @@ class BestBinaryRelationBounds { public: // Register the fact that expr \in [lb, ub] is true. // - // Returns true if this fact is new, that is if the bounds are tighter than - // the current ones. - bool Add(LinearExpression2 expr, IntegerValue lb, IntegerValue ub); + // If lb==kMinIntegerValue it only register that expr <= ub (and symmetrically + // for ub==kMaxIntegerValue). + // + // Returns for each of the bound if it was restricted (added/updated), if it + // was ignored because a better or equal bound was already present, or if it + // was rejected because it was invalid (e.g. the expression was a degenerate + // linear2 or the bound was a min/max value). + enum class AddResult { + ADDED, + UPDATED, + NOT_BETTER, + INVALID, + }; + std::pair Add(LinearExpression2 expr, IntegerValue lb, + IntegerValue ub); // Returns the known status of expr <= bound. RelationStatus GetStatus(LinearExpression2 expr, IntegerValue lb, @@ -450,6 +496,18 @@ class BestBinaryRelationBounds { // entry in the hash-map. IntegerValue GetUpperBound(LinearExpression2 expr) const; + // Same as GetUpperBound() but assume the expression is already canonicalized. + // This is slightly faster. + IntegerValue UpperBoundWhenCanonicalized(LinearExpression2 expr) const; + + int64_t num_bounds() const { return best_bounds_.size(); } + + std::vector> + GetSortedNonTrivialUpperBounds() const; + + std::vector> + GetSortedNonTrivialBounds() const; + private: // The best bound on the given "canonicalized" expression. absl::flat_hash_map> @@ -512,6 +570,28 @@ std::ostream& operator<<(std::ostream& os, const ValueLiteralPair& p); DEFINE_STRONG_INDEX_TYPE(IntervalVariable); const IntervalVariable kNoIntervalVariable(-1); +// This functions appears in hot spot, and so it is important to inline it. +// +// TODO(user): Maybe introduce a CanonicalizedLinear2 class so we automatically +// get the better function, and it documents when we have canonicalized +// expression. +inline IntegerValue BestBinaryRelationBounds::UpperBoundWhenCanonicalized( + LinearExpression2 expr) const { + DCHECK_EQ(expr.DivideByGcd(), 1); + DCHECK(expr.IsCanonicalized()); + const bool negated = expr.NegateForCanonicalization(); + const auto it = best_bounds_.find(expr); + if (it != best_bounds_.end()) { + const auto [known_lb, known_ub] = it->second; + if (negated) { + return -known_lb; + } else { + return known_ub; + } + } + return kMaxIntegerValue; +} + // ============================================================================ // Implementation. // ============================================================================ @@ -552,8 +632,8 @@ inline IntegerLiteral AffineExpression::GreaterOrEqual( : IntegerLiteral::FalseLiteral(); } DCHECK_GT(coeff, 0); - return IntegerLiteral::GreaterOrEqual(var, - CeilRatio(bound - constant, coeff)); + return IntegerLiteral::GreaterOrEqual( + var, coeff == 1 ? bound - constant : CeilRatio(bound - constant, coeff)); } // var * coeff + constant <= bound. @@ -563,7 +643,8 @@ inline IntegerLiteral AffineExpression::LowerOrEqual(IntegerValue bound) const { : IntegerLiteral::FalseLiteral(); } DCHECK_GT(coeff, 0); - return IntegerLiteral::LowerOrEqual(var, FloorRatio(bound - constant, coeff)); + return IntegerLiteral::LowerOrEqual( + var, coeff == 1 ? bound - constant : FloorRatio(bound - constant, coeff)); } } // namespace sat diff --git a/ortools/sat/integer_base_test.cc b/ortools/sat/integer_base_test.cc index e3b069cd02..10774a554a 100644 --- a/ortools/sat/integer_base_test.cc +++ b/ortools/sat/integer_base_test.cc @@ -13,6 +13,8 @@ #include "ortools/sat/integer_base.h" +#include + #include "gtest/gtest.h" namespace operations_research::sat { @@ -59,12 +61,17 @@ TEST(BestBinaryRelationBoundsTest, Basic) { expr.coeffs[0] = IntegerValue(1); expr.coeffs[1] = IntegerValue(-1); + using AddResult = BestBinaryRelationBounds::AddResult; + BestBinaryRelationBounds best_bounds; - EXPECT_TRUE(best_bounds.Add(expr, IntegerValue(0), IntegerValue(5))); - EXPECT_TRUE(best_bounds.Add(expr, IntegerValue(3), IntegerValue(8))); - EXPECT_TRUE(best_bounds.Add(expr, IntegerValue(-1), IntegerValue(4))); - EXPECT_FALSE( - best_bounds.Add(expr, IntegerValue(3), IntegerValue(4))); // best + EXPECT_EQ(best_bounds.Add(expr, IntegerValue(0), IntegerValue(5)), + std::make_pair(AddResult::ADDED, AddResult::ADDED)); + EXPECT_EQ(best_bounds.Add(expr, IntegerValue(3), IntegerValue(8)), + std::make_pair(AddResult::UPDATED, AddResult::NOT_BETTER)); + EXPECT_EQ(best_bounds.Add(expr, IntegerValue(-1), IntegerValue(4)), + std::make_pair(AddResult::NOT_BETTER, AddResult::UPDATED)); + EXPECT_EQ(best_bounds.Add(expr, IntegerValue(3), IntegerValue(4)), // best + std::make_pair(AddResult::NOT_BETTER, AddResult::NOT_BETTER)); EXPECT_EQ(RelationStatus::IS_TRUE, best_bounds.GetStatus(expr, IntegerValue(-10), IntegerValue(4))); @@ -85,8 +92,10 @@ TEST(BestBinaryRelationBoundsTest, UpperBound) { expr.coeffs[0] = IntegerValue(1); expr.coeffs[1] = IntegerValue(-1); + using AddResult = BestBinaryRelationBounds::AddResult; BestBinaryRelationBounds best_bounds; - EXPECT_TRUE(best_bounds.Add(expr, IntegerValue(0), IntegerValue(5))); + EXPECT_EQ(best_bounds.Add(expr, IntegerValue(0), IntegerValue(5)), + std::make_pair(AddResult::ADDED, AddResult::ADDED)); EXPECT_EQ(best_bounds.GetUpperBound(expr), IntegerValue(5)); diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 7e095a2eba..f98b2935b6 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -394,7 +394,7 @@ std::function IntegerValueSelectionHeuristic( value_selection_heuristics.push_back( [model, response_manager](IntegerVariable var) { return SplitUsingBestSolutionValueInRepository( - var, response_manager->SolutionsRepository(), model); + var, response_manager->SolutionPool().BestSolutions(), model); }); } } @@ -872,7 +872,6 @@ std::function CumulativePrecedenceSearchHeuristic( // TODO(user): Add heuristic ordering for creating interesting precedence // first. bool found_precedence_to_add = false; - std::vector conflict; helper->ClearReason(); for (const int s : open_tasks) { for (const int t : open_tasks) { @@ -897,13 +896,13 @@ std::function CumulativePrecedenceSearchHeuristic( // fixed all literal, but if it is not, we can just return this // decision. if (trail->Assignment().LiteralIsFalse(Literal(existing))) { - conflict.push_back(Literal(existing)); + helper->MutableLiteralReason()->push_back(Literal(existing)); continue; } } else { // Make sure s could be before t. if (helper->EndMin(s) > helper->StartMax(t)) { - helper->AddReasonForBeingBefore(t, s); + helper->AddReasonForBeingBeforeAssumingNoOverlap(t, s); continue; } @@ -929,24 +928,24 @@ std::function CumulativePrecedenceSearchHeuristic( // // TODO(user): We need to add the reason for demand_min and capacity_max. // TODO(user): unfortunately we can't report it from here. - std::vector integer_reason = - *helper->MutableIntegerReason(); if (!h.capacity.IsConstant()) { - integer_reason.push_back( + helper->MutableIntegerReason()->push_back( integer_trail->UpperBoundAsLiteral(h.capacity)); } const auto& demands = h.demand_helper->Demands(); for (const int t : open_tasks) { if (helper->IsOptional(t)) { CHECK(trail->Assignment().LiteralIsTrue(helper->PresenceLiteral(t))); - conflict.push_back(helper->PresenceLiteral(t).Negated()); + helper->MutableLiteralReason()->push_back( + helper->PresenceLiteral(t).Negated()); } const AffineExpression d = demands[t]; if (!d.IsConstant()) { - integer_reason.push_back(integer_trail->LowerBoundAsLiteral(d)); + helper->MutableIntegerReason()->push_back( + integer_trail->LowerBoundAsLiteral(d)); } } - integer_trail->ReportConflict(conflict, integer_reason); + (void)helper->ReportConflict(); search_helper->NotifyThatConflictWasFoundDuringGetDecision(); if (VLOG_IS_ON(2)) { LOG(INFO) << "Conflict between precedences !"; @@ -1026,7 +1025,7 @@ std::function RandomizeOnRestartHeuristic( value_selection_heuristics.push_back( [model, response_manager](IntegerVariable var) { return SplitUsingBestSolutionValueInRepository( - var, response_manager->SolutionsRepository(), model); + var, response_manager->SolutionPool().BestSolutions(), model); }); value_selection_weight.push_back(5); } diff --git a/ortools/sat/intervals.cc b/ortools/sat/intervals.cc index c50429e71f..113ad4e5d9 100644 --- a/ortools/sat/intervals.cc +++ b/ortools/sat/intervals.cc @@ -43,7 +43,7 @@ IntervalsRepository::IntervalsRepository(Model* model) sat_solver_(model->GetOrCreate()), implications_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), - relations_maps_(model->GetOrCreate()) {} + reified_precedences_(model->GetOrCreate()) {} IntervalVariable IntervalsRepository::CreateInterval(IntegerVariable start, IntegerVariable end, @@ -155,9 +155,9 @@ IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial( } // Abort if the relation is already known. - if (relations_maps_->GetLevelZeroPrecedenceStatus(a.end, b.start) == + if (reified_precedences_->GetLevelZeroPrecedenceStatus(a.end, b.start) == RelationStatus::IS_TRUE || - relations_maps_->GetLevelZeroPrecedenceStatus(b.end, a.start) == + reified_precedences_->GetLevelZeroPrecedenceStatus(b.end, a.start) == RelationStatus::IS_TRUE) { return kNoLiteralIndex; } @@ -181,10 +181,10 @@ IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial( // Also insert it in precedences. if (enforcement_literals.empty()) { - relations_maps_->AddReifiedPrecedenceIfNonTrivial(a_before_b, a.end, - b.start); - relations_maps_->AddReifiedPrecedenceIfNonTrivial(a_before_b.Negated(), - b.end, a.start); + reified_precedences_->AddReifiedPrecedenceIfNonTrivial(a_before_b, a.end, + b.start); + reified_precedences_->AddReifiedPrecedenceIfNonTrivial(a_before_b.Negated(), + b.end, a.start); } enforcement_literals.push_back(a_before_b); @@ -212,12 +212,12 @@ IntervalsRepository::GetOrCreateDisjunctivePrecedenceLiteralIfNonTrivial( bool IntervalsRepository::CreatePrecedenceLiteralIfNonTrivial( AffineExpression x, AffineExpression y) { - const LiteralIndex index = relations_maps_->GetReifiedPrecedence(x, y); + const LiteralIndex index = reified_precedences_->GetReifiedPrecedence(x, y); if (index != kNoLiteralIndex) return false; // We want l => x <= y and not(l) => x > y <=> y + 1 <= x // Do not create l if the relation is always true or false. - if (relations_maps_->GetLevelZeroPrecedenceStatus(x, y) != + if (reified_precedences_->GetLevelZeroPrecedenceStatus(x, y) != RelationStatus::IS_UNKNOWN) { return false; } @@ -225,7 +225,7 @@ bool IntervalsRepository::CreatePrecedenceLiteralIfNonTrivial( // Create a new literal. const BooleanVariable boolean_var = sat_solver_->NewBooleanVariable(); const Literal x_before_y = Literal(boolean_var, true); - relations_maps_->AddReifiedPrecedenceIfNonTrivial(x_before_y, x, y); + reified_precedences_->AddReifiedPrecedenceIfNonTrivial(x_before_y, x, y); AffineExpression y_plus_one = y; y_plus_one.constant += 1; @@ -236,7 +236,7 @@ bool IntervalsRepository::CreatePrecedenceLiteralIfNonTrivial( LiteralIndex IntervalsRepository::GetPrecedenceLiteral( AffineExpression x, AffineExpression y) const { - return relations_maps_->GetReifiedPrecedence(x, y); + return reified_precedences_->GetReifiedPrecedence(x, y); } Literal IntervalsRepository::GetOrCreatePrecedenceLiteral(AffineExpression x, @@ -247,7 +247,7 @@ Literal IntervalsRepository::GetOrCreatePrecedenceLiteral(AffineExpression x, } CHECK(CreatePrecedenceLiteralIfNonTrivial(x, y)); - const LiteralIndex index = relations_maps_->GetReifiedPrecedence(x, y); + const LiteralIndex index = reified_precedences_->GetReifiedPrecedence(x, y); CHECK_NE(index, kNoLiteralIndex); return Literal(index); } diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index 8b36fd47d2..fe4f0fde0b 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -28,6 +28,7 @@ #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/no_overlap_2d_helper.h" +#include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/scheduling_helpers.h" @@ -189,7 +190,7 @@ class IntervalsRepository { SatSolver* sat_solver_; BinaryImplicationGraph* implications_; IntegerTrail* integer_trail_; - BinaryRelationsMaps* relations_maps_; + ReifiedLinear2Bounds* reified_precedences_; // Literal indicating if the tasks is executed. Tasks that are always executed // will have a kNoLiteralIndex entry in this vector. diff --git a/ortools/sat/java/BUILD.bazel b/ortools/sat/java/BUILD.bazel index dc1589eb1e..76bc1a8521 100644 --- a/ortools/sat/java/BUILD.bazel +++ b/ortools/sat/java/BUILD.bazel @@ -19,7 +19,7 @@ load("//bazel:swig_java.bzl", "ortools_java_wrap_cc") ortools_java_wrap_cc( name = "sat", - src = "sat.i", + src = "sat.swig", java_deps = [ "//ortools/sat:cp_model_java_proto", "//ortools/sat:sat_parameters_java_proto", diff --git a/ortools/sat/java/CMakeLists.txt b/ortools/sat/java/CMakeLists.txt index ea6aeec7bf..6f7243292b 100644 --- a/ortools/sat/java/CMakeLists.txt +++ b/ortools/sat/java/CMakeLists.txt @@ -11,17 +11,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -set_property(SOURCE sat.i PROPERTY CPLUSPLUS ON) -set_property(SOURCE sat.i PROPERTY SWIG_MODULE_NAME main) -set_property(SOURCE sat.i PROPERTY COMPILE_DEFINITIONS +set_property(SOURCE sat.swig PROPERTY CPLUSPLUS ON) +set_property(SOURCE sat.swig PROPERTY SWIG_MODULE_NAME main) +set_property(SOURCE sat.swig PROPERTY COMPILE_DEFINITIONS ${OR_TOOLS_COMPILE_DEFINITIONS} ABSL_MUST_USE_RESULT=) -set_property(SOURCE sat.i PROPERTY COMPILE_OPTIONS +set_property(SOURCE sat.swig PROPERTY COMPILE_OPTIONS -package ${JAVA_PACKAGE}.sat) swig_add_library(jnisat TYPE OBJECT LANGUAGE java OUTPUT_DIR ${JAVA_PROJECT_DIR}/${JAVA_SRC_PATH}/sat - SOURCES sat.i) + SOURCES sat.swig) target_include_directories(jnisat PRIVATE ${JNI_INCLUDE_DIRS}) set_target_properties(jnisat PROPERTIES diff --git a/ortools/sat/java/sat.i b/ortools/sat/java/sat.swig similarity index 98% rename from ortools/sat/java/sat.i rename to ortools/sat/java/sat.swig index af1b325e2b..f2def937a9 100644 --- a/ortools/sat/java/sat.i +++ b/ortools/sat/java/sat.swig @@ -17,7 +17,7 @@ %include "ortools/util/java/proto.i" -%import "ortools/util/java/sorted_interval_list.i" +%import "ortools/util/java/sorted_interval_list.swig" %{ #include "ortools/sat/cp_model.pb.h" @@ -76,7 +76,7 @@ PROTO2_RETURN(operations_research::sat::CpSolverResponse, // The only difference is that the argument is not a basic type, and needs // processing to be passed to the std::function. // -// TODO(user): cleanup java/functions.i and move the code there. +// TODO(user): cleanup java/functions.swig and move the code there. %{ #include // std::make_shared %} diff --git a/ortools/sat/linear_propagation.cc b/ortools/sat/linear_propagation.cc index c77ff22752..330483c928 100644 --- a/ortools/sat/linear_propagation.cc +++ b/ortools/sat/linear_propagation.cc @@ -384,8 +384,8 @@ LinearPropagator::LinearPropagator(Model* model) rev_int_repository_(model->GetOrCreate()), rev_integer_value_repository_( model->GetOrCreate()), - precedences_(model->GetOrCreate()), - binary_relations_(model->GetOrCreate()), + precedences_(model->GetOrCreate()), + linear3_bounds_(model->GetOrCreate()), random_(model->GetOrCreate()), shared_stats_(model->GetOrCreate()), watcher_id_(watcher_->Register(this)), @@ -538,7 +538,8 @@ bool LinearPropagator::Propagate() { // - Z + Y >= 6 ==> Z >= 1 // - (1) again to push T <= 10 and reach the propagation fixed point. Bitset64::View in_queue = in_queue_.view(); - const bool push_affine_ub = push_affine_ub_for_binary_relations_; + const bool push_affine_ub = push_affine_ub_for_binary_relations_ || + trail_->CurrentDecisionLevel() == 0; while (true) { // We always process the whole queue in FIFO order. // Note that the order really only matter for infeasible constraint so it @@ -612,7 +613,7 @@ bool LinearPropagator::Propagate() { // The rev_rhs was updated to: initial_rhs - lb(vars[2]) * coeffs[2]. const IntegerValue initial_rhs = info.rev_rhs + coeffs[2] * integer_trail_->LowerBound(vars[2]); - binary_relations_->AddAffineUpperBound( + linear3_bounds_->AddAffineUpperBound( expr, AffineExpression(vars[2], -coeffs[2], initial_rhs)); } else if (info.rev_size == 3) { for (int i = 0; i < 3; ++i) { @@ -623,7 +624,7 @@ bool LinearPropagator::Propagate() { expr.vars[1] = vars[b]; expr.coeffs[0] = coeffs[a]; expr.coeffs[1] = coeffs[b]; - binary_relations_->AddAffineUpperBound( + linear3_bounds_->AddAffineUpperBound( expr, AffineExpression(vars[i], -coeffs[i], info.rev_rhs)); } } diff --git a/ortools/sat/linear_propagation.h b/ortools/sat/linear_propagation.h index b98f46711e..ab4027b665 100644 --- a/ortools/sat/linear_propagation.h +++ b/ortools/sat/linear_propagation.h @@ -421,8 +421,8 @@ class LinearPropagator : public PropagatorInterface, TimeLimit* time_limit_; RevIntRepository* rev_int_repository_; RevIntegerValueRepository* rev_integer_value_repository_; - PrecedenceRelations* precedences_; - BinaryRelationsMaps* binary_relations_; + EnforcedLinear2Bounds* precedences_; + Linear2BoundsFromLinear3* linear3_bounds_; ModelRandomGenerator* random_; SharedStatistics* shared_stats_ = nullptr; const int watcher_id_; diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 2a96fdb1f7..0bd15c832a 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -699,8 +699,8 @@ std::optional DetectMakespanFromPrecedences( } std::vector output; - auto* precedences = model->GetOrCreate(); - precedences->ComputeFullPrecedences(end_vars, &output); + auto* evaluator = model->GetOrCreate(); + evaluator->ComputeFullPrecedences(end_vars, &output); for (const auto& p : output) { // TODO(user): What if we have more than one candidate makespan ? if (p.indices.size() != ends.size()) continue; diff --git a/ortools/sat/no_overlap_2d_helper.cc b/ortools/sat/no_overlap_2d_helper.cc index 9fee042fff..94484b160e 100644 --- a/ortools/sat/no_overlap_2d_helper.cc +++ b/ortools/sat/no_overlap_2d_helper.cc @@ -97,8 +97,8 @@ void ClearAndAddMandatoryOverlapReason(int box1, int box2, y->ClearReason(); y->AddPresenceReason(box1); y->AddPresenceReason(box2); - y->AddReasonForBeingBefore(box1, box2); - y->AddReasonForBeingBefore(box2, box1); + y->AddReasonForBeingBeforeAssumingNoOverlap(box1, box2); + y->AddReasonForBeingBeforeAssumingNoOverlap(box2, box1); } } // namespace @@ -162,7 +162,7 @@ bool LeftBoxBeforeRightBoxOnFirstDimension(int left, int right, x->ClearReason(); x->AddPresenceReason(left); x->AddPresenceReason(right); - x->AddReasonForBeingBefore(left, right); + x->AddReasonForBeingBeforeAssumingNoOverlap(left, right); x->AddEndMinReason(left, left_end_min); // left and right must overlap on y. ClearAndAddMandatoryOverlapReason(left, right, y); @@ -177,7 +177,7 @@ bool LeftBoxBeforeRightBoxOnFirstDimension(int left, int right, x->ClearReason(); x->AddPresenceReason(left); x->AddPresenceReason(right); - x->AddReasonForBeingBefore(left, right); + x->AddReasonForBeingBeforeAssumingNoOverlap(left, right); x->AddStartMaxReason(right, right_start_max); // left and right must overlap on y. ClearAndAddMandatoryOverlapReason(left, right, y); diff --git a/ortools/sat/opb_reader.h b/ortools/sat/opb_reader.h index cb6a36d500..ef452c2e7f 100644 --- a/ortools/sat/opb_reader.h +++ b/ortools/sat/opb_reader.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -87,6 +88,7 @@ class OpbReader { LOG(INFO) << "#variables: " << num_variables_; LOG(INFO) << "#constraints: " << constraints_.size(); LOG(INFO) << "#objective: " << objective_.size(); + if (top_cost_.has_value()) LOG(INFO) << "top_cost: " << top_cost_.value(); const std::string error_message = ValidateModel(); if (!error_message.empty()) { @@ -134,14 +136,16 @@ class OpbReader { void ProcessNewLine(const std::string& line) { const std::vector words = absl::StrSplit(line, absl::ByAnyChar(" ;"), absl::SkipEmpty()); - if (words.empty() || words[0].empty() || words[0][0] == '*') { - // TODO(user): Parse comments. + if (words.empty() || words[0].empty() || words[0][0] == '*') return; + + if (words[0] == "soft:") { + if (words.size() == 1) return; + int64_t top_cost; + if (!ParseInt64Into(words[1], &top_cost)) return; + top_cost_ = top_cost; return; } - // We ignore the number of soft constraints. - if (words[0] == "soft:") return; - if (words[0] == "min:") { for (int i = 1; i < words.size(); ++i) { const std::string& word = words[i]; @@ -364,6 +368,12 @@ class OpbReader { obj->add_coeffs(term.coeff); } } + + if (top_cost_.has_value()) { + CpObjectiveProto* obj = model->mutable_objective(); + obj->add_domain(std::numeric_limits::min()); + obj->add_domain(top_cost_.value()); + } } int num_variables_; @@ -371,6 +381,7 @@ class OpbReader { std::vector constraints_; absl::flat_hash_map, int> product_to_var_; bool model_is_supported_ = true; + std::optional top_cost_; }; } // namespace sat diff --git a/ortools/sat/parameters_validation.cc b/ortools/sat/parameters_validation.cc index 36af16ffab..7c0de26f54 100644 --- a/ortools/sat/parameters_validation.cc +++ b/ortools/sat/parameters_validation.cc @@ -141,6 +141,8 @@ std::string ValidateParameters(const SatParameters& params) { TEST_POSITIVE(glucose_decay_increment_period); TEST_POSITIVE(shared_tree_max_nodes_per_worker); TEST_POSITIVE(shared_tree_open_leaves_per_worker); + TEST_NON_NEGATIVE(shared_tree_split_min_dtime); + TEST_IS_FINITE(shared_tree_split_min_dtime); TEST_POSITIVE(mip_var_scaling); // Test LP tolerances. diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index 562795dd3c..339c6ece5e 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -17,7 +17,9 @@ #include #include +#include #include +#include #include #include @@ -53,60 +55,287 @@ namespace operations_research { namespace sat { -bool PrecedenceRelations::AddBounds(LinearExpression2 expr, IntegerValue lb, - IntegerValue ub) { - expr.CanonicalizeAndUpdateBounds(lb, ub); +LinearExpression2Index Linear2Indices::AddOrGet( + LinearExpression2 original_expr) { + LinearExpression2 expr = original_expr; + DCHECK(expr.IsCanonicalized()); + DCHECK_EQ(expr.DivideByGcd(), 1); + DCHECK_NE(expr.coeffs[0], 0); + DCHECK_NE(expr.coeffs[1], 0); + const bool negated = expr.NegateForCanonicalization(); + auto [it, inserted] = expr_to_index_.insert({expr, exprs_.size()}); + if (inserted) { + CHECK_LT(2 * exprs_.size() + 1, + std::numeric_limits::max()); + exprs_.push_back(expr); + } + const LinearExpression2Index result = + negated ? NegationOf(LinearExpression2Index(2 * it->second)) + : LinearExpression2Index(2 * it->second); - if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { - // This class handles only binary relationships, let something else handle - // the case where there is actually a single variable. + if (!inserted) return result; + + // Update our per-variable and per-pair lookup tables. + IntegerVariable var1 = PositiveVariable(expr.vars[0]); + IntegerVariable var2 = PositiveVariable(expr.vars[1]); + if (var1 > var2) std::swap(var1, var2); + var_pair_to_bounds_[{var1, var2}].push_back(result); + var_to_bounds_[var1].push_back(result); + var_to_bounds_[var2].push_back(result); + + return result; +} + +void Linear2Watcher::NotifyBoundChanged(LinearExpression2 expr) { + DCHECK(expr.IsCanonicalized()); + DCHECK_EQ(expr.DivideByGcd(), 1); + ++timestamp_; + for (const int id : propagator_ids_) { + watcher_->CallOnNextPropagate(id); + } + for (IntegerVariable var : expr.non_zero_vars()) { + var = PositiveVariable(var); // TODO(user): Be more precise? + if (var >= var_timestamp_.size()) { + var_timestamp_.resize(var + 1, 0); + } + var_timestamp_[var]++; + } +} + +int64_t Linear2Watcher::VarTimestamp(IntegerVariable var) { + var = PositiveVariable(var); + return var < var_timestamp_.size() ? var_timestamp_[var] : 0; +} + +bool RootLevelLinear2Bounds::AddUpperBound(LinearExpression2Index index, + IntegerValue ub) { + const LinearExpression2 expr = lin2_indices_->GetExpression(index); + const IntegerValue zero_level_ub = integer_trail_->LevelZeroUpperBound(expr); + if (ub >= zero_level_ub) { return false; } + if (best_upper_bounds_.size() <= index) { + best_upper_bounds_.resize(index.value() + 1, kMaxIntegerValue); + } + if (ub >= best_upper_bounds_[index]) { + return false; + } + best_upper_bounds_[index] = ub; - // Add to root_relations_. + ++num_updates_; + linear2_watcher_->NotifyBoundChanged(expr); + + // Simple relations. // - // TODO(user): AddInternal() only returns true if this is the first relation - // between head and tail. But we can still avoid an extra lookup. - const bool add_ub = ub < LevelZeroUpperBound(expr); - LinearExpression2 expr_for_lb = expr; - expr_for_lb.Negate(); - const bool add_lb = lb > -LevelZeroUpperBound(expr_for_lb); - if (!add_ub && !add_lb) { - return false; + // TODO(user): Remove them each time we go back to level zero and they become + // trivially true ? + if (IntTypeAbs(expr.coeffs[0]) == 1 && IntTypeAbs(expr.coeffs[1]) == 1) { + if (index >= in_coeff_one_lookup_.size()) { + in_coeff_one_lookup_.resize(index + 1, false); + } + if (!in_coeff_one_lookup_[index]) { + const IntegerVariable a = + expr.coeffs[0] > 0 ? expr.vars[0] : NegationOf(expr.vars[0]); + const IntegerVariable b = + expr.coeffs[1] > 0 ? expr.vars[1] : NegationOf(expr.vars[1]); + + coeff_one_var_lookup_.resize(integer_trail_->NumIntegerVariables()); + in_coeff_one_lookup_[index] = true; + coeff_one_var_lookup_[a].push_back({b, index}); + coeff_one_var_lookup_[b].push_back({a, index}); + } } - if (add_ub) { - AddInternal(expr, ub); - } - if (add_lb) { - AddInternal(expr_for_lb, -lb); - } - - // If we are not built, make sure there is enough room in the graph. - // TODO(user): Alternatively, force caller to do a Resize(). - const int max_node = - std::max(PositiveVariable(expr.vars[0]), PositiveVariable(expr.vars[1])) - .value() + - 1; - if (!is_built_ && max_node >= graph_.num_nodes()) { - graph_.AddNode(max_node); + // Share. + // + // TODO(user): It seems we could change the canonicalization to only use + // positive variable? that would simplify a bit the code here and not make it + // worse elsewhere? + if (shared_linear2_bounds_ != nullptr) { + const IntegerValue lb = -LevelZeroUpperBound(NegationOf(index)); + const int proto_var0 = + cp_model_mapping_->GetProtoVariableFromIntegerVariable( + PositiveVariable(expr.vars[0])); + const int proto_var1 = + cp_model_mapping_->GetProtoVariableFromIntegerVariable( + PositiveVariable(expr.vars[1])); + if (proto_var0 >= 0 && proto_var1 >= 0) { + // This is also a relation between cp_model proto variable. Share it! + // Note that since expr is canonicalized, this one should too. + SharedLinear2Bounds::Key key; + key.vars[0] = proto_var0; + key.coeffs[0] = + VariableIsPositive(expr.vars[0]) ? expr.coeffs[0] : -expr.coeffs[0]; + key.vars[1] = proto_var1; + key.coeffs[1] = + VariableIsPositive(expr.vars[1]) ? expr.coeffs[1] : -expr.coeffs[1]; + shared_linear2_bounds_->Add(shared_linear2_bounds_id_, key, lb, ub); + } } return true; } -bool PrecedenceRelations::AddUpperBound(LinearExpression2 expr, - IntegerValue ub) { - return AddBounds(expr, kMinIntegerValue, ub); +// TODO(user): If we add an indexing for "coeff * var" this is kind of +// easy to generalize to affine relations, not just "simple one". +int RootLevelLinear2Bounds::AugmentSimpleRelations(IntegerVariable var, + int work_limit) { + var = PositiveVariable(var); + if (var >= coeff_one_var_lookup_.size()) return 0; + if (NegationOf(var) >= coeff_one_var_lookup_.size()) return 0; + + // Note that this never touches in_coeff_one_lookup_[var/NegationOf(var)], + // so it should be safe to iterate on it. + int work_done = 0; + for (const auto [a, a_index] : coeff_one_var_lookup_[var]) { + CHECK_NE(PositiveVariable(a), var); + const IntegerValue a_ub = best_upper_bounds_[a_index]; + for (const auto [b, b_index] : coeff_one_var_lookup_[NegationOf(var)]) { + if (PositiveVariable(b) == PositiveVariable(a)) continue; + CHECK_NE(PositiveVariable(b), var); + if (++work_done > work_limit) return work_done; + + const LinearExpression2 candidate{a, b, 1, 1}; + AddUpperBound(candidate, a_ub + best_upper_bounds_[b_index]); + } + } + return work_done; } -void PrecedenceRelations::PushConditionalRelation( - absl::Span enforcements, LinearExpression2 expr, - IntegerValue rhs) { +RootLevelLinear2Bounds::~RootLevelLinear2Bounds() { + if (!VLOG_IS_ON(1)) return; + std::vector> stats; + stats.push_back({"RootLevelLinear2Bounds/num_updates", num_updates_}); + shared_stats_->AddStats(stats); +} + +RelationStatus RootLevelLinear2Bounds::GetLevelZeroStatus( + LinearExpression2 expr, IntegerValue lb, IntegerValue ub) const { expr.SimpleCanonicalization(); if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { - return; + return RelationStatus::IS_UNKNOWN; } + const IntegerValue known_ub = LevelZeroUpperBound(expr); + expr.Negate(); + const IntegerValue known_lb = -LevelZeroUpperBound(expr); + if (lb <= known_lb && ub >= known_ub) return RelationStatus::IS_TRUE; + if (lb > known_ub || ub < known_lb) return RelationStatus::IS_FALSE; + return RelationStatus::IS_UNKNOWN; +} + +IntegerValue RootLevelLinear2Bounds::GetUpperBoundNoTrail( + LinearExpression2Index index) const { + if (best_upper_bounds_.size() <= index) { + return kMaxIntegerValue; + } + return best_upper_bounds_[index]; +} + +std::vector> +RootLevelLinear2Bounds::GetSortedNonTrivialUpperBounds() const { + std::vector> result; + for (LinearExpression2Index index = LinearExpression2Index{0}; + index < best_upper_bounds_.size(); ++index) { + const IntegerValue ub = best_upper_bounds_[index]; + if (ub == kMaxIntegerValue) continue; + const LinearExpression2 expr = lin2_indices_->GetExpression(index); + if (ub < integer_trail_->LevelZeroUpperBound(expr)) { + result.push_back({expr, ub}); + } + } + std::sort(result.begin(), result.end()); + return result; +} + +std::vector> +RootLevelLinear2Bounds::GetAllBoundsContainingVariable( + IntegerVariable var) const { + std::vector> result; + for (const LinearExpression2Index index : + lin2_indices_->GetAllLinear2ContainingVariable(var)) { + const IntegerValue lb = -GetUpperBoundNoTrail(NegationOf(index)); + const IntegerValue ub = GetUpperBoundNoTrail(index); + const LinearExpression2 expr = lin2_indices_->GetExpression(index); + const IntegerValue trail_lb = integer_trail_->LevelZeroLowerBound(expr); + const IntegerValue trail_ub = integer_trail_->LevelZeroUpperBound(expr); + if (lb <= trail_lb && ub >= trail_ub) continue; + LinearExpression2 explicit_vars_expr = expr; + if (explicit_vars_expr.vars[0] == NegationOf(var)) { + explicit_vars_expr.vars[0] = NegationOf(explicit_vars_expr.vars[0]); + explicit_vars_expr.coeffs[0] = -explicit_vars_expr.coeffs[0]; + } + if (explicit_vars_expr.vars[1] == NegationOf(var)) { + explicit_vars_expr.vars[1] = NegationOf(explicit_vars_expr.vars[1]); + explicit_vars_expr.coeffs[1] = -explicit_vars_expr.coeffs[1]; + } + if (explicit_vars_expr.vars[1] == var) { + std::swap(explicit_vars_expr.vars[0], explicit_vars_expr.vars[1]); + std::swap(explicit_vars_expr.coeffs[0], explicit_vars_expr.coeffs[1]); + } + DCHECK(explicit_vars_expr.vars[0] == var); + result.push_back( + {explicit_vars_expr, std::max(lb, trail_lb), std::min(ub, trail_ub)}); + } + return result; +} + +// Return a list of (lb <= expr <= ub), with expr.vars = {var1, var2}, where +// at least one of the bounds is non-trivial and the potential other +// non-trivial bound is tight. +std::vector> +RootLevelLinear2Bounds::GetAllBoundsContainingVariables( + IntegerVariable var1, IntegerVariable var2) const { + std::vector> result; + for (const LinearExpression2Index index : + lin2_indices_->GetAllLinear2ContainingVariables(var1, var2)) { + const IntegerValue lb = -GetUpperBoundNoTrail(NegationOf(index)); + const IntegerValue ub = GetUpperBoundNoTrail(index); + const LinearExpression2 expr = lin2_indices_->GetExpression(index); + const IntegerValue trail_lb = integer_trail_->LevelZeroLowerBound(expr); + const IntegerValue trail_ub = integer_trail_->LevelZeroUpperBound(expr); + if (lb <= trail_lb && ub >= trail_ub) continue; + + LinearExpression2 explicit_vars_expr = expr; + if (explicit_vars_expr.vars[0] == NegationOf(var1) || + explicit_vars_expr.vars[0] == NegationOf(var2)) { + explicit_vars_expr.vars[0] = NegationOf(explicit_vars_expr.vars[0]); + explicit_vars_expr.coeffs[0] = -explicit_vars_expr.coeffs[0]; + } + if (explicit_vars_expr.vars[1] == NegationOf(var1) || + explicit_vars_expr.vars[1] == NegationOf(var2)) { + explicit_vars_expr.vars[1] = NegationOf(explicit_vars_expr.vars[1]); + explicit_vars_expr.coeffs[1] = -explicit_vars_expr.coeffs[1]; + } + if (explicit_vars_expr.vars[1] == var1) { + std::swap(explicit_vars_expr.vars[0], explicit_vars_expr.vars[1]); + std::swap(explicit_vars_expr.coeffs[0], explicit_vars_expr.coeffs[1]); + } + DCHECK(explicit_vars_expr.vars[0] == var1 && + explicit_vars_expr.vars[1] == var2); + result.push_back( + {explicit_vars_expr, std::max(lb, trail_lb), std::min(ub, trail_ub)}); + } + return result; +} + +absl::Span> +RootLevelLinear2Bounds::GetVariablesInSimpleRelation( + IntegerVariable var) const { + if (var >= coeff_one_var_lookup_.size()) return {}; + return coeff_one_var_lookup_[var]; +} + +EnforcedLinear2Bounds::~EnforcedLinear2Bounds() { + if (!VLOG_IS_ON(1)) return; + std::vector> stats; + stats.push_back({"EnforcedLinear2Bounds/num_conditional_relation_updates", + num_conditional_relation_updates_}); + shared_stats_->AddStats(stats); +} + +void EnforcedLinear2Bounds::PushConditionalRelation( + absl::Span enforcements, LinearExpression2Index lin2_index, + IntegerValue rhs) { // This must be currently true. if (DEBUG_MODE) { for (const Literal l : enforcements) { @@ -115,51 +344,49 @@ void PrecedenceRelations::PushConditionalRelation( } if (enforcements.empty() || trail_->CurrentDecisionLevel() == 0) { - AddUpperBound(expr, rhs); + root_level_bounds_->AddUpperBound(lin2_index, rhs); return; } - const IntegerValue gcd = expr.DivideByGcd(); - rhs = FloorRatio(rhs, gcd); + if (rhs >= root_level_bounds_->LevelZeroUpperBound(lin2_index)) return; + const LinearExpression2 expr = lin2_indices_->GetExpression(lin2_index); - // Ignore if no better than best_relations, otherwise increase it. - { - const auto [it, inserted] = best_relations_.insert({expr, rhs}); - if (!inserted) { - if (rhs >= it->second) return; // Ignore. - it->second = rhs; - } - } + linear2_watcher_->NotifyBoundChanged(expr); + ++num_conditional_relation_updates_; const int new_index = conditional_stack_.size(); - const auto [it, inserted] = conditional_relations_.insert({expr, new_index}); - if (inserted) { + if (conditional_relations_.size() <= lin2_index) { + conditional_relations_.resize(lin2_index.value() + 1, -1); + } + if (conditional_relations_[lin2_index] == -1) { + conditional_relations_[lin2_index] = new_index; CreateLevelEntryIfNeeded(); - conditional_stack_.emplace_back(/*prev_entry=*/-1, rhs, expr, enforcements); + conditional_stack_.emplace_back(/*prev_entry=*/-1, rhs, lin2_index, + enforcements); if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) { const int new_size = std::max(expr.vars[0].value(), expr.vars[1].value()) + 1; - if (new_size > conditional_after_.size()) { - conditional_after_.resize(new_size); + if (new_size > conditional_var_lookup_.size()) { + conditional_var_lookup_.resize(new_size); } - conditional_after_[expr.vars[0]].push_back(NegationOf(expr.vars[1])); - conditional_after_[expr.vars[1]].push_back(NegationOf(expr.vars[0])); + conditional_var_lookup_[expr.vars[0]].push_back( + {expr.vars[1], lin2_index}); + conditional_var_lookup_[expr.vars[1]].push_back( + {expr.vars[0], lin2_index}); } } else { - // We should only decrease because we ignored entry worse than the one in - // best_relations_. - const int prev_entry = it->second; - DCHECK_LT(rhs, conditional_stack_[prev_entry].rhs); + const int prev_entry = conditional_relations_[lin2_index]; + if (rhs >= conditional_stack_[prev_entry].rhs) return; // Update. - it->second = new_index; + conditional_relations_[lin2_index] = new_index; CreateLevelEntryIfNeeded(); - conditional_stack_.emplace_back(prev_entry, rhs, expr, enforcements); + conditional_stack_.emplace_back(prev_entry, rhs, lin2_index, enforcements); } } -void PrecedenceRelations::CreateLevelEntryIfNeeded() { +void EnforcedLinear2Bounds::CreateLevelEntryIfNeeded() { const int current = trail_->CurrentDecisionLevel(); if (!level_to_stack_size_.empty() && level_to_stack_size_.back().first == current) @@ -168,7 +395,7 @@ void PrecedenceRelations::CreateLevelEntryIfNeeded() { } // We only pop what is needed. -void PrecedenceRelations::SetLevel(int level) { +void EnforcedLinear2Bounds::SetLevel(int level) { while (!level_to_stack_size_.empty() && level_to_stack_size_.back().first > level) { const int target = level_to_stack_size_.back().second; @@ -177,18 +404,17 @@ void PrecedenceRelations::SetLevel(int level) { const ConditionalEntry& back = conditional_stack_.back(); if (back.prev_entry != -1) { conditional_relations_[back.key] = back.prev_entry; - UpdateBestRelation(back.key, conditional_stack_[back.prev_entry].rhs); } else { - UpdateBestRelation(back.key, kMaxIntegerValue); - conditional_relations_.erase(back.key); + conditional_relations_[back.key] = -1; + const LinearExpression2 expr = lin2_indices_->GetExpression(back.key); - if (back.key.coeffs[0] == 1 && back.key.coeffs[1] == 1) { - DCHECK_EQ(conditional_after_[back.key.vars[0]].back(), - NegationOf(back.key.vars[1])); - DCHECK_EQ(conditional_after_[back.key.vars[1]].back(), - NegationOf(back.key.vars[0])); - conditional_after_[back.key.vars[0]].pop_back(); - conditional_after_[back.key.vars[1]].pop_back(); + if (expr.coeffs[0] == 1 && expr.coeffs[1] == 1) { + DCHECK_EQ(conditional_var_lookup_[expr.vars[0]].back().first, + expr.vars[1]); + DCHECK_EQ(conditional_var_lookup_[expr.vars[1]].back().first, + expr.vars[0]); + conditional_var_lookup_[expr.vars[0]].pop_back(); + conditional_var_lookup_[expr.vars[1]].pop_back(); } } conditional_stack_.pop_back(); @@ -197,116 +423,95 @@ void PrecedenceRelations::SetLevel(int level) { } } -IntegerValue PrecedenceRelations::LevelZeroUpperBound( - LinearExpression2 expr) const { - expr.SimpleCanonicalization(); - const IntegerValue gcd = expr.DivideByGcd(); - const auto it = root_relations_.find(expr); - if (it != root_relations_.end()) { - return CapProdI(it->second, gcd); - } - return kMaxIntegerValue; -} - -void PrecedenceRelations::AddReasonForUpperBoundLowerThan( - LinearExpression2 expr, IntegerValue ub, +void EnforcedLinear2Bounds::AddReasonForUpperBoundLowerThan( + LinearExpression2Index index, IntegerValue ub, std::vector* literal_reason, std::vector* /*unused*/) const { - expr.SimpleCanonicalization(); - if (ub >= LevelZeroUpperBound(expr)) return; - const IntegerValue gcd = expr.DivideByGcd(); - const auto it = conditional_relations_.find(expr); - DCHECK(it != conditional_relations_.end()); + if (ub >= root_level_bounds_->LevelZeroUpperBound(index)) return; + DCHECK_LT(index, conditional_relations_.size()); + const int entry_index = conditional_relations_[index]; + DCHECK_NE(entry_index, -1); - const ConditionalEntry& entry = conditional_stack_[it->second]; + const ConditionalEntry& entry = conditional_stack_[entry_index]; if (DEBUG_MODE) { for (const Literal l : entry.enforcements) { CHECK(trail_->Assignment().LiteralIsTrue(l)); } } - DCHECK_EQ(CapProdI(gcd, entry.rhs), UpperBound(expr)); - DCHECK_LE(CapProdI(gcd, entry.rhs), ub); + DCHECK_LE(entry.rhs, ub); for (const Literal l : entry.enforcements) { literal_reason->push_back(l.Negated()); } } -IntegerValue PrecedenceRelations::UpperBound(LinearExpression2 expr) const { - expr.SimpleCanonicalization(); - const IntegerValue gcd = expr.DivideByGcd(); - const auto it = best_relations_.find(expr); - if (it != best_relations_.end()) { - return CapProdI(gcd, it->second); +IntegerValue EnforcedLinear2Bounds::GetUpperBoundFromEnforced( + LinearExpression2Index index) const { + if (index >= conditional_relations_.size()) { + return kMaxIntegerValue; + } + const int entry_index = conditional_relations_[index]; + if (entry_index == -1) { + return kMaxIntegerValue; + } else { + const ConditionalEntry& entry = conditional_stack_[entry_index]; + if (DEBUG_MODE) { + for (const Literal l : entry.enforcements) { + CHECK(trail_->Assignment().LiteralIsTrue(l)); + } + } + DCHECK_LT(entry.rhs, root_level_bounds_->LevelZeroUpperBound(index)); + return entry.rhs; } - DCHECK(!root_relations_.contains(expr)); - DCHECK(!conditional_relations_.contains(expr)); - return kMaxIntegerValue; } -void PrecedenceRelations::Build() { - if (is_built_) return; - is_built_ = true; +bool TransitivePrecedencesEvaluator::Build() { + const int64_t in_timestamp = root_level_bounds_->num_updates(); + if (in_timestamp <= build_timestamp_) return true; + build_timestamp_ = in_timestamp; - const int num_nodes = graph_.num_nodes(); - util_intops::StrongVector> - before(num_nodes); + const std::vector> + root_relations_sorted = + root_level_bounds_->GetSortedNonTrivialUpperBounds(); + int max_node = 0; + for (const auto [expr, _] : root_relations_sorted) { + max_node = std::max(max_node, PositiveVariable(expr.vars[0]).value()); + max_node = std::max(max_node, PositiveVariable(expr.vars[1]).value()); + } + max_node++; // For negation. + + // Is it a DAG? + // Get a topological order of the DAG formed by all the arcs that are present. + // + // TODO(user): This can fail if we don't have a DAG. But in the end we + // don't really need a topological order, just something that is close to + // one so that we can compute an approximated transitive closure in O(n^2) and + // not O(n^3). We could use an heuristic instead, like as long as there is + // node with an in-degree of zero, add them to the order and update the + // in-degree of the other (by removing outgoing arcs). If there is a cycle + // (i.e. no node with no incoming arc), pick one with a small in-degree + // randomly. + DenseIntStableTopologicalSorter sorter(max_node); + for (const auto [expr, negated_offset] : root_relations_sorted) { + // Coefficients should be positive. + DCHECK_GT(expr.coeffs[0], 0); + DCHECK_GT(expr.coeffs[1], 0); - // We will construct a graph with the current relation from all_relations_. - // And use this to compute the "closure". - CHECK(arc_offsets_.empty()); - graph_.ReserveArcs(2 * root_relations_.size()); - std::vector> root_relations_sorted( - root_relations_.begin(), root_relations_.end()); - std::sort(root_relations_sorted.begin(), root_relations_sorted.end()); - for (const auto [var_pair, negated_offset] : root_relations_sorted) { // TODO(user): Support negative offset? // // Note that if we only have >= 0 ones, if we do have a cycle, we could - // make sure all variales are the same, and otherwise, we have a DAG or a + // make sure all variables are the same, and otherwise, we have a DAG or a // conflict. const IntegerValue offset = -negated_offset; if (offset < 0) continue; - if (var_pair.coeffs[0] != 1 || var_pair.coeffs[1] != 1) { + if (expr.coeffs[0] != 1 || expr.coeffs[1] != 1) { // TODO(user): Support non-1 coefficients. continue; } // We have two arcs. - { - const IntegerVariable tail = var_pair.vars[0]; - const IntegerVariable head = NegationOf(var_pair.vars[1]); - graph_.AddArc(tail.value(), head.value()); - arc_offsets_.push_back(offset); - CHECK_LT(var_pair.vars[1], before.size()); - before[head].push_back(tail); - } - { - const IntegerVariable tail = var_pair.vars[1]; - const IntegerVariable head = NegationOf(var_pair.vars[0]); - graph_.AddArc(tail.value(), head.value()); - arc_offsets_.push_back(offset); - CHECK_LT(var_pair.vars[1], before.size()); - before[head].push_back(tail); - } - } - - std::vector permutation; - graph_.Build(&permutation); - util::Permute(permutation, &arc_offsets_); - - // Is it a DAG? - // Get a topological order of the DAG formed by all the arcs that are present. - // - // TODO(user): This can fail if we don't have a DAG. We could just skip Bad - // edges instead, and have a sub-DAG as an heuristic. Or analyze the arc - // weight and make sure cycle are not an issue. We can also start with arcs - // with strictly positive weight. - // - // TODO(user): Only explore the sub-graph reachable from "vars". - DenseIntStableTopologicalSorter sorter(num_nodes); - for (int arc = 0; arc < graph_.num_arcs(); ++arc) { - sorter.AddEdge(graph_.Tail(arc), graph_.Head(arc)); + sorter.AddEdge(expr.vars[0].value(), NegationOf(expr.vars[1]).value()); + sorter.AddEdge(expr.vars[1].value(), NegationOf(expr.vars[0]).value()); } int next; bool graph_has_cycle = false; @@ -315,58 +520,43 @@ void PrecedenceRelations::Build() { topological_order_.push_back(IntegerVariable(next)); if (graph_has_cycle) { is_dag_ = false; - return; + return true; } } is_dag_ = !graph_has_cycle; - // Lets build full precedences if we don't have too many of them. - // TODO(user): Also do that if we don't have a DAG? - if (!is_dag_) return; - - int work = 0; - const int kWorkLimit = 1e6; - for (const IntegerVariable tail_var : topological_order_) { - if (++work > kWorkLimit) break; - for (const int arc : graph_.OutgoingArcs(tail_var.value())) { - DCHECK_EQ(tail_var.value(), graph_.Tail(arc)); - const IntegerVariable head_var = IntegerVariable(graph_.Head(arc)); - const IntegerValue arc_offset = arc_offsets_[arc]; - - if (++work > kWorkLimit) break; - if (AddInternal(LinearExpression2::Difference(tail_var, head_var), - -arc_offset)) { - before[head_var].push_back(tail_var); - } - - for (const IntegerVariable before_var : before[tail_var]) { - if (++work > kWorkLimit) break; - LinearExpression2 expr_for_key(before_var, tail_var, 1, -1); - expr_for_key.SimpleCanonicalization(); - const IntegerValue offset = - -root_relations_.at(expr_for_key) + arc_offset; - if (AddInternal(LinearExpression2::Difference(before_var, head_var), - -offset)) { - before[head_var].push_back(before_var); - } - } + // Lets get the transitive closure if it is cheap. This is also a way not to + // add too many relations per call. + int total_work = 0; + const int kWorkLimit = params_->transitive_precedences_work_limit(); + if (kWorkLimit > 0) { + for (const IntegerVariable var : topological_order_) { + const int work = root_level_bounds_->AugmentSimpleRelations( + var, kWorkLimit - total_work); + total_work += work; + if (total_work >= kWorkLimit) break; } } - VLOG(2) << "Full precedences. Work=" << work - << " Relations=" << root_relations_.size(); + build_timestamp_ = root_level_bounds_->num_updates(); + VLOG(2) << "Full precedences. Work=" << total_work + << " Relations=" << root_relations_sorted.size() + << " num_added=" << build_timestamp_ - in_timestamp; + return true; } -void PrecedenceRelations::ComputeFullPrecedences( +// TODO(user): There is probably little need for that function. For small +// problem, we already augment root_level_bounds_ will all the relation obtained +// by transitive closure, so this algo only need to look at direct dependency in +// root_level_bounds_->GetVariablesInSimpleRelation(). And for large graph, we +// probably do not want this. +void TransitivePrecedencesEvaluator::ComputeFullPrecedences( absl::Span vars, std::vector* output) { output->clear(); - if (!is_built_) Build(); + Build(); // Will do nothing if we are up to date. if (!is_dag_) return; - VLOG(2) << "num_nodes: " << graph_.num_nodes() - << " num_arcs: " << graph_.num_arcs() << " is_dag: " << is_dag_; - // Compute all precedences. // We loop over the node in topological order, and we maintain for all // variable we encounter, the list of "to_consider" variables that are before. @@ -394,10 +584,12 @@ void PrecedenceRelations::ComputeFullPrecedences( } } - for (const int arc : graph_.OutgoingArcs(tail_var.value())) { - CHECK_EQ(tail_var.value(), graph_.Tail(arc)); - const IntegerVariable head_var = IntegerVariable(graph_.Head(arc)); - const IntegerValue arc_offset = arc_offsets_[arc]; + // We look for tail_var + offset <= head_var. + for (const auto [neg_head_var, index] : + root_level_bounds_->GetVariablesInSimpleRelation(tail_var)) { + const IntegerVariable head_var = NegationOf(neg_head_var); + const IntegerValue arc_offset = + -root_level_bounds_->GetUpperBoundNoTrail(index); // No need to create an empty entry in this case. if (tail_map.empty() && !to_consider.contains(tail_var)) continue; @@ -451,12 +643,10 @@ void PrecedenceRelations::ComputeFullPrecedences( } } -void PrecedenceRelations::CollectPrecedences( +void EnforcedLinear2Bounds::CollectPrecedences( absl::Span vars, std::vector* output) { - // +1 for the negation. - const int needed_size = - std::max(after_.size(), conditional_after_.size()) + 1; + const int needed_size = integer_trail_->NumIntegerVariables().value(); var_to_degree_.resize(needed_size); var_to_last_index_.resize(needed_size); var_with_positive_degree_.resize(needed_size); @@ -468,29 +658,31 @@ void PrecedenceRelations::CollectPrecedences( IntegerVariable* var_with_positive_degree = var_with_positive_degree_.data(); int* var_to_degree = var_to_degree_.data(); int* var_to_last_index = var_to_last_index_.data(); - const auto process = [&](int index, absl::Span v) { - for (const IntegerVariable after : v) { - DCHECK_LT(after, needed_size); - if (var_to_degree[after.value()] == 0) { - var_with_positive_degree[num_relevants++] = after; - } else { - // We do not want duplicates. - if (var_to_last_index[after.value()] == index) continue; - } + const auto process = + [&](int var_index, + absl::Span> + v) { + for (const auto [other, lin2_index] : v) { + const IntegerVariable after = NegationOf(other); + DCHECK_LT(after, needed_size); + if (var_to_degree[after.value()] == 0) { + var_with_positive_degree[num_relevants++] = after; + } else { + // We do not want duplicates. + if (var_to_last_index[after.value()] == var_index) continue; + } - tmp_precedences_.push_back({after, index}); - var_to_degree[after.value()]++; - var_to_last_index[after.value()] = index; - } - }; + tmp_precedences_.push_back({after, var_index, lin2_index}); + var_to_degree[after.value()]++; + var_to_last_index[after.value()] = var_index; + } + }; - for (int index = 0; index < vars.size(); ++index) { - const IntegerVariable var = vars[index]; - if (var < after_.size()) { - process(index, after_[var]); - } - if (var < conditional_after_.size()) { - process(index, conditional_after_[var]); + for (int var_index = 0; var_index < vars.size(); ++var_index) { + const IntegerVariable var = vars[var_index]; + process(var_index, root_level_bounds_->GetVariablesInSimpleRelation(var)); + if (var < conditional_var_lookup_.size()) { + process(var_index, conditional_var_lookup_[var]); } } @@ -498,8 +690,9 @@ void PrecedenceRelations::CollectPrecedences( // For that we transform var_to_degree to point to the first position of // each lbvar in the output vector. int start = 0; - for (int i = 0; i < num_relevants; ++i) { - const IntegerVariable var = var_with_positive_degree[i]; + const absl::Span relevant_variables = + absl::MakeSpan(var_with_positive_degree, num_relevants); + for (const IntegerVariable var : relevant_variables) { const int degree = var_to_degree[var.value()]; if (degree > 1) { var_to_degree[var.value()] = start; @@ -520,8 +713,7 @@ void PrecedenceRelations::CollectPrecedences( // Cleanup var_to_degree, note that we don't need to clean // var_to_last_index_. - for (int i = 0; i < num_relevants; ++i) { - const IntegerVariable var = var_with_positive_degree[i]; + for (const IntegerVariable var : relevant_variables) { var_to_degree[var.value()] = 0; } } @@ -1149,37 +1341,16 @@ bool PrecedencesPropagator::BellmanFordTarjan(Trail* trail) { return true; } -void BinaryRelationRepository::Add(Literal lit, LinearTerm a, LinearTerm b, +void BinaryRelationRepository::Add(Literal lit, LinearExpression2 expr, IntegerValue lhs, IntegerValue rhs) { - if (lit.Index() != kNoLiteralIndex) { - num_enforced_relations_++; - DCHECK(a.coeff == 0 || a.var != kNoIntegerVariable); - DCHECK(b.coeff == 0 || b.var != kNoIntegerVariable); - } else { - DCHECK_NE(a.coeff, 0); - DCHECK_NE(b.coeff, 0); - DCHECK_NE(a.var, kNoIntegerVariable); - DCHECK_NE(b.var, kNoIntegerVariable); - } + expr.MakeVariablesPositive(); + CHECK_NE(lit.Index(), kNoLiteralIndex); + num_enforced_relations_++; + DCHECK(expr.coeffs[0] == 0 || expr.vars[0] != kNoIntegerVariable); + DCHECK(expr.coeffs[1] == 0 || expr.vars[1] != kNoIntegerVariable); - Relation r; - r.enforcement = lit; - r.a = a; - r.b = b; - r.lhs = lhs; - r.rhs = rhs; - - // We shall only consider positive variable here. - if (r.a.var != kNoIntegerVariable && !VariableIsPositive(r.a.var)) { - r.a.var = NegationOf(r.a.var); - r.a.coeff = -r.a.coeff; - } - if (r.b.var != kNoIntegerVariable && !VariableIsPositive(r.b.var)) { - r.b.var = NegationOf(r.b.var); - r.b.coeff = -r.b.coeff; - } - - relations_.push_back(std::move(r)); + relations_.push_back( + {.enforcement = lit, .expr = expr, .lhs = lhs, .rhs = rhs}); } void BinaryRelationRepository::AddPartialRelation(Literal lit, @@ -1188,37 +1359,25 @@ void BinaryRelationRepository::AddPartialRelation(Literal lit, DCHECK_NE(a, kNoIntegerVariable); DCHECK_NE(b, kNoIntegerVariable); DCHECK_NE(a, b); - Add(lit, LinearTerm(a, 1), LinearTerm(b, 1), 0, 0); + Add(lit, LinearExpression2(a, b, 1, 1), 0, 0); } void BinaryRelationRepository::Build() { DCHECK(!is_built_); is_built_ = true; std::vector> literal_key_values; - std::vector> var_key_values; const int num_relations = relations_.size(); literal_key_values.reserve(num_enforced_relations_); - var_key_values.reserve(num_relations - num_enforced_relations_); for (int i = 0; i < num_relations; ++i) { const Relation& r = relations_[i]; - if (r.enforcement.Index() == kNoLiteralIndex) { - var_key_values.emplace_back(r.a.var, i); - var_key_values.emplace_back(r.b.var, i); - std::pair key(r.a.var, r.b.var); - if (relations_[i].a.var > relations_[i].b.var) { - std::swap(key.first, key.second); - } - var_pair_to_relations_[key].push_back(i); - } else { - literal_key_values.emplace_back(r.enforcement.Index(), i); - } + literal_key_values.emplace_back(r.enforcement.Index(), i); } lit_to_relations_.ResetFromPairs(literal_key_values); - var_to_relations_.ResetFromPairs(var_key_values); } bool BinaryRelationRepository::PropagateLocalBounds( - const IntegerTrail& integer_trail, Literal lit, + const IntegerTrail& integer_trail, + const RootLevelLinear2Bounds& root_level_bounds, Literal lit, const absl::flat_hash_map& input, absl::flat_hash_map* output) const { DCHECK_NE(lit.Index(), kNoLiteralIndex); @@ -1241,24 +1400,28 @@ bool BinaryRelationRepository::PropagateLocalBounds( auto update_upper_bound_by_var = [&](IntegerVariable var, IntegerValue ub) { update_lower_bound_by_var(NegationOf(var), -ub); }; - auto update_var_bounds = [&](const LinearTerm& a, const LinearTerm& b, - IntegerValue lhs, IntegerValue rhs) { - if (a.coeff == 0) return; + auto update_var_bounds = [&](const LinearExpression2& expr, IntegerValue lhs, + IntegerValue rhs) { + if (expr.coeffs[0] == 0) return; // lb(b.y) <= b.y <= ub(b.y) and lhs <= a.x + b.y <= rhs imply // ceil((lhs - ub(b.y)) / a) <= x <= floor((rhs - lb(b.y)) / a) - if (b.coeff != 0) { - lhs = lhs - b.coeff * get_upper_bound(b.var); - rhs = rhs - b.coeff * get_lower_bound(b.var); + if (expr.coeffs[1] != 0) { + lhs = lhs - expr.coeffs[1] * get_upper_bound(expr.vars[1]); + rhs = rhs - expr.coeffs[1] * get_lower_bound(expr.vars[1]); } - update_lower_bound_by_var(a.var, MathUtil::CeilOfRatio(lhs, a.coeff)); - update_upper_bound_by_var(a.var, MathUtil::FloorOfRatio(rhs, a.coeff)); + update_lower_bound_by_var(expr.vars[0], + MathUtil::CeilOfRatio(lhs, expr.coeffs[0])); + update_upper_bound_by_var(expr.vars[0], + MathUtil::FloorOfRatio(rhs, expr.coeffs[0])); }; auto update_var_bounds_from_relation = [&](Relation r) { - r.a.MakeCoeffPositive(); - r.b.MakeCoeffPositive(); - update_var_bounds(r.a, r.b, r.lhs, r.rhs); - update_var_bounds(r.b, r.a, r.lhs, r.rhs); + r.expr.SimpleCanonicalization(); + + update_var_bounds(r.expr, r.lhs, r.rhs); + std::swap(r.expr.vars[0], r.expr.vars[1]); + std::swap(r.expr.coeffs[0], r.expr.coeffs[1]); + update_var_bounds(r.expr, r.lhs, r.rhs); }; if (lit.Index() < lit_to_relations_.size()) { for (const int relation_index : lit_to_relations_[lit]) { @@ -1266,9 +1429,10 @@ bool BinaryRelationRepository::PropagateLocalBounds( } } for (const auto& [var, _] : input) { - if (var >= var_to_relations_.size()) continue; - for (const int relation_index : var_to_relations_[var]) { - update_var_bounds_from_relation(relations_[relation_index]); + for (const auto& [expr, lb, ub] : + root_level_bounds.GetAllBoundsContainingVariable(var)) { + update_var_bounds_from_relation( + Relation{Literal(kNoLiteralIndex), expr, lb, ub}); } } @@ -1291,17 +1455,22 @@ bool GreaterThanAtLeastOneOfDetector::AddRelationFromIndices( const IntegerValue var_lb = integer_trail->LevelZeroLowerBound(var); for (const int index : indices) { Relation r = repository_.relation(index); - if (r.a.var != PositiveVariable(var)) std::swap(r.a, r.b); - CHECK_EQ(r.a.var, PositiveVariable(var)); + if (r.expr.vars[0] != PositiveVariable(var)) { + std::swap(r.expr.vars[0], r.expr.vars[1]); + std::swap(r.expr.coeffs[0], r.expr.coeffs[1]); + } + CHECK_EQ(r.expr.vars[0], PositiveVariable(var)); - if ((r.a.coeff == 1) == VariableIsPositive(var)) { + if ((r.expr.coeffs[0] == 1) == VariableIsPositive(var)) { // a + b >= lhs if (r.lhs <= kMinIntegerValue) continue; - exprs.push_back(AffineExpression(r.b.var, -r.b.coeff, r.lhs)); + exprs.push_back( + AffineExpression(r.expr.vars[1], -r.expr.coeffs[1], r.lhs)); } else { // -a + b <= rhs. if (r.rhs >= kMaxIntegerValue) continue; - exprs.push_back(AffineExpression(r.b.var, r.b.coeff, -r.rhs)); + exprs.push_back( + AffineExpression(r.expr.vars[1], r.expr.coeffs[1], -r.rhs)); } // Ignore this entry if it is always true. @@ -1349,11 +1518,13 @@ int GreaterThanAtLeastOneOfDetector:: for (const int index : repository_.IndicesOfRelationsEnforcedBy(l.Index())) { const Relation& r = repository_.relation(index); - if (r.a.var != kNoIntegerVariable && IntTypeAbs(r.a.coeff) == 1) { - infos.push_back({r.a.var, index}); + if (r.expr.vars[0] != kNoIntegerVariable && + IntTypeAbs(r.expr.coeffs[0]) == 1) { + infos.push_back({r.expr.vars[0], index}); } - if (r.b.var != kNoIntegerVariable && IntTypeAbs(r.b.coeff) == 1) { - infos.push_back({r.b.var, index}); + if (r.expr.vars[1] != kNoIntegerVariable && + IntTypeAbs(r.expr.coeffs[1]) == 1) { + infos.push_back({r.expr.vars[1], index}); } } } @@ -1403,17 +1574,19 @@ int GreaterThanAtLeastOneOfDetector:: for (int index = 0; index < repository_.size(); ++index) { const Relation& r = repository_.relation(index); if (r.enforcement.Index() == kNoLiteralIndex) continue; - if (r.a.var != kNoIntegerVariable && IntTypeAbs(r.a.coeff) == 1) { - if (r.a.var >= var_to_relations.size()) { - var_to_relations.resize(r.a.var + 1); + if (r.expr.vars[0] != kNoIntegerVariable && + IntTypeAbs(r.expr.coeffs[0]) == 1) { + if (r.expr.vars[0] >= var_to_relations.size()) { + var_to_relations.resize(r.expr.vars[0] + 1); } - var_to_relations[r.a.var].push_back(index); + var_to_relations[r.expr.vars[0]].push_back(index); } - if (r.b.var != kNoIntegerVariable && IntTypeAbs(r.b.coeff) == 1) { - if (r.b.var >= var_to_relations.size()) { - var_to_relations.resize(r.b.var + 1); + if (r.expr.vars[1] != kNoIntegerVariable && + IntTypeAbs(r.expr.coeffs[1]) == 1) { + if (r.expr.vars[1] >= var_to_relations.size()) { + var_to_relations.resize(r.expr.vars[1] + 1); } - var_to_relations[r.b.var].push_back(index); + var_to_relations[r.expr.vars[1]].push_back(index); } } @@ -1531,11 +1704,9 @@ int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints( return num_added_constraints; } -BinaryRelationsMaps::BinaryRelationsMaps(Model* model) - : integer_trail_(model->GetOrCreate()), - integer_encoder_(model->GetOrCreate()), - watcher_(model->GetOrCreate()), - shared_stats_(model->GetOrCreate()) { +ReifiedLinear2Bounds::ReifiedLinear2Bounds(Model* model) + : integer_encoder_(model->GetOrCreate()), + best_root_level_bounds_(model->GetOrCreate()) { int index = 0; model->GetOrCreate()->callbacks.push_back( [index = index, trail = model->GetOrCreate(), this]() mutable { @@ -1552,11 +1723,11 @@ BinaryRelationsMaps::BinaryRelationsMaps(Model* model) // Linear scan. for (const auto [l, expr, ub] : all_reified_relations_) { if (relevant_true_literals.contains(l)) { - AddRelationBounds(expr, kMinIntegerValue, ub); + best_root_level_bounds_->Add(expr, kMinIntegerValue, ub); VLOG(2) << "New fixed precedence: " << expr << " <= " << ub << " (was reified by " << l << ")"; } else if (relevant_true_literals.contains(l.Negated())) { - AddRelationBounds(expr, ub + 1, kMaxIntegerValue); + best_root_level_bounds_->Add(expr, ub + 1, kMaxIntegerValue); VLOG(2) << "New fixed precedence: " << expr << " > " << ub << " (was reified by not(" << l << "))"; } @@ -1565,112 +1736,26 @@ BinaryRelationsMaps::BinaryRelationsMaps(Model* model) }); } -BinaryRelationsMaps::~BinaryRelationsMaps() { +Linear2BoundsFromLinear3::~Linear2BoundsFromLinear3() { if (!VLOG_IS_ON(1)) return; std::vector> stats; - stats.push_back({"BinaryRelationsMaps/num_relations", num_updates_}); stats.push_back( - {"BinaryRelationsMaps/num_affine_updates", num_affine_updates_}); + {"Linear2BoundsFromLinear3/num_affine_updates", num_affine_updates_}); shared_stats_->AddStats(stats); } -IntegerValue BinaryRelationsMaps::GetImpliedUpperBound( - const LinearExpression2& expr) const { - DCHECK_GE(expr.coeffs[0], 0); - DCHECK_GE(expr.coeffs[1], 0); - IntegerValue implied_ub = 0; - for (const int i : {0, 1}) { - if (expr.coeffs[i] > 0) { - implied_ub += expr.coeffs[i] * integer_trail_->UpperBound(expr.vars[i]); - } - } - return implied_ub; -} - -std::pair -BinaryRelationsMaps::GetImpliedLevelZeroBounds( - const LinearExpression2& expr) const { - // Compute the implied bounds on the expression. - IntegerValue implied_lb = 0; - IntegerValue implied_ub = 0; - if (expr.coeffs[0] != 0) { - CHECK_GE(expr.vars[0], 0); - implied_lb += - expr.coeffs[0] * integer_trail_->LevelZeroLowerBound(expr.vars[0]); - implied_ub += - expr.coeffs[0] * integer_trail_->LevelZeroUpperBound(expr.vars[0]); - } - if (expr.coeffs[1] != 0) { - CHECK_GE(expr.vars[1], 0); - implied_lb += - expr.coeffs[1] * integer_trail_->LevelZeroLowerBound(expr.vars[1]); - implied_ub += - expr.coeffs[1] * integer_trail_->LevelZeroUpperBound(expr.vars[1]); - } - - return {implied_lb, implied_ub}; -} - -void BinaryRelationsMaps::AddRelationBounds(LinearExpression2 expr, - IntegerValue lb, IntegerValue ub) { - expr.CanonicalizeAndUpdateBounds(lb, ub); - const auto [implied_lb, implied_ub] = GetImpliedLevelZeroBounds(expr); - lb = std::max(lb, implied_lb); - ub = std::min(ub, implied_ub); - - if (lb > ub) return; // unsat ?? - if (lb == implied_lb && ub == implied_ub) return; // trivially true. - - if (best_root_level_bounds_.Add(expr, lb, ub)) { - // TODO(user): Also push them to a global shared repository after - // remapping IntegerVariable to proto indices. - ++num_updates_; - } -} - -RelationStatus BinaryRelationsMaps::GetLevelZeroStatus(LinearExpression2 expr, - IntegerValue lb, - IntegerValue ub) const { - expr.CanonicalizeAndUpdateBounds(lb, ub); - const auto [implied_lb, implied_ub] = GetImpliedLevelZeroBounds(expr); - lb = std::max(lb, implied_lb); - ub = std::min(ub, implied_ub); - - // Returns directly if the status can be derived from the implied bounds. - if (lb > ub) return RelationStatus::IS_FALSE; - if (lb == implied_lb && ub == implied_ub) return RelationStatus::IS_TRUE; - - // Relax as best_root_level_bounds_.GetStatus() might have older bounds. - if (lb == implied_lb) lb = kMinIntegerValue; - if (ub == implied_ub) ub = kMaxIntegerValue; - - return best_root_level_bounds_.GetStatus(expr, lb, ub); -} - -std::pair BinaryRelationsMaps::FromDifference( - const AffineExpression& a, const AffineExpression& b) const { - LinearExpression2 expr; - expr.vars[0] = a.var; - expr.vars[1] = b.var; - expr.coeffs[0] = a.coeff; - expr.coeffs[1] = -b.coeff; - IntegerValue lb = kMinIntegerValue; // unused. - IntegerValue ub = b.constant - a.constant; - expr.CanonicalizeAndUpdateBounds(lb, ub, /*allow_negation=*/false); - return {std::move(expr), ub}; -} - -RelationStatus BinaryRelationsMaps::GetLevelZeroPrecedenceStatus( +RelationStatus ReifiedLinear2Bounds::GetLevelZeroPrecedenceStatus( AffineExpression a, AffineExpression b) const { - const auto [expr, ub] = FromDifference(a, b); - return GetLevelZeroStatus(expr, kMinIntegerValue, ub); + const auto [expr, ub] = EncodeDifferenceLowerThan(a, b, 0); + return best_root_level_bounds_->GetLevelZeroStatus(expr, kMinIntegerValue, + ub); } -void BinaryRelationsMaps::AddReifiedPrecedenceIfNonTrivial(Literal l, - AffineExpression a, - AffineExpression b) { - const auto [expr, ub] = FromDifference(a, b); - const RelationStatus status = GetLevelZeroStatus(expr, kMinIntegerValue, ub); +void ReifiedLinear2Bounds::AddReifiedPrecedenceIfNonTrivial( + Literal l, AffineExpression a, AffineExpression b) { + const auto [expr, ub] = EncodeDifferenceLowerThan(a, b, 0); + const RelationStatus status = + best_root_level_bounds_->GetLevelZeroStatus(expr, kMinIntegerValue, ub); if (status != RelationStatus::IS_UNKNOWN) return; relation_to_lit_.insert({{expr, ub}, l}); @@ -1679,10 +1764,11 @@ void BinaryRelationsMaps::AddReifiedPrecedenceIfNonTrivial(Literal l, all_reified_relations_.push_back({l, expr, ub}); } -LiteralIndex BinaryRelationsMaps::GetReifiedPrecedence(AffineExpression a, - AffineExpression b) { - const auto [expr, ub] = FromDifference(a, b); - const RelationStatus status = GetLevelZeroStatus(expr, kMinIntegerValue, ub); +LiteralIndex ReifiedLinear2Bounds::GetReifiedPrecedence(AffineExpression a, + AffineExpression b) { + const auto [expr, ub] = EncodeDifferenceLowerThan(a, b, 0); + const RelationStatus status = + best_root_level_bounds_->GetLevelZeroStatus(expr, kMinIntegerValue, ub); if (status == RelationStatus::IS_TRUE) { return integer_encoder_->GetTrueLiteral().Index(); } @@ -1695,115 +1781,152 @@ LiteralIndex BinaryRelationsMaps::GetReifiedPrecedence(AffineExpression a, return it->second; } -bool BinaryRelationsMaps::AddAffineUpperBound(LinearExpression2 expr, - AffineExpression affine_ub) { - const IntegerValue new_ub = integer_trail_->UpperBound(affine_ub); - expr.SimpleCanonicalization(); +Linear2BoundsFromLinear3::Linear2BoundsFromLinear3(Model* model) + : integer_trail_(model->GetOrCreate()), + trail_(model->GetOrCreate()), + linear2_watcher_(model->GetOrCreate()), + watcher_(model->GetOrCreate()), + shared_stats_(model->GetOrCreate()), + root_level_bounds_(model->GetOrCreate()), + lin2_indices_(model->GetOrCreate()) {} - // Not better than trivial upper bound. - if (GetImpliedUpperBound(expr) <= new_ub) return false; - - // Not better than the root level upper bound. - if (best_root_level_bounds_.GetUpperBound(expr) <= new_ub) return false; - - const IntegerValue gcd = expr.DivideByGcd(); - - const auto it = best_affine_ub_.find(expr); - if (it != best_affine_ub_.end()) { - const auto [old_affine_ub, old_gcd] = it->second; - // We have an affine bound for this expr in the map. Can be exactly the - // same, a better one or a worse one. - if (old_affine_ub == affine_ub && old_gcd == gcd) { - // The affine bound is already in the map. - NotifyWatchingPropagators(); // The affine bound was updated. - return false; - } - const IntegerValue old_ub = - FloorRatio(integer_trail_->UpperBound(old_affine_ub), old_gcd); - if (old_ub <= new_ub) return false; // old bound is better. +// Note that for speed we do not compare to the trivial or root level bounds. +// +// It is okay to still store it in the hash-map, since at worst we will have no +// more entries than 3 * number_of_linear3_in_the_problem. +bool Linear2BoundsFromLinear3::AddAffineUpperBound( + LinearExpression2Index lin2_index, IntegerValue lin_expr_gcd, + AffineExpression affine_ub) { + // At level zero, just add it to root_level_bounds_. + if (trail_->CurrentDecisionLevel() == 0 || affine_ub.IsConstant()) { + root_level_bounds_->AddUpperBound( + lin2_index, FloorRatio(integer_trail_->LevelZeroUpperBound(affine_ub), + lin_expr_gcd)); + return false; // Not important. + } + + // We have gcd * canonical_expr <= affine_ub, + // so we do need to store a "divisor". + if (lin2_index >= best_affine_ub_.size()) { + best_affine_ub_.resize(lin2_index.value() + 1, {AffineExpression(), 0}); + } + auto& [old_affine_ub, old_divisor] = best_affine_ub_[lin2_index]; + if (old_divisor != 0) { + // We have an affine bound for this expr in the map. Can be exactly the + // same, a better one or a worse one. + // + // Note that we expect exactly the same most of the time as it should be + // rare to have many linear3 "competing" for the same linear2 bound. + if (old_affine_ub == affine_ub && old_divisor == lin_expr_gcd) { + linear2_watcher_->NotifyBoundChanged( + lin2_indices_->GetExpression(lin2_index)); + return false; + } + + const IntegerValue new_ub = + FloorRatioWithTest(integer_trail_->UpperBound(affine_ub), lin_expr_gcd); + const IntegerValue old_ub = FloorRatioWithTest( + integer_trail_->UpperBound(old_affine_ub), old_divisor); + if (old_ub <= new_ub) return false; // old bound is better. + + best_affine_ub_[lin2_index] = {affine_ub, lin_expr_gcd}; // Overwrite. + } else { + // Note that this should almost never happen (only once per lin2). + best_affine_ub_[lin2_index] = {affine_ub, lin_expr_gcd}; } - // We have gcd * canonical_expr <= affine_ub, so we do need to store a - // "divisor". ++num_affine_updates_; - best_affine_ub_[expr] = {affine_ub, gcd}; - NotifyWatchingPropagators(); + linear2_watcher_->NotifyBoundChanged( + lin2_indices_->GetExpression(lin2_index)); return true; } -void BinaryRelationsMaps::NotifyWatchingPropagators() const { - for (const int id : propagator_ids_) { - watcher_->CallOnNextPropagate(id); - } +IntegerValue Linear2BoundsFromLinear3::GetUpperBoundFromLinear3( + LinearExpression2Index lin2_index) const { + if (lin2_index >= best_affine_ub_.size()) return kMaxIntegerValue; + auto [affine, divisor] = best_affine_ub_[lin2_index]; + if (divisor == 0) return kMaxIntegerValue; + return FloorRatio(integer_trail_->UpperBound(affine), divisor); } -IntegerValue BinaryRelationsMaps::UpperBound(LinearExpression2 expr) const { - expr.SimpleCanonicalization(); - - const IntegerValue trivial_ub = GetImpliedUpperBound(expr); - const IntegerValue root_level_ub = - best_root_level_bounds_.GetUpperBound(expr); - const IntegerValue best_ub = std::min(root_level_ub, trivial_ub); - - const IntegerValue gcd = expr.DivideByGcd(); - const auto it = best_affine_ub_.find(expr); - if (it == best_affine_ub_.end()) { - return best_ub; - } else { - const auto [affine, divisor] = it->second; - const IntegerValue canonical_ub = - FloorRatio(integer_trail_->UpperBound(affine), divisor); - return std::min(best_ub, CapProdI(gcd, canonical_ub)); - } -} - -// TODO(user): If the trivial bound is better, its explanation is different... -void BinaryRelationsMaps::AddReasonForUpperBoundLowerThan( - LinearExpression2 expr, IntegerValue ub, +void Linear2BoundsFromLinear3::AddReasonForUpperBoundLowerThan( + LinearExpression2Index lin2_index, IntegerValue ub, std::vector* /*literal_reason*/, std::vector* integer_reason) const { - expr.SimpleCanonicalization(); + DCHECK_LE(GetUpperBoundFromLinear3(lin2_index), ub); + DCHECK_LT(lin2_index, best_affine_ub_.size()); - if (expr.coeffs[0] == 0 && expr.coeffs[1] == 0) return; // trivially zero - - // Starts by simple bounds. - if (best_root_level_bounds_.GetUpperBound(expr) <= ub) return; - - // Add explanation if it is a trivial bound. - const IntegerValue implied_ub = GetImpliedUpperBound(expr); - if (implied_ub <= ub) { - const IntegerValue slack = ub - implied_ub; - expr.Negate(); // AppendRelaxedLinearReason() explains a lower bound. - absl::Span vars = expr.non_zero_vars(); - absl::Span coeffs = expr.non_zero_coeffs(); - integer_trail_->AppendRelaxedLinearReason(slack, coeffs, vars, - integer_reason); - return; - } - - // None of the bound above are enough, try the affine one. Note that gcd * - // expr <= ub, is the same as asking why expr <= FloorRatio(ub, gcd). - const IntegerValue gcd = expr.DivideByGcd(); - const auto it = best_affine_ub_.find(expr); - if (it == best_affine_ub_.end()) return; - - // We want the reason for "expr <= ub", that is the reason for - // - "gcd * canonical_expr <= ub" - // - "canonical_expr <= FloorRatio(ub, gcd); - // - // knowing that canonical_expr <= affine_ub / divisor. - const auto [affine, divisor] = it->second; - integer_reason->push_back( - affine.LowerOrEqual(CapProdI(FloorRatio(ub, gcd) + 1, divisor) - 1)); + // We want the reason for "expr <= ub" + // knowing that expr <= affine / divisor. + const auto [affine, divisor] = best_affine_ub_[lin2_index]; + DCHECK_NE(divisor, 0); + integer_reason->push_back(affine.LowerOrEqual(CapProdI(ub + 1, divisor) - 1)); } -std::vector -BinaryRelationsMaps::GetAllExpressionsWithAffineBounds() const { - std::vector result; - for (const auto [expr, info] : best_affine_ub_) { - result.push_back(expr); +IntegerValue Linear2Bounds::UpperBound( + LinearExpression2Index lin2_index) const { + return std::min( + NonTrivialUpperBound(lin2_index), + integer_trail_->UpperBound(lin2_indices_->GetExpression(lin2_index))); +} + +IntegerValue Linear2Bounds::UpperBound(LinearExpression2 expr) const { + expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0) { + return integer_trail_->UpperBound(expr); } - return result; + DCHECK_NE(expr.coeffs[1], 0); + const IntegerValue gcd = expr.DivideByGcd(); + IntegerValue ub = integer_trail_->UpperBound(expr); + const LinearExpression2Index index = lin2_indices_->GetIndex(expr); + if (index != kNoLinearExpression2Index) { + ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(index)); + ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(index)); + ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(index)); + } + return CapProdI(gcd, ub); +} + +void Linear2Bounds::AddReasonForUpperBoundLowerThan( + LinearExpression2 expr, IntegerValue ub, + std::vector* literal_reason, + std::vector* integer_reason) const { + DCHECK_LE(UpperBound(expr), ub); + + // Explanation are by order of preference, with no reason needed first. + if (integer_trail_->LevelZeroUpperBound(expr) <= ub) { + return; + } + expr.SimpleCanonicalization(); + const IntegerValue gcd = expr.DivideByGcd(); + ub = FloorRatio(ub, gcd); + const LinearExpression2Index index = lin2_indices_->GetIndex(expr); + if (index != kNoLinearExpression2Index) { + // No reason. + if (root_level_bounds_->GetUpperBoundNoTrail(index) <= ub) { + return; + } + // This one is a single literal. + if (enforced_bounds_->GetUpperBoundFromEnforced(index) <= ub) { + return enforced_bounds_->AddReasonForUpperBoundLowerThan( + index, ub, literal_reason, integer_reason); + } + // This one is a single var upper bound. + if (linear3_bounds_->GetUpperBoundFromLinear3(index) <= ub) { + return linear3_bounds_->AddReasonForUpperBoundLowerThan( + index, ub, literal_reason, integer_reason); + } + } + + // Trivial linear2 bounds from its variables. + const IntegerValue implied_ub = integer_trail_->UpperBound(expr); + const IntegerValue slack = ub - implied_ub; + DCHECK_GE(slack, 0); + expr.Negate(); // AppendRelaxedLinearReason() explains a lower bound. + absl::Span vars = expr.non_zero_vars(); + absl::Span coeffs = expr.non_zero_coeffs(); + integer_trail_->AppendRelaxedLinearReason(slack, coeffs, vars, + integer_reason); } } // namespace sat diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index 4ded3dbc4d..59b0d0f064 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -18,15 +18,20 @@ #include #include #include +#include #include #include +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" #include "ortools/graph/graph.h" +#include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" @@ -41,58 +46,265 @@ namespace operations_research { namespace sat { +DEFINE_STRONG_INDEX_TYPE(LinearExpression2Index); +const LinearExpression2Index kNoLinearExpression2Index(-1); +inline LinearExpression2Index NegationOf(LinearExpression2Index i) { + return LinearExpression2Index(i.value() ^ 1); +} + +inline bool Linear2IsPositive(LinearExpression2Index i) { + return (i.value() & 1) == 0; +} + +inline LinearExpression2Index PositiveLinear2(LinearExpression2Index i) { + return LinearExpression2Index(i.value() & (~1)); +} + +// Class to hold a list of LinearExpression2 that have (potentially) non-trivial +// bounds. This class is overzealous, in the sense that if a linear2 is in the +// list, it does not necessarily mean that it has a non-trivial bound, but the +// converse is true: if a linear2 is not in the list, +// Linear2Bounds::GetUpperBound() will return a trivial bound. +class Linear2Indices { + public: + Linear2Indices() = default; + + // Returns a never-changing index for the given linear expression. + // The expression must already be canonicalized and divided by its GCD. + LinearExpression2Index AddOrGet(LinearExpression2 expr); + + // Returns a never-changing index for the given linear expression if it is + // potentially non-trivial, otherwise returns kNoLinearExpression2Index. The + // expression must already be canonicalized and divided by its GCD. + LinearExpression2Index GetIndex(LinearExpression2 expr) const; + + LinearExpression2 GetExpression(LinearExpression2Index index) const; + + // Return all positive linear2 expressions that have a potentially non-trivial + // bound. When calling this code it is often a good idea to check both the + // expression on the span and its negation. The order is fixed forever and + // this span can only grow by appending new expressions. + absl::Span GetStoredLinear2Indices() const { + return exprs_; + } + + // Return a list of all potentially non-trivial LinearExpression2Indexes + // containing a given variable. + absl::Span GetAllLinear2ContainingVariable( + IntegerVariable var) const; + + // Return a list of all potentially non-trivial LinearExpression2Indexes + // containing a given pair of variables. + absl::Span GetAllLinear2ContainingVariables( + IntegerVariable var1, IntegerVariable var2) const; + + private: + std::vector exprs_; + absl::flat_hash_map expr_to_index_; + + // Map to implement GetAllBoundsContainingVariable(). + absl::flat_hash_map> + var_to_bounds_; + // Map to implement GetAllBoundsContainingVariables(). + absl::flat_hash_map, + absl::InlinedVector> + var_pair_to_bounds_; +}; + +// Simple "watcher" class that will be notified if a linear2 bound changed. It +// can also be queried to see if LinearExpression2 involving a specific variable +// changed since last time. +class Linear2Watcher { + public: + explicit Linear2Watcher(Model* model) + : watcher_(model->GetOrCreate()) {} + + // This assumes `expr` is canonicalized and divided by its gcd. + void NotifyBoundChanged(LinearExpression2 expr); + + // Register a GenericLiteralWatcher() id so that propagation is called as + // soon as a bound on a linear2 changed. + void WatchAllLinearExpressions2(int id) { propagator_ids_.insert(id); } + + // Allow to know if some bounds changed since last query. + int64_t Timestamp() const { return timestamp_; } + int64_t VarTimestamp(IntegerVariable var); + + private: + GenericLiteralWatcher* watcher_; + + int64_t timestamp_ = 0; + util_intops::StrongVector var_timestamp_; + absl::btree_set propagator_ids_; +}; + +// This holds all the relation lhs <= linear2 <= rhs that are true at level +// zero. It is the source of truth across all the solver for such bounds. +class RootLevelLinear2Bounds { + public: + explicit RootLevelLinear2Bounds(Model* model) + : integer_trail_(model->GetOrCreate()), + linear2_watcher_(model->GetOrCreate()), + shared_stats_(model->GetOrCreate()), + lin2_indices_(model->GetOrCreate()), + cp_model_mapping_(model->GetOrCreate()), + shared_linear2_bounds_(model->Mutable()), + shared_linear2_bounds_id_( + shared_linear2_bounds_ == nullptr + ? 0 + : shared_linear2_bounds_->RegisterNewId(model->Name())) {} + + ~RootLevelLinear2Bounds(); + + // Add a relation lb <= expr <= ub. If expr is not a proper linear2 expression + // (e.g. 0*x + y, y + y, y - y) it will be ignored. + // Returns a pair saying whether the lower/upper bounds for this expr became + // more restricted than what was currently stored. + std::pair Add(LinearExpression2 expr, IntegerValue lb, + IntegerValue ub) { + if (integer_trail_->LevelZeroUpperBound(expr) <= ub && + integer_trail_->LevelZeroLowerBound(expr) >= lb) { + return {false, false}; + } + const bool negated = expr.CanonicalizeAndUpdateBounds(lb, ub); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return {false, false}; + const LinearExpression2Index index = lin2_indices_->AddOrGet(expr); + bool ub_changed = AddUpperBound(index, ub); + bool lb_changed = AddUpperBound(NegationOf(index), -lb); + if (negated) { + std::swap(lb_changed, ub_changed); + } + return {lb_changed, ub_changed}; + } + + // Same as above, but only update the upper bound. + bool AddUpperBound(LinearExpression2 expr, IntegerValue ub) { + if (integer_trail_->LevelZeroUpperBound(expr) <= ub) return false; + expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return false; + const IntegerValue gcd = expr.DivideByGcd(); + ub = FloorRatio(ub, gcd); + return AddUpperBound(lin2_indices_->AddOrGet(expr), ub); + } + + // All modifications go through this function. + bool AddUpperBound(LinearExpression2Index index, IntegerValue ub); + + IntegerValue LevelZeroUpperBound(LinearExpression2 expr) const { + expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { + return integer_trail_->LevelZeroUpperBound(expr); + } + const IntegerValue gcd = expr.DivideByGcd(); + const LinearExpression2Index index = lin2_indices_->GetIndex(expr); + if (index == kNoLinearExpression2Index) { + return integer_trail_->LevelZeroUpperBound(expr); + } + return CapProdI(gcd, LevelZeroUpperBound(index)); + } + + IntegerValue LevelZeroUpperBound(LinearExpression2Index index) const { + const LinearExpression2 expr = lin2_indices_->GetExpression(index); + // TODO(user): Remove the expression from the root_level_relations_ if + // the zero-level bound got more restrictive. + return std::min(integer_trail_->LevelZeroUpperBound(expr), + GetUpperBoundNoTrail(index)); + } + + // Return a list of (expr <= ub) sorted by expr. They are guaranteed to be + // better than the trivial upper bound. + std::vector> + GetSortedNonTrivialUpperBounds() const; + + // Return a list of (lb <= expr <= ub), with expr.vars[0] = var, where at + // least one of the bounds is non-trivial and the potential other non-trivial + // bound is tight. + // + // As the class name indicates, all bounds are level zero ones. + std::vector> + GetAllBoundsContainingVariable(IntegerVariable var) const; + + // Return a list of (lb <= expr <= ub), with expr.vars = {var1, var2}, where + // at least one of the bounds is non-trivial and the potential other + // non-trivial bound is tight. + // + // As the class name indicates, all bounds are level zero ones. + std::vector> + GetAllBoundsContainingVariables(IntegerVariable var1, + IntegerVariable var2) const; + + // For a given variable `var`, return all variables `other` so that + // LinearExpression2(var, other, 1, 1) has a non trivial upper bound. + // Note that using negation one can also recover x + y >= lb and x - y <= ub. + absl::Span> + GetVariablesInSimpleRelation(IntegerVariable var) const; + + // For all pairs of relation 'a + var <= x' and 'neg(var) + b <= y' try to add + // 'a + b <= x + y' if that relation is better. + // + // This can be quadratic. Returns the amount of "work" done, and abort if + // we reach the limit. This uses GetVariablesInSimpleRelation(). + int AugmentSimpleRelations(IntegerVariable var, int work_limit); + + RelationStatus GetLevelZeroStatus(LinearExpression2 expr, IntegerValue lb, + IntegerValue ub) const; + + // Low-level function that returns the zero-level upper bound if it is + // non-trivial. Otherwise returns kMaxIntegerValue. This is a different + // behavior from LevelZeroUpperBound() that would return the implied + // zero-level bound from the trail for trivial ones. `expr` must be + // canonicalized and gcd-reduced. + IntegerValue GetUpperBoundNoTrail(LinearExpression2Index index) const; + + int64_t num_updates() const { return num_updates_; } + + private: + IntegerTrail* integer_trail_; + Linear2Watcher* linear2_watcher_; + SharedStatistics* shared_stats_; + Linear2Indices* lin2_indices_; + CpModelMapping* cp_model_mapping_; + SharedLinear2Bounds* shared_linear2_bounds_; // Might be nullptr. + + const int shared_linear2_bounds_id_; + + util_intops::StrongVector + best_upper_bounds_; + + // coeff_one_var_lookup_[var] contains all the other_var such that we have a + // linear2 relation var + other_var <= ub. We also store that relation index. + util_intops::StrongVector in_coeff_one_lookup_; + util_intops::StrongVector< + IntegerVariable, + std::vector>> + coeff_one_var_lookup_; + + int64_t num_updates_ = 0; +}; + struct FullIntegerPrecedence { IntegerVariable var; std::vector indices; std::vector offsets; }; -// Stores all the precedences relation of the form "a*x + b*y <= ub" -// that we could extract from the linear constraint of the model. These are -// stored in a directed graph. +// This class is used to compute the transitive closure of the level-zero +// precedence relations. // -// TODO(user): Support conditional relation. // TODO(user): Support non-DAG like graph. -// TODO(user): Support variable offset that can be updated as search progress. -class PrecedenceRelations : public ReversibleInterface { +class TransitivePrecedencesEvaluator { public: - explicit PrecedenceRelations(Model* model) - : params_(*model->GetOrCreate()), - trail_(model->GetOrCreate()), - integer_trail_(model->GetOrCreate()) { - integer_trail_->RegisterReversibleClass(this); + explicit TransitivePrecedencesEvaluator(Model* model) + : params_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + shared_stats_(model->GetOrCreate()), + root_level_bounds_(model->GetOrCreate()) { + // Call Build() each time we go back to level zero. + model->GetOrCreate()->callbacks.push_back( + [this]() { return Build(); }); } - void Resize(int num_variables) { - graph_.ReserveNodes(num_variables); - graph_.AddNode(num_variables - 1); - } - - // Add a relation lb <= expr <= ub. If expr is not a proper linear2 expression - // (e.g. 0*x + y, y + y, y - y) it will be ignored. Returns true if it was - // added and is considered "new". - bool AddBounds(LinearExpression2 expr, IntegerValue lb, IntegerValue ub); - - // Same as above, but only for the upper bound. - bool AddUpperBound(LinearExpression2 expr, IntegerValue ub); - - // Adds add relation (enf => expr <= rhs) that is assumed to be true at - // the current level. - // - // It will be automatically reverted via the SetLevel() functions that is - // called before any integer propagations trigger. - // - // This is assumed to be called when a relation becomes true (enforcement are - // assigned) and when it becomes false in reverse order (CHECKed). - // - // If expr is not a proper linear2 expression (e.g. 0*x + y, y + y, y - y) it - // will be ignored. - void PushConditionalRelation(absl::Span enforcements, - LinearExpression2 expr, IntegerValue rhs); - - // Called each time we change decision level. - void SetLevel(int level) final; - // Returns a set of relations var >= max_i(vars[index[i]] + offsets[i]). // // This currently only works if the precedence relation form a DAG. @@ -112,121 +324,127 @@ class PrecedenceRelations : public ReversibleInterface { void ComputeFullPrecedences(absl::Span vars, std::vector* output); - // Returns a set of precedences (var, index) such that var is after - // vars[index]. All entries for the same variable will be contiguous and - // sorted by index. We only list variable with at least two entries. The - // offset can be retrieved via UpperBound(vars[index], var). + // The current code requires the internal data to be processed once all + // root-level relations are loaded. // - // For more efficiency, this method ignores all linear2 expressions with any - // coefficient different from 1. + // If we don't have too many variable, we compute the full transitive closure + // and then push back to RootLevelLinear2Bounds if there is a relation between + // two variables. This can be used to optimize some scheduling propagation and + // reasons. + // + // Warning: If there are too many, this will NOT push all relations. + bool Build(); + + private: + SatParameters* params_; + IntegerTrail* integer_trail_; + SharedStatistics* shared_stats_; + RootLevelLinear2Bounds* root_level_bounds_; + + int64_t build_timestamp_ = -1; + bool is_dag_ = false; + std::vector topological_order_; +}; + +// Stores all the precedences relation of the form "{lits} => a*x + b*y <= ub" +// that we could extract from the model. +class EnforcedLinear2Bounds : public ReversibleInterface { + public: + explicit EnforcedLinear2Bounds(Model* model) + : params_(*model->GetOrCreate()), + trail_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + linear2_watcher_(model->GetOrCreate()), + root_level_bounds_(model->GetOrCreate()), + shared_stats_(model->GetOrCreate()), + lin2_indices_(model->GetOrCreate()) { + integer_trail_->RegisterReversibleClass(this); + } + + ~EnforcedLinear2Bounds() override; + + // Adds add relation (enf => expr <= rhs) that is assumed to be true at + // the current level. + // + // It will be automatically reverted via the SetLevel() functions that is + // called before any integer propagations trigger. + // + // This is assumed to be called when a relation becomes true (enforcement are + // assigned) and when it becomes false in reverse order (CHECKed). + // + // If expr is not a proper linear2 expression (e.g. 0*x + y, y + y, y - y) it + // will be ignored. + void PushConditionalRelation(absl::Span enforcements, + LinearExpression2Index index, IntegerValue rhs); + + void PushConditionalRelation(absl::Span enforcements, + LinearExpression2 expr, IntegerValue rhs) { + expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return; + const IntegerValue gcd = expr.DivideByGcd(); + rhs = FloorRatio(rhs, gcd); + return PushConditionalRelation(enforcements, lin2_indices_->AddOrGet(expr), + rhs); + } + + // Called each time we change decision level. + void SetLevel(int level) final; + + // Returns a set of precedences such that we have a relation + // of the form vars[index] <= var + offset. + // + // All entries for the same variable will be contiguous and sorted by index. + // We only list variable with at least two entries. The up to date offset can + // be retrieved later via Linear2Bounds::UpperBound(lin2_index). + // + // This method currently ignores all linear2 expressions with any coefficient + // different from 1. + // + // TODO(user): Ideally this should be moved to a new class and maybe augmented + // with other kind of precedences. struct PrecedenceData { IntegerVariable var; - int index; + int var_index; + LinearExpression2Index lin2_index; }; void CollectPrecedences(absl::Span vars, std::vector* output); - // If we don't have too many variable, we compute the full transitive closure - // and can query in O(1) if there is a relation between two variables. - // This can be used to optimize some scheduling propagation and reasons. - // - // Warning: If there are too many, this will NOT contain all relations. - // - // Returns kMaxIntegerValue if there are none, otherwise return an upper bound - // such that expr <= ub. - IntegerValue LevelZeroUpperBound(LinearExpression2 expr) const; - - // Returns the maximum value for expr, and the reason for it (all - // true). Note that we always check LevelZeroUpperBound() so if it is better, - // the returned literal reason will be empty. - // - // We separate the two because usually the reason is only needed when we push, - // which happen less often, so we don't mind doing two hash lookups, and we - // really want to optimize the UpperBound() instead. - // - // Important: This doesn't contains the transitive closure. - // Important: The span is only valid in a narrow scope. - IntegerValue UpperBound(LinearExpression2 expr) const; + // Low-level function that returns the upper bound if there is some enforced + // relations only. Otherwise always returns kMaxIntegerValue. + // `expr` must be canonicalized and gcd-reduced. + IntegerValue GetUpperBoundFromEnforced(LinearExpression2Index index) const; void AddReasonForUpperBoundLowerThan( - LinearExpression2 expr, IntegerValue ub, + LinearExpression2Index index, IntegerValue ub, std::vector* literal_reason, std::vector* integer_reason) const; - // The current code requires the internal data to be processed once all - // relations are loaded. - // - // TODO(user): Be more dynamic as we start to add relations during search. - void Build(); - private: void CreateLevelEntryIfNeeded(); - // expr <= ub. - bool AddInternal(LinearExpression2 expr, IntegerValue ub) { - expr.SimpleCanonicalization(); - if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { - return false; - } - const auto [it, inserted] = root_relations_.insert({expr, ub}); - UpdateBestRelationIfBetter(expr, ub); - if (inserted) { - if (expr.coeffs[0] != 1 || expr.coeffs[1] != 1) { - return true; - } - const int new_size = - std::max(expr.vars[0].value(), expr.vars[1].value()) + 1; - if (new_size > after_.size()) after_.resize(new_size); - after_[expr.vars[0]].push_back(NegationOf(expr.vars[1])); - after_[expr.vars[1]].push_back(NegationOf(expr.vars[0])); - return true; - } - it->second = std::min(it->second, ub); - return false; - } - - void UpdateBestRelationIfBetter(LinearExpression2 expr, IntegerValue rhs) { - const auto [it, inserted] = best_relations_.insert({expr, rhs}); - if (!inserted) { - it->second = std::min(it->second, rhs); - } - } - - void UpdateBestRelation(LinearExpression2 expr, IntegerValue rhs) { - const auto it = root_relations_.find(expr); - if (it != root_relations_.end()) { - rhs = std::min(rhs, it->second); - } - if (rhs == kMaxIntegerValue) { - best_relations_.erase(expr); - } else { - best_relations_[expr] = rhs; - } - } - const SatParameters& params_; Trail* trail_; IntegerTrail* integer_trail_; + Linear2Watcher* linear2_watcher_; + RootLevelLinear2Bounds* root_level_bounds_; + SharedStatistics* shared_stats_; + Linear2Indices* lin2_indices_; - util::StaticGraph<> graph_; - std::vector arc_offsets_; - - bool is_built_ = false; - bool is_dag_ = false; - std::vector topological_order_; + int64_t num_conditional_relation_updates_ = 0; // Conditional stack for push/pop of conditional relations. // // TODO(user): this kind of reversible hash_map is already implemented in // other part of the code. Consolidate. struct ConditionalEntry { - ConditionalEntry(int p, IntegerValue r, LinearExpression2 k, + ConditionalEntry(int p, IntegerValue r, LinearExpression2Index k, absl::Span e) : prev_entry(p), rhs(r), key(k), enforcements(e.begin(), e.end()) {} int prev_entry; IntegerValue rhs; - LinearExpression2 key; + LinearExpression2Index key; absl::InlinedVector enforcements; }; std::vector conditional_stack_; @@ -234,21 +452,15 @@ class PrecedenceRelations : public ReversibleInterface { // This is always stored in the form (expr <= rhs). // The conditional relations contains indices in the conditional_stack_. - absl::flat_hash_map root_relations_; - absl::flat_hash_map conditional_relations_; - - // Contains std::min() of the offset from root_relations_ and - // conditional_relations_. - absl::flat_hash_map best_relations_; + util_intops::StrongVector conditional_relations_; // Store for each variable x, the variables y that appears alongside it in - // LevelZeroUpperBound(expr) or UpperBound(expr). That is the variable - // that are after x with an offset. Note that conditional_after_ is updated on + // lit => x + y <= ub. Note that conditional_var_lookup_ is updated on // dive/backtrack. - util_intops::StrongVector> - after_; - util_intops::StrongVector> - conditional_after_; + util_intops::StrongVector< + IntegerVariable, + std::vector>> + conditional_var_lookup_; // Temp data for CollectPrecedences. std::vector var_with_positive_degree_; @@ -257,6 +469,265 @@ class PrecedenceRelations : public ReversibleInterface { std::vector tmp_precedences_; }; +// A relation of the form enforcement => expr \in [lhs, rhs]. +// Note that the [lhs, rhs] interval should always be within [min_activity, +// max_activity] where the activity is the value of expr. +struct Relation { + Literal enforcement; + LinearExpression2 expr; + IntegerValue lhs; + IntegerValue rhs; + + bool operator==(const Relation& other) const { + return enforcement == other.enforcement && expr == other.expr && + lhs == other.lhs && rhs == other.rhs; + } + + template + friend void AbslStringify(Sink& sink, const Relation& relation) { + absl::Format(&sink, "%s => %v in [%v, %v]", + relation.enforcement.DebugString(), relation.expr, + relation.lhs, relation.rhs); + } +}; + +// A repository of all the enforced linear constraints of size 1 or 2. +// +// TODO(user): This is not always needed, find a way to clean this once we +// don't need it. +class BinaryRelationRepository { + public: + int size() const { return relations_.size(); } + + // The returned relation is guaranteed to only have positive variables. + const Relation& relation(int index) const { return relations_[index]; } + + // Returns the indices of the relations that are enforced by the given + // literal. + absl::Span IndicesOfRelationsEnforcedBy(LiteralIndex lit) const { + if (lit >= lit_to_relations_.size()) return {}; + return lit_to_relations_[lit]; + } + + // Adds a conditional relation lit => expr \in [lhs, rhs] (one of the coeffs + // can be zero). + void Add(Literal lit, LinearExpression2 expr, IntegerValue lhs, + IntegerValue rhs); + + // Adds a partial conditional relation between two variables, with unspecified + // coefficients and bounds. + void AddPartialRelation(Literal lit, IntegerVariable a, IntegerVariable b); + + // Builds the literal to relations mapping. This should be called once all the + // relations have been added. + void Build(); + + // Assuming level-zero bounds + any (var >= value) in the input map, + // fills "output" with a "propagated" set of bounds assuming lit is true (by + // using the relations enforced by lit, as well as the non-enforced ones). + // Note that we will only fill bounds > level-zero ones in output. + // + // Returns false if the new bounds are infeasible at level zero. + // + // Important: by default this does not call output->clear() so we can take + // the max with already inferred bounds. + bool PropagateLocalBounds( + const IntegerTrail& integer_trail, + const RootLevelLinear2Bounds& root_level_bounds, Literal lit, + const absl::flat_hash_map& input, + absl::flat_hash_map* output) const; + + private: + bool is_built_ = false; + int num_enforced_relations_ = 0; + std::vector relations_; + CompactVectorVector lit_to_relations_; +}; + +// Class that keeps the best upper bound for a*x + b*y by using all the linear3 +// relations of the form a*x + b*y + c*z <= d. +class Linear2BoundsFromLinear3 { + public: + explicit Linear2BoundsFromLinear3(Model* model); + ~Linear2BoundsFromLinear3(); + + // If the given upper bound evaluate better than the current one we have, this + // will replace it and returns true, otherwise it returns false. + bool AddAffineUpperBound(LinearExpression2Index lin2_index, + IntegerValue lin_expr_gcd, + AffineExpression affine_ub); + + bool AddAffineUpperBound(LinearExpression2 expr, AffineExpression affine_ub) { + expr.SimpleCanonicalization(); + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) return false; + const IntegerValue gcd = expr.DivideByGcd(); + return AddAffineUpperBound(lin2_indices_->AddOrGet(expr), gcd, affine_ub); + } + + // Most users should just use Linear2Bounds::UpperBound() instead. + // + // Returns the upper bound only if there is some relations coming from a + // linear3. Otherwise always returns kMaxIntegerValue. + // `expr` must be canonicalized and gcd-reduced. + IntegerValue GetUpperBoundFromLinear3( + LinearExpression2Index lin2_index) const; + + // Most users should use Linear2Bounds::AddReasonForUpperBoundLowerThan() + // instead. + // + // Adds the reason for GetUpperBoundFromLinear3() to be <= ub. + // `expr` must be canonicalized and gcd-reduced. + void AddReasonForUpperBoundLowerThan( + LinearExpression2Index lin2_index, IntegerValue ub, + std::vector* literal_reason, + std::vector* integer_reason) const; + + private: + IntegerTrail* integer_trail_; + Trail* trail_; + Linear2Watcher* linear2_watcher_; + GenericLiteralWatcher* watcher_; + SharedStatistics* shared_stats_; + RootLevelLinear2Bounds* root_level_bounds_; + Linear2Indices* lin2_indices_; + + int64_t num_affine_updates_ = 0; + + // This stores linear2 <= AffineExpression / divisor. + // + // Note(user): This is a "cheap way" to not have to deal with backtracking, If + // we have many possible AffineExpression that bounds a LinearExpression2, we + // keep the best one during "search dive" but on backtrack we might have a + // sub-optimal relation. + util_intops::StrongVector> + best_affine_ub_; +}; + +// TODO(user): Merge with BinaryRelationRepository. Note that this one provides +// different indexing though, so it could be kept separate. +// TODO(user): Use LinearExpression2 instead of pairs of AffineExpression for +// consistency with other classes. +class ReifiedLinear2Bounds { + public: + explicit ReifiedLinear2Bounds(Model* model); + + // Return the status of a <= b; + RelationStatus GetLevelZeroPrecedenceStatus(AffineExpression a, + AffineExpression b) const; + + // Register the fact that l <=> ( a <= b ). + // These are considered equivalence relation. + void AddReifiedPrecedenceIfNonTrivial(Literal l, AffineExpression a, + AffineExpression b); + + // Returns kNoLiteralIndex if we don't have a literal <=> ( a <= b ), or + // returns that literal if we have one. Note that we will return the + // true/false literal if the status is known at level zero. + LiteralIndex GetReifiedPrecedence(AffineExpression a, AffineExpression b); + + private: + IntegerEncoder* integer_encoder_; + RootLevelLinear2Bounds* best_root_level_bounds_; + + // This stores relations l <=> (linear2 <= rhs). + absl::flat_hash_map, Literal> + relation_to_lit_; + + // This is used to detect relations that become fixed at level zero and + // "upgrade" them to non-enforced relations. Because we only do that when + // we fix variable, a linear scan shouldn't be too bad and is relatively + // compact memory wise. + absl::flat_hash_set variable_appearing_in_reified_relations_; + std::vector> + all_reified_relations_; +}; + +// Simple wrapper around the different repositories for bounds of linear2. +// This should provide the best bounds. +class Linear2Bounds { + public: + explicit Linear2Bounds(Model* model) + : integer_trail_(model->GetOrCreate()), + root_level_bounds_(model->GetOrCreate()), + enforced_bounds_(model->GetOrCreate()), + linear3_bounds_(model->GetOrCreate()), + lin2_indices_(model->GetOrCreate()) {} + + // Returns the best known upper-bound of the given LinearExpression2 at the + // current decision level. If its explanation is needed, it can be queried + // with the second function. + IntegerValue UpperBound(LinearExpression2 expr) const; + IntegerValue UpperBound(LinearExpression2Index lin2_index) const; + + void AddReasonForUpperBoundLowerThan( + LinearExpression2 expr, IntegerValue ub, + std::vector* literal_reason, + std::vector* integer_reason) const; + + // Like UpperBound() but do not consider the bounds coming from + // the individual variable bounds. This is faster. + IntegerValue NonTrivialUpperBound(LinearExpression2Index lin2_index) const; + + private: + IntegerTrail* integer_trail_; + RootLevelLinear2Bounds* root_level_bounds_; + EnforcedLinear2Bounds* enforced_bounds_; + Linear2BoundsFromLinear3* linear3_bounds_; + Linear2Indices* lin2_indices_; +}; + +// Detects if at least one of a subset of linear of size 2 or 1, touching the +// same variable, must be true. When this is the case we add a new propagator to +// propagate that fact. +// +// TODO(user): Shall we do that on the main thread before the workers are +// spawned? note that the probing version need the model to be loaded though. +class GreaterThanAtLeastOneOfDetector { + public: + explicit GreaterThanAtLeastOneOfDetector(Model* model) + : repository_(*model->GetOrCreate()) {} + + // Advanced usage. To be called once all the constraints have been added to + // the model. This will detect GreaterThanAtLeastOneOfConstraint(). + // Returns the number of added constraint. + // + // TODO(user): This can be quite slow, add some kind of deterministic limit + // so that we can use it all the time. + int AddGreaterThanAtLeastOneOfConstraints(Model* model, + bool auto_detect_clauses = false); + + private: + // Given an existing clause, sees if it can be used to add "greater than at + // least one of" type of constraints. Returns the number of such constraint + // added. + int AddGreaterThanAtLeastOneOfConstraintsFromClause( + absl::Span clause, Model* model); + + // Another approach for AddGreaterThanAtLeastOneOfConstraints(), this one + // might be a bit slow as it relies on the propagation engine to detect + // clauses between incoming arcs presence literals. + // Returns the number of added constraints. + int AddGreaterThanAtLeastOneOfConstraintsWithClauseAutoDetection( + Model* model); + + // Once we identified a clause and relevant indices, this build the + // constraint. Returns true if we actually add it. + bool AddRelationFromIndices(IntegerVariable var, + absl::Span clause, + absl::Span indices, Model* model); + + BinaryRelationRepository& repository_; +}; + +// ============================================================================= +// Old precedences propagator. +// +// This is superseded by the new LinearPropagator and should only be used if the +// option 'new_linear_propagation' is false. We still keep it around to +// benchmark and test the new code vs this one. +// ============================================================================= + // This class implement a propagator on simple inequalities between integer // variables of the form (i1 + offset <= i2). The offset can be constant or // given by the value of a third integer variable. Offsets can also be negative. @@ -277,7 +748,7 @@ class PrecedencesPropagator : public SatPropagator, PropagatorInterface { public: explicit PrecedencesPropagator(Model* model) : SatPropagator("PrecedencesPropagator"), - relations_(model->GetOrCreate()), + relations_(model->GetOrCreate()), trail_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), shared_stats_(model->Mutable()), @@ -405,7 +876,7 @@ class PrecedencesPropagator : public SatPropagator, PropagatorInterface { // External class needed to get the IntegerVariable lower bounds and Enqueue // new ones. - PrecedenceRelations* relations_; + EnforcedLinear2Bounds* relations_; Trail* trail_; IntegerTrail* integer_trail_; SharedStatistics* shared_stats_ = nullptr; @@ -471,260 +942,16 @@ class PrecedencesPropagator : public SatPropagator, PropagatorInterface { int64_t num_enforcement_pushes_ = 0; }; -// Similar to AffineExpression, but with a zero constant. -// If coeff is zero, then this is always zero and var is ignored. -struct LinearTerm { - LinearTerm() = default; - LinearTerm(IntegerVariable v, IntegerValue c) : var(v), coeff(c) {} - - void MakeCoeffPositive() { - if (coeff < 0) { - coeff = -coeff; - var = NegationOf(var); - } - } - - bool operator==(const LinearTerm& other) const { - return var == other.var && coeff == other.coeff; - } - - IntegerVariable var = kNoIntegerVariable; - IntegerValue coeff = IntegerValue(0); -}; - -// A relation of the form enforcement => a + b \in [lhs, rhs]. -// Note that the [lhs, rhs] interval should always be within [min_activity, -// max_activity] where the activity is the value of a + b. -struct Relation { - Literal enforcement; - LinearTerm a; - LinearTerm b; - IntegerValue lhs; - IntegerValue rhs; - - bool operator==(const Relation& other) const { - return enforcement == other.enforcement && a == other.a && b == other.b && - lhs == other.lhs && rhs == other.rhs; - } -}; - -// A repository of all the enforced linear constraints of size 1 or 2, and of -// all the non-enforced linear constraints of size 2. -// -// TODO(user): This is not always needed, find a way to clean this once we -// don't need it. -class BinaryRelationRepository { - public: - int size() const { return relations_.size(); } - - // The returned relation is guaranteed to only have positive variables. - const Relation& relation(int index) const { return relations_[index]; } - - // Returns the indices of the relations that are enforced by the given - // literal. - absl::Span IndicesOfRelationsEnforcedBy(LiteralIndex lit) const { - if (lit >= lit_to_relations_.size()) return {}; - return lit_to_relations_[lit]; - } - - // Returns the indices of the non-enforced relations that contain the given - // (positive) variable. - absl::Span IndicesOfRelationsContaining( - IntegerVariable var) const { - if (var >= var_to_relations_.size()) return {}; - return var_to_relations_[var]; - } - - // Returns the indices of the non-enforced relations that contain the given - // (positive) variables. - absl::Span IndicesOfRelationsBetween(IntegerVariable var1, - IntegerVariable var2) const { - if (var1 > var2) std::swap(var1, var2); - const std::pair key(var1, var2); - const auto it = var_pair_to_relations_.find(key); - if (it == var_pair_to_relations_.end()) return {}; - return it->second; - } - - // Adds a conditional relation lit => a + b \in [lhs, rhs] (one of the terms - // can be zero), or an always true binary relation a + b \in [lhs, rhs] (both - // terms must be non-zero). - void Add(Literal lit, LinearTerm a, LinearTerm b, IntegerValue lhs, - IntegerValue rhs); - - // Adds a partial conditional relation between two variables, with unspecified - // coefficients and bounds. - void AddPartialRelation(Literal lit, IntegerVariable a, IntegerVariable b); - - // Builds the literal to relations mapping. This should be called once all the - // relations have been added. - void Build(); - - // Assuming level-zero bounds + any (var >= value) in the input map, - // fills "output" with a "propagated" set of bounds assuming lit is true (by - // using the relations enforced by lit, as well as the non-enforced ones). - // Note that we will only fill bounds > level-zero ones in output. - // - // Returns false if the new bounds are infeasible at level zero. - // - // Important: by default this does not call output->clear() so we can take - // the max with already inferred bounds. - bool PropagateLocalBounds( - const IntegerTrail& integer_trail, Literal lit, - const absl::flat_hash_map& input, - absl::flat_hash_map* output) const; - - private: - bool is_built_ = false; - int num_enforced_relations_ = 0; - std::vector relations_; - CompactVectorVector lit_to_relations_; - CompactVectorVector var_to_relations_; - absl::flat_hash_map, - std::vector> - var_pair_to_relations_; -}; - -// TODO(user): Merge with BinaryRelationRepository. Note that this one provides -// different indexing though, so it could be kept separate. The -// LinearExpression2 data structure is also slightly more efficient. -class BinaryRelationsMaps { - public: - explicit BinaryRelationsMaps(Model* model); - ~BinaryRelationsMaps(); - - // This mainly wraps BestBinaryRelationBounds, but in addition it checks the - // current LevelZero variable bounds to detect trivially true or false - // relation. - void AddRelationBounds(LinearExpression2 expr, IntegerValue lb, - IntegerValue ub); - RelationStatus GetLevelZeroStatus(LinearExpression2 expr, IntegerValue lb, - IntegerValue ub) const; - - // Return the status of a <= b; - RelationStatus GetLevelZeroPrecedenceStatus(AffineExpression a, - AffineExpression b) const; - - // Register the fact that l <=> ( a <= b ). - // These are considered equivalence relation. - void AddReifiedPrecedenceIfNonTrivial(Literal l, AffineExpression a, - AffineExpression b); - - // Returns kNoLiteralIndex if we don't have a literal <=> ( a <= b ), or - // returns that literal if we have one. Note that we will return the - // true/false literal if the status is known at level zero. - LiteralIndex GetReifiedPrecedence(AffineExpression a, AffineExpression b); - - // If the given upper bound evaluate better than the current one we have, this - // will replace it and returns true, otherwise it returns false. - // - // Note that we never store trivial upper bound (using the current variable - // domain). - bool AddAffineUpperBound(LinearExpression2 expr, AffineExpression affine_ub); - - // Returns the best known upper-bound of the given LinearExpression2 at the - // current decision level. If its explanation is needed, it can be queried - // with the second function. - IntegerValue UpperBound(LinearExpression2 expr) const; - void AddReasonForUpperBoundLowerThan( - LinearExpression2 expr, IntegerValue ub, - std::vector* literal_reason, - std::vector* integer_reason) const; - - // Warning, the order will not be deterministic. - std::vector GetAllExpressionsWithAffineBounds() const; - - int NumExpressionsWithAffineBounds() const { return best_affine_ub_.size(); } - - void WatchAllLinearExpressions2(int id) { propagator_ids_.insert(id); } - - private: - void NotifyWatchingPropagators() const; - - // Return the pair (a - b) <= rhs. - std::pair FromDifference( - const AffineExpression& a, const AffineExpression& b) const; - - IntegerValue GetImpliedUpperBound(const LinearExpression2& expr) const; - std::pair GetImpliedLevelZeroBounds( - const LinearExpression2& expr) const; - - IntegerTrail* integer_trail_; - IntegerEncoder* integer_encoder_; - GenericLiteralWatcher* watcher_; - SharedStatistics* shared_stats_; - BestBinaryRelationBounds best_root_level_bounds_; - - int64_t num_updates_ = 0; - int64_t num_affine_updates_ = 0; - - // This stores relations l <=> (linear2 <= rhs). - absl::flat_hash_map, Literal> - relation_to_lit_; - - // This is used to detect relations that become fixed at level zero and - // "upgrade" them to non-enforced relations. Because we only do that when - // we fix variable, a linear scan shouldn't be too bad and is relatively - // compact memory wise. - absl::flat_hash_set variable_appearing_in_reified_relations_; - std::vector> - all_reified_relations_; - - // This stores linear2 <= AffineExpression / divisor. - // - // Note(user): This is a "cheap way" to not have to deal with backtracking, If - // we have many possible AffineExpression that bounds a LinearExpression2, we - // keep the best one during "search dive" but on backtrack we might have a - // sub-optimal relation. - absl::flat_hash_map> - best_affine_ub_; - - absl::btree_set propagator_ids_; -}; - -// Detects if at least one of a subset of linear of size 2 or 1, touching the -// same variable, must be true. When this is the case we add a new propagator to -// propagate that fact. -// -// TODO(user): Shall we do that on the main thread before the workers are -// spawned? note that the probing version need the model to be loaded though. -class GreaterThanAtLeastOneOfDetector { - public: - explicit GreaterThanAtLeastOneOfDetector(Model* model) - : repository_(*model->GetOrCreate()) {} - - // Advanced usage. To be called once all the constraints have been added to - // the model. This will detect GreaterThanAtLeastOneOfConstraint(). - // Returns the number of added constraint. - // - // TODO(user): This can be quite slow, add some kind of deterministic limit - // so that we can use it all the time. - int AddGreaterThanAtLeastOneOfConstraints(Model* model, - bool auto_detect_clauses = false); - - private: - // Given an existing clause, sees if it can be used to add "greater than at - // least one of" type of constraints. Returns the number of such constraint - // added. - int AddGreaterThanAtLeastOneOfConstraintsFromClause( - absl::Span clause, Model* model); - - // Another approach for AddGreaterThanAtLeastOneOfConstraints(), this one - // might be a bit slow as it relies on the propagation engine to detect - // clauses between incoming arcs presence literals. - // Returns the number of added constraints. - int AddGreaterThanAtLeastOneOfConstraintsWithClauseAutoDetection( - Model* model); - - // Once we identified a clause and relevant indices, this build the - // constraint. Returns true if we actually add it. - bool AddRelationFromIndices(IntegerVariable var, - absl::Span clause, - absl::Span indices, Model* model); - - BinaryRelationRepository& repository_; -}; +// This can be in a hot-loop, so we want to inline it if possible. +inline IntegerValue Linear2Bounds::NonTrivialUpperBound( + LinearExpression2Index lin2_index) const { + CHECK_NE(lin2_index, kNoLinearExpression2Index); + IntegerValue ub = kMaxIntegerValue; + ub = std::min(ub, root_level_bounds_->GetUpperBoundNoTrail(lin2_index)); + ub = std::min(ub, enforced_bounds_->GetUpperBoundFromEnforced(lin2_index)); + ub = std::min(ub, linear3_bounds_->GetUpperBoundFromLinear3(lin2_index)); + return ub; +} // ============================================================================= // Implementation of the small API functions below. @@ -768,43 +995,6 @@ inline void PrecedencesPropagator::AddPrecedenceWithAllOptions( // Model based functions. // ============================================================================= -// a <= b. -inline std::function LowerOrEqual(IntegerVariable a, - IntegerVariable b) { - return [=](Model* model) { - return model->GetOrCreate()->AddPrecedence(a, b); - }; -} - -// a + offset <= b. -inline std::function LowerOrEqualWithOffset(IntegerVariable a, - IntegerVariable b, - int64_t offset) { - return [=](Model* model) { - LinearExpression2 expr(a, b, 1, -1); - model->GetOrCreate()->AddUpperBound( - expr, IntegerValue(-offset)); - model->GetOrCreate()->AddPrecedenceWithOffset( - a, b, IntegerValue(offset)); - }; -} - -// a + offset <= b. (when a and b are of the form 1 * var + offset). -inline std::function AffineCoeffOneLowerOrEqualWithOffset( - AffineExpression a, AffineExpression b, int64_t offset) { - CHECK_NE(a.var, kNoIntegerVariable); - CHECK_EQ(a.coeff, 1); - CHECK_NE(b.var, kNoIntegerVariable); - CHECK_EQ(b.coeff, 1); - return [=](Model* model) { - LinearExpression2 expr(a.var, b.var, 1, -1); - model->GetOrCreate()->AddUpperBound( - expr, -a.constant + b.constant - offset); - model->GetOrCreate()->AddPrecedenceWithOffset( - a.var, b.var, a.constant - b.constant + offset); - }; -} - // l => (a + b <= ub). inline void AddConditionalSum2LowerOrEqual( absl::Span enforcement_literals, IntegerVariable a, @@ -812,8 +1002,8 @@ inline void AddConditionalSum2LowerOrEqual( // TODO(user): Refactor to be sure we do not miss any level zero relations. if (enforcement_literals.empty()) { LinearExpression2 expr(a, b, 1, 1); - model->GetOrCreate()->AddUpperBound(expr, - IntegerValue(ub)); + model->GetOrCreate()->AddUpperBound( + expr, IntegerValue(ub)); } PrecedencesPropagator* p = model->GetOrCreate(); @@ -832,34 +1022,21 @@ inline void AddConditionalSum3LowerOrEqual( enforcement_literals); } -// a >= b. -inline std::function GreaterOrEqual(IntegerVariable a, - IntegerVariable b) { - return [=](Model* model) { - return model->GetOrCreate()->AddPrecedence(b, a); - }; -} - // a == b. +// +// ABSL_DEPRECATED("Use linear constraint API instead") inline std::function Equality(IntegerVariable a, IntegerVariable b) { return [=](Model* model) { - model->Add(LowerOrEqual(a, b)); - model->Add(LowerOrEqual(b, a)); - }; -} - -// a + offset == b. -inline std::function EqualityWithOffset(IntegerVariable a, - IntegerVariable b, - int64_t offset) { - return [=](Model* model) { - model->Add(LowerOrEqualWithOffset(a, b, offset)); - model->Add(LowerOrEqualWithOffset(b, a, -offset)); + auto* precedences = model->GetOrCreate(); + precedences->AddPrecedence(a, b); + precedences->AddPrecedence(b, a); }; } // is_le => (a + offset <= b). +// +// ABSL_DEPRECATED("Use linear constraint API instead") inline std::function ConditionalLowerOrEqualWithOffset( IntegerVariable a, IntegerVariable b, int64_t offset, Literal is_le) { return [=](Model* model) { @@ -868,6 +1045,60 @@ inline std::function ConditionalLowerOrEqualWithOffset( }; } +inline LinearExpression2Index Linear2Indices::GetIndex( + LinearExpression2 expr) const { + if (expr.coeffs[0] == 0 || expr.coeffs[1] == 0) { + return kNoLinearExpression2Index; + } + DCHECK(expr.IsCanonicalized()); + DCHECK_EQ(expr.DivideByGcd(), 1); + const bool negated = expr.NegateForCanonicalization(); + auto it = expr_to_index_.find(expr); + if (it == expr_to_index_.end()) return kNoLinearExpression2Index; + + const LinearExpression2Index positive_index(2 * it->second); + if (negated) { + return NegationOf(positive_index); + } else { + return positive_index; + } +} + +inline LinearExpression2 Linear2Indices::GetExpression( + LinearExpression2Index index) const { + DCHECK_NE(index, kNoLinearExpression2Index); + const int lookup_index = index.value() / 2; + DCHECK_LT(lookup_index, exprs_.size()); + if (Linear2IsPositive(index)) { + return exprs_[lookup_index]; + } else { + LinearExpression2 result = exprs_[lookup_index]; + result.Negate(); + return result; + } +} + +inline absl::Span +Linear2Indices::GetAllLinear2ContainingVariable(IntegerVariable var) const { + const IntegerVariable positive_var = PositiveVariable(var); + auto it = var_to_bounds_.find(positive_var); + if (it == var_to_bounds_.end()) return {}; + return it->second; +} + +inline absl::Span +Linear2Indices::GetAllLinear2ContainingVariables(IntegerVariable var1, + IntegerVariable var2) const { + IntegerVariable positive_var1 = PositiveVariable(var1); + IntegerVariable positive_var2 = PositiveVariable(var2); + if (positive_var1 > positive_var2) { + std::swap(positive_var1, positive_var2); + } + auto it = var_pair_to_bounds_.find({positive_var1, positive_var2}); + if (it == var_pair_to_bounds_.end()) return {}; + return it->second; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/precedences_test.cc b/ortools/sat/precedences_test.cc index 781781d5f2..715be1b237 100644 --- a/ortools/sat/precedences_test.cc +++ b/ortools/sat/precedences_test.cc @@ -14,10 +14,12 @@ #include "ortools/sat/precedences.h" #include +#include #include #include #include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" #include "gtest/gtest.h" #include "ortools/base/gmock.h" #include "ortools/base/parse_test_proto.h" @@ -38,6 +40,7 @@ namespace { using ::google::protobuf::contrib::parse_proto::ParseTestProto; using ::testing::ElementsAre; +using ::testing::FieldsAre; using ::testing::IsEmpty; using ::testing::UnorderedElementsAre; @@ -59,123 +62,135 @@ std::vector AddVariables(IntegerTrail* integer_trail) { return vars; } -TEST(PrecedenceRelationsTest, BasicAPI) { +TEST(EnforcedLinear2BoundsTest, BasicAPI) { Model model; IntegerTrail* integer_trail = model.GetOrCreate(); + auto* root_bounds = model.GetOrCreate(); + auto* precedence_builder = + model.GetOrCreate(); const std::vector vars = AddVariables(integer_trail); // Note that odd indices are for the negation. IntegerVariable a(0), b(2), c(4), d(6); - PrecedenceRelations precedences(&model); - precedences.AddUpperBound(LinearExpression2::Difference(a, b), -10); - precedences.AddUpperBound(LinearExpression2::Difference(d, c), -7); - precedences.AddUpperBound(LinearExpression2::Difference(b, d), -5); + root_bounds->AddUpperBound(LinearExpression2::Difference(a, b), -10); + root_bounds->AddUpperBound(LinearExpression2::Difference(d, c), -7); + root_bounds->AddUpperBound(LinearExpression2::Difference(b, d), -5); - precedences.Build(); + precedence_builder->Build(); EXPECT_EQ( - precedences.LevelZeroUpperBound(LinearExpression2::Difference(a, b)), + root_bounds->LevelZeroUpperBound(LinearExpression2::Difference(a, b)), -10); - EXPECT_EQ(precedences.LevelZeroUpperBound( + EXPECT_EQ(root_bounds->LevelZeroUpperBound( LinearExpression2::Difference(NegationOf(b), NegationOf(a))), -10); EXPECT_EQ( - precedences.LevelZeroUpperBound(LinearExpression2::Difference(a, c)), + root_bounds->LevelZeroUpperBound(LinearExpression2::Difference(a, c)), -22); - EXPECT_EQ(precedences.LevelZeroUpperBound( + EXPECT_EQ(root_bounds->LevelZeroUpperBound( LinearExpression2::Difference(NegationOf(c), NegationOf(a))), -22); EXPECT_EQ( - precedences.LevelZeroUpperBound(LinearExpression2::Difference(a, d)), + root_bounds->LevelZeroUpperBound(LinearExpression2::Difference(a, d)), -15); - EXPECT_EQ(precedences.LevelZeroUpperBound( + EXPECT_EQ(root_bounds->LevelZeroUpperBound( LinearExpression2::Difference(NegationOf(d), NegationOf(a))), -15); EXPECT_EQ( - precedences.LevelZeroUpperBound(LinearExpression2::Difference(d, a)), - kMaxIntegerValue); + root_bounds->LevelZeroUpperBound(LinearExpression2::Difference(d, a)), + 100); // Once built, we can update the offsets. // Note however that this would not propagate through the precedence graphs. - precedences.AddUpperBound(LinearExpression2::Difference(a, b), -15); + root_bounds->AddUpperBound(LinearExpression2::Difference(a, b), -15); EXPECT_EQ( - precedences.LevelZeroUpperBound(LinearExpression2::Difference(a, b)), + root_bounds->LevelZeroUpperBound(LinearExpression2::Difference(a, b)), -15); - EXPECT_EQ(precedences.LevelZeroUpperBound( + EXPECT_EQ(root_bounds->LevelZeroUpperBound( LinearExpression2::Difference(NegationOf(b), NegationOf(a))), -15); } -TEST(PrecedenceRelationsTest, CornerCase1) { +TEST(EnforcedLinear2BoundsTest, CornerCase1) { Model model; IntegerTrail* integer_trail = model.GetOrCreate(); + auto* root_bounds = model.GetOrCreate(); + auto* precedence_builder = + model.GetOrCreate(); const std::vector vars = AddVariables(integer_trail); // Note that odd indices are for the negation. IntegerVariable a(0), b(2), c(4), d(6); - PrecedenceRelations precedences(&model); - precedences.AddUpperBound(LinearExpression2::Difference(a, b), -10); - precedences.AddUpperBound(LinearExpression2::Difference(b, c), -7); - precedences.AddUpperBound(LinearExpression2::Difference(b, d), -5); - precedences.AddUpperBound(LinearExpression2::Difference(NegationOf(b), a), - -5); + root_bounds->AddUpperBound(LinearExpression2::Difference(a, b), -10); + root_bounds->AddUpperBound(LinearExpression2::Difference(b, c), -7); + root_bounds->AddUpperBound(LinearExpression2::Difference(b, d), -5); + root_bounds->AddUpperBound(LinearExpression2::Difference(NegationOf(b), a), + -5); - precedences.Build(); - EXPECT_EQ(precedences.LevelZeroUpperBound( + precedence_builder->Build(); + EXPECT_EQ(root_bounds->LevelZeroUpperBound( LinearExpression2::Difference(NegationOf(b), a)), -5); - EXPECT_EQ(precedences.LevelZeroUpperBound( + EXPECT_EQ(root_bounds->LevelZeroUpperBound( LinearExpression2::Difference(NegationOf(b), c)), -22); - EXPECT_EQ(precedences.LevelZeroUpperBound( + EXPECT_EQ(root_bounds->LevelZeroUpperBound( LinearExpression2::Difference(NegationOf(b), d)), -20); } -TEST(PrecedenceRelationsTest, CornerCase2) { +TEST(EnforcedLinear2BoundsTest, CornerCase2) { Model model; IntegerTrail* integer_trail = model.GetOrCreate(); + auto* root_bounds = model.GetOrCreate(); + auto* precedence_builder = + model.GetOrCreate(); const std::vector vars = AddVariables(integer_trail); // Note that odd indices are for the negation. IntegerVariable a(0), b(2), c(4), d(6); - PrecedenceRelations precedences(&model); - precedences.AddUpperBound(LinearExpression2::Difference(NegationOf(a), a), - -10); - precedences.AddUpperBound(LinearExpression2::Difference(a, b), -7); - precedences.AddUpperBound(LinearExpression2::Difference(a, c), -5); - precedences.AddUpperBound(LinearExpression2::Difference(a, d), -2); - EXPECT_EQ(precedences.LevelZeroUpperBound( + root_bounds->AddUpperBound(LinearExpression2::Difference(NegationOf(a), a), + -10); + root_bounds->AddUpperBound(LinearExpression2::Difference(a, b), -7); + root_bounds->AddUpperBound(LinearExpression2::Difference(a, c), -5); + root_bounds->AddUpperBound(LinearExpression2::Difference(a, d), -2); + EXPECT_EQ(root_bounds->LevelZeroUpperBound( LinearExpression2::Difference(NegationOf(b), NegationOf(a))), -7); - precedences.Build(); + precedence_builder->Build(); } -TEST(PrecedenceRelationsTest, CoefficientGreaterThanOne) { +TEST(EnforcedLinear2BoundsTest, CoefficientGreaterThanOne) { Model model; IntegerTrail* integer_trail = model.GetOrCreate(); + auto* root_bounds = model.GetOrCreate(); + auto* precedence_builder = + model.GetOrCreate(); const std::vector vars = AddVariables(integer_trail); // Note that odd indices are for the negation. IntegerVariable a(0), b(2), c(4); - PrecedenceRelations precedences(&model); - precedences.AddUpperBound(LinearExpression2(a, b, 3, -4), 7); - precedences.AddUpperBound(LinearExpression2(a, c, 2, -3), -5); - precedences.AddUpperBound(LinearExpression2(a, b, 6, -8), 5); - EXPECT_EQ(precedences.LevelZeroUpperBound(LinearExpression2(a, b, 9, -12)), + EnforcedLinear2Bounds precedences(&model); + root_bounds->AddUpperBound(LinearExpression2(a, b, 3, -4), 7); + root_bounds->AddUpperBound(LinearExpression2(a, c, 2, -3), -5); + root_bounds->AddUpperBound(LinearExpression2(a, b, 6, -8), 5); + EXPECT_EQ(root_bounds->LevelZeroUpperBound(LinearExpression2(a, b, 9, -12)), 6); - precedences.Build(); + precedence_builder->Build(); } -TEST(PrecedenceRelationsTest, ConditionalRelations) { +TEST(EnforcedLinear2BoundsTest, ConditionalRelations) { Model model; auto* sat_solver = model.GetOrCreate(); + auto* lin2_bounds = model.GetOrCreate(); auto* integer_trail = model.GetOrCreate(); + auto* precedences = model.GetOrCreate(); + auto* lin2_indices = model.GetOrCreate(); const std::vector vars = AddVariables(integer_trail); const Literal l(model.Add(NewBooleanVariable()), true); @@ -183,31 +198,29 @@ TEST(PrecedenceRelationsTest, ConditionalRelations) { // Note that odd indices are for the negation. IntegerVariable a(0), b(2); - PrecedenceRelations precedences(&model); - precedences.PushConditionalRelation({l}, LinearExpression2(a, b, 1, 1), 15); - precedences.PushConditionalRelation({l}, LinearExpression2(a, b, 1, 1), 20); + precedences->PushConditionalRelation({l}, LinearExpression2(a, b, 1, 1), 15); + precedences->PushConditionalRelation({l}, LinearExpression2(a, b, 1, 1), 20); + LinearExpression2 expr_a_plus_b = + LinearExpression2::Difference(a, NegationOf(b)); + expr_a_plus_b.SimpleCanonicalization(); // We only keep the best one. - EXPECT_EQ( - precedences.UpperBound(LinearExpression2::Difference(a, NegationOf(b))), - 15); + EXPECT_EQ(lin2_bounds->UpperBound(expr_a_plus_b), 15); std::vector literal_reason; std::vector integer_reason; - precedences.AddReasonForUpperBoundLowerThan( - LinearExpression2::Difference(a, NegationOf(b)), 15, &literal_reason, + precedences->AddReasonForUpperBoundLowerThan( + lin2_indices->AddOrGet(expr_a_plus_b), 15, &literal_reason, &integer_reason); EXPECT_THAT(literal_reason, ElementsAre(l.Negated())); // Backtrack works. EXPECT_TRUE(sat_solver->ResetToLevelZero()); - EXPECT_EQ( - precedences.UpperBound(LinearExpression2::Difference(a, NegationOf(b))), - kMaxIntegerValue); + EXPECT_EQ(lin2_bounds->UpperBound(expr_a_plus_b), 200); literal_reason.clear(); integer_reason.clear(); - precedences.AddReasonForUpperBoundLowerThan( - LinearExpression2::Difference(a, NegationOf(b)), kMaxIntegerValue, - &literal_reason, &integer_reason); + precedences->AddReasonForUpperBoundLowerThan( + lin2_indices->AddOrGet(expr_a_plus_b), kMaxIntegerValue, &literal_reason, + &integer_reason); EXPECT_THAT(literal_reason, IsEmpty()); } @@ -435,8 +448,9 @@ TEST(PrecedencesPropagatorTest, ZeroWeightCycleOnDiscreteDomain) { NewIntegerVariable(Domain::FromValues({3, 6, 9, 14, 16, 18, 20, 35}))); // Add the fact that a == b with two inequalities. - model.Add(LowerOrEqual(a, b)); - model.Add(LowerOrEqual(b, a)); + auto* precedences = model.GetOrCreate(); + precedences->AddPrecedence(a, b); + precedences->AddPrecedence(b, a); // After propagation, we should detect that the only common values fall in // [16, 20]. @@ -455,7 +469,8 @@ TEST(PrecedencesPropagatorTest, ConditionalPrecedencesOnFixedLiteral) { // To trigger the old bug, we need to add some precedences. IntegerVariable x = model.Add(NewIntegerVariable(0, 100)); IntegerVariable y = model.Add(NewIntegerVariable(50, 100)); - model.Add(LowerOrEqual(x, y)); + auto* precedences = model.GetOrCreate(); + precedences->AddPrecedence(x, y); // We then add a Boolean variable and fix it. // This will trigger a propagation. @@ -472,33 +487,34 @@ TEST(PrecedencesPropagatorTest, ConditionalPrecedencesOnFixedLiteral) { #undef EXPECT_BOUNDS_EQ -TEST(PrecedenceRelationsTest, CollectPrecedences) { +TEST(EnforcedLinear2BoundsTest, CollectPrecedences) { Model model; auto* integer_trail = model.GetOrCreate(); - auto* relations = model.GetOrCreate(); + auto* relations = model.GetOrCreate(); + auto* root_bounds = model.GetOrCreate(); std::vector vars = AddVariables(integer_trail); - relations->AddUpperBound(LinearExpression2::Difference(vars[0], vars[2]), - IntegerValue(-1)); - relations->AddUpperBound(LinearExpression2::Difference(vars[0], vars[5]), - IntegerValue(-1)); - relations->AddUpperBound(LinearExpression2::Difference(vars[1], vars[2]), - IntegerValue(-1)); - relations->AddUpperBound(LinearExpression2::Difference(vars[2], vars[4]), - IntegerValue(-1)); - relations->AddUpperBound(LinearExpression2::Difference(vars[3], vars[4]), - IntegerValue(-1)); - relations->AddUpperBound(LinearExpression2::Difference(vars[4], vars[5]), - IntegerValue(-1)); + root_bounds->AddUpperBound(LinearExpression2::Difference(vars[0], vars[2]), + IntegerValue(-1)); + root_bounds->AddUpperBound(LinearExpression2::Difference(vars[0], vars[5]), + IntegerValue(-1)); + root_bounds->AddUpperBound(LinearExpression2::Difference(vars[1], vars[2]), + IntegerValue(-1)); + root_bounds->AddUpperBound(LinearExpression2::Difference(vars[2], vars[4]), + IntegerValue(-1)); + root_bounds->AddUpperBound(LinearExpression2::Difference(vars[3], vars[4]), + IntegerValue(-1)); + root_bounds->AddUpperBound(LinearExpression2::Difference(vars[4], vars[5]), + IntegerValue(-1)); - std::vector p; + std::vector p; relations->CollectPrecedences({vars[0], vars[2], vars[3]}, &p); // Note that we do not return precedences with just one variable. std::vector indices; std::vector variables; for (const auto precedence : p) { - indices.push_back(precedence.index); + indices.push_back(precedence.var_index); variables.push_back(precedence.var); } EXPECT_EQ(indices, (std::vector{1, 2})); @@ -511,46 +527,76 @@ TEST(PrecedenceRelationsTest, CollectPrecedences) { TEST(BinaryRelationRepositoryTest, Build) { Model model; - const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); - const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); - const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); + const IntegerVariable x = model.Add(NewIntegerVariable(-100, 100)); + const IntegerVariable y = model.Add(NewIntegerVariable(-100, 100)); + const IntegerVariable z = model.Add(NewIntegerVariable(-100, 100)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); const Literal lit_b = Literal(model.Add(NewBooleanVariable()), true); BinaryRelationRepository repository; - repository.Add(lit_a, {NegationOf(x), 1}, {y, 1}, 2, 8); - repository.Add(Literal(kNoLiteralIndex), {x, 2}, {y, -2}, 0, 10); - repository.Add(lit_a, {x, -3}, {NegationOf(y), 2}, 1, 15); - repository.Add(lit_b, {x, -3}, {kNoIntegerVariable, 0}, 3, 5); - repository.Add(Literal(kNoLiteralIndex), {x, 3}, {y, -1}, 5, 15); - repository.Add(Literal(kNoLiteralIndex), {x, 1}, {z, -1}, 0, 10); + RootLevelLinear2Bounds* root_level_bounds = + model.GetOrCreate(); + repository.Add(lit_a, LinearExpression2(NegationOf(x), y, 1, 1), 2, 8); + root_level_bounds->Add(LinearExpression2(x, y, 2, -2), 0, 10); + repository.Add(lit_a, LinearExpression2(x, NegationOf(y), -3, 2), 1, 15); + repository.Add(lit_b, LinearExpression2(x, kNoIntegerVariable, -3, 0), 3, 5); + root_level_bounds->Add(LinearExpression2(x, y, 3, -1), 5, 15); + root_level_bounds->Add(LinearExpression2::Difference(x, z), 0, 10); repository.AddPartialRelation(lit_b, x, z); repository.Build(); - EXPECT_EQ(repository.size(), 7); - EXPECT_EQ(repository.relation(0), (Relation{lit_a, {x, -1}, {y, 1}, 2, 8})); - EXPECT_EQ(repository.relation(1), - (Relation{Literal(kNoLiteralIndex), {x, 2}, {y, -2}, 0, 10})); - EXPECT_EQ(repository.relation(2), (Relation{lit_a, {x, -3}, {y, -2}, 1, 15})); - EXPECT_EQ(repository.relation(3), - (Relation{lit_b, {x, -3}, {kNoIntegerVariable, 0}, 3, 5})); - EXPECT_EQ(repository.relation(6), (Relation{lit_b, {x, 1}, {z, 1}, 0, 0})); - EXPECT_THAT(repository.IndicesOfRelationsEnforcedBy(lit_a), - UnorderedElementsAre(0, 2)); - EXPECT_THAT(repository.IndicesOfRelationsEnforcedBy(lit_b), - UnorderedElementsAre(3, 6)); - EXPECT_THAT(repository.IndicesOfRelationsContaining(x), - UnorderedElementsAre(1, 4, 5)); - EXPECT_THAT(repository.IndicesOfRelationsContaining(y), - UnorderedElementsAre(1, 4)); - EXPECT_THAT(repository.IndicesOfRelationsContaining(z), - UnorderedElementsAre(5)); - EXPECT_THAT(repository.IndicesOfRelationsBetween(x, y), - UnorderedElementsAre(1, 4)); - EXPECT_THAT(repository.IndicesOfRelationsBetween(y, x), - UnorderedElementsAre(1, 4)); - EXPECT_THAT(repository.IndicesOfRelationsBetween(x, z), - UnorderedElementsAre(5)); - EXPECT_THAT(repository.IndicesOfRelationsBetween(z, y), IsEmpty()); + auto get_rel = [&](absl::Span indexes) { + std::vector result; + for (int i : indexes) { + result.push_back(repository.relation(i)); + } + return result; + }; + std::vector all(repository.size()); + std::iota(all.begin(), all.end(), 0); + EXPECT_THAT( + get_rel(all), + UnorderedElementsAre( + Relation{lit_a, LinearExpression2(x, y, -1, 1), 2, 8}, + Relation{lit_a, LinearExpression2(x, y, -3, -2), 1, 15}, + Relation{lit_b, LinearExpression2(kNoIntegerVariable, x, 0, -3), 3, + 5}, + Relation{lit_b, LinearExpression2(x, z, 1, 1), 0, 0})); + EXPECT_THAT(get_rel(repository.IndicesOfRelationsEnforcedBy(lit_a)), + UnorderedElementsAre( + Relation{lit_a, LinearExpression2(x, y, -1, 1), 2, 8}, + Relation{lit_a, LinearExpression2(x, y, -3, -2), 1, 15})); + EXPECT_THAT( + get_rel(repository.IndicesOfRelationsEnforcedBy(lit_b)), + UnorderedElementsAre( + Relation{lit_b, LinearExpression2(kNoIntegerVariable, x, 0, -3), 3, + 5}, + Relation{lit_b, LinearExpression2(x, z, 1, 1), 0, 0})); + EXPECT_THAT(root_level_bounds->GetAllBoundsContainingVariable(x), + UnorderedElementsAre( + FieldsAre(LinearExpression2(x, NegationOf(y), 1, 1), 0, 5), + + FieldsAre(LinearExpression2(x, NegationOf(y), 3, 1), 5, 15), + FieldsAre(LinearExpression2(x, NegationOf(z), 1, 1), 0, 10))); + EXPECT_THAT( + root_level_bounds->GetAllBoundsContainingVariable(y), + UnorderedElementsAre(FieldsAre(LinearExpression2(y, x, -1, 1), 0, 5), + FieldsAre(LinearExpression2(y, x, -1, 3), 5, 15))); + EXPECT_THAT( + root_level_bounds->GetAllBoundsContainingVariable(z), + UnorderedElementsAre(FieldsAre(LinearExpression2(z, x, -1, 1), 0, 10))); + EXPECT_THAT( + root_level_bounds->GetAllBoundsContainingVariables(x, y), + UnorderedElementsAre(FieldsAre(LinearExpression2(x, y, 1, -1), 0, 5), + FieldsAre(LinearExpression2(x, y, 3, -1), 5, 15))); + EXPECT_THAT( + root_level_bounds->GetAllBoundsContainingVariables(y, x), + UnorderedElementsAre(FieldsAre(LinearExpression2(y, x, -1, 1), 0, 5), + FieldsAre(LinearExpression2(y, x, -1, 3), 5, 15))); + EXPECT_THAT( + root_level_bounds->GetAllBoundsContainingVariables(x, z), + UnorderedElementsAre(FieldsAre(LinearExpression2(x, z, 1, -1), 0, 10))); + EXPECT_THAT(root_level_bounds->GetAllBoundsContainingVariables(z, y), + IsEmpty()); } std::vector GetRelations(Model& model) { @@ -559,12 +605,11 @@ std::vector GetRelations(Model& model) { std::vector relations; for (int i = 0; i < repository.size(); ++i) { Relation r = repository.relation(i); - if (r.a.coeff < 0) { + if (r.expr.coeffs[0] < 0) { r = Relation({r.enforcement, - {r.a.var, -r.a.coeff}, - {r.b.var, -r.b.coeff}, - -r.rhs, - -r.lhs}); + LinearExpression2(r.expr.vars[0], r.expr.vars[1], + -r.expr.coeffs[0], -r.expr.coeffs[1]), + -r.rhs, -r.lhs}); } relations.push_back(r); } @@ -622,22 +667,16 @@ TEST(BinaryRelationRepositoryTest, LoadCpModelAddUnaryAndBinaryRelations) { LoadCpModel(model_proto, &model); const CpModelMapping& mapping = *model.GetOrCreate(); - EXPECT_THAT(GetRelations(model), - UnorderedElementsAre(Relation{mapping.Literal(0), - {mapping.Integer(2), 1}, - {mapping.Integer(3), -1}, - 0, - 10}, - Relation{mapping.Literal(1), - {mapping.Integer(2), 1}, - {kNoIntegerVariable, 0}, - 5, - 10}, - Relation{Literal(kNoLiteralIndex), - {mapping.Integer(2), 3}, - {mapping.Integer(3), -2}, - -10, - 10})); + EXPECT_THAT( + GetRelations(model), + UnorderedElementsAre(Relation{mapping.Literal(0), + LinearExpression2::Difference( + mapping.Integer(2), mapping.Integer(3)), + 0, 10}, + Relation{mapping.Literal(1), + LinearExpression2(kNoIntegerVariable, + mapping.Integer(2), 0, 1), + 5, 10})); } TEST(BinaryRelationRepositoryTest, @@ -672,8 +711,10 @@ TEST(BinaryRelationRepositoryTest, // - b => x - 10.a in [10, 90] EXPECT_THAT(GetRelations(model), UnorderedElementsAre( - Relation{mapping.Literal(0), {x, 1}, {b, -10}, 10, 90}, - Relation{mapping.Literal(1), {x, 1}, {a, -10}, 10, 90})); + Relation{mapping.Literal(0), LinearExpression2(b, x, 10, -1), + -90, -10}, + Relation{mapping.Literal(1), LinearExpression2(a, x, 10, -1), + -90, -10})); } TEST(BinaryRelationRepositoryTest, @@ -706,10 +747,12 @@ TEST(BinaryRelationRepositoryTest, // Two binary relations enforced by only one literal should be added: // - a => x + 10.b in [10, 90] // - b => x + 10.a in [10, 90] - EXPECT_THAT(GetRelations(model), - UnorderedElementsAre( - Relation{mapping.Literal(0), {x, 1}, {b, 10}, 10, 90}, - Relation{mapping.Literal(1), {x, 1}, {a, 10}, 10, 90})); + EXPECT_THAT( + GetRelations(model), + UnorderedElementsAre( + Relation{mapping.Literal(0), LinearExpression2(b, x, 10, 1), 10, 90}, + Relation{mapping.Literal(1), LinearExpression2(a, x, 10, 1), 10, + 90})); } TEST(BinaryRelationRepositoryTest, @@ -745,8 +788,9 @@ TEST(BinaryRelationRepositoryTest, EXPECT_THAT( GetRelations(model), UnorderedElementsAre( - Relation{mapping.Literal(0), {x, 1}, {b, 10}, 20, 100}, - Relation{mapping.Literal(1).Negated(), {x, 1}, {a, -10}, 10, 90})); + Relation{mapping.Literal(0), LinearExpression2(b, x, 10, 1), 20, 100}, + Relation{mapping.Literal(1).Negated(), + LinearExpression2(a, x, 10, -1), -90, -10})); } TEST(BinaryRelationRepositoryTest, @@ -782,8 +826,9 @@ TEST(BinaryRelationRepositoryTest, EXPECT_THAT( GetRelations(model), UnorderedElementsAre( - Relation{mapping.Literal(0), {x, 1}, {b, -10}, 0, 80}, - Relation{mapping.Literal(1).Negated(), {x, 1}, {a, 10}, 10, 90})); + Relation{mapping.Literal(0), LinearExpression2(b, x, 10, -1), -80, 0}, + Relation{mapping.Literal(1).Negated(), LinearExpression2(a, x, 10, 1), + 10, 90})); } TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_EnforcedRelation) { @@ -792,14 +837,17 @@ TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_EnforcedRelation) { const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); BinaryRelationRepository repository; - repository.Add(lit_a, {x, -1}, {y, 1}, 2, 10); // lit_a => y => x + 2 + RootLevelLinear2Bounds* root_level_bounds = + model.GetOrCreate(); + repository.Add(lit_a, LinearExpression2::Difference(y, x), 2, + 10); // lit_a => y => x + 2 repository.Build(); IntegerTrail* integer_trail = model.GetOrCreate(); absl::flat_hash_map input = {{x, 3}}; absl::flat_hash_map output; - const bool result = - repository.PropagateLocalBounds(*integer_trail, lit_a, input, &output); + const bool result = repository.PropagateLocalBounds( + *integer_trail, *root_level_bounds, lit_a, input, &output); EXPECT_TRUE(result); EXPECT_THAT(output, UnorderedElementsAre(std::make_pair(NegationOf(x), -8), @@ -808,43 +856,50 @@ TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_EnforcedRelation) { TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_UnenforcedRelation) { Model model; - const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); - const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); + RootLevelLinear2Bounds* root_level_bounds = + model.GetOrCreate(); + const IntegerVariable x = model.Add(NewIntegerVariable(-100, 100)); + const IntegerVariable y = model.Add(NewIntegerVariable(-100, 100)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); - const Literal kNoLiteral = Literal(kNoLiteralIndex); BinaryRelationRepository repository; - repository.Add(lit_a, {x, -1}, {y, 1}, -5, 10); // lit_a => y => x - 5 - repository.Add(kNoLiteral, {x, -1}, {y, 1}, 2, 10); // y => x + 2 + repository.Add(lit_a, LinearExpression2(x, y, -1, 1), -5, + 10); // lit_a => y => x - 5 + root_level_bounds->Add(LinearExpression2(x, y, -1, 1), 2, + 10); // y => x + 2 repository.Build(); IntegerTrail* integer_trail = model.GetOrCreate(); absl::flat_hash_map input = {{x, 3}}; absl::flat_hash_map output; - const bool result = - repository.PropagateLocalBounds(*integer_trail, lit_a, input, &output); + const bool result = repository.PropagateLocalBounds( + *integer_trail, *root_level_bounds, lit_a, input, &output); EXPECT_TRUE(result); - EXPECT_THAT(output, UnorderedElementsAre(std::make_pair(NegationOf(x), -8), + EXPECT_THAT(output, UnorderedElementsAre(std::make_pair(NegationOf(x), -98), std::make_pair(y, 5))); } TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_EnforcedBoundSmallerThanLevelZeroBound) { Model model; + RootLevelLinear2Bounds* root_level_bounds = + model.GetOrCreate(); const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); const Literal lit_b = Literal(model.Add(NewBooleanVariable()), true); BinaryRelationRepository repository; - repository.Add(lit_a, {x, -1}, {y, 1}, -5, 10); // lit_a => y => x - 5 - repository.Add(lit_b, {x, -1}, {y, 1}, 2, 10); // lit_b => y => x + 2 + repository.Add(lit_a, LinearExpression2::Difference(y, x), -5, + 10); // lit_a => y => x - 5 + repository.Add(lit_b, LinearExpression2::Difference(y, x), 2, + 10); // lit_b => y => x + 2 repository.Build(); IntegerTrail* integer_trail = model.GetOrCreate(); absl::flat_hash_map input = {{x, 3}}; absl::flat_hash_map output; - const bool result = - repository.PropagateLocalBounds(*integer_trail, lit_a, input, &output); + const bool result = repository.PropagateLocalBounds( + *integer_trail, *root_level_bounds, lit_a, input, &output); EXPECT_TRUE(result); EXPECT_THAT(output, IsEmpty()); @@ -857,14 +912,17 @@ TEST(BinaryRelationRepositoryTest, const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); BinaryRelationRepository repository; - repository.Add(lit_a, {x, -1}, {y, 1}, 2, 10); // lit_a => y => x + 2 + RootLevelLinear2Bounds* root_level_bounds = + model.GetOrCreate(); + repository.Add(lit_a, LinearExpression2::Difference(y, x), 2, + 10); // lit_a => y => x + 2 repository.Build(); IntegerTrail* integer_trail = model.GetOrCreate(); absl::flat_hash_map input = {{x, 3}}; absl::flat_hash_map output = {{y, 8}}; - const bool result = - repository.PropagateLocalBounds(*integer_trail, lit_a, input, &output); + const bool result = repository.PropagateLocalBounds( + *integer_trail, *root_level_bounds, lit_a, input, &output); EXPECT_TRUE(result); EXPECT_THAT(output, UnorderedElementsAre(std::make_pair(NegationOf(x), -8), @@ -877,14 +935,17 @@ TEST(BinaryRelationRepositoryTest, PropagateLocalBounds_Infeasible) { const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const Literal lit_a = Literal(model.Add(NewBooleanVariable()), true); BinaryRelationRepository repository; - repository.Add(lit_a, {x, -1}, {y, 1}, 8, 10); // lit_a => y => x + 8 + RootLevelLinear2Bounds* root_level_bounds = + model.GetOrCreate(); + repository.Add(lit_a, LinearExpression2::Difference(y, x), 8, + 10); // lit_a => y => x + 8 repository.Build(); IntegerTrail* integer_trail = model.GetOrCreate(); absl::flat_hash_map input = {{x, 3}}; absl::flat_hash_map output; - const bool result = - repository.PropagateLocalBounds(*integer_trail, lit_a, input, &output); + const bool result = repository.PropagateLocalBounds( + *integer_trail, *root_level_bounds, lit_a, input, &output); EXPECT_FALSE(result); EXPECT_THAT(output, UnorderedElementsAre(std::make_pair(NegationOf(x), -2), @@ -903,9 +964,12 @@ TEST(GreaterThanAtLeastOneOfDetectorTest, AddGreaterThanAtLeastOneOf) { model.Add(ClauseConstraint({lit_a, lit_b, lit_c})); auto* repository = model.GetOrCreate(); - repository->Add(lit_a, {a, -1}, {d, 1}, 2, 1000); // d >= a + 2 - repository->Add(lit_b, {b, -1}, {d, 1}, -1, 1000); // d >= b -1 - repository->Add(lit_c, {c, -1}, {d, 1}, 0, 1000); // d >= c + repository->Add(lit_a, LinearExpression2::Difference(d, a), 2, + 1000); // d >= a + 2 + repository->Add(lit_b, LinearExpression2::Difference(d, b), -1, + 1000); // d >= b -1 + repository->Add(lit_c, LinearExpression2::Difference(d, c), 0, + 1000); // d >= c repository->Build(); auto* detector = model.GetOrCreate(); @@ -931,9 +995,11 @@ TEST(GreaterThanAtLeastOneOfDetectorTest, model.Add(ClauseConstraint({lit_a, lit_b, lit_c})); auto* repository = model.GetOrCreate(); - repository->Add(lit_a, {a, -1}, {d, 1}, 2, 1000); // d >= a + 2 - repository->Add(lit_b, {b, -1}, {d, 1}, -1, 1000); // d >= b -1 - repository->Add(lit_c, {c, -1}, {d, 1}, 0, 1000); // d >= c + repository->Add(lit_a, LinearExpression2(a, d, -1, 1), 2, + 1000); // d >= a + 2 + repository->Add(lit_b, LinearExpression2(b, d, -1, 1), -1, + 1000); // d >= b -1 + repository->Add(lit_c, LinearExpression2(c, d, -1, 1), 0, 1000); // d >= c repository->Build(); auto* detector = model.GetOrCreate(); @@ -947,7 +1013,7 @@ TEST(GreaterThanAtLeastOneOfDetectorTest, EXPECT_EQ(model.Get(LowerBound(d)), std::min({2 + 2, 5 - 1, 3 + 0})); } -TEST(PrecedencesPropagatorTest, ComputeFullPrecedencesIfCycle) { +TEST(TransitivePrecedencesEvaluatorTest, ComputeFullPrecedencesIfCycle) { Model model; std::vector vars(10); for (int i = 0; i < vars.size(); ++i) { @@ -955,18 +1021,19 @@ TEST(PrecedencesPropagatorTest, ComputeFullPrecedencesIfCycle) { } // Even if the weight are compatible, we will fail here. - model.Add(LowerOrEqualWithOffset(vars[0], vars[1], 2)); - model.Add(LowerOrEqualWithOffset(vars[1], vars[2], 2)); - model.Add(LowerOrEqualWithOffset(vars[2], vars[1], -10)); - model.Add(LowerOrEqualWithOffset(vars[0], vars[2], 5)); + auto* r = model.GetOrCreate(); + r->AddUpperBound(LinearExpression2::Difference(vars[0], vars[1]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[1], vars[2]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[2], vars[1]), 10); + r->AddUpperBound(LinearExpression2::Difference(vars[0], vars[2]), -5); std::vector precedences; - model.GetOrCreate()->ComputeFullPrecedences( + model.GetOrCreate()->ComputeFullPrecedences( {vars[0], vars[1]}, &precedences); EXPECT_TRUE(precedences.empty()); } -TEST(PrecedencesPropagatorTest, BasicFiltering) { +TEST(TransitivePrecedencesEvaluatorTest, BasicTest1) { Model model; std::vector vars(10); for (int i = 0; i < vars.size(); ++i) { @@ -978,14 +1045,15 @@ TEST(PrecedencesPropagatorTest, BasicFiltering) { // 0 2 -- 4 // \ / // 3 - model.Add(LowerOrEqualWithOffset(vars[0], vars[1], 2)); - model.Add(LowerOrEqualWithOffset(vars[1], vars[2], 2)); - model.Add(LowerOrEqualWithOffset(vars[0], vars[3], 1)); - model.Add(LowerOrEqualWithOffset(vars[3], vars[2], 2)); - model.Add(LowerOrEqualWithOffset(vars[2], vars[4], 2)); + auto* r = model.GetOrCreate(); + r->AddUpperBound(LinearExpression2::Difference(vars[0], vars[1]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[1], vars[2]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[0], vars[3]), -1); + r->AddUpperBound(LinearExpression2::Difference(vars[3], vars[2]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[2], vars[4]), -2); std::vector precedences; - model.GetOrCreate()->ComputeFullPrecedences( + model.GetOrCreate()->ComputeFullPrecedences( {vars[0], vars[1], vars[3]}, &precedences); // We only output size at least 2, and "relevant" precedences. @@ -996,7 +1064,7 @@ TEST(PrecedencesPropagatorTest, BasicFiltering) { EXPECT_THAT(precedences[0].indices, ElementsAre(0, 1, 2)); } -TEST(PrecedencesPropagatorTest, BasicFiltering2) { +TEST(TransitivePrecedencesEvaluatorTest, BasicTest2) { Model model; std::vector vars(10); for (int i = 0; i < vars.size(); ++i) { @@ -1008,15 +1076,16 @@ TEST(PrecedencesPropagatorTest, BasicFiltering2) { // 0 2 -- 4 // \ / / // 3 5 - model.Add(LowerOrEqualWithOffset(vars[0], vars[1], 2)); - model.Add(LowerOrEqualWithOffset(vars[1], vars[2], 2)); - model.Add(LowerOrEqualWithOffset(vars[0], vars[3], 1)); - model.Add(LowerOrEqualWithOffset(vars[3], vars[2], 2)); - model.Add(LowerOrEqualWithOffset(vars[2], vars[4], 2)); - model.Add(LowerOrEqualWithOffset(vars[5], vars[4], 7)); + auto* r = model.GetOrCreate(); + r->AddUpperBound(LinearExpression2::Difference(vars[0], vars[1]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[1], vars[2]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[0], vars[3]), -1); + r->AddUpperBound(LinearExpression2::Difference(vars[3], vars[2]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[2], vars[4]), -2); + r->AddUpperBound(LinearExpression2::Difference(vars[5], vars[4]), -7); std::vector precedences; - model.GetOrCreate()->ComputeFullPrecedences( + model.GetOrCreate()->ComputeFullPrecedences( {vars[0], vars[1], vars[3]}, &precedences); // Same as before here. @@ -1027,7 +1096,7 @@ TEST(PrecedencesPropagatorTest, BasicFiltering2) { // But if we ask for 5, we will get two results. precedences.clear(); - model.GetOrCreate()->ComputeFullPrecedences( + model.GetOrCreate()->ComputeFullPrecedences( {vars[0], vars[1], vars[3], vars[5]}, &precedences); ASSERT_EQ(precedences.size(), 2); EXPECT_EQ(precedences[0].var, vars[2]); @@ -1043,6 +1112,7 @@ TEST(BinaryRelationMapsTest, AffineUpperBound) { const IntegerVariable x = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable z = model.Add(NewIntegerVariable(0, 2)); + const IntegerVariable w = model.Add(NewIntegerVariable(0, 20)); // x - y; LinearExpression2 expr; @@ -1052,35 +1122,41 @@ TEST(BinaryRelationMapsTest, AffineUpperBound) { expr.coeffs[1] = IntegerValue(-1); // Starts with trivial level zero bound. - auto* tested = model.GetOrCreate(); - EXPECT_EQ(tested->UpperBound(expr), IntegerValue(10)); + auto* bounds = model.GetOrCreate(); + auto* lin3_bounds = model.GetOrCreate(); + auto* root_bounds = model.GetOrCreate(); + EXPECT_EQ(bounds->UpperBound(expr), IntegerValue(10)); + + auto* search = model.GetOrCreate(); + search->TakeDecision( + Literal(search->GetDecisionLiteral(BooleanOrIntegerLiteral( + IntegerLiteral::LowerOrEqual(w, IntegerValue(10)))))); // Lets add a relation. - tested->AddRelationBounds(expr, IntegerValue(-5), IntegerValue(5)); - EXPECT_EQ(tested->UpperBound(expr), IntegerValue(5)); + root_bounds->Add(expr, IntegerValue(-5), IntegerValue(5)); + EXPECT_EQ(bounds->UpperBound(expr), IntegerValue(5)); // Note that we canonicalize with gcd. expr.coeffs[0] *= 3; expr.coeffs[1] *= 3; - EXPECT_EQ(tested->UpperBound(expr), IntegerValue(15)); + EXPECT_EQ(bounds->UpperBound(expr), IntegerValue(15)); // Lets add an affine upper bound to that expression <= 4 * z + 1. - EXPECT_TRUE(tested->AddAffineUpperBound( + EXPECT_TRUE(lin3_bounds->AddAffineUpperBound( expr, AffineExpression(z, IntegerValue(4), IntegerValue(1)))); - EXPECT_EQ(tested->UpperBound(expr), IntegerValue(9)); + EXPECT_EQ(bounds->UpperBound(expr), IntegerValue(9)); // Lets test the reason, first push a new bound. - auto* search = model.GetOrCreate(); search->TakeDecision( Literal(search->GetDecisionLiteral(BooleanOrIntegerLiteral( IntegerLiteral::LowerOrEqual(z, IntegerValue(1)))))); // Because of gcd, even though ub(affine) is now 5, we get 3, - EXPECT_EQ(tested->UpperBound(expr), IntegerValue(3)); + EXPECT_EQ(bounds->UpperBound(expr), IntegerValue(3)); { std::vector literal_reason; std::vector integer_reason; - tested->AddReasonForUpperBoundLowerThan(expr, IntegerValue(4), + bounds->AddReasonForUpperBoundLowerThan(expr, IntegerValue(4), &literal_reason, &integer_reason); EXPECT_THAT(literal_reason, ElementsAre()); EXPECT_THAT(integer_reason, @@ -1091,7 +1167,7 @@ TEST(BinaryRelationMapsTest, AffineUpperBound) { { std::vector literal_reason; std::vector integer_reason; - tested->AddReasonForUpperBoundLowerThan(expr, IntegerValue(9), + bounds->AddReasonForUpperBoundLowerThan(expr, IntegerValue(9), &literal_reason, &integer_reason); EXPECT_THAT(literal_reason, ElementsAre()); EXPECT_THAT(integer_reason, @@ -1101,7 +1177,7 @@ TEST(BinaryRelationMapsTest, AffineUpperBound) { // This is implied by the level zero relation x <= 5 std::vector literal_reason; std::vector integer_reason; - tested->AddReasonForUpperBoundLowerThan(expr, IntegerValue(15), + bounds->AddReasonForUpperBoundLowerThan(expr, IntegerValue(15), &literal_reason, &integer_reason); EXPECT_THAT(literal_reason, ElementsAre()); EXPECT_THAT(integer_reason, ElementsAre()); @@ -1110,7 +1186,7 @@ TEST(BinaryRelationMapsTest, AffineUpperBound) { // Note that the bound works on the canonicalized expr. expr.coeffs[0] /= 3; expr.coeffs[1] /= 3; - EXPECT_EQ(tested->UpperBound(expr), IntegerValue(1)); + EXPECT_EQ(bounds->UpperBound(expr), IntegerValue(1)); } } // namespace diff --git a/ortools/sat/primary_variables.cc b/ortools/sat/primary_variables.cc index 3071139260..c32b1b0654 100644 --- a/ortools/sat/primary_variables.cc +++ b/ortools/sat/primary_variables.cc @@ -411,6 +411,21 @@ VariableRelationships ComputeVariableRelationships(const CpModelProto& model) { -num_times_variable_appears_as_preferred_to_deduce[b], -num_times_variable_appears_as_deducible[b]); }); + + // Put in front of the queue all the variables that can readily be deduced + // using some constraint. + for (int c = 0; c < model.constraints_size(); ++c) { + ConstraintData& data = constraint_data[c]; + if (data.input_vars.size() + data.deducible_vars.size() != 1) { + continue; + } + if (!data.deducible_vars.empty()) { + vars_queue.push_front(*data.deducible_vars.begin()); + } else if (data.is_linear_inequality) { + vars_queue.push_front(*data.input_vars.begin()); + } + } + std::vector constraints_to_check; while (!vars_queue.empty()) { const int v = vars_queue.front(); diff --git a/ortools/sat/python/BUILD.bazel b/ortools/sat/python/BUILD.bazel index ff0364fcba..0def869eee 100644 --- a/ortools/sat/python/BUILD.bazel +++ b/ortools/sat/python/BUILD.bazel @@ -89,6 +89,7 @@ pybind_extension( ":linear_expr", ":linear_expr_doc", ":proto_builder_pybind11", + "//ortools/base:string_view_migration", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:cp_model_utils", "//ortools/sat:sat_parameters_cc_proto", diff --git a/ortools/sat/python/CMakeLists.txt b/ortools/sat/python/CMakeLists.txt index 6e91bfdefb..c42f3e5844 100644 --- a/ortools/sat/python/CMakeLists.txt +++ b/ortools/sat/python/CMakeLists.txt @@ -11,7 +11,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -pybind11_add_module(cp_model_helper_pybind11 MODULE cp_model_helper.cc) +set(WRAPPERS_NAME sat_python_wrappers) + +add_library(${WRAPPERS_NAME} OBJECT wrappers.h wrappers.cc) +set_target_properties(${WRAPPERS_NAME} PROPERTIES + POSITION_INDEPENDENT_CODE ON) +target_include_directories(${WRAPPERS_NAME} PUBLIC + ${PROJECT_SOURCE_DIR} + ${PROJECT_BINARY_DIR}) +target_link_libraries(${WRAPPERS_NAME} PUBLIC + absl::memory + absl::synchronization + absl::str_format + protobuf::libprotobuf) +add_library(${PROJECT_NAMESPACE}::${WRAPPERS_NAME} ALIAS ${WRAPPERS_NAME}) + +# gen_proto_builder_pybind11 code generator. +add_executable(gen_proto_builder_pybind11) +target_sources(gen_proto_builder_pybind11 PRIVATE "gen_proto_builder_pybind11.cc") +target_include_directories(gen_proto_builder_pybind11 PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_compile_features(gen_proto_builder_pybind11 PRIVATE cxx_std_17) +target_link_libraries(gen_proto_builder_pybind11 PRIVATE + absl::flags_commandlineflag + absl::flags_parse + absl::flags_usage + absl::die_if_null + absl::str_format + protobuf::libprotobuf + ${PROJECT_NAMESPACE}::ortools_proto + ${PROJECT_NAMESPACE}::${WRAPPERS_NAME}) + +include(GNUInstallDirs) +if(APPLE) + set_target_properties(gen_proto_builder_pybind11 PROPERTIES INSTALL_RPATH + "@loader_path/../${CMAKE_INSTALL_LIBDIR};@loader_path") +elseif(UNIX) + cmake_path(RELATIVE_PATH CMAKE_INSTALL_FULL_LIBDIR + BASE_DIRECTORY ${CMAKE_INSTALL_FULL_BINDIR} + OUTPUT_VARIABLE libdir_relative_path) + set_target_properties(gen_proto_builder_pybind11 PROPERTIES + INSTALL_RPATH "$ORIGIN/${libdir_relative_path}") +endif() + +install(TARGETS gen_proto_builder_pybind11) + +add_custom_command( + OUTPUT proto_builder_pybind11.h + COMMAND gen_proto_builder_pybind11 > proto_builder_pybind11.h + COMMENT "Generate C++ proto_builder_pybind11.h" + VERBATIM) + +pybind11_add_module(cp_model_helper_pybind11 MODULE cp_model_helper.cc proto_builder_pybind11.h) set_target_properties(cp_model_helper_pybind11 PROPERTIES LIBRARY_OUTPUT_NAME "cp_model_helper") @@ -26,8 +76,8 @@ elseif(UNIX) endif() target_link_libraries(cp_model_helper_pybind11 PRIVATE ${PROJECT_NAMESPACE}::ortools - pybind11_native_proto_caster - protobuf::libprotobuf) + protobuf::libprotobuf + ) target_include_directories(cp_model_helper_pybind11 PRIVATE ${protobuf_SOURCE_DIR}) add_library(${PROJECT_NAMESPACE}::cp_model_helper_pybind11 ALIAS cp_model_helper_pybind11) diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 435be0f7ba..cd97ae8aa1 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -45,7 +45,6 @@ Other methods and functions listed are primarily used for developing OR-Tools, rather than for solving specific optimization problems. """ -import copy import threading import time from typing import ( @@ -65,19 +64,21 @@ import warnings import numpy as np import pandas as pd -from ortools.sat import cp_model_pb2 -from ortools.sat import sat_parameters_pb2 from ortools.sat.python import cp_model_helper as cmh -from ortools.sat.python import cp_model_numbers as cmn from ortools.util.python import sorted_interval_list # Import external types. -Domain = sorted_interval_list.Domain BoundedLinearExpression = cmh.BoundedLinearExpression +CpModelProto = cmh.CpModelProto +CpSolverResponse = cmh.CpSolverResponse +CpSolverStatus = cmh.CpSolverStatus +Domain = sorted_interval_list.Domain FlatFloatExpr = cmh.FlatFloatExpr FlatIntExpr = cmh.FlatIntExpr +IntVar = cmh.IntVar LinearExpr = cmh.LinearExpr NotBooleanVariable = cmh.NotBooleanVariable +SatParameters = cmh.SatParameters # The classes below allow linear expressions to be expressed naturally with the @@ -90,39 +91,52 @@ INT32_MIN = -(2**31) INT32_MAX = 2**31 - 1 # CpSolver status (exported to avoid importing cp_model_cp2). -UNKNOWN = cp_model_pb2.UNKNOWN -MODEL_INVALID = cp_model_pb2.MODEL_INVALID -FEASIBLE = cp_model_pb2.FEASIBLE -INFEASIBLE = cp_model_pb2.INFEASIBLE -OPTIMAL = cp_model_pb2.OPTIMAL +UNKNOWN = cmh.CpSolverStatus.UNKNOWN +UNKNOWN = cmh.CpSolverStatus.UNKNOWN +MODEL_INVALID = cmh.CpSolverStatus.MODEL_INVALID +FEASIBLE = cmh.CpSolverStatus.FEASIBLE +INFEASIBLE = cmh.CpSolverStatus.INFEASIBLE +OPTIMAL = cmh.CpSolverStatus.OPTIMAL # Variable selection strategy -CHOOSE_FIRST = cp_model_pb2.DecisionStrategyProto.CHOOSE_FIRST -CHOOSE_LOWEST_MIN = cp_model_pb2.DecisionStrategyProto.CHOOSE_LOWEST_MIN -CHOOSE_HIGHEST_MAX = cp_model_pb2.DecisionStrategyProto.CHOOSE_HIGHEST_MAX -CHOOSE_MIN_DOMAIN_SIZE = cp_model_pb2.DecisionStrategyProto.CHOOSE_MIN_DOMAIN_SIZE -CHOOSE_MAX_DOMAIN_SIZE = cp_model_pb2.DecisionStrategyProto.CHOOSE_MAX_DOMAIN_SIZE +CHOOSE_FIRST = cmh.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_FIRST +CHOOSE_LOWEST_MIN = ( + cmh.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_LOWEST_MIN +) +CHOOSE_HIGHEST_MAX = ( + cmh.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_HIGHEST_MAX +) +CHOOSE_MIN_DOMAIN_SIZE = ( + cmh.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_MIN_DOMAIN_SIZE +) +CHOOSE_MAX_DOMAIN_SIZE = ( + cmh.DecisionStrategyProto.VariableSelectionStrategy.CHOOSE_MAX_DOMAIN_SIZE +) # Domain reduction strategy -SELECT_MIN_VALUE = cp_model_pb2.DecisionStrategyProto.SELECT_MIN_VALUE -SELECT_MAX_VALUE = cp_model_pb2.DecisionStrategyProto.SELECT_MAX_VALUE -SELECT_LOWER_HALF = cp_model_pb2.DecisionStrategyProto.SELECT_LOWER_HALF -SELECT_UPPER_HALF = cp_model_pb2.DecisionStrategyProto.SELECT_UPPER_HALF -SELECT_MEDIAN_VALUE = cp_model_pb2.DecisionStrategyProto.SELECT_MEDIAN_VALUE -SELECT_RANDOM_HALF = cp_model_pb2.DecisionStrategyProto.SELECT_RANDOM_HALF +SELECT_MIN_VALUE = cmh.DecisionStrategyProto.DomainReductionStrategy.SELECT_MIN_VALUE +SELECT_MAX_VALUE = cmh.DecisionStrategyProto.DomainReductionStrategy.SELECT_MAX_VALUE +SELECT_LOWER_HALF = cmh.DecisionStrategyProto.DomainReductionStrategy.SELECT_LOWER_HALF +SELECT_UPPER_HALF = cmh.DecisionStrategyProto.DomainReductionStrategy.SELECT_UPPER_HALF +SELECT_MEDIAN_VALUE = ( + cmh.DecisionStrategyProto.DomainReductionStrategy.SELECT_MEDIAN_VALUE +) +SELECT_RANDOM_HALF = ( + cmh.DecisionStrategyProto.DomainReductionStrategy.SELECT_RANDOM_HALF +) # Search branching -AUTOMATIC_SEARCH = sat_parameters_pb2.SatParameters.AUTOMATIC_SEARCH -FIXED_SEARCH = sat_parameters_pb2.SatParameters.FIXED_SEARCH -PORTFOLIO_SEARCH = sat_parameters_pb2.SatParameters.PORTFOLIO_SEARCH -LP_SEARCH = sat_parameters_pb2.SatParameters.LP_SEARCH -PSEUDO_COST_SEARCH = sat_parameters_pb2.SatParameters.PSEUDO_COST_SEARCH +AUTOMATIC_SEARCH = cmh.SatParameters.SearchBranching.AUTOMATIC_SEARCH +FIXED_SEARCH = cmh.SatParameters.SearchBranching.FIXED_SEARCH +PORTFOLIO_SEARCH = cmh.SatParameters.SearchBranching.PORTFOLIO_SEARCH +LP_SEARCH = cmh.SatParameters.SearchBranching.LP_SEARCH +PSEUDO_COST_SEARCH = cmh.SatParameters.SearchBranching.PSEUDO_COST_SEARCH PORTFOLIO_WITH_QUICK_RESTART_SEARCH = ( - sat_parameters_pb2.SatParameters.PORTFOLIO_WITH_QUICK_RESTART_SEARCH + cmh.SatParameters.SearchBranching.PORTFOLIO_WITH_QUICK_RESTART_SEARCH ) -HINT_SEARCH = sat_parameters_pb2.SatParameters.HINT_SEARCH -PARTIAL_FIXED_SEARCH = sat_parameters_pb2.SatParameters.PARTIAL_FIXED_SEARCH -RANDOMIZED_SEARCH = sat_parameters_pb2.SatParameters.RANDOMIZED_SEARCH +HINT_SEARCH = cmh.SatParameters.SearchBranching.HINT_SEARCH +PARTIAL_FIXED_SEARCH = cmh.SatParameters.SearchBranching.PARTIAL_FIXED_SEARCH +RANDOMIZED_SEARCH = cmh.SatParameters.SearchBranching.RANDOMIZED_SEARCH # Type aliases IntegralT = Union[int, np.int8, np.uint8, np.int32, np.uint32, np.int64, np.uint64] @@ -170,34 +184,17 @@ ArcT = Tuple[IntegralT, IntegralT, LiteralT] _IndexOrSeries = Union[pd.Index, pd.Series] -def display_bounds(bounds: Sequence[int]) -> str: - """Displays a flattened list of intervals.""" - out = "" - for i in range(0, len(bounds), 2): - if i != 0: - out += ", " - if bounds[i] == bounds[i + 1]: - out += str(bounds[i]) - else: - out += str(bounds[i]) + ".." + str(bounds[i + 1]) - return out - - -def short_name(model: cp_model_pb2.CpModelProto, i: int) -> str: +def short_name(model: cmh.CpModelProto, i: int) -> str: """Returns a short name of an integer variable, or its negation.""" - if i < 0: - return f"not({short_name(model, -i - 1)})" - v = model.variables[i] - if v.name: - return v.name - elif len(v.domain) == 2 and v.domain[0] == v.domain[1]: - return str(v.domain[0]) + if i >= 0: + return str(IntVar(model, i)) else: - return f"[{display_bounds(v.domain)}]" + return f"not({IntVar(model, -i - 1)})" def short_expr_name( - model: cp_model_pb2.CpModelProto, e: cp_model_pb2.LinearExpressionProto + model: cmh.CpModelProto, + e: cmh.LinearExpressionProto, ) -> str: """Pretty-print LinearExpressionProto instances.""" if not e.vars: @@ -221,106 +218,41 @@ def short_expr_name( return str(e) -class IntVar(cmh.BaseIntVar): - """An integer variable. +def arg_is_boolean(x: Any) -> bool: + """Checks if the x is a boolean.""" + if isinstance(x, bool): + return True + if isinstance(x, np.bool_): + return True + return False - An IntVar is an object that can take on any integer value within defined - ranges. Variables appear in constraint like: - x + y >= 5 - AllDifferent([x, y, z]) - - Solving a model is equivalent to finding, for each variable, a single value - from the set of initial values (called the initial domain), such that the - model is feasible, or optimal if you provided an objective function. - """ - - def __init__( - self, - model: cp_model_pb2.CpModelProto, - domain: Union[int, sorted_interval_list.Domain], - is_boolean: bool, - name: Optional[str], - ) -> None: - """See CpModel.new_int_var below.""" - self.__model: cp_model_pb2.CpModelProto = model - # Python do not support multiple __init__ methods. - # This method is only called from the CpModel class. - # We hack the parameter to support the two cases: - # case 1: - # model is a CpModelProto, domain is a Domain, and name is a string. - # case 2: - # model is a CpModelProto, domain is an index (int), and name is None. - if isinstance(domain, IntegralTypes) and name is None: - cmh.BaseIntVar.__init__(self, int(domain), is_boolean) +def rebuild_from_linear_expression_proto( + proto: cmh.LinearExpressionProto, + model_proto: cmh.CpModelProto, +) -> LinearExprT: + """Recreate a LinearExpr from a LinearExpressionProto.""" + num_elements = len(proto.vars) + if num_elements == 0: + return proto.offset + elif num_elements == 1: + var = IntVar(model_proto, proto.vars[0]) + return LinearExpr.affine( + var, proto.coeffs[0], proto.offset + ) # pytype: disable=bad-return-type + else: + variables = [] + for var_index in range(len(proto.vars)): + var = IntVar(model_proto, var_index) + variables.append(var) + if proto.offset != 0: + coeffs = [] + coeffs.extend(proto.coeffs) + coeffs.append(1) + variables.append(proto.offset) + return LinearExpr.weighted_sum(variables, coeffs) else: - cmh.BaseIntVar.__init__(self, len(model.variables), is_boolean) - proto: cp_model_pb2.IntegerVariableProto = self.__model.variables.add() - proto.domain.extend( - cast(sorted_interval_list.Domain, domain).flattened_intervals() - ) - if name is not None: - proto.name = name - - def __copy__(self) -> "IntVar": - """Returns a shallowcopy of the variable.""" - return IntVar(self.__model, self.index, self.is_boolean, None) - - def __deepcopy__(self, memo: Any) -> "IntVar": - """Returns a deepcopy of the variable.""" - return IntVar( - copy.deepcopy(self.__model, memo), self.index, self.is_boolean, None - ) - - @property - def proto(self) -> cp_model_pb2.IntegerVariableProto: - """Returns the variable protobuf.""" - return self.__model.variables[self.index] - - @property - def model_proto(self) -> cp_model_pb2.CpModelProto: - """Returns the model protobuf.""" - return self.__model - - def is_equal_to(self, other: Any) -> bool: - """Returns true if self == other in the python sense.""" - if not isinstance(other, IntVar): - return False - return self.index == other.index - - def __str__(self) -> str: - if not self.proto.name: - if ( - len(self.proto.domain) == 2 - and self.proto.domain[0] == self.proto.domain[1] - ): - # Special case for constants. - return str(self.proto.domain[0]) - elif self.is_boolean: - return f"BooleanVar({self.__index})" - else: - return f"IntVar({self.__index})" - else: - return self.proto.name - - def __repr__(self) -> str: - return f"{self}({display_bounds(self.proto.domain)})" - - @property - def name(self) -> str: - if not self.proto or not self.proto.name: - return "" - return self.proto.name - - # Pre PEP8 compatibility. - # pylint: disable=invalid-name - def Name(self) -> str: - return self.name - - def Proto(self) -> cp_model_pb2.IntegerVariableProto: - return self.proto - - # pylint: enable=invalid-name + return LinearExpr.weighted_sum(variables, proto.coeffs) class Constraint: @@ -338,24 +270,22 @@ class Constraint: model.add(x + 2 * y == 5).only_enforce_if(b.negated()) """ - def __init__( - self, - cp_model: "CpModel", - ) -> None: - self.__index: int = len(cp_model.proto.constraints) + def __init__(self, cp_model: "CpModel", index: Optional[int] = None) -> None: self.__cp_model: "CpModel" = cp_model - self.__constraint: cp_model_pb2.ConstraintProto = ( + if index is None: + self.__index: int = len(cp_model.proto.constraints) cp_model.proto.constraints.add() - ) + else: + self.__index: int = index @overload - def only_enforce_if(self, boolvar: Iterable[LiteralT]) -> "Constraint": ... + def only_enforce_if(self, literals: Iterable[LiteralT]) -> "Constraint": ... @overload - def only_enforce_if(self, *boolvar: LiteralT) -> "Constraint": ... + def only_enforce_if(self, *literals: LiteralT) -> "Constraint": ... - def only_enforce_if(self, *boolvar) -> "Constraint": - """Adds an enforcement literal to the constraint. + def only_enforce_if(self, *literals) -> "Constraint": + """Adds one or more enforcement literals to the constraint. This method adds one or more literals (that is, a boolean variable or its negation) as enforcement literals. The conjunction of all these literals @@ -366,43 +296,30 @@ class Constraint: BoolOr, BoolAnd, and linear constraints all support enforcement literals. Args: - *boolvar: One or more Boolean literals. + *literals: One or more Boolean literals. Returns: self. """ - for lit in expand_generator_or_tuple(boolvar): - if (cmn.is_boolean(lit) and lit) or ( - isinstance(lit, IntegralTypes) and lit == 1 - ): - # Always true. Do nothing. - pass - elif (cmn.is_boolean(lit) and not lit) or ( - isinstance(lit, IntegralTypes) and lit == 0 - ): - self.__constraint.enforcement_literal.append( - self.__cp_model.new_constant(0).index - ) - else: - self.__constraint.enforcement_literal.append( - cast(cmh.Literal, lit).index - ) + cmh.CpSatHelper.add_enforcement_literals( + self.__index, + self.__cp_model.expand_literals_to_index_list(literals), + self.__cp_model.proto, + ) return self def with_name(self, name: str) -> "Constraint": """Sets the name of the constraint.""" if name: - self.__constraint.name = name + cmh.CpSatHelper.set_ct_name(self.__index, name, self.__cp_model.proto) else: - self.__constraint.ClearField("name") + cmh.CpSatHelper.clear_ct_name(self.__index, self.__cp_model.proto) return self @property def name(self) -> str: """Returns the name of the constraint.""" - if not self.__constraint or not self.__constraint.name: - return "" - return self.__constraint.name + return cmh.CpSatHelper.ct_name(self.__index, self.__cp_model.proto) @property def index(self) -> int: @@ -410,9 +327,15 @@ class Constraint: return self.__index @property - def proto(self) -> cp_model_pb2.ConstraintProto: + def proto(self) -> cmh.ConstraintProto: """Returns the constraint protobuf.""" - return self.__constraint + return self.__cp_model.proto.constraints[self.__index] + + def __str__(self) -> str: + return ( + f"Constraint({self.__index}," + f" {self.__cp_model.proto.constraints[self.__index]})" + ) # Pre PEP8 compatibility. # pylint: disable=invalid-name @@ -425,55 +348,12 @@ class Constraint: def Index(self) -> int: return self.index - def Proto(self) -> cp_model_pb2.ConstraintProto: + def Proto(self) -> cmh.ConstraintProto: return self.proto # pylint: enable=invalid-name -class VariableList: - """Stores all integer variables of the model.""" - - def __init__(self) -> None: - self.__var_list: list[IntVar] = [] - - def append(self, var: IntVar) -> None: - assert var.index == len(self.__var_list) - self.__var_list.append(var) - - def get(self, index: int) -> IntVar: - if index < 0 or index >= len(self.__var_list): - raise ValueError("Index out of bounds.") - return self.__var_list[index] - - def rebuild_expr( - self, - proto: cp_model_pb2.LinearExpressionProto, - ) -> LinearExprT: - """Recreate a LinearExpr from a LinearExpressionProto.""" - num_elements = len(proto.vars) - if num_elements == 0: - return proto.offset - elif num_elements == 1: - var = self.get(proto.vars[0]) - return LinearExpr.affine( - var, proto.coeffs[0], proto.offset - ) # pytype: disable=bad-return-type - else: - variables = [] - for var_index in range(len(proto.vars)): - var = self.get(var_index) - variables.append(var) - if proto.offset != 0: - coeffs = [] - coeffs.extend(proto.coeffs) - coeffs.append(1) - variables.append(proto.offset) - return LinearExpr.weighted_sum(variables, coeffs) - else: - return LinearExpr.weighted_sum(variables, proto.coeffs) - - class IntervalVar: """Represents an Interval variable. @@ -497,18 +377,16 @@ class IntervalVar: def __init__( self, - model: cp_model_pb2.CpModelProto, - var_list: VariableList, - start: Union[cp_model_pb2.LinearExpressionProto, int], - size: Optional[cp_model_pb2.LinearExpressionProto], - end: Optional[cp_model_pb2.LinearExpressionProto], + model: cmh.CpModelProto, + start: Union[cmh.LinearExpressionProto, int], + size: Optional[cmh.LinearExpressionProto], + end: Optional[cmh.LinearExpressionProto], is_present_index: Optional[int], name: Optional[str], ) -> None: - self.__model: cp_model_pb2.CpModelProto = model - self.__var_list: VariableList = var_list + self.__model: cmh.CpModelProto = model self.__index: int - self.__ct: cp_model_pb2.ConstraintProto + self.__ct: cmh.ConstraintProto # As with the IntVar::__init__ method, we hack the __init__ method to # support two use cases: # case 1: called when creating a new interval variable. @@ -530,13 +408,13 @@ class IntervalVar: self.__ct = self.__model.constraints.add() if start is None: raise TypeError("start is not defined") - self.__ct.interval.start.CopyFrom(start) + self.__ct.interval.start.copy_from(start) if size is None: raise TypeError("size is not defined") - self.__ct.interval.size.CopyFrom(size) + self.__ct.interval.size.copy_from(size) if end is None: raise TypeError("end is not defined") - self.__ct.interval.end.CopyFrom(end) + self.__ct.interval.end.copy_from(end) if is_present_index is not None: self.__ct.enforcement_literal.append(is_present_index) if name: @@ -548,12 +426,12 @@ class IntervalVar: return self.__index @property - def proto(self) -> cp_model_pb2.ConstraintProto: + def proto(self) -> cmh.ConstraintProto: """Returns the interval protobuf.""" return self.__model.constraints[self.__index] @property - def model_proto(self) -> cp_model_pb2.CpModelProto: + def model_proto(self) -> cmh.CpModelProto: """Returns the model protobuf.""" return self.__model @@ -585,13 +463,19 @@ class IntervalVar: return self.proto.name def start_expr(self) -> LinearExprT: - return self.__var_list.rebuild_expr(self.proto.interval.start) + return rebuild_from_linear_expression_proto( + self.proto.interval.start, self.__model + ) def size_expr(self) -> LinearExprT: - return self.__var_list.rebuild_expr(self.proto.interval.size) + return rebuild_from_linear_expression_proto( + self.proto.interval.size, self.__model + ) def end_expr(self) -> LinearExprT: - return self.__var_list.rebuild_expr(self.proto.interval.end) + return rebuild_from_linear_expression_proto( + self.proto.interval.end, self.__model + ) # Pre PEP8 compatibility. # pylint: disable=invalid-name @@ -601,7 +485,7 @@ class IntervalVar: def Index(self) -> int: return self.index - def Proto(self) -> cp_model_pb2.ConstraintProto: + def Proto(self) -> cmh.ConstraintProto: return self.proto StartExpr = start_expr @@ -642,14 +526,13 @@ class CpModel: Methods beginning with: - * ```New``` create integer, boolean, or interval variables. - * ```add``` create new constraints and add them to the model. + * ```new_``` create integer, boolean, or interval variables. + * ```add_``` create new constraints and add them to the model. """ def __init__(self) -> None: - self.__model: cp_model_pb2.CpModelProto = cp_model_pb2.CpModelProto() + self.__model: cmh.CpModelProto = cmh.CpModelProto() self.__constant_map: Dict[IntegralT, int] = {} - self.__var_list: VariableList = VariableList() # Naming. @property @@ -665,21 +548,6 @@ class CpModel: self.__model.name = name # Integer variable. - - def _append_int_var(self, var: IntVar) -> IntVar: - """Appends an integer variable to the list of variables.""" - self.__var_list.append(var) - return var - - def _get_int_var(self, index: int) -> IntVar: - return self.__var_list.get(index) - - def rebuild_from_linear_expression_proto( - self, - proto: cp_model_pb2.LinearExpressionProto, - ) -> LinearExpr: - return self.__var_list.rebuild_expr(proto) - def new_int_var(self, lb: IntegralT, ub: IntegralT, name: str) -> IntVar: """Create an integer variable with domain [lb, ub]. @@ -695,14 +563,10 @@ class CpModel: Returns: a variable whose domain is [lb, ub]. """ - domain_is_boolean = lb >= 0 and ub <= 1 - return self._append_int_var( - IntVar( - self.__model, - sorted_interval_list.Domain(lb, ub), - domain_is_boolean, - name, - ) + return ( + IntVar(self.__model) + .with_name(name) + .with_domain(sorted_interval_list.Domain(lb, ub)) ) def new_int_var_from_domain( @@ -721,21 +585,20 @@ class CpModel: Returns: a variable whose domain is the given domain. """ - domain_is_boolean = domain.min() >= 0 and domain.max() <= 1 - return self._append_int_var( - IntVar(self.__model, domain, domain_is_boolean, name) - ) + return IntVar(self.__model).with_name(name).with_domain(domain) def new_bool_var(self, name: str) -> IntVar: """Creates a 0-1 variable with the given name.""" - return self._append_int_var( - IntVar(self.__model, sorted_interval_list.Domain(0, 1), True, name) + return ( + IntVar(self.__model) + .with_name(name) + .with_domain(sorted_interval_list.Domain(0, 1)) ) def new_constant(self, value: IntegralT) -> IntVar: """Declares a constant integer.""" index: int = self.get_or_make_index_from_constant(value) - return self._get_int_var(index) + return IntVar(self.__model, index) def new_int_var_series( self, @@ -790,15 +653,10 @@ class CpModel: index=index, data=[ # pylint: disable=g-complex-comprehension - self._append_int_var( - IntVar( - model=self.__model, - name=f"{name}[{i}]", - domain=sorted_interval_list.Domain( - lower_bounds[i], upper_bounds[i] - ), - is_boolean=lower_bounds[i] >= 0 and upper_bounds[i] <= 1, - ) + IntVar(self.__model) + .with_name(f"{name}[{i}]") + .with_domain( + sorted_interval_list.Domain(lower_bounds[i], upper_bounds[i]) ) for i in index ], @@ -830,14 +688,9 @@ class CpModel: index=index, data=[ # pylint: disable=g-complex-comprehension - self._append_int_var( - IntVar( - model=self.__model, - name=f"{name}[{i}]", - domain=sorted_interval_list.Domain(0, 1), - is_boolean=True, - ) - ) + IntVar(self.__model) + .with_name(f"{name}[{i}]") + .with_domain(sorted_interval_list.Domain(0, 1)) for i in index ], ) @@ -889,21 +742,15 @@ class CpModel: TypeError: If the `ct` is not a `BoundedLinearExpression` or a Boolean. """ if isinstance(ct, BoundedLinearExpression): - result = Constraint(self) - model_ct = self.__model.constraints[result.index] - for var in ct.vars: - model_ct.linear.vars.append(var.index) - model_ct.linear.coeffs.extend(ct.coeffs) - model_ct.linear.domain.extend( - [ - cmn.capped_subtraction(x, ct.offset) - for x in ct.bounds.flattened_intervals() - ] + return Constraint( + self, + cmh.CpSatHelper.add_bounded_linear_expression_to_model( + ct, self.__model + ), ) - return result - if ct and cmn.is_boolean(ct): + if ct and arg_is_boolean(ct): return self.add_bool_or([True]) - if not ct and cmn.is_boolean(ct): + if not ct and arg_is_boolean(ct): return self.add_bool_or([]) # Evaluate to false. raise TypeError(f"not supported: CpModel.add({type(ct).__name__!r})") @@ -928,7 +775,7 @@ class CpModel: """ ct = Constraint(self) model_ct = self.__model.constraints[ct.index] - expanded = expand_generator_or_tuple(expressions) + expanded = expand_exprs_generator_or_tuple(expressions) model_ct.all_diff.exprs.extend( self.parse_linear_expression(x) for x in expanded ) @@ -960,14 +807,10 @@ class CpModel: expression: LinearExprT = list(expressions)[int(index)] return self.add(expression == target) - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.element.linear_index.CopyFrom(self.parse_linear_expression(index)) - model_ct.element.exprs.extend( - [self.parse_linear_expression(e) for e in expressions] + return Constraint( + self, + cmh.CpSatHelper.add_element(index, expressions, target, self.__model), ) - model_ct.element.linear_target.CopyFrom(self.parse_linear_expression(target)) - return ct def add_circuit(self, arcs: Sequence[ArcT]) -> Constraint: """Adds Circuit(arcs). @@ -1417,15 +1260,9 @@ class CpModel: def add_bool_or(self, *literals): """Adds `Or(literals) == true`: sum(literals) >= 1.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.bool_or.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_bool_or(lits, self.__model) + return Constraint(self, index) @overload def add_at_least_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1443,17 +1280,11 @@ class CpModel: @overload def add_at_most_one(self, *literals: LiteralT) -> Constraint: ... - def add_at_most_one(self, *literals): + def add_at_most_one(self, *literals) -> Constraint: """Adds `AtMostOne(literals)`: `sum(literals) <= 1`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.at_most_one.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_at_most_one(lits, self.__model) + return Constraint(self, index) @overload def add_exactly_one(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1463,15 +1294,9 @@ class CpModel: def add_exactly_one(self, *literals): """Adds `ExactlyOne(literals)`: `sum(literals) == 1`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.exactly_one.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_exactly_one(lits, self.__model) + return Constraint(self, index) @overload def add_bool_and(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1481,15 +1306,9 @@ class CpModel: def add_bool_and(self, *literals): """Adds `And(literals) == true`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.bool_and.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_bool_and(lits, self.__model) + return Constraint(self, index) @overload def add_bool_xor(self, literals: Iterable[LiteralT]) -> Constraint: ... @@ -1509,57 +1328,75 @@ class CpModel: Returns: An `Constraint` object. """ - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.bool_xor.literals.extend( - [ - self.get_or_make_boolean_index(x) - for x in expand_generator_or_tuple(literals) - ] - ) - return ct + lits = self.expand_literals_to_index_list(literals) + index: int = cmh.CpSatHelper.add_bool_xor(lits, self.__model) + return Constraint(self, index) + @overload def add_min_equality( - self, target: LinearExprT, exprs: Iterable[LinearExprT] - ) -> Constraint: - """Adds `target == Min(exprs)`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.lin_max.exprs.extend( - [self.parse_linear_expression(x, True) for x in exprs] - ) - model_ct.lin_max.target.CopyFrom(self.parse_linear_expression(target, True)) - return ct + self, target: LinearExprT, expressions: Iterable[LinearExprT] + ) -> Constraint: ... + @overload + def add_min_equality( + self, target: LinearExprT, *expressions: LinearExprT + ) -> Constraint: ... + + def add_min_equality(self, target, *expressions) -> Constraint: + """Adds `target == Min(expressions)`.""" + exprs = [e for e in expand_exprs_generator_or_tuple(expressions)] + return Constraint( + self, + cmh.CpSatHelper.add_linear_argument_constraint( + "min", + target, + exprs, + self.__model, + ), + ) + + @overload def add_max_equality( - self, target: LinearExprT, exprs: Iterable[LinearExprT] - ) -> Constraint: - """Adds `target == Max(exprs)`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.lin_max.exprs.extend([self.parse_linear_expression(x) for x in exprs]) - model_ct.lin_max.target.CopyFrom(self.parse_linear_expression(target)) - return ct + self, target: LinearExprT, expressions: Iterable[LinearExprT] + ) -> Constraint: ... + + @overload + def add_max_equality( + self, target: LinearExprT, *expressions: LinearExprT + ) -> Constraint: ... + + def add_max_equality(self, target, *expressions) -> Constraint: + """Adds `target == Max(expressions)`.""" + exprs = [e for e in expand_exprs_generator_or_tuple(expressions)] + return Constraint( + self, + cmh.CpSatHelper.add_linear_argument_constraint( + "max", + target, + exprs, + self.__model, + ), + ) def add_division_equality( self, target: LinearExprT, num: LinearExprT, denom: LinearExprT ) -> Constraint: """Adds `target == num // denom` (integer division rounded towards 0).""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.int_div.exprs.append(self.parse_linear_expression(num)) - model_ct.int_div.exprs.append(self.parse_linear_expression(denom)) - model_ct.int_div.target.CopyFrom(self.parse_linear_expression(target)) - return ct + return Constraint( + self, + cmh.CpSatHelper.add_linear_argument_constraint( + "div", target, [num, denom], self.__model + ), + ) def add_abs_equality(self, target: LinearExprT, expr: LinearExprT) -> Constraint: """Adds `target == Abs(expr)`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.lin_max.exprs.append(self.parse_linear_expression(expr)) - model_ct.lin_max.exprs.append(self.parse_linear_expression(expr, True)) - model_ct.lin_max.target.CopyFrom(self.parse_linear_expression(target)) - return ct + return Constraint( + self, + cmh.CpSatHelper.add_linear_argument_constraint( + "max", target, [expr, -expr], self.__model + ), + ) def add_modulo_equality( self, target: LinearExprT, expr: LinearExprT, mod: LinearExprT @@ -1583,12 +1420,12 @@ class CpModel: Returns: A `Constraint` object. """ - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.int_mod.exprs.append(self.parse_linear_expression(expr)) - model_ct.int_mod.exprs.append(self.parse_linear_expression(mod)) - model_ct.int_mod.target.CopyFrom(self.parse_linear_expression(target)) - return ct + return Constraint( + self, + cmh.CpSatHelper.add_linear_argument_constraint( + "mod", target, [expr, mod], self.__model + ), + ) def add_multiplication_equality( self, @@ -1596,16 +1433,15 @@ class CpModel: *expressions: Union[Iterable[LinearExprT], LinearExprT], ) -> Constraint: """Adds `target == expressions[0] * .. * expressions[n]`.""" - ct = Constraint(self) - model_ct = self.__model.constraints[ct.index] - model_ct.int_prod.exprs.extend( - [ - self.parse_linear_expression(expr) - for expr in expand_generator_or_tuple(expressions) - ] + return Constraint( + self, + cmh.CpSatHelper.add_linear_argument_constraint( + "prod", + target, + expand_exprs_generator_or_tuple(expressions), + self.__model, + ), ) - model_ct.int_prod.target.CopyFrom(self.parse_linear_expression(target)) - return ct # Scheduling support @@ -1646,7 +1482,6 @@ class CpModel: ) return IntervalVar( self.__model, - self.__var_list, start_expr, size_expr, end_expr, @@ -1731,7 +1566,6 @@ class CpModel: ) return IntervalVar( self.__model, - self.__var_list, start_expr, size_expr, end_expr, @@ -1833,7 +1667,6 @@ class CpModel: ) return IntervalVar( self.__model, - self.__var_list, start_expr, size_expr, end_expr, @@ -1932,7 +1765,6 @@ class CpModel: is_present_index = self.get_or_make_boolean_index(is_present) return IntervalVar( self.__model, - self.__var_list, start_expr, size_expr, end_expr, @@ -2074,32 +1906,32 @@ class CpModel: ) for d in demands: model_ct.cumulative.demands.append(self.parse_linear_expression(d)) - model_ct.cumulative.capacity.CopyFrom(self.parse_linear_expression(capacity)) + model_ct.cumulative.capacity.copy_from(self.parse_linear_expression(capacity)) return cumulative # Support for model cloning. def clone(self) -> "CpModel": """Reset the model, and creates a new one from a CpModelProto instance.""" clone = CpModel() - clone.proto.CopyFrom(self.proto) - clone.rebuild_var_and_constant_map() + clone.proto.copy_from(self.proto) + clone.rebuild_constant_map() return clone - def rebuild_var_and_constant_map(self): + def rebuild_constant_map(self): """Internal method used during model cloning.""" for i, var in enumerate(self.__model.variables): if len(var.domain) == 2 and var.domain[0] == var.domain[1]: self.__constant_map[var.domain[0]] = i - is_boolean = ( - len(var.domain) == 2 and var.domain[0] >= 0 and var.domain[1] <= 1 - ) - self.__var_list.append(IntVar(self.__model, i, is_boolean, None)) def get_bool_var_from_proto_index(self, index: int) -> IntVar: """Returns an already created Boolean variable from its index.""" - result = self._get_int_var(index) - if not result.is_boolean: + if index < 0 or index >= len(self.__model.variables): raise ValueError( + f"get_bool_var_from_proto_index: out of bound index {index}" + ) + result = IntVar(self.__model, index) + if not result.is_boolean: + raise TypeError( f"get_bool_var_from_proto_index: index {index} does not reference a" " boolean variable" ) @@ -2107,7 +1939,11 @@ class CpModel: def get_int_var_from_proto_index(self, index: int) -> IntVar: """Returns an already created integer variable from its index.""" - return self._get_int_var(index) + if index < 0 or index >= len(self.__model.variables): + raise ValueError( + f"get_int_var_from_proto_index: out of bound index {index}" + ) + return IntVar(self.__model, index) def get_interval_var_from_proto_index(self, index: int) -> IntervalVar: """Returns an already created interval variable from its index.""" @@ -2116,13 +1952,13 @@ class CpModel: f"get_interval_var_from_proto_index: out of bound index {index}" ) ct = self.__model.constraints[index] - if not ct.HasField("interval"): + if not ct.has_interval(): raise ValueError( f"get_interval_var_from_proto_index: index {index} does not" " reference an" + " interval variable" ) - return IntervalVar(self.__model, self.__var_list, index, None, None, None, None) + return IntervalVar(self.__model, index, None, None, None, None) # Helpers. @@ -2130,7 +1966,7 @@ class CpModel: return str(self.__model) @property - def proto(self) -> cp_model_pb2.CpModelProto: + def proto(self) -> cmh.CpModelProto: """Returns the underlying CpModelProto.""" return self.__model @@ -2160,9 +1996,11 @@ class CpModel: return self.get_or_make_index_from_constant(1) if arg == ~int(True): return self.get_or_make_index_from_constant(0) - arg = cmn.assert_is_zero_or_one(arg) - return self.get_or_make_index_from_constant(arg) - if cmn.is_boolean(arg): + arg_as_int: int = int(arg) + if arg_as_int < 0 or arg_as_int > 1: + raise TypeError(f"Not a boolean: {arg}") + return self.get_or_make_index_from_constant(arg_as_int) + if arg_is_boolean(arg): return self.get_or_make_index_from_constant(int(arg)) raise TypeError( "not supported:" f" model.get_or_make_boolean_index({type(arg).__name__!r})" @@ -2184,11 +2022,9 @@ class CpModel: def parse_linear_expression( self, linear_expr: LinearExprT, negate: bool = False - ) -> cp_model_pb2.LinearExpressionProto: + ) -> cmh.LinearExpressionProto: """Returns a LinearExpressionProto built from a LinearExpr instance.""" - result: cp_model_pb2.LinearExpressionProto = ( - cp_model_pb2.LinearExpressionProto() - ) + result: cmh.LinearExpressionProto = cmh.LinearExpressionProto() mult = -1 if negate else 1 if isinstance(linear_expr, IntegralTypes): result.offset = int(linear_expr) * mult @@ -2244,19 +2080,19 @@ class CpModel: self._set_objective(obj, minimize=False) def has_objective(self) -> bool: - return self.__model.HasField("objective") or self.__model.HasField( - "floating_point_objective" + return ( + self.__model.has_objective() or self.__model.has_floating_point_objective() ) def clear_objective(self): - self.__model.ClearField("objective") - self.__model.ClearField("floating_point_objective") + self.__model.clear_objective() + self.__model.clear_floating_point_objective() def add_decision_strategy( self, variables: Sequence[IntVar], - var_strategy: cp_model_pb2.DecisionStrategyProto.VariableSelectionStrategy, - domain_strategy: cp_model_pb2.DecisionStrategyProto.DomainReductionStrategy, + var_strategy: cmh.DecisionStrategyProto.VariableSelectionStrategy, + domain_strategy: cmh.DecisionStrategyProto.DomainReductionStrategy, ) -> None: """Adds a search strategy to the model. @@ -2269,9 +2105,7 @@ class CpModel: solve() will fail. """ - strategy: cp_model_pb2.DecisionStrategyProto = ( - self.__model.search_strategy.add() - ) + strategy: cmh.DecisionStrategyProto = self.__model.search_strategy.add() for v in variables: expr = strategy.exprs.add() if v.index >= 0: @@ -2308,11 +2142,11 @@ class CpModel: def remove_all_names(self) -> None: """Removes all names from the model.""" - self.__model.ClearField("name") + self.__model.clear_name() for v in self.__model.variables: - v.ClearField("name") + v.clear_name() for c in self.__model.constraints: - c.ClearField("name") + c.clear_name() @overload def add_hint(self, var: IntVar, value: int) -> None: ... @@ -2331,7 +2165,7 @@ class CpModel: def clear_hints(self): """Removes any solution hint from the model.""" - self.__model.ClearField("solution_hint") + self.__model.clear_solution_hint() def add_assumption(self, lit: LiteralT) -> None: """Adds the literal to the model as assumptions.""" @@ -2344,7 +2178,7 @@ class CpModel: def clear_assumptions(self) -> None: """Removes all assumptions from the model.""" - self.__model.ClearField("assumptions") + self.__model.clear_assumptions() # Helpers. def assert_is_boolean_variable(self, x: LiteralT) -> None: @@ -2359,6 +2193,27 @@ class CpModel: f"TypeError: {type(x).__name__!r} is not a boolean variable" ) + def expand_literals_generator_or_tuple( + self, args: Union[Tuple[LiteralT, ...], Iterable[LiteralT]] + ): + if hasattr(args, "__len__"): # Tuple + if len(args) != 1: + return args + if isinstance(args[0], (NumberTypes, cmh.Literal)): + return args + # Generator + return args[0] + + def expand_literals_to_index_list( + self, + literals: Union[Tuple[LiteralT, ...], Iterable[LiteralT]], + ) -> list[int]: + """Expands a tuple or generator of literals to a list of indices.""" + return [ + self.get_or_make_boolean_index(lit) + for lit in self.expand_literals_generator_or_tuple(literals) + ] + # Compatibility with pre PEP8 # pylint: disable=invalid-name @@ -2368,7 +2223,7 @@ class CpModel: def SetName(self, name: str) -> None: self.name = name - def Proto(self) -> cp_model_pb2.CpModelProto: + def Proto(self) -> cmh.CpModelProto: return self.proto NewIntVar = new_int_var @@ -2434,26 +2289,16 @@ class CpModel: # pylint: enable=invalid-name -@overload -def expand_generator_or_tuple( - args: Union[Tuple[LiteralT, ...], Iterable[LiteralT]], -) -> Union[Iterable[LiteralT], LiteralT]: ... - - -@overload -def expand_generator_or_tuple( - args: Union[Tuple[LinearExprT, ...], Iterable[LinearExprT]], -) -> Union[Iterable[LinearExprT], LinearExprT]: ... - - -def expand_generator_or_tuple(args): - if hasattr(args, "__len__"): # Tuple - if len(args) != 1: - return args - if isinstance(args[0], (NumberTypes, LinearExpr)): - return args +def expand_exprs_generator_or_tuple( + expressions: Union[Tuple[LinearExprT, ...], Iterable[LinearExprT]], +) -> Union[Iterable[LinearExprT], LinearExprT]: + if hasattr(expressions, "__len__"): # Tuple + if len(expressions) != 1: + return expressions + if isinstance(expressions[0], (NumberTypes, LinearExpr)): + return expressions # Generator - return args[0] + return expressions[0] class CpSolver: @@ -2469,9 +2314,7 @@ class CpSolver: def __init__(self) -> None: self.__response_wrapper: Optional[cmh.ResponseWrapper] = None - self.parameters: sat_parameters_pb2.SatParameters = ( - sat_parameters_pb2.SatParameters() - ) + self.parameters: cmh.SatParameters = cmh.SatParameters() self.log_callback: Optional[Callable[[str], None]] = None self.best_bound_callback: Optional[Callable[[float], None]] = None self.__solve_wrapper: Optional[cmh.SolveWrapper] = None @@ -2481,7 +2324,7 @@ class CpSolver: self, model: CpModel, solution_callback: Optional["CpSolverSolutionCallback"] = None, - ) -> cp_model_pb2.CpSolverStatus: + ) -> cmh.CpSolverStatus: """Solves a problem and passes each solution to the callback if not null.""" with self.__lock: self.__solve_wrapper = cmh.SolveWrapper() @@ -2628,6 +2471,21 @@ class CpSolver: """Returns the number of search branches explored by the solver.""" return self._checked_response.num_branches() + @property + def num_boolean_propagations(self) -> int: + """Returns the number of Boolean propagations done by the solver.""" + return self._checked_response.num_boolean_propagations() + + @property + def num_integer_propagations(self) -> int: + """Returns the number of integer propagations done by the solver.""" + return self._checked_response.num_integer_propagations() + + @property + def deterministic_time(self) -> float: + """Returns the deterministic time in seconds since the creation of the solver.""" + return self._checked_response.deterministic_time() + @property def wall_time(self) -> float: """Returns the wall time in seconds since the creation of the solver.""" @@ -2639,7 +2497,7 @@ class CpSolver: return self._checked_response.user_time() @property - def response_proto(self) -> cp_model_pb2.CpSolverResponse: + def response_proto(self) -> cmh.CpSolverResponse: """Returns the response object.""" return self._checked_response.response() @@ -2655,7 +2513,7 @@ class CpSolver: """Returns the name of the status returned by solve().""" if status is None: status = self._checked_response.status() - return cp_model_pb2.CpSolverStatus.Name(status) + return status.name def solution_info(self) -> str: """Returns some information on the solve process. @@ -2699,7 +2557,7 @@ class CpSolver: def ObjectiveValue(self) -> float: return self.objective_value - def ResponseProto(self) -> cp_model_pb2.CpSolverResponse: + def ResponseProto(self) -> cmh.CpSolverResponse: return self.response_proto def ResponseStats(self) -> str: @@ -2709,7 +2567,7 @@ class CpSolver: self, model: CpModel, solution_callback: Optional["CpSolverSolutionCallback"] = None, - ) -> cp_model_pb2.CpSolverStatus: + ) -> cmh.CpSolverStatus: return self.solve(model, solution_callback) def SolutionInfo(self) -> str: @@ -2738,7 +2596,7 @@ class CpSolver: def SolveWithSolutionCallback( self, model: CpModel, callback: "CpSolverSolutionCallback" - ) -> cp_model_pb2.CpSolverStatus: + ) -> cmh.CpSolverStatus: """DEPRECATED Use solve() with the callback argument.""" warnings.warn( "solve_with_solution_callback is deprecated; use solve() with" @@ -2749,7 +2607,7 @@ class CpSolver: def SearchForAllSolutions( self, model: CpModel, callback: "CpSolverSolutionCallback" - ) -> cp_model_pb2.CpSolverStatus: + ) -> cmh.CpSolverStatus: """DEPRECATED Use solve() with the right parameter. Search for all solutions of a satisfiability problem. @@ -2783,7 +2641,7 @@ class CpSolver: enumerate_all = self.parameters.enumerate_all_solutions self.parameters.enumerate_all_solutions = True - status: cp_model_pb2.CpSolverStatus = self.solve(model, callback) + status: cmh.CpSolverStatus = self.solve(model, callback) # Restore parameter. self.parameters.enumerate_all_solutions = enumerate_all @@ -2944,7 +2802,7 @@ class CpSolverSolutionCallback(cmh.SolutionCallback): return self.UserTime() @property - def response_proto(self) -> cp_model_pb2.CpSolverResponse: + def response_proto(self) -> cmh.CpSolverResponse: """Returns the response object.""" if not self.has_response(): raise RuntimeError("solve() has not been called.") diff --git a/ortools/sat/python/cp_model_helper_test.py b/ortools/sat/python/cp_model_helper_test.py index d5901787a7..71b90845f2 100644 --- a/ortools/sat/python/cp_model_helper_test.py +++ b/ortools/sat/python/cp_model_helper_test.py @@ -18,10 +18,8 @@ import sys from absl.testing import absltest -from google.protobuf import text_format -from ortools.sat import cp_model_pb2 -from ortools.sat import sat_parameters_pb2 from ortools.sat.python import cp_model_helper as cmh +from ortools.util.python import sorted_interval_list class Callback(cmh.SolutionCallback): @@ -47,19 +45,6 @@ class BestBoundCallback: self.best_bound = bb -class TestIntVar(cmh.BaseIntVar): - - def __init__(self, index: int, name: str, is_boolean: bool = False) -> None: - cmh.BaseIntVar.__init__(self, index, is_boolean) - self._name = name - - def __str__(self) -> str: - return self._name - - def __repr__(self) -> str: - return self._name - - class CpModelHelperTest(absltest.TestCase): def tearDown(self) -> None: @@ -71,8 +56,8 @@ class CpModelHelperTest(absltest.TestCase): variables { domain: [ -10, 10 ] } variables { domain: [ -5, -5, 3, 6 ] } """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cmh.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) d0 = cmh.CpSatHelper.variable_domain(model.variables[0]) d1 = cmh.CpSatHelper.variable_domain(model.variables[1]) @@ -112,13 +97,13 @@ class CpModelHelperTest(absltest.TestCase): coeffs: -1 scaling_factor: -1 }""" - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cmh.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) solve_wrapper = cmh.SolveWrapper() response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) self.assertEqual(30.0, response_wrapper.objective_value()) def test_simple_solve_with_core(self): @@ -153,20 +138,21 @@ class CpModelHelperTest(absltest.TestCase): coeffs: -1 scaling_factor: -1 }""" - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cmh.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) - parameters = sat_parameters_pb2.SatParameters(optimize_with_core=True) + parameters = cmh.SatParameters() + parameters.optimize_with_core = True solve_wrapper = cmh.SolveWrapper() solve_wrapper.set_parameters(parameters) response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) self.assertEqual(30.0, response_wrapper.objective_value()) def test_simple_solve_with_proto_api(self): - model = cp_model_pb2.CpModelProto() + model = cmh.CpModelProto() x = model.variables.add() x.domain.extend([-10, 10]) y = model.variables.add() @@ -184,7 +170,7 @@ class CpModelHelperTest(absltest.TestCase): solve_wrapper = cmh.SolveWrapper() response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) self.assertEqual(30.0, response_wrapper.objective_value()) self.assertEqual(30.0, response_wrapper.best_objective_bound()) self.assertRaises(TypeError, response_wrapper.value, None) @@ -198,19 +184,19 @@ class CpModelHelperTest(absltest.TestCase): constraints { linear { vars: 0 vars: 1 coeffs: 1 coeffs: 1 domain: 6 domain: 6 } } """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cmh.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) solve_wrapper = cmh.SolveWrapper() callback = Callback() solve_wrapper.add_solution_callback(callback) - params = sat_parameters_pb2.SatParameters() + params = cmh.SatParameters() params.enumerate_all_solutions = True solve_wrapper.set_parameters(params) response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) self.assertEqual(5, callback.solution_count()) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) def test_best_bound_callback(self): model_string = """ @@ -225,13 +211,13 @@ class CpModelHelperTest(absltest.TestCase): offset: 0.6 } """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cmh.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) solve_wrapper = cmh.SolveWrapper() best_bound_callback = BestBoundCallback() solve_wrapper.add_best_bound_callback(best_bound_callback.new_best_bound) - params = sat_parameters_pb2.SatParameters() + params = cmh.SatParameters() params.num_workers = 1 params.linearization_level = 2 params.log_search_progress = True @@ -239,7 +225,7 @@ class CpModelHelperTest(absltest.TestCase): response_wrapper = solve_wrapper.solve_and_return_response_wrapper(model) self.assertEqual(2.6, best_bound_callback.best_bound) - self.assertEqual(cp_model_pb2.OPTIMAL, response_wrapper.status()) + self.assertEqual(cmh.OPTIMAL, response_wrapper.status()) def test_model_stats(self): model_string = """ @@ -275,15 +261,16 @@ class CpModelHelperTest(absltest.TestCase): } name: 'testModelStats' """ - model = cp_model_pb2.CpModelProto() - self.assertTrue(text_format.Parse(model_string, model)) + model = cmh.CpModelProto() + self.assertTrue(model.parse_text_format(model_string)) stats = cmh.CpSatHelper.model_stats(model) self.assertTrue(stats) def test_int_lin_expr(self): - x = TestIntVar(0, "x") + model = cmh.CpModelProto() + x = cmh.IntVar(model).with_name("x") self.assertTrue(x.is_integer()) - self.assertIsInstance(x, cmh.BaseIntVar) + self.assertIsInstance(x, cmh.IntVar) self.assertIsInstance(x, cmh.LinearExpr) e1 = x + 2 self.assertTrue(e1.is_integer()) @@ -291,7 +278,7 @@ class CpModelHelperTest(absltest.TestCase): e2 = 3 + x self.assertTrue(e2.is_integer()) self.assertEqual(str(e2), "(x + 3)") - y = TestIntVar(1, "y") + y = cmh.IntVar(model).with_name("y") e3 = y * 5 self.assertTrue(e3.is_integer()) self.assertEqual(str(e3), "(5 * y)") @@ -304,7 +291,8 @@ class CpModelHelperTest(absltest.TestCase): e6 = x - 2 * y self.assertTrue(e6.is_integer()) self.assertEqual(str(e6), "(x + (-2 * y))") - z = TestIntVar(2, "z", True) + z = cmh.IntVar(model).with_name("z") + z.domain = sorted_interval_list.Domain.from_values([0, 1]) e7 = -z self.assertTrue(e7.is_integer()) self.assertEqual(str(e7), "(-z)") @@ -326,9 +314,10 @@ class CpModelHelperTest(absltest.TestCase): self.assertEqual(str(e12), "(x + (-y) + (-2 * z))") def test_float_lin_expr(self): - x = TestIntVar(0, "x") + model = cmh.CpModelProto() + x = cmh.IntVar(model).with_name("x") self.assertTrue(x.is_integer()) - self.assertIsInstance(x, TestIntVar) + self.assertIsInstance(x, cmh.IntVar) self.assertIsInstance(x, cmh.LinearExpr) e1 = x + 2.5 self.assertFalse(e1.is_integer()) @@ -336,7 +325,7 @@ class CpModelHelperTest(absltest.TestCase): e2 = 3.1 + x self.assertFalse(e2.is_integer()) self.assertEqual(str(e2), "(x + 3.1)") - y = TestIntVar(1, "y") + y = cmh.IntVar(model).with_name("y") e3 = y * 5.2 self.assertFalse(e3.is_integer()) self.assertEqual(str(e3), "(5.2 * y)") @@ -353,7 +342,7 @@ class CpModelHelperTest(absltest.TestCase): self.assertFalse(e7.is_integer()) self.assertEqual(str(e7), "(x + (-(2.4 * y)))") - z = TestIntVar(2, "z") + z = cmh.IntVar(model).with_name("z") e8 = cmh.LinearExpr.sum([x, y, z, -2]) self.assertTrue(e8.is_integer()) self.assertEqual(str(e8), "(x + y + z - 2)") @@ -371,5 +360,89 @@ class CpModelHelperTest(absltest.TestCase): self.assertEqual(str(e12), "(3.1 * (x + 2))") +class CpModelBuilderTest(absltest.TestCase): + + def test_basic(self): + model_proto = cmh.CpModelProto() + + # Singular message. + objective = model_proto.objective + + # Singular int. + self.assertEqual(objective.offset, 0) + objective.offset = 123 + self.assertEqual(objective.offset, 123) + + # Set a message. + new_obj = cmh.CpObjectiveProto() + new_obj.offset = 456 + model_proto.objective = new_obj + self.assertEqual(objective.offset, 456) + + # Large int. + objective.offset = 500000000000 + self.assertEqual(objective.offset, 500000000000) + + # Repeated message. + my_var = model_proto.variables.add() + + # Singular string. + self.assertEqual(my_var.name, "") + my_var.name = "my_var" + self.assertEqual(my_var.name, "my_var") + my_var.domain.extend([0, 1]) + domain = list(my_var.domain) + self.assertLen(domain, 2) + self.assertEqual(domain[0], 0) + self.assertEqual(domain[1], 1) + + # Repeated int. + objective.vars.append(0) + self.assertLen(objective.vars, 1) + self.assertEqual(objective.vars[0], 0) + objective.vars[0] = 42 + self.assertEqual(objective.vars[0], 42) + + # Singular enum + search_strategy = model_proto.search_strategy.add() + self.assertEqual( + search_strategy.variable_selection_strategy, + cmh.DecisionStrategyProto.CHOOSE_FIRST, + ) + search_strategy.variable_selection_strategy = ( + cmh.DecisionStrategyProto.CHOOSE_LOWEST_MIN + ) + self.assertEqual( + search_strategy.variable_selection_strategy, + cmh.DecisionStrategyProto.CHOOSE_LOWEST_MIN, + ) + + +class SatParametersBuilderTest(absltest.TestCase): + + def test_basic_api(self) -> None: + params = cmh.SatParameters() + + # Test that we can set and get an integer parameter. + params.num_workers = 10 + self.assertEqual(params.num_workers, 10) + + # Test that we can set and get an enum parameter. + self.assertEqual( + params.clause_cleanup_ordering, + cmh.SatParameters.ClauseOrdering.CLAUSE_ACTIVITY, + ) + params.clause_cleanup_ordering = cmh.SatParameters.ClauseOrdering.CLAUSE_LBD + self.assertEqual( + params.clause_cleanup_ordering, + cmh.SatParameters.ClauseOrdering.CLAUSE_LBD, + ) + + # Test that we can set and get a repeated string parameter. + params.subsolvers.append("no_lp") + self.assertLen(params.subsolvers, 1) + self.assertEqual(params.subsolvers[0], "no_lp") + + if __name__ == "__main__": absltest.main() diff --git a/ortools/sat/python/cp_model_numbers.py b/ortools/sat/python/cp_model_numbers.py deleted file mode 100644 index 26b7928df5..0000000000 --- a/ortools/sat/python/cp_model_numbers.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2010-2025 Google LLC -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""helpers methods for the cp_model module.""" - -import numbers -from typing import Any -import numpy as np - - -INT_MIN = -9223372036854775808 # hardcoded to be platform independent. -INT_MAX = 9223372036854775807 - - -def is_boolean(x: Any) -> bool: - """Checks if the x is a boolean.""" - if isinstance(x, bool): - return True - if isinstance(x, np.bool_): - return True - return False - - -def assert_is_zero_or_one(x: Any) -> int: - """Asserts that x is 0 or 1 and returns it as an int.""" - if not isinstance(x, numbers.Integral): - raise TypeError(f"Not a boolean: {x} of type {type(x)}") - x_as_int = int(x) - if x_as_int < 0 or x_as_int > 1: - raise TypeError(f"Not a boolean: {x}") - return x_as_int - - -def to_capped_int64(v: int) -> int: - """Restrict v within [INT_MIN..INT_MAX] range.""" - if v > INT_MAX: - return INT_MAX - if v < INT_MIN: - return INT_MIN - return v - - -def capped_subtraction(x: int, y: int) -> int: - """Saturated arithmetics. Returns x - y truncated to the int64_t range.""" - if y == 0: - return x - if x == y: - if x == INT_MAX or x == INT_MIN: - raise OverflowError("Integer NaN: subtracting INT_MAX or INT_MIN to itself") - return 0 - if x == INT_MAX or x == INT_MIN: - return x - if y == INT_MAX: - return INT_MIN - if y == INT_MIN: - return INT_MAX - return to_capped_int64(x - y) diff --git a/ortools/sat/python/cp_model_numbers_test.py b/ortools/sat/python/cp_model_numbers_test.py deleted file mode 100644 index 1e0b91a0fc..0000000000 --- a/ortools/sat/python/cp_model_numbers_test.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2010-2025 Google LLC -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -from absl.testing import absltest -import numpy as np - -from ortools.sat.python import cp_model_numbers as cmn - - -class CpModelNumbersTest(absltest.TestCase): - - def tearDown(self) -> None: - super().tearDown() - sys.stdout.flush() - - def test_is_boolean(self): - self.assertTrue(cmn.is_boolean(True)) - self.assertTrue(cmn.is_boolean(False)) - self.assertFalse(cmn.is_boolean(1)) - self.assertFalse(cmn.is_boolean(0)) - self.assertTrue(cmn.is_boolean(np.bool_(1))) - self.assertTrue(cmn.is_boolean(np.bool_(0))) - - def test_to_capped_int64(self): - self.assertEqual(cmn.to_capped_int64(cmn.INT_MAX), cmn.INT_MAX) - self.assertEqual(cmn.to_capped_int64(cmn.INT_MAX + 1), cmn.INT_MAX) - self.assertEqual(cmn.to_capped_int64(cmn.INT_MIN), cmn.INT_MIN) - self.assertEqual(cmn.to_capped_int64(cmn.INT_MIN - 1), cmn.INT_MIN) - self.assertEqual(cmn.to_capped_int64(15), 15) - - def test_capped_subtraction(self): - self.assertEqual(cmn.capped_subtraction(10, 5), 5) - self.assertEqual(cmn.capped_subtraction(cmn.INT_MIN, 5), cmn.INT_MIN) - self.assertEqual(cmn.capped_subtraction(cmn.INT_MIN, -5), cmn.INT_MIN) - self.assertEqual(cmn.capped_subtraction(cmn.INT_MAX, 5), cmn.INT_MAX) - self.assertEqual(cmn.capped_subtraction(cmn.INT_MAX, -5), cmn.INT_MAX) - self.assertEqual(cmn.capped_subtraction(2, cmn.INT_MIN), cmn.INT_MAX) - self.assertEqual(cmn.capped_subtraction(2, cmn.INT_MAX), cmn.INT_MIN) - self.assertRaises( - OverflowError, cmn.capped_subtraction, cmn.INT_MAX, cmn.INT_MAX - ) - self.assertRaises( - OverflowError, cmn.capped_subtraction, cmn.INT_MIN, cmn.INT_MIN - ) - self.assertRaises(TypeError, cmn.capped_subtraction, 5, "dummy") - self.assertRaises(TypeError, cmn.capped_subtraction, "dummy", 5) - - -if __name__ == "__main__": - absltest.main() diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index 9bbaee5513..c446246ef8 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -21,7 +21,6 @@ from absl.testing import absltest import numpy as np import pandas as pd -from ortools.sat import cp_model_pb2 from ortools.sat.python import cp_model from ortools.sat.python import cp_model_helper as cmh @@ -184,6 +183,14 @@ class CpModelTest(absltest.TestCase): super().tearDown() sys.stdout.flush() + def test_is_boolean(self): + self.assertTrue(cp_model.arg_is_boolean(True)) + self.assertTrue(cp_model.arg_is_boolean(False)) + self.assertFalse(cp_model.arg_is_boolean(1)) + self.assertFalse(cp_model.arg_is_boolean(0)) + self.assertTrue(cp_model.arg_is_boolean(np.bool_(1))) + self.assertTrue(cp_model.arg_is_boolean(np.bool_(0))) + def test_create_integer_variable(self) -> None: model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") @@ -230,6 +237,9 @@ class CpModelTest(absltest.TestCase): one = model.new_constant(1) self.assertEqual("1", str(one)) self.assertEqual("not(1)", str(~one)) + no_name = model.new_bool_var("") + self.assertEqual("b4", str(no_name)) + self.assertEqual("not(b4)", str(~no_name)) z = model.new_int_var(0, 2, "z") self.assertRaises(TypeError, z.negated) self.assertRaises(TypeError, z.__invert__) @@ -284,14 +294,14 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, solver.float_value, None) self.assertRaises(TypeError, solver.boolean_value, None) - def test_linear_constraint(self) -> None: + def test_empty_linear_constraint(self) -> None: model = cp_model.CpModel() model.add_linear_constraint(5, 0, 10) model.add_linear_constraint(-1, 0, 10) self.assertLen(model.proto.constraints, 2) - self.assertTrue(model.proto.constraints[0].HasField("bool_and")) + self.assertTrue(model.proto.constraints[0].has_bool_and()) self.assertEmpty(model.proto.constraints[0].bool_and.literals) - self.assertTrue(model.proto.constraints[1].HasField("bool_or")) + self.assertTrue(model.proto.constraints[1].has_bool_or()) self.assertEmpty(model.proto.constraints[1].bool_or.literals) def test_linear_non_equal(self) -> None: @@ -315,6 +325,17 @@ class CpModelTest(absltest.TestCase): self.assertEqual(2, ct.linear.domain[0]) self.assertEqual(2, ct.linear.domain[1]) + def test_large_constants(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(-10, 10, "x") + ct = model.add(x * 50000000000 == 1234567890).proto + self.assertLen(ct.linear.vars, 1) + self.assertLen(ct.linear.coeffs, 1) + self.assertEqual(50000000000, ct.linear.coeffs[0]) + self.assertLen(ct.linear.domain, 2) + self.assertEqual(1234567890, ct.linear.domain[0]) + self.assertEqual(1234567890, ct.linear.domain[1]) + def testGe(self) -> None: model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") @@ -476,7 +497,7 @@ class CpModelTest(absltest.TestCase): model.add(x * 2 - 1 * y == 1) model.minimize(x * 1 - 2 * y + 3) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(5, solver.value(x)) self.assertEqual(15, solver.value(x * 3)) self.assertEqual(6, solver.value(1 + x)) @@ -488,7 +509,7 @@ class CpModelTest(absltest.TestCase): y = model.new_int_var(0, 10, "y") model.maximize(x.negated() * 3.5 + x.negated() - y + 2 * y + 1.6) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertFalse(solver.boolean_value(x)) self.assertTrue(solver.boolean_value(x.negated())) self.assertEqual(-10, solver.value(-y)) @@ -505,7 +526,7 @@ class CpModelTest(absltest.TestCase): + cp_model.LinearExpr.weighted_sum([x3, x4.negated()], [2, 4]) ) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(5, solver.value(3 + 2 * x1)) self.assertEqual(3, solver.value(x1 + x2 + x3)) self.assertEqual(1, solver.value(cp_model.LinearExpr.sum([x1, x2, x3, 0, -2]))) @@ -525,7 +546,7 @@ class CpModelTest(absltest.TestCase): model.add(2 * x - y == 1) model.maximize(x - 2 * y + 3) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(-4, solver.value(x)) self.assertEqual(-9, solver.value(y)) self.assertEqual(17, solver.objective_value) @@ -536,7 +557,7 @@ class CpModelTest(absltest.TestCase): model.add(x >= -1) model.minimize(10) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(10, solver.objective_value) def test_maximize_constant(self) -> None: @@ -545,7 +566,7 @@ class CpModelTest(absltest.TestCase): model.add(x >= -1) model.maximize(5) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(5, solver.objective_value) def test_add_true(self) -> None: @@ -554,7 +575,7 @@ class CpModelTest(absltest.TestCase): model.add(3 >= -1) model.minimize(x) solver = cp_model.CpSolver() - self.assertEqual("OPTIMAL", solver.status_name(solver.solve(model))) + self.assertEqual("OPTIMAL", solver.solve(model).name) self.assertEqual(-10, solver.value(x)) def test_add_false(self) -> None: @@ -563,7 +584,8 @@ class CpModelTest(absltest.TestCase): model.add(3 <= -1) model.minimize(x) solver = cp_model.CpSolver() - self.assertEqual("INFEASIBLE", solver.status_name(solver.solve(model))) + status: cmh.CpSolverStatus = solver.solve(model) + self.assertEqual("INFEASIBLE", status.name) def test_sum(self) -> None: model = cp_model.CpModel() @@ -838,7 +860,7 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints[0].linear.vars, 1) self.assertEqual(x[3].index, model.proto.constraints[0].linear.vars[0]) self.assertEqual(1, model.proto.constraints[0].linear.coeffs[0]) - self.assertEqual([2, 2], model.proto.constraints[0].linear.domain) + self.assertEqual([2, 2], list(model.proto.constraints[0].linear.domain)) def test_affine_element(self) -> None: model = cp_model.CpModel() @@ -992,6 +1014,62 @@ class CpModelTest(absltest.TestCase): self.assertEqual(0, model.proto.constraints[0].lin_max.target.vars[0]) self.assertEqual(1, model.proto.constraints[0].lin_max.target.coeffs[0]) + def test_max_equality_list(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(0, 4, "x") + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] + model.add_max_equality(x, [y[0], y[2], y[1], y[3]]) + self.assertLen(model.proto.variables, 6) + self.assertLen(model.proto.constraints[0].lin_max.exprs, 4) + self.assertEqual(0, model.proto.constraints[0].lin_max.target.vars[0]) + self.assertEqual(1, model.proto.constraints[0].lin_max.target.coeffs[0]) + + def test_max_equality_tuple(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(0, 4, "x") + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] + model.add_max_equality(x, (y[0], y[2], y[1], y[3])) + self.assertLen(model.proto.variables, 6) + self.assertLen(model.proto.constraints[0].lin_max.exprs, 4) + self.assertEqual(0, model.proto.constraints[0].lin_max.target.vars[0]) + self.assertEqual(1, model.proto.constraints[0].lin_max.target.coeffs[0]) + + def test_max_equality_generator(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(0, 4, "x") + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] + model.add_max_equality(x, (z for z in y)) + self.assertLen(model.proto.variables, 6) + self.assertLen(model.proto.constraints[0].lin_max.exprs, 5) + self.assertEqual(0, model.proto.constraints[0].lin_max.target.vars[0]) + self.assertEqual(1, model.proto.constraints[0].lin_max.target.coeffs[0]) + + def test_max_equality_args(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(0, 4, "x") + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] + model.add_max_equality(x, y[2], y[4]) + self.assertLen(model.proto.variables, 6) + self.assertLen(model.proto.constraints[0].lin_max.exprs, 2) + self.assertEqual(0, model.proto.constraints[0].lin_max.target.vars[0]) + self.assertEqual(1, model.proto.constraints[0].lin_max.target.coeffs[0]) + + def test_max_equality_with_constant(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(0, 4, "x") + y = model.new_int_var(0, 4, "y") + model.add_max_equality(x, [y, 3]) + self.assertLen(model.proto.variables, 2) + self.assertLen(model.proto.constraints, 1) + lin_max = model.proto.constraints[0].lin_max + self.assertLen(lin_max.exprs, 2) + self.assertLen(lin_max.exprs[0].vars, 1) + self.assertEqual(1, lin_max.exprs[0].vars[0]) + self.assertEqual(1, lin_max.exprs[0].coeffs[0]) + self.assertEqual(0, lin_max.exprs[0].offset) + self.assertEmpty(lin_max.exprs[1].vars) + self.assertEqual(3, lin_max.exprs[1].offset) + def test_min_equality(self) -> None: model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") @@ -1032,6 +1110,16 @@ class CpModelTest(absltest.TestCase): self.assertEqual(0, model.proto.constraints[0].lin_max.target.vars[0]) self.assertEqual(-1, model.proto.constraints[0].lin_max.target.coeffs[0]) + def test_min_equality_args(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(0, 4, "x") + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] + model.add_min_equality(x, y[2], y[4]) + self.assertLen(model.proto.variables, 6) + self.assertLen(model.proto.constraints[0].lin_max.exprs, 2) + self.assertEqual(0, model.proto.constraints[0].lin_max.target.vars[0]) + self.assertEqual(-1, model.proto.constraints[0].lin_max.target.coeffs[0]) + def test_min_equality_with_constant(self) -> None: model = cp_model.CpModel() x = model.new_int_var(0, 4, "x") @@ -1149,6 +1237,16 @@ class CpModelTest(absltest.TestCase): self.assertLen(model.proto.constraints[0].int_prod.exprs, 5) self.assertEqual(0, model.proto.constraints[0].int_prod.target.vars[0]) + def test_multiplication_equality_generator(self) -> None: + model = cp_model.CpModel() + x = model.new_int_var(0, 4, "x") + y = [model.new_int_var(0, 4, f"y{i}") for i in range(5)] + model.add_multiplication_equality(x, y[2], y[3]) + self.assertLen(model.proto.variables, 6) + self.assertLen(model.proto.constraints, 1) + self.assertLen(model.proto.constraints[0].int_prod.exprs, 2) + self.assertEqual(0, model.proto.constraints[0].int_prod.target.vars[0]) + def test_implication(self) -> None: model = cp_model.CpModel() x = model.new_bool_var("x") @@ -1285,12 +1383,12 @@ class CpModelTest(absltest.TestCase): self.assertEqual(~i.size_expr(), ~y) self.assertRaises(TypeError, i.start_expr().negated) - proto = cp_model_pb2.LinearExpressionProto() + proto = cmh.LinearExpressionProto() proto.vars.append(x.index) proto.coeffs.append(1) proto.vars.append(y.index) proto.coeffs.append(2) - expr1 = model.rebuild_from_linear_expression_proto(proto) + expr1 = cp_model.rebuild_from_linear_expression_proto(proto, model.proto) canonical_expr1 = cmh.FlatIntExpr(expr1) self.assertEqual(canonical_expr1.vars[0], x) self.assertEqual(canonical_expr1.vars[1], y) @@ -1301,7 +1399,7 @@ class CpModelTest(absltest.TestCase): self.assertRaises(TypeError, canonical_expr1.vars[0].negated) proto.offset = 2 - expr2 = model.rebuild_from_linear_expression_proto(proto) + expr2 = cp_model.rebuild_from_linear_expression_proto(proto, model.proto) canonical_expr2 = cmh.FlatIntExpr(expr2) self.assertEqual(canonical_expr2.vars[0], x) self.assertEqual(canonical_expr2.vars[1], y) @@ -1474,7 +1572,7 @@ class CpModelTest(absltest.TestCase): self.assertEqual(repr(i), "i(start = x, size = 2, end = y)") b = model.new_bool_var("b") self.assertEqual(repr(b), "b(0..1)") - self.assertEqual(repr(~b), "NotBooleanVariable(index=3)") + self.assertEqual(repr(~b), "NotBooleanVariable(var_index=3)") x1 = model.new_int_var(0, 4, "x1") y1 = model.new_int_var(0, 3, "y1") j = model.new_optional_interval_var(x1, 2, y1, b, "j") @@ -1486,16 +1584,6 @@ class CpModelTest(absltest.TestCase): repr(k), "k(start = x2, size = 2, end = y2, is_present = not(b))" ) - def testDisplayBounds(self) -> None: - self.assertEqual("10..20", cp_model.display_bounds([10, 20])) - self.assertEqual("10", cp_model.display_bounds([10, 10])) - self.assertEqual("10..15, 20..30", cp_model.display_bounds([10, 15, 20, 30])) - - def test_short_name(self) -> None: - model = cp_model.CpModel() - model.proto.variables.add(domain=[5, 10]) - self.assertEqual("[5..10]", cp_model.short_name(model.proto, 0)) - def test_integer_expression_errors(self) -> None: model = cp_model.CpModel() x = model.new_int_var(0, 1, "x") @@ -1525,10 +1613,23 @@ class CpModelTest(absltest.TestCase): model = cp_model.CpModel() x = model.new_int_var(0, 1, "x") y = model.new_int_var(-10, 10, "y") + b = model.new_bool_var("b") model.add_linear_constraint(x + 2 * y, 0, 10) model.minimize(y) solver = cp_model.CpSolver() self.assertRaises(RuntimeError, solver.value, x) + self.assertRaises(RuntimeError, solver.boolean_value, b) + self.assertRaises(RuntimeError, lambda: solver.best_objective_bound) + self.assertRaises(RuntimeError, lambda: solver.deterministic_time) + self.assertRaises(RuntimeError, lambda: solver.num_boolean_propagations) + self.assertRaises(RuntimeError, lambda: solver.num_booleans) + self.assertRaises(RuntimeError, lambda: solver.num_branches) + self.assertRaises(RuntimeError, lambda: solver.num_conflicts) + self.assertRaises(RuntimeError, lambda: solver.num_integer_propagations) + self.assertRaises(RuntimeError, lambda: solver.objective_value) + self.assertRaises(RuntimeError, lambda: solver.response_proto) + self.assertRaises(RuntimeError, lambda: solver.user_time) + self.assertRaises(RuntimeError, lambda: solver.wall_time) solver.solve(model) self.assertRaises(TypeError, solver.value, "not_a_variable") self.assertRaises(TypeError, model.add_bool_or, [x, y]) @@ -1885,7 +1986,7 @@ class CpModelTest(absltest.TestCase): with self.assertRaises(ValueError): new_model.get_interval_var_from_proto_index(-1) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): new_model.get_bool_var_from_proto_index(x.index) with self.assertRaises(ValueError): @@ -1908,8 +2009,8 @@ class CpModelTest(absltest.TestCase): deepcopy_c = copy.deepcopy(c) self.assertIsNot(deepcopy_c.model, c.model) self.assertIsNot(deepcopy_c.var, c.var) - self.assertIs(deepcopy_c.model.proto, deepcopy_c.var.model_proto) - self.assertIs( + self.assertEqual(deepcopy_c.model.proto, deepcopy_c.var.model_proto) + self.assertEqual( deepcopy_c.var, deepcopy_c.model.get_int_var_from_proto_index(x.index), ) @@ -2296,10 +2397,10 @@ TRFM""" solver.best_bound_callback = best_bound_callback.new_best_bound status = solver.Solve(model, solution_callback) if status == cp_model.OPTIMAL: - self.assertLess( - time.time(), - max(best_bound_callback.last_time, solution_callback.last_time) + 9.0, + last_activity = max( + best_bound_callback.last_time, solution_callback.last_time ) + self.assertLess(time.time(), last_activity + 15.0) def test_issue4434(self) -> None: model = cp_model.CpModel() diff --git a/ortools/sat/python/gen_proto_builder_pybind11.cc b/ortools/sat/python/gen_proto_builder_pybind11.cc new file mode 100644 index 0000000000..0857c01c89 --- /dev/null +++ b/ortools/sat/python/gen_proto_builder_pybind11.cc @@ -0,0 +1,49 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/flags/parse.h" +#include "absl/flags/usage.h" +#include "absl/log/die_if_null.h" +#include "absl/log/initialize.h" +#include "absl/strings/str_format.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/python/wrappers.h" +#include "ortools/sat/sat_parameters.pb.h" + +namespace operations_research::sat::python { + +void ParseAndGenerate() { + absl::PrintF( + R"( + +// This is a generated file, do not edit. +#if defined(IMPORT_PROTO_WRAPPER_CODE) +%s +#endif // defined(IMPORT_PROTO_WRAPPER_CODE) +)", + GeneratePybindCode({ABSL_DIE_IF_NULL(CpModelProto::descriptor()), + ABSL_DIE_IF_NULL(CpSolverResponse::descriptor()), + ABSL_DIE_IF_NULL(SatParameters::descriptor())})); +} + +} // namespace operations_research::sat::python + +int main(int argc, char* argv[]) { + // We do not use InitGoogle() to avoid linking with or-tools as this would + // create a circular dependency. + absl::InitializeLog(); + absl::SetProgramUsageMessage(argv[0]); + absl::ParseCommandLine(argc, argv); + operations_research::sat::python::ParseAndGenerate(); + return 0; +} diff --git a/ortools/sat/python/linear_expr.cc b/ortools/sat/python/linear_expr.cc index f8c2954f62..0b87598b67 100644 --- a/ortools/sat/python/linear_expr.cc +++ b/ortools/sat/python/linear_expr.cc @@ -26,6 +26,8 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_utils.h" #include "ortools/util/fp_roundtrip_conv.h" #include "ortools/util/sorted_interval_list.h" @@ -143,8 +145,7 @@ void FloatExprVisitor::AddToProcess(std::shared_ptr expr, void FloatExprVisitor::AddConstant(double constant) { offset_ += constant; } -void FloatExprVisitor::AddVarCoeff(std::shared_ptr var, - double coeff) { +void FloatExprVisitor::AddVarCoeff(std::shared_ptr var, double coeff) { canonical_terms_[var] += coeff; } @@ -156,7 +157,7 @@ void FloatExprVisitor::ProcessAll() { } } -double FloatExprVisitor::Process(std::vector>* vars, +double FloatExprVisitor::Process(std::vector>* vars, std::vector* coeffs) { ProcessAll(); @@ -316,7 +317,7 @@ std::string FlatIntExpr::DebugString() const { return absl::StrCat( "FlatIntExpr([", absl::StrJoin(vars_, ", ", - [](std::string* out, std::shared_ptr var) { + [](std::string* out, std::shared_ptr var) { absl::StrAppend(out, var->DebugString()); }), "], [", absl::StrJoin(coeffs_, ", "), "], ", offset_, ")"); @@ -745,8 +746,7 @@ void IntExprVisitor::AddToProcess(std::shared_ptr expr, void IntExprVisitor::AddConstant(int64_t constant) { offset_ += constant; } -void IntExprVisitor::AddVarCoeff(std::shared_ptr var, - int64_t coeff) { +void IntExprVisitor::AddVarCoeff(std::shared_ptr var, int64_t coeff) { canonical_terms_[var] += coeff; } @@ -759,7 +759,7 @@ bool IntExprVisitor::ProcessAll() { return true; } -bool IntExprVisitor::Process(std::vector>* vars, +bool IntExprVisitor::Process(std::vector>* vars, std::vector* coeffs, int64_t* offset) { if (!ProcessAll()) return false; vars->clear(); @@ -789,64 +789,146 @@ bool IntExprVisitor::Evaluate(const CpSolverResponse& solution, // the same index and different models. int64_t Literal::Hash() const { return absl::HashOf(index()); } -bool BaseIntVarComparator::operator()(std::shared_ptr lhs, - std::shared_ptr rhs) const { +bool IntVarComparator::operator()(std::shared_ptr lhs, + std::shared_ptr rhs) const { return lhs->index() < rhs->index(); } -BaseIntVar::BaseIntVar(int index, bool is_boolean) - : index_(index), is_boolean_(is_boolean) {} - -std::shared_ptr BaseIntVar::negated() { - if (negated_ == nullptr) { - std::shared_ptr self = - std::static_pointer_cast(shared_from_this()); - negated_ = std::make_shared(self); +std::string IntVar::name() const { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return ""; } - return negated_; + return model_proto_->variables(index_).name(); } -int NotBooleanVariable::index() const { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - return -var->index() - 1; +void IntVar::SetName(const std::string& name) { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return; + } + if (name.empty()) { + model_proto_->mutable_variables(index_)->clear_name(); + } else { + model_proto_->mutable_variables(index_)->set_name(name); + } } +Domain IntVar::domain() const { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return Domain(); + } + return ReadDomainFromProto(model_proto_->variables(index_)); +} + +void IntVar::SetDomain(const Domain& domain) { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return; + } + FillDomainInProto(domain, model_proto_->mutable_variables(index_)); +} + +std::shared_ptr IntVar::model_proto() const { + return model_proto_; +} + +IntegerVariableProto* IntVar::proto() const { + if (model_proto_ == nullptr || index_ >= model_proto_->variables_size()) { + return nullptr; + } + return model_proto_->mutable_variables(index_); +} + +bool IntVar::is_boolean() const { + IntegerVariableProto* var_proto = proto(); + if (var_proto == nullptr) return false; + return var_proto->domain_size() == 2 && var_proto->domain(0) >= 0 && + var_proto->domain(1) <= 1; +} + +bool IntVar::is_fixed() const { + IntegerVariableProto* var_proto = proto(); + if (var_proto == nullptr) return false; + return var_proto->domain_size() == 2 && + var_proto->domain(0) == var_proto->domain(1); +} + +std::shared_ptr IntVar::negated() const { + return std::make_shared(model_proto_, index_); +} + +namespace { +std::string VarDomainToString(IntegerVariableProto* var_proto) { + std::string domain_str; + for (int i = 0; i < var_proto->domain_size(); i += 2) { + const int64_t lb = var_proto->domain(i); + const int64_t ub = var_proto->domain(i + 1); + if (i > 0) absl::StrAppend(&domain_str, ", "); + if (lb == ub) { + absl::StrAppend(&domain_str, lb); + } else { + absl::StrAppend(&domain_str, lb, "..", ub); + } + } + return domain_str; +} + +} // namespace + +std::string IntVar::ToString() const { + std::string var_name = name(); + IntegerVariableProto* var_proto = proto(); + if (var_name.empty()) { + if (is_fixed() && var_proto != nullptr && var_proto->domain_size() >= 2) { + return absl::StrCat(var_proto->domain(0)); + } else if (is_boolean()) { + return absl::StrCat("b", index_); + } else { + return absl::StrCat("x", index_); + } + } + return var_name; +} + +std::string IntVar::DebugString() const { + std::string var_name = name(); + if (var_name.empty()) { + if (is_boolean()) { + var_name = absl::StrCat("b", index_); + } else { + var_name = absl::StrCat("x", index_); + } + } + IntegerVariableProto* var_proto = proto(); + if (var_proto == nullptr) return var_name; + return absl::StrCat(var_name, "(", VarDomainToString(var_proto), ")"); +} + +int NotBooleanVariable::index() const { return NegatedRef(var_index_); } + /** * Returns the negation of the current literal, that is the original Boolean * variable. */ -std::shared_ptr NotBooleanVariable::negated() { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - return var; +std::shared_ptr NotBooleanVariable::negated() const { + return std::make_shared(model_proto_, var_index_); } bool NotBooleanVariable::VisitAsInt(IntExprVisitor& lin, int64_t c) { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - lin.AddVarCoeff(var, -c); + lin.AddVarCoeff(std::make_shared(model_proto_, var_index_), -c); lin.AddConstant(c); return true; } void NotBooleanVariable::VisitAsFloat(FloatExprVisitor& lin, double c) { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - lin.AddVarCoeff(var, -c); + lin.AddVarCoeff(std::make_shared(model_proto_, var_index_), -c); lin.AddConstant(c); } std::string NotBooleanVariable::ToString() const { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - return absl::StrCat("not(", var->ToString(), ")"); + return absl::StrCat("not(", negated()->ToString(), ")"); } std::string NotBooleanVariable::DebugString() const { - std::shared_ptr var = var_.lock(); - CHECK(var != nullptr); // Cannot happen as checked in the pybind11 code. - return absl::StrCat("NotBooleanVariable(index=", var->index(), ")"); + return absl::StrCat("NotBooleanVariable(var_index=", var_index_, ")"); } BoundedLinearExpression::BoundedLinearExpression( @@ -868,7 +950,7 @@ BoundedLinearExpression::BoundedLinearExpression( } const Domain& BoundedLinearExpression::bounds() const { return bounds_; } -const std::vector>& BoundedLinearExpression::vars() +const std::vector>& BoundedLinearExpression::vars() const { return vars_; } @@ -956,7 +1038,7 @@ std::string BoundedLinearExpression::DebugString() const { return absl::StrCat( "BoundedLinearExpression(vars=[", absl::StrJoin(vars_, ", ", - [](std::string* out, std::shared_ptr var) { + [](std::string* out, std::shared_ptr var) { absl::StrAppend(out, var->DebugString()); }), "], coeffs=[", absl::StrJoin(coeffs_, ", "), "], offset=", offset_, diff --git a/ortools/sat/python/linear_expr.h b/ortools/sat/python/linear_expr.h index 06d973f9ea..3e74f256a5 100644 --- a/ortools/sat/python/linear_expr.h +++ b/ortools/sat/python/linear_expr.h @@ -36,7 +36,7 @@ class FloatExprVisitor; class LinearExpr; class IntExprVisitor; class LinearExpr; -class BaseIntVar; +class IntVar; class NotBooleanVariable; /** @@ -152,9 +152,9 @@ class LinearExpr : public std::enable_shared_from_this { }; /// Compare the indices of variables. -struct BaseIntVarComparator { - bool operator()(std::shared_ptr lhs, - std::shared_ptr rhs) const; +struct IntVarComparator { + bool operator()(std::shared_ptr lhs, + std::shared_ptr rhs) const; }; /// A visitor class to process a floating point linear expression. @@ -162,15 +162,15 @@ class FloatExprVisitor { public: void AddToProcess(std::shared_ptr expr, double coeff); void AddConstant(double constant); - void AddVarCoeff(std::shared_ptr var, double coeff); + void AddVarCoeff(std::shared_ptr var, double coeff); void ProcessAll(); - double Process(std::vector>* vars, + double Process(std::vector>* vars, std::vector* coeffs); double Evaluate(const CpSolverResponse& solution); private: std::vector, double>> to_process_; - absl::btree_map, double, BaseIntVarComparator> + absl::btree_map, double, IntVarComparator> canonical_terms_; double offset_ = 0; }; @@ -188,7 +188,7 @@ class FlatFloatExpr : public LinearExpr { /// expression. explicit FlatFloatExpr(std::shared_ptr expr); /// Returns the array of variables of the flattened expression. - const std::vector>& vars() const { return vars_; } + const std::vector>& vars() const { return vars_; } /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const { return coeffs_; } /// Returns the offset of the flattened expression. @@ -202,7 +202,7 @@ class FlatFloatExpr : public LinearExpr { } private: - std::vector> vars_; + std::vector> vars_; std::vector coeffs_; double offset_ = 0; }; @@ -212,15 +212,15 @@ class IntExprVisitor { public: void AddToProcess(std::shared_ptr expr, int64_t coeff); void AddConstant(int64_t constant); - void AddVarCoeff(std::shared_ptr var, int64_t coeff); + void AddVarCoeff(std::shared_ptr var, int64_t coeff); bool ProcessAll(); - bool Process(std::vector>* vars, + bool Process(std::vector>* vars, std::vector* coeffs, int64_t* offset); bool Evaluate(const CpSolverResponse& solution, int64_t* value); private: std::vector, int64_t>> to_process_; - absl::btree_map, int64_t, BaseIntVarComparator> + absl::btree_map, int64_t, IntVarComparator> canonical_terms_; int64_t offset_ = 0; }; @@ -238,7 +238,7 @@ class FlatIntExpr : public LinearExpr { /// expression. explicit FlatIntExpr(std::shared_ptr expr); /// Returns the array of variables of the flattened expression. - const std::vector>& vars() const { return vars_; } + const std::vector>& vars() const { return vars_; } /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const { return coeffs_; } /// Returns the offset of the flattened expression. @@ -265,7 +265,7 @@ class FlatIntExpr : public LinearExpr { std::string DebugString() const override; private: - std::vector> vars_; + std::vector> vars_; std::vector coeffs_; int64_t offset_ = 0; bool ok_ = true; @@ -479,90 +479,115 @@ class Literal : public LinearExpr { * Returns: * The negation of the current literal. */ - virtual std::shared_ptr negated() = 0; + virtual std::shared_ptr negated() const = 0; /// Returns the hash of the current literal. int64_t Hash() const; }; /** - * A class to hold a variable index. It is the base class for Integer - * variables. + * An integer variable. + * + * An IntVar is an object that can take on any integer value within defined + * ranges. Variables appear in constraint like: + * + * x + y >= 5 + * AllDifferent([x, y, z]) + * + * Solving a model is equivalent to finding, for each variable, a single value + * from the set of initial values (called the initial domain), such that the + * model is feasible, or optimal if you provided an objective function. */ -class BaseIntVar : public Literal { +class IntVar : public Literal { public: - explicit BaseIntVar(int index) : index_(index), is_boolean_(false) { + IntVar(std::shared_ptr model, int index) + : model_proto_(model), index_(index) { DCHECK_GE(index, 0); } - BaseIntVar(int index, bool is_boolean); - ~BaseIntVar() override = default; + explicit IntVar(std::shared_ptr model) + : model_proto_(model), index_(model->variables_size()) { + model->add_variables(); + } + ~IntVar() override = default; + + /// Returns the index of the variable in the model. int index() const override { return index_; } + /// Returns the name of the variable. + std::string name() const; + + /// Overwrite the name of the variable. If name is empty, this method clears + /// the name of the variable. + void SetName(const std::string& name); + + /// Returns a copy of the domain of the variable. + Domain domain() const; + + /// Overwrite the domain of the variable. + void SetDomain(const Domain& domain); + + /// Returns the model proto. + std::shared_ptr model_proto() const; + + /// Returns the proto of the variable. + IntegerVariableProto* proto() const; + + /// Returns the negation of the current variable. + std::shared_ptr negated() const override; + + /// Returns true if the variable has a Boolean domain (0 or 1). + bool is_boolean() const; + + /// Returns true if the variable is fixed. + bool is_fixed() const; + bool VisitAsInt(IntExprVisitor& lin, int64_t c) override { - std::shared_ptr var = - std::static_pointer_cast(shared_from_this()); + std::shared_ptr var = + std::static_pointer_cast(shared_from_this()); lin.AddVarCoeff(var, c); return true; } void VisitAsFloat(FloatExprVisitor& lin, double c) override { - std::shared_ptr var = - std::static_pointer_cast(shared_from_this()); + std::shared_ptr var = + std::static_pointer_cast(shared_from_this()); lin.AddVarCoeff(var, c); } - std::string ToString() const override { - if (negated_ != nullptr) { - return absl::StrCat("BooleanBaseIntVar(", index_, ")"); - } else { - return absl::StrCat("BaseIntVar(", index_, ")"); - } - } + std::string ToString() const override; - std::string DebugString() const override { - return absl::StrCat("BaseIntVar(index=", index_, - ", is_boolean=", negated_ != nullptr, ")"); - } + std::string DebugString() const override; - /// Returns the negation of the current variable. - std::shared_ptr negated() override; + bool operator<(const IntVar& other) const { return index_ < other.index_; } - /// Returns true if the variable has a Boolean domain (0 or 1). - bool is_boolean() const { return is_boolean_; } - - bool operator<(const BaseIntVar& other) const { - return index_ < other.index_; - } - - protected: + private: + std::shared_ptr model_proto_; const int index_; - const bool is_boolean_; - std::shared_ptr negated_; }; template -H AbslHashValue(H h, std::shared_ptr i) { +H AbslHashValue(H h, std::shared_ptr i) { return H::combine(std::move(h), i->index()); } /// A class to hold a negated variable index. class NotBooleanVariable : public Literal { public: - explicit NotBooleanVariable(std::shared_ptr var) : var_(var) {} + explicit NotBooleanVariable(std::shared_ptr model_proto, + int var_index) + : model_proto_(model_proto), var_index_(var_index) {} ~NotBooleanVariable() override = default; /// Returns the index of the current literal. int index() const override; - bool ok() const { return !var_.expired(); } - /** * Returns the negation of the current literal, that is the original Boolean * variable. */ - std::shared_ptr negated() override; + std::shared_ptr negated() const override; bool VisitAsInt(IntExprVisitor& lin, int64_t c) override; @@ -573,11 +598,8 @@ class NotBooleanVariable : public Literal { std::string DebugString() const override; private: - // We keep a weak ptr to the base variable to avoid a circular dependency. - // The base variable holds a shared pointer to the negated variable. - // Any call to a risky method is checked at the pybind11 level to raise a - // python exception before the call is made. - std::weak_ptr var_; + std::shared_ptr model_proto_; + const int var_index_; }; /// A class to hold a linear expression with bounds. @@ -597,7 +619,7 @@ class BoundedLinearExpression { /// Returns the bounds constraining the expression passed to the constructor. const Domain& bounds() const; /// Returns the array of variables of the flattened expression. - const std::vector>& vars() const; + const std::vector>& vars() const; /// Returns the array of coefficients of the flattened expression. const std::vector& coeffs() const; /// Returns the offset of the flattened expression. @@ -609,7 +631,7 @@ class BoundedLinearExpression { bool CastToBool(bool* result) const; private: - std::vector> vars_; + std::vector> vars_; std::vector coeffs_; int64_t offset_; const Domain bounds_; diff --git a/ortools/sat/python/linear_expr_doc.h b/ortools/sat/python/linear_expr_doc.h index d36484d457..b62752c217 100644 --- a/ortools/sat/python/linear_expr_doc.h +++ b/ortools/sat/python/linear_expr_doc.h @@ -46,55 +46,53 @@ static const char* __doc_operations_research_sat_python_AbslHashValue = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar = - R"doc(A class to hold a variable index. It is the base class for Integer -variables.)doc"; +static const char* __doc_operations_research_sat_python_IntVar = + R"doc(A class to hold an integer or Boolean variable)doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_2 = - R"doc(A class to hold a variable index. It is the base class for Integer -variables.)doc"; +static const char* __doc_operations_research_sat_python_IntVar_2 = + R"doc(A class to hold an integer or Boolean variable)doc"; -static const char* __doc_operations_research_sat_python_BaseIntVarComparator = +static const char* __doc_operations_research_sat_python_IntVarComparator = R"doc(Compare the indices of variables.)doc"; static const char* - __doc_operations_research_sat_python_BaseIntVarComparator_operator_call = + __doc_operations_research_sat_python_IntVarComparator_operator_call = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_BaseIntVar = +static const char* __doc_operations_research_sat_python_IntVar_IntVar = R"doc()doc"; -static const char* - __doc_operations_research_sat_python_BaseIntVar_BaseIntVar_2 = R"doc()doc"; - -static const char* __doc_operations_research_sat_python_BaseIntVar_DebugString = +static const char* __doc_operations_research_sat_python_IntVar_IntVar_2 = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_ToString = +static const char* __doc_operations_research_sat_python_IntVar_DebugString = R"doc()doc"; -static const char* - __doc_operations_research_sat_python_BaseIntVar_VisitAsFloat = R"doc()doc"; - -static const char* __doc_operations_research_sat_python_BaseIntVar_VisitAsInt = +static const char* __doc_operations_research_sat_python_IntVar_ToString = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_index = +static const char* __doc_operations_research_sat_python_IntVar_VisitAsFloat = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_index_2 = +static const char* __doc_operations_research_sat_python_IntVar_VisitAsInt = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_is_boolean = +static const char* __doc_operations_research_sat_python_IntVar_index = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntVar_index_2 = + R"doc()doc"; + +static const char* __doc_operations_research_sat_python_IntVar_is_boolean = R"doc(Returns true if the variable has a Boolean domain (0 or 1).)doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_negated = +static const char* __doc_operations_research_sat_python_IntVar_negated = R"doc(Returns the negation of the current variable.)doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_negated_2 = +static const char* __doc_operations_research_sat_python_IntVar_negated_2 = R"doc()doc"; -static const char* __doc_operations_research_sat_python_BaseIntVar_operator_lt = +static const char* __doc_operations_research_sat_python_IntVar_operator_lt = R"doc()doc"; static const char* diff --git a/ortools/sat/python/wrappers.cc b/ortools/sat/python/wrappers.cc new file mode 100644 index 0000000000..2fd1539199 --- /dev/null +++ b/ortools/sat/python/wrappers.cc @@ -0,0 +1,450 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/python/wrappers.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/span.h" +#include "google/protobuf/descriptor.h" + +namespace operations_research::sat::python { + +// A class that generates pybind11 code for a proto message. +class Generator { + public: + struct Context { + static Context TopLevel(const google::protobuf::Descriptor& msg) { + const std::string cpp_name = GetQualifiedCppName(msg); + const std::string shared_name = + absl::StrCat("std::shared_ptr<", cpp_name, ">"); + return {.cpp_name = cpp_name, .self_mutable_name = shared_name}; + } + + static Context Nested(const google::protobuf::Descriptor& msg) { + const std::string cpp_name = GetQualifiedCppName(msg); + return {.cpp_name = cpp_name, + .self_mutable_name = absl::StrCat(cpp_name, "*")}; + } + + std::string cpp_name; + std::string self_mutable_name; + }; + + explicit Generator( + absl::Span roots) + : message_stack_(roots.begin(), roots.end()) { + // DFS on root. + while (!message_stack_.empty()) { + const google::protobuf::Descriptor* const msg = message_stack_.back(); + message_stack_.pop_back(); + if (!visited_messages_.insert(msg).second) continue; + const bool is_top_level = absl::c_linear_search(roots, msg); + current_context_ = + is_top_level ? Context::TopLevel(*msg) : Context::Nested(*msg); + if (is_top_level) { + GenerateTopLevelMessageDecl(*msg); + } else { + GenerateMessageDecl(*msg); + } + GenerateMessageFields(*msg); + absl::StrAppend(&out_, ";\n"); + } + + // Now generate wrappers for enums, repeated and repeated ptr fields that + // were encountered along the way. + for (const google::protobuf::EnumDescriptor* pb_enum : enum_types_) { + GenerateEnumDecl(*pb_enum); + } + for (const google::protobuf::Descriptor* msg : repeated_ptr_types_) { + GenerateRepeatedPtrDecl(*msg); + } + for (const absl::string_view scalar_type : repeated_scalar_types_) { + GenerateRepeatedScalarDecl(scalar_type); + } + } + + std::string Result() && { return std::move(out_); } + + private: + template + static std::string GetQualifiedCppName(const DescriptorT& descriptor) { + return absl::StrReplaceAll(descriptor.full_name(), {{".", "::"}}); + } + + template + static std::string GetEscapedName(const DescriptorT& descriptor) { + return absl::StrReplaceAll(descriptor.full_name(), {{".", "_"}}); + } + + static std::string GetCppType( + const google::protobuf::FieldDescriptor::CppType cpp_type, + const google::protobuf::FieldDescriptor& field) { + switch (cpp_type) { + case google::protobuf::FieldDescriptor::CPPTYPE_INT32: + return "int32_t"; + case google::protobuf::FieldDescriptor::CPPTYPE_INT64: + return "int64_t"; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: + return "uint32_t"; + case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: + return "uint64_t"; + case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: + return "double"; + case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: + return "float"; + case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: + return "bool"; + case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: + return GetQualifiedCppName(*field.enum_type()); + case google::protobuf::FieldDescriptor::CPPTYPE_STRING: + return "std::string"; + default: + LOG(FATAL) << "Unsupported type: " << cpp_type; + } + } + + // Generates a pybind11 wrapper class declaration for a top level message. + void GenerateTopLevelMessageDecl(const google::protobuf::Descriptor& msg) { + CHECK(wrapper_id_.emplace(&msg, wrapper_id_.size()).second) + << "duplicate message: " << msg.full_name(); + absl::SubstituteAndAppend(&out_, R"( + const auto $0 = py::class_<$1, std::shared_ptr<$1>>($2, "$3"))", + GetWrapperName(&msg), current_context_.cpp_name, + GetWrapperName(msg.containing_type()), + msg.name()); + // Add constructor and utilities. + absl::SubstituteAndAppend(&out_, R"( + .def(py::init<>()) + .def("copy_from", + [](std::shared_ptr<$0> self, std::shared_ptr<$0> other) { + self->CopyFrom(*other); + }) + .def("merge_from", + [](std::shared_ptr<$0> self, std::shared_ptr<$0> other) { + self->MergeFrom(*other); + }) + .def("merge_text_format", + [](std::shared_ptr<$0> self, const std::string& text) { + return google::protobuf::TextFormat::MergeFromString(text, self.get()); + }) + .def("parse_text_format", + [](std::shared_ptr<$0> self, const std::string& text) { + return google::protobuf::TextFormat::ParseFromString(text, self.get()); + }) + .def("__copy__", + [](std::shared_ptr<$0> self) { + return self; + }) + .def("__deepcopy__", + [](std::shared_ptr<$0> self, py::dict) { + std::shared_ptr<$0> result = std::make_shared<$0>(); + result->CopyFrom(*self); + return result; + }) + .def("__str__", + [](std::shared_ptr<$0> self) { + return operations_research::ProtobufDebugString(*self); + }))", + current_context_.cpp_name); + } + + // Generates a pybind11 wrapper class declaration for a message. + void GenerateMessageDecl(const google::protobuf::Descriptor& msg) { + CHECK(wrapper_id_.emplace(&msg, wrapper_id_.size()).second) + << "duplicate message: " << msg.full_name(); + absl::SubstituteAndAppend(&out_, R"( + const auto $0 = py::class_<$1>($2, "$3"))", + GetWrapperName(&msg), current_context_.cpp_name, + GetWrapperName(msg.containing_type()), + msg.name()); + // Add constructor and utilities. + absl::SubstituteAndAppend(&out_, R"( + .def(py::init<>()) + .def("copy_from", + []($0* self, const $0& other) { self->CopyFrom(other); }) + .def("merge_from", + []($0* self, const $0& other) { self->MergeFrom(other); }) + .def("merge_text_format", + []($0* self, const std::string& text) { + return google::protobuf::TextFormat::MergeFromString(text, self); + }) + .def("parse_text_format", + []($0* self, const std::string& text) { + return google::protobuf::TextFormat::ParseFromString(text, self); + }) + .def("__copy__", + []($0 self) { + return $0(self); + }) + .def("__deepcopy__", + []($0 self, py::dict) { + return $0(self); + }) + .def("__str__", + []($0 self) { + return operations_research::ProtobufDebugString(self); + }))", + current_context_.cpp_name); + } + + // Generates a pybind11 wrapper class declaration for an enum. + void GenerateEnumDecl(const google::protobuf::EnumDescriptor& pb_enum) { + absl::SubstituteAndAppend(&out_, R"( + py::enum_<$0>($1, "$2"))", + GetQualifiedCppName(pb_enum), + GetWrapperName(pb_enum.containing_type()), + pb_enum.name()); + for (int i = 0; i < pb_enum.value_count(); ++i) { + const google::protobuf::EnumValueDescriptor& value = *pb_enum.value(i); + absl::SubstituteAndAppend(&out_, R"( + .value("$0", $1))", + value.name(), GetQualifiedCppName(value)); + } + absl::SubstituteAndAppend(&out_, R"( + .export_values();)"); + } + + // Generates a pybind11 wrapper class declaration & definitions for a repeated + // ptr. + void GenerateRepeatedPtrDecl(const google::protobuf::Descriptor& msg) { + absl::SubstituteAndAppend(&out_, R"( + py::class_>(m, "repeated_$1") + .def("add", + [](google::protobuf::RepeatedPtrField<$0>* self) { + return self->Add(); + }, + py::return_value_policy::reference, py::keep_alive<0, 1>()) + .def("append", [](google::protobuf::RepeatedPtrField<$0>* self, const $0& value) { + *self->Add() = value; + }) + .def("extend", + [](google::protobuf::RepeatedPtrField<$0>* self, const std::vector<$0>& values) { + for (const $0& value : values) { + *self->Add() = value; + } + }) + .def("__len__", &google::protobuf::RepeatedPtrField<$0>::size) + .def("__getitem__", + [](google::protobuf::RepeatedPtrField<$0>* self, int index) { + if (index >= self->size()) { + PyErr_SetString(PyExc_IndexError, "Index out of range"); + throw py::error_already_set(); + } + return self->Mutable(index); + }, + py::return_value_policy::reference, py::keep_alive<0, 1>());)", + GetQualifiedCppName(msg), msg.name()); + } + + // Generates a pybind11 wrapper class declaration & definitions for a repeated + // scalar. + void GenerateRepeatedScalarDecl(absl::string_view scalar_type) { + if (scalar_type == "std::string") { + absl::StrAppend(&out_, R"( + py::class_>(m, "repeated_scalar_std_string") + .def("append", + [](google::protobuf::RepeatedPtrField* self, std::string str) { + self->Add(std::move(str)); + }) + .def("extend", + [](google::protobuf::RepeatedPtrField* self, + const std::vector& values) { + self->Add(values.begin(), values.end()); + }) + .def("__len__", [](const google::protobuf::RepeatedPtrField& self) { + return self.size(); + }) + .def("__getitem__", + [](const google::protobuf::RepeatedPtrField& self, int index) { + if (index >= self.size()) { + PyErr_SetString(PyExc_IndexError, "Index out of range"); + throw py::error_already_set(); + } + + return self.Get(index); + }, + py::return_value_policy::copy) + .def("__setitem__", + [](google::protobuf::RepeatedPtrField* self, + int index, const std::string& value) { + self->at(index) = value; + }) + .def("__str__", [](const google::protobuf::RepeatedPtrField& self) { + return absl::StrCat("[", absl::StrJoin(self, ", "), "]"); + });)"); + } else { + absl::SubstituteAndAppend( + &out_, R"( + py::class_>(m, "repeated_scalar_$1") + .def("append", [](google::protobuf::RepeatedField<$0>* self, $0 value) { + self->Add(value); + }) + .def("extend", [](google::protobuf::RepeatedField<$0>* self, + const std::vector<$0>& values) { + self->Add(values.begin(), values.end()); + }) + .def("__len__", [](const google::protobuf::RepeatedField<$0>& self) { + return self.size(); + }) + .def("__getitem__", [](const google::protobuf::RepeatedField<$0>& self, int index) { + if (index >= self.size()) { + PyErr_SetString(PyExc_IndexError, "Index out of range"); + throw py::error_already_set(); + } + + return self.Get(index); + }) + .def("__setitem__", &google::protobuf::RepeatedField<$0>::Set) + .def("__str__", [](const google::protobuf::RepeatedField<$0>& self) { + return absl::StrCat("[", absl::StrJoin(self, ", "), "]"); + });)", + scalar_type, absl::StrReplaceAll(scalar_type, {{"::", "_"}})); + } + } + + void GenerateRepeatedField(const google::protobuf::FieldDescriptor& field) { + const google::protobuf::Descriptor* msg_type = field.message_type(); + if (msg_type != nullptr) { + // Repeated message. + absl::SubstituteAndAppend( + &out_, R"( + .def_property_readonly( + "$0", + []($1 self) { return self->mutable_$2(); }, + py::return_value_policy::reference, py::keep_alive<0, 1>()))", + field.name(), current_context_.self_mutable_name, field.name()); + // We'll need to generate the wrapping for `proto2::RepeatedPtrField<$3>`. + repeated_ptr_types_.insert(msg_type); + // We'll need to generate the wrapping for this message type. + message_stack_.push_back(ABSL_DIE_IF_NULL(field.message_type())); + } else { + // Repeated scalar field. + absl::SubstituteAndAppend(&out_, R"( + .def_property_readonly( + "$0", + []($1 self) { return self->mutable_$0(); }, + py::return_value_policy::reference, py::keep_alive<0, 1>()))", + field.name(), + current_context_.self_mutable_name); + // We'll need to generate the wrapping for `proto2::RepeatedField<$2>`. + repeated_scalar_types_.insert(GetCppType(field.cpp_type(), field)); + } + } + + void GenerateSingularField(const google::protobuf::FieldDescriptor& field) { + if (const google::protobuf::Descriptor* msg_type = field.message_type()) { + // Singular message. + absl::SubstituteAndAppend(&out_, R"( + .def_property( + "$0", + []($1 self) { return self->mutable_$0(); }, + []($1 self, $2 arg) { *self->mutable_$0() = arg; }, + py::return_value_policy::reference_internal) + .def("clear_$0", []($1 self) { self->clear_$0(); }) + .def("has_$0", []($1 self) { return self->has_$0(); }))", + field.name(), + current_context_.self_mutable_name, + GetQualifiedCppName(*msg_type)); + // We'll need to generate the wrapping for this message type. + message_stack_.push_back(ABSL_DIE_IF_NULL(field.message_type())); + } else { + if (const google::protobuf::EnumDescriptor* enum_type = + field.enum_type()) { + enum_types_.insert(enum_type); + } + // Singular scalar (int, string, ...). + absl::SubstituteAndAppend(&out_, R"( + .def_property( + "$0", + []($1 msg) { return msg->$0(); }, + []($1 msg, $2 arg) { return msg->set_$0(arg); }) + .def("clear_$0", []($1 self) { self->clear_$0(); }))", + field.name(), + current_context_.self_mutable_name, + GetCppType(field.cpp_type(), field)); + } + } + + // Generates definitions for accessing fields of a message. + void GenerateMessageFields(const google::protobuf::Descriptor& msg) { + const std::string msg_name = GetQualifiedCppName(msg); + + for (int i = 0; i < msg.field_count(); ++i) { + const google::protobuf::FieldDescriptor& field = + *ABSL_DIE_IF_NULL(msg.field(i)); + if (field.is_repeated()) { + GenerateRepeatedField(field); + } else { + GenerateSingularField(field); + } + } + } + + // Returns the wrapper name for a message (or "m" if `msg` is null). + // Dies if the scope is not found. + std::string GetWrapperName(const google::protobuf::Descriptor* msg) { + const auto it = wrapper_id_.find(msg); + CHECK(it != wrapper_id_.end()) + << "wrapper id not found: " << msg->full_name(); + if (msg == nullptr) return "m"; + return absl::StrCat("gen_", it->second); + } + + // This identifies the pybind11 wrapper variable for a `_class` declaration in + // the generated code. These are used to generate enums in the correct + // scope. + static constexpr int kNoScope = 0; + absl::flat_hash_map wrapper_id_ = { + {nullptr, kNoScope}}; + + // Output buffer. + std::string out_; + + // Our DFS stack. + std::vector message_stack_; + absl::flat_hash_set + visited_messages_; + + // A list of enum wrappers to generate. + absl::flat_hash_set + enum_types_; + // A list of repeated ptr wrappers to generate. + absl::flat_hash_set + repeated_ptr_types_; + // A list of repeated scalar wrappers to generate. + absl::flat_hash_set repeated_scalar_types_; + + // Context for the current message being generated. + Context current_context_; +}; + +std::string GeneratePybindCode( + absl::Span roots) { + return Generator(roots).Result(); +} + +} // namespace operations_research::sat::python diff --git a/ortools/sat/python/wrappers.h b/ortools/sat/python/wrappers.h new file mode 100644 index 0000000000..04aaa6e594 --- /dev/null +++ b/ortools/sat/python/wrappers.h @@ -0,0 +1,31 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_PYTHON_WRAPPERS_H_ +#define OR_TOOLS_SAT_PYTHON_WRAPPERS_H_ + +#include + +#include "absl/base/nullability.h" +#include "absl/types/span.h" +#include "google/protobuf/descriptor.h" + +namespace operations_research::sat::python { + +// Generated pybind11 code for the given proto messages. +std::string GeneratePybindCode( + absl::Span roots); + +} // namespace operations_research::sat::python + +#endif // OR_TOOLS_SAT_PYTHON_WRAPPERS_H_ diff --git a/ortools/sat/rins.cc b/ortools/sat/rins.cc index 11a6672285..fbd01e762f 100644 --- a/ortools/sat/rins.cc +++ b/ortools/sat/rins.cc @@ -204,14 +204,11 @@ ReducedDomainNeighborhood GetRinsRensNeighborhood( if (relaxation_values.empty()) return reduced_domains; // Not generated. std::bernoulli_distribution three_out_of_four(0.75); - - if (response_manager != nullptr && - response_manager->SolutionsRepository().NumSolutions() > 0 && + if (response_manager != nullptr && response_manager->HasFeasibleSolution() && three_out_of_four(random)) { // Rins. std::shared_ptr::Solution> solution = - response_manager->SolutionsRepository().GetRandomBiasedSolution( - random); + response_manager->SolutionPool().GetSolutionToImprove(random); FillRinsNeighborhood(solution->variable_values, relaxation_values, difficulty, random, reduced_domains); reduced_domains.source_info = "rins_"; diff --git a/ortools/sat/rins_test.cc b/ortools/sat/rins_test.cc index 74ff5995da..17740b6d10 100644 --- a/ortools/sat/rins_test.cc +++ b/ortools/sat/rins_test.cc @@ -16,6 +16,7 @@ #include #include +#include "absl/strings/match.h" #include "absl/types/span.h" #include "gtest/gtest.h" #include "ortools/base/parse_test_proto.h" @@ -150,19 +151,24 @@ TEST(GetRinsRensNeighborhoodTest, GetRinsRensNeighborhoodLP) { // Add a lp solution. lp_solutions.NewLPSolution({3.5, 5}); lp_solutions.Synchronize(); + // Add a solution. CpSolverResponse solution; solution.add_solution(4); solution.add_solution(5); shared_response_manager->NewSolution(solution.solution(), solution.solution_info(), &model); - shared_response_manager->MutableSolutionsRepository()->Synchronize(); + shared_response_manager->Synchronize(); - const ReducedDomainNeighborhood rins_neighborhood = GetRinsRensNeighborhood( - shared_response_manager, &lp_solutions, &incomplete_solutions, - /*difficulty=*/0.5, random); + ReducedDomainNeighborhood rins_neighborhood; + for (int i = 0; i < 100; ++i) { + rins_neighborhood = GetRinsRensNeighborhood( + shared_response_manager, &lp_solutions, &incomplete_solutions, + /*difficulty=*/0.5, random); + if (absl::StartsWith(rins_neighborhood.source_info, "rins")) break; + } - EXPECT_EQ(rins_neighborhood.reduced_domain_vars.size(), 0); + EXPECT_TRUE(rins_neighborhood.reduced_domain_vars.empty()); EXPECT_EQ(rins_neighborhood.fixed_vars.size(), 1); EXPECT_EQ(rins_neighborhood.fixed_vars[0].first, 1); EXPECT_EQ(rins_neighborhood.fixed_vars[0].second, 5); diff --git a/ortools/sat/routing_cuts.cc b/ortools/sat/routing_cuts.cc index 13f484a901..1da0ac49e1 100644 --- a/ortools/sat/routing_cuts.cc +++ b/ortools/sat/routing_cuts.cc @@ -124,6 +124,7 @@ MinOutgoingFlowHelper::MinOutgoingFlowHelper( trail_(*model->GetOrCreate()), integer_trail_(*model->GetOrCreate()), integer_encoder_(*model->GetOrCreate()), + root_level_bounds_(*model->GetOrCreate()), shared_stats_(model->GetOrCreate()), in_subset_(num_nodes, false), index_in_subset_(num_nodes, -1), @@ -629,7 +630,8 @@ int MinOutgoingFlowHelper::ComputeMinOutgoingFlow( // If this arc cannot be taken skip. tmp_lbs.clear(); if (!binary_relation_repository_.PropagateLocalBounds( - integer_trail_, lit, node_var_lower_bounds_[tail], &tmp_lbs)) { + integer_trail_, root_level_bounds_, lit, + node_var_lower_bounds_[tail], &tmp_lbs)) { continue; } @@ -755,8 +757,8 @@ int MinOutgoingFlowHelper::ComputeTightMinOutgoingFlow( // If this arc cannot be taken skip. tmp_lbs.clear(); if (!binary_relation_repository_.PropagateLocalBounds( - integer_trail_, literals_[outgoing_arc_index], path_bounds, - &tmp_lbs)) { + integer_trail_, root_level_bounds_, + literals_[outgoing_arc_index], path_bounds, &tmp_lbs)) { continue; } @@ -916,7 +918,7 @@ bool MinOutgoingFlowHelper::SubsetMightBeServedWithKRoutes( absl::flat_hash_map copy = state.lbs; return binary_relation_repository_.PropagateLocalBounds( - integer_trail_, unique_lit, copy, &state.lbs); + integer_trail_, root_level_bounds_, unique_lit, copy, &state.lbs); }; // We always start with the first node in this case. @@ -1011,7 +1013,8 @@ bool MinOutgoingFlowHelper::SubsetMightBeServedWithKRoutes( } } else { if (!binary_relation_repository_.PropagateLocalBounds( - integer_trail_, literal, from_state.lbs, &to_state.lbs)) { + integer_trail_, root_level_bounds_, literal, from_state.lbs, + &to_state.lbs)) { continue; } } @@ -1077,12 +1080,24 @@ struct LocalRelation { IntegerVariable UniqueSharedVariable(const sat::Relation& r1, const sat::Relation& r2) { - DCHECK_NE(r1.a.var, r1.b.var); - DCHECK_NE(r2.a.var, r2.b.var); - if (r1.a.var == r2.a.var && r1.b.var != r2.b.var) return r1.a.var; - if (r1.a.var == r2.b.var && r1.b.var != r2.a.var) return r1.a.var; - if (r1.b.var == r2.a.var && r1.a.var != r2.b.var) return r1.b.var; - if (r1.b.var == r2.b.var && r1.a.var != r2.a.var) return r1.b.var; + DCHECK_NE(r1.expr.vars[0], r1.expr.vars[1]); + DCHECK_NE(r2.expr.vars[0], r2.expr.vars[1]); + if (r1.expr.vars[0] == r2.expr.vars[0] && + r1.expr.vars[1] != r2.expr.vars[1]) { + return r1.expr.vars[0]; + } + if (r1.expr.vars[0] == r2.expr.vars[1] && + r1.expr.vars[1] != r2.expr.vars[0]) { + return r1.expr.vars[0]; + } + if (r1.expr.vars[1] == r2.expr.vars[0] && + r1.expr.vars[0] != r2.expr.vars[1]) { + return r1.expr.vars[1]; + } + if (r1.expr.vars[1] == r2.expr.vars[1] && + r1.expr.vars[0] != r2.expr.vars[0]) { + return r1.expr.vars[1]; + } return kNoIntegerVariable; } @@ -1254,10 +1269,11 @@ class RouteRelationsBuilder { binary_relation_repository_.IndicesOfRelationsEnforcedBy( literals_[i])) { const auto& r = binary_relation_repository_.relation(relation_index); - if (r.a.var == kNoIntegerVariable || r.b.var == kNoIntegerVariable) { + if (r.expr.vars[0] == kNoIntegerVariable || + r.expr.vars[1] == kNoIntegerVariable) { continue; } - cc_finder.AddEdge(r.a.var, r.b.var); + cc_finder.AddEdge(r.expr.vars[0], r.expr.vars[1]); } } const std::vector> connected_components = @@ -1283,10 +1299,11 @@ class RouteRelationsBuilder { binary_relation_repository_.IndicesOfRelationsEnforcedBy( literals_[i])) { const auto& r = binary_relation_repository_.relation(relation_index); - if (r.a.var == kNoIntegerVariable || r.b.var == kNoIntegerVariable) { + if (r.expr.vars[0] == kNoIntegerVariable || + r.expr.vars[1] == kNoIntegerVariable) { continue; } - const int dimension = dimension_by_var_[r.a.var]; + const int dimension = dimension_by_var_[r.expr.vars[0]]; adjacent_relation_indices_[dimension][tails_[i]].push_back( relation_index); adjacent_relation_indices_[dimension][heads_[i]].push_back( @@ -1360,24 +1377,25 @@ class RouteRelationsBuilder { binary_relation_repository_.IndicesOfRelationsEnforcedBy( literals_[arc_index])) { const auto& r = binary_relation_repository_.relation(relation_index); - if (r.a.var == kNoIntegerVariable || r.b.var == kNoIntegerVariable) { + if (r.expr.vars[0] == kNoIntegerVariable || + r.expr.vars[1] == kNoIntegerVariable) { continue; } - if (r.a.var == node_expr.var) { + if (r.expr.vars[0] == node_expr.var) { if (candidate_var != kNoIntegerVariable && - candidate_var != r.b.var) { + candidate_var != r.expr.vars[1]) { candidate_var_is_unique = false; break; } - candidate_var = r.b.var; + candidate_var = r.expr.vars[1]; } - if (r.b.var == node_expr.var) { + if (r.expr.vars[1] == node_expr.var) { if (candidate_var != kNoIntegerVariable && - candidate_var != r.a.var) { + candidate_var != r.expr.vars[0]) { candidate_var_is_unique = false; break; } - candidate_var = r.a.var; + candidate_var = r.expr.vars[0]; } } if (candidate_var != kNoIntegerVariable && candidate_var_is_unique) { @@ -1394,6 +1412,8 @@ class RouteRelationsBuilder { const auto& integer_encoder = *model->GetOrCreate(); const auto& trail = *model->GetOrCreate(); const auto& integer_trail = *model->GetOrCreate(); + const auto& root_level_bounds = + *model->GetOrCreate(); DCHECK_EQ(trail.CurrentDecisionLevel(), 0); flat_arc_dim_relations_ = std::vector( @@ -1488,21 +1508,26 @@ class RouteRelationsBuilder { // Try to match the relation variables with the node expression // variables. First swap the relation terms if needed (this does not // change the relation bounds). - if ((r.a.var != kNoIntegerVariable && r.a.var == head_expr.var) || - (r.b.var != kNoIntegerVariable && r.b.var == tail_expr.var)) { - std::swap(r.a, r.b); + if ((r.expr.vars[0] != kNoIntegerVariable && + r.expr.vars[0] == head_expr.var) || + (r.expr.vars[1] != kNoIntegerVariable && + r.expr.vars[1] == tail_expr.var)) { + std::swap(r.expr.vars[0], r.expr.vars[1]); + std::swap(r.expr.coeffs[0], r.expr.coeffs[1]); } // If the relation has only one term, try to remove the variable // in the node expression corresponding to the missing term. - if (r.a.var == kNoIntegerVariable) { + if (r.expr.vars[0] == kNoIntegerVariable) { if (!to_constant(tail_expr)) continue; - } else if (r.b.var == kNoIntegerVariable) { + } else if (r.expr.vars[1] == kNoIntegerVariable) { if (!to_constant(head_expr)) continue; } // If the relation and node expression variables do not match, we // cannot use this relation for this arc. - if (!((tail_expr.var == r.a.var && head_expr.var == r.b.var) || - (tail_expr.var == r.b.var && head_expr.var == r.a.var))) { + if (!((tail_expr.var == r.expr.vars[0] && + head_expr.var == r.expr.vars[1]) || + (tail_expr.var == r.expr.vars[1] && + head_expr.var == r.expr.vars[0]))) { continue; } ComputeArcRelation(i, dimension, tail_expr, head_expr, r, @@ -1512,13 +1537,12 @@ class RouteRelationsBuilder { // Check if we can use non-enforced relations to improve the relations. if (!tail_expr.IsEmpty() && !head_expr.IsEmpty()) { - for (const int relation_index : - binary_relation_repository_.IndicesOfRelationsBetween( + for (const auto& [expr, lb, ub] : + root_level_bounds.GetAllBoundsContainingVariables( tail_expr.var, head_expr.var)) { - ComputeArcRelation( - i, dimension, tail_expr, head_expr, - binary_relation_repository_.relation(relation_index), - integer_trail); + ComputeArcRelation(i, dimension, tail_expr, head_expr, + Relation{Literal(kNoLiteralIndex), expr, lb, ub}, + integer_trail); } } @@ -1553,20 +1577,25 @@ class RouteRelationsBuilder { const NodeExpression& tail_expr, const NodeExpression& head_expr, sat::Relation r, const IntegerTrail& integer_trail) { - DCHECK((r.a.var == tail_expr.var && r.b.var == head_expr.var) || - (r.a.var == head_expr.var && r.b.var == tail_expr.var)); - if (r.a.var != tail_expr.var) std::swap(r.a, r.b); - if (r.a.coeff == 0 || tail_expr.coeff == 0) { - LocalRelation result = ComputeArcUnaryRelation(head_expr, tail_expr, - r.b.coeff, r.lhs, r.rhs); + DCHECK( + (r.expr.vars[0] == tail_expr.var && r.expr.vars[1] == head_expr.var) || + (r.expr.vars[0] == head_expr.var && r.expr.vars[1] == tail_expr.var)); + if (r.expr.vars[0] != tail_expr.var) { + std::swap(r.expr.vars[0], r.expr.vars[1]); + std::swap(r.expr.coeffs[0], r.expr.coeffs[1]); + } + if (r.expr.coeffs[0] == 0 || tail_expr.coeff == 0) { + LocalRelation result = ComputeArcUnaryRelation( + head_expr, tail_expr, r.expr.coeffs[1], r.lhs, r.rhs); std::swap(result.tail_coeff, result.head_coeff); ProcessNewArcRelation(arc_index, dimension, result); return; } - if (r.b.coeff == 0 || head_expr.coeff == 0) { - ProcessNewArcRelation(arc_index, dimension, - ComputeArcUnaryRelation(tail_expr, head_expr, - r.a.coeff, r.lhs, r.rhs)); + if (r.expr.coeffs[1] == 0 || head_expr.coeff == 0) { + ProcessNewArcRelation( + arc_index, dimension, + ComputeArcUnaryRelation(tail_expr, head_expr, r.expr.coeffs[0], r.lhs, + r.rhs)); return; } const auto [lhs, rhs] = @@ -1680,14 +1709,16 @@ IntegerValue GetDifferenceLowerBound( // TODO(user): overflows could happen if the node expressions are // provided by the user in the model proto. auto lower_bound = [&](IntegerValue k) { - const IntegerValue y_coeff = y_expr.coeff - k * r.b.coeff; - const IntegerValue x_coeff = k * (-r.a.coeff) - x_expr.coeff; + const IntegerValue y_coeff = y_expr.coeff - k * r.expr.coeffs[1]; + const IntegerValue x_coeff = k * (-r.expr.coeffs[0]) - x_expr.coeff; return y_coeff * (y_coeff >= 0 ? y_var_bounds.first : y_var_bounds.second) + x_coeff * (x_coeff >= 0 ? x_var_bounds.first : x_var_bounds.second) + k * (k >= 0 ? r.lhs : r.rhs); }; - const IntegerValue k_x = MathUtil::FloorOfRatio(x_expr.coeff, -r.a.coeff); - const IntegerValue k_y = MathUtil::FloorOfRatio(y_expr.coeff, r.b.coeff); + const IntegerValue k_x = + MathUtil::FloorOfRatio(x_expr.coeff, -r.expr.coeffs[0]); + const IntegerValue k_y = + MathUtil::FloorOfRatio(y_expr.coeff, r.expr.coeffs[1]); IntegerValue result = lower_bound(0); result = std::max(result, lower_bound(k_x)); result = std::max(result, lower_bound(k_x + 1)); @@ -1702,14 +1733,14 @@ std::pair GetDifferenceBounds( const sat::Relation& r, const std::pair& x_var_bounds, const std::pair& y_var_bounds) { - DCHECK_EQ(x_expr.var, r.a.var); - DCHECK_EQ(y_expr.var, r.b.var); + DCHECK_EQ(x_expr.var, r.expr.vars[0]); + DCHECK_EQ(y_expr.var, r.expr.vars[1]); DCHECK_NE(x_expr.var, kNoIntegerVariable); DCHECK_NE(y_expr.var, kNoIntegerVariable); DCHECK_NE(x_expr.coeff, 0); DCHECK_NE(y_expr.coeff, 0); - DCHECK_NE(r.a.coeff, 0); - DCHECK_NE(r.b.coeff, 0); + DCHECK_NE(r.expr.coeffs[0], 0); + DCHECK_NE(r.expr.coeffs[1], 0); const IntegerValue lb = GetDifferenceLowerBound(x_expr, y_expr, r, x_var_bounds, y_var_bounds); const IntegerValue ub = -GetDifferenceLowerBound( @@ -1830,6 +1861,7 @@ BinaryRelationRepository ComputePartialBinaryRelationRepository( ToPositiveIntegerVariable(vars[0]), ToPositiveIntegerVariable(vars[1])); } + Model empty_model; repository.Build(); return repository; } @@ -1917,6 +1949,7 @@ class RoutingCutHelper { *model->GetOrCreate()), random_(model->GetOrCreate()), encoder_(model->GetOrCreate()), + root_level_bounds_(*model->GetOrCreate()), in_subset_(num_nodes, false), self_arc_literal_(num_nodes_), self_arc_lp_value_(num_nodes_), @@ -2050,6 +2083,7 @@ class RoutingCutHelper { const BinaryRelationRepository& binary_relation_repository_; ModelRandomGenerator* random_; IntegerEncoder* encoder_; + const RootLevelLinear2Bounds& root_level_bounds_; std::vector in_subset_; @@ -2755,7 +2789,8 @@ void RoutingCutHelper::GenerateCutsForInfeasiblePaths( const Literal next_literal = literals_[arc_index]; next_state.bounds = state.bounds; if (binary_relation_repository_.PropagateLocalBounds( - integer_trail_, next_literal, state.bounds, &next_state.bounds)) { + integer_trail_, root_level_bounds_, next_literal, state.bounds, + &next_state.bounds)) { // Do not explore "long" paths to keep the search time bounded. if (path_length < max_path_length) { path_nodes[next_state.last_node] = true; diff --git a/ortools/sat/routing_cuts.h b/ortools/sat/routing_cuts.h index 584e0cca19..0713f01bc4 100644 --- a/ortools/sat/routing_cuts.h +++ b/ortools/sat/routing_cuts.h @@ -545,6 +545,7 @@ class MinOutgoingFlowHelper { const Trail& trail_; const IntegerTrail& integer_trail_; const IntegerEncoder& integer_encoder_; + const RootLevelLinear2Bounds& root_level_bounds_; SharedStatistics* shared_stats_; // Temporary data used by ComputeMinOutgoingFlow(). Always contain default diff --git a/ortools/sat/routing_cuts_test.cc b/ortools/sat/routing_cuts_test.cc index 9b38802d5f..39bb0469ee 100644 --- a/ortools/sat/routing_cuts_test.cc +++ b/ortools/sat/routing_cuts_test.cc @@ -65,7 +65,7 @@ std::pair ExactDifferenceBounds( IntegerValue ub = kMinIntegerValue; for (IntegerValue x = x_bounds.first; x <= x_bounds.second; ++x) { for (IntegerValue y = y_bounds.first; y <= y_bounds.second; ++y) { - const IntegerValue r_value = x * r.a.coeff + y * r.b.coeff; + const IntegerValue r_value = x * r.expr.coeffs[0] + y * r.expr.coeffs[1]; if (r_value < r.lhs || r_value > r.rhs) continue; const IntegerValue difference = y_expr.ValueAt(y) - x_expr.ValueAt(x); lb = std::min(lb, difference); @@ -101,8 +101,7 @@ TEST(GetDifferenceBounds, RandomTest) { const NodeExpression y_expr(y, B, absl::Uniform(random, -5, 5)); const Relation r{ .enforcement = lit, - .a = LinearTerm(x, a), - .b = LinearTerm(y, b), + .expr = LinearExpression2(x, y, a, b), .lhs = lhs, .rhs = rhs, }; @@ -162,8 +161,8 @@ TEST(MinOutgoingFlowHelperTest, CapacityConstraints) { // picked up by the vehicle leaving n. const int head_load = head == 0 ? 0 : head + 10; // loads[head] - loads[tail] >= head_load - repository->Add(literal, {loads[head], 1}, {loads[tail], -1}, head_load, - 1000); + repository->Add(literal, LinearExpression2(loads[head], loads[tail], 1, -1), + head_load, 1000); } repository->Build(); // Subject under test. @@ -230,11 +229,13 @@ TEST_P(DimensionBasedMinOutgoingFlowHelperTest, BasicCapacities) { if (tail == 0 || head == 0) continue; if (pickup) { // loads[head] - loads[tail] >= demand - repository->Add(literal, {loads[head], 1}, {loads[tail], -1}, + repository->Add(literal, + LinearExpression2(loads[head], loads[tail], 1, -1), demands[use_outgoing_load ? head : tail], 1000); } else { // loads[tail] - loads[head] >= demand - repository->Add(literal, {loads[tail], 1}, {loads[head], -1}, + repository->Add(literal, + LinearExpression2(loads[tail], loads[head], 1, -1), demands[use_outgoing_load ? head : tail], 1000); } } @@ -301,11 +302,13 @@ TEST_P(DimensionBasedMinOutgoingFlowHelperTest, const int tail = tails[i]; if (pickup) { // loads[head] - loads[tail] >= demand - repository->Add(literals[i], {loads[head], 1}, {loads[tail], -1}, + repository->Add(literals[i], + LinearExpression2::Difference(loads[head], loads[tail]), demands[use_outgoing_load ? head : tail], 1000); } else { // loads[tail] - loads[head] >= demand - repository->Add(literals[i], {loads[tail], 1}, {loads[head], -1}, + repository->Add(literals[i], + LinearExpression2::Difference(loads[tail], loads[head]), demands[use_outgoing_load ? head : tail], 1000); } } @@ -357,8 +360,8 @@ TEST(MinOutgoingFlowHelperTest, NodeExpressionWithConstant) { auto* repository = model.GetOrCreate(); // Capacity constraint: (offset_load2 + offset) - load1 >= demand1 - repository->Add(literals[0], {offset_load2, 1}, {load1, -1}, demand1 - offset, - 1000); + repository->Add(literals[0], LinearExpression2(offset_load2, load1, 1, -1), + demand1 - offset, 1000); repository->Build(); std::unique_ptr route_relations_helper = RouteRelationsHelper::Create(num_nodes, tails, heads, literals, @@ -398,7 +401,8 @@ TEST(MinOutgoingFlowHelperTest, ConstantNodeExpression) { auto* repository = model.GetOrCreate(); // Capacity constraint: load2 - load1 >= demand1 - repository->Add(literals[0], {kNoIntegerVariable, 0}, {load1, -1}, + repository->Add(literals[0], + LinearExpression2(kNoIntegerVariable, load1, 0, -1), demand1 - load2, 1000); repository->Build(); std::unique_ptr route_relations_helper = @@ -451,7 +455,8 @@ TEST(MinOutgoingFlowHelperTest, NodeExpressionUsingArcLiteralAsVariable) { // Capacity constraint: load2 - load1 >= demand1. This expands to // (capacity - demand2 - demand3 * l) - load1 >= demand1, i.e., // -demand3 * l - load1 >= demand1 + demand2 - capacity - repository->Add(literals[0], {arc_2_3_var, -demand3}, {load1, -1}, + repository->Add(literals[0], + LinearExpression2(arc_2_3_var, load1, -demand3, -1), demand1 + demand2 - capacity, 1000); // Capacity constraint: load3 - load2 >= demand2. This expands to // (capacity - demand3) - (capacity - demand2 - demand3 * l) >= demand2 which, @@ -508,7 +513,8 @@ TEST(MinOutgoingFlowHelperTest, // Capacity constraint: load2 - load1 >= demand1. This expands to // (capacity - demand2 - demand3 + demand3 * l) - load1 >= demand1, i.e., // demand3 * l - load1 >= demand1 + demand2 + demand3 - capacity - repository->Add(literals[0], {arc_2_3_var, demand3}, {load1, -1}, + repository->Add(literals[0], + LinearExpression2(arc_2_3_var, load1, demand3, -1), demand1 + demand2 + demand3 - capacity, 1000); // Capacity constraint: load3 - load2 >= demand2. This expands to // (capacity - demand3) - (capacity - demand2 - demand3 + demand3 * l) >= @@ -566,7 +572,7 @@ TEST(MinOutgoingFlowHelperTest, ArcNodeExpressionsWithSharedVariable) { // Capacity constraint: load2 - load1 >= demand1. This expands to // (capacity - demand2 - demand3) - coeff * x - load1 >= demand1, i.e., // -coeff * x - load1 >= demand1 + demand2 + demand3 - capacity. - repository->Add(literals[0], {x, -coeff}, {load1, -1}, + repository->Add(literals[0], LinearExpression2(x, load1, -coeff, -1), demand1 + demand2 + demand3 - capacity, 1000); // Capacity constraint: load3 - load2 >= demand2. This expands to // (capacity - demand3) - (capacity - demand2 - demand3) >= demand2, which @@ -629,12 +635,14 @@ TEST(MinOutgoingFlowHelperTest, UnaryRelationForTwoNodeExpressions) { // constraint is enforced by arc_1_2_lit we can assume it is true, which // implies that x = 0. Hence the constraint simplifies to load1 <= capacity - // demand2 - demand1. - repository->Add(literals[0], {load1, 1}, {kNoIntegerVariable, 0}, 0, + repository->Add(literals[0], + LinearExpression2(load1, kNoIntegerVariable, 1, 0), 0, capacity - demand1 - demand2); // Capacity constraint: load3 - load2 >= demand2. This expands to // load3 - ((capacity - demand2) - demand1 * x) >= demand2, i.e. to load3 + // demand1 * x >= capacity - repository->Add(literals[1], {load3, 1}, {x, demand1}, capacity, 1000); + repository->Add(literals[1], LinearExpression2(load3, x, 1, demand1), + capacity, 1000); repository->Build(); std::unique_ptr route_relations_helper = RouteRelationsHelper::Create(num_nodes, tails, heads, literals, @@ -687,8 +695,10 @@ TEST(MinOutgoingFlowHelperTest, NodeMustBeInnerNode) { auto* repository = model.GetOrCreate(); for (int i = 0; i < num_arcs; ++i) { // loads[head] - loads[tail] >= demand[arc] - repository->Add(literals[i], {loads[heads[i]], 1}, {loads[tails[i]], -1}, - demands[i], 1000); + repository->Add( + literals[i], + LinearExpression2(loads[heads[i]], loads[tails[i]], 1, -1), + demands[i], 1000); } repository->Build(); @@ -745,8 +755,10 @@ TEST(MinOutgoingFlowHelperTest, BetterUseOfUpperBound) { auto* repository = model.GetOrCreate(); for (int i = 0; i < num_arcs; ++i) { // loads[head] - loads[tail] >= demand[arc] - repository->Add(literals[i], {loads[heads[i]], 1}, {loads[tails[i]], -1}, - demands[i], 1000); + repository->Add( + literals[i], + LinearExpression2::Difference(loads[heads[i]], loads[tails[i]]), + demands[i], 1000); } repository->Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -783,8 +795,9 @@ TEST(MinOutgoingFlowHelperTest, DimensionBasedMinOutgoingFlow_IsolatedNodes) { literals.push_back(Literal(model.Add(NewBooleanVariable()), true)); variables.push_back(model.Add(NewIntegerVariable(0, 100))); // Dummy relation, used only to associate a variable with each node. - repository->Add(literals.back(), {variables[head], 1}, {variables[0], -1}, - 1, 100); + repository->Add(literals.back(), + LinearExpression2(variables[head], variables[0], 1, -1), 1, + 100); } repository->Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -834,8 +847,8 @@ TEST(MinOutgoingFlowHelperTest, TimeWindows) { const auto& [tail, head] = arc; const int travel_time = 10 - tail; // times[head] - times[tail] >= travel_time - repository->Add(literal, {times[head], 1}, {times[tail], -1}, travel_time, - 1000); + repository->Add(literal, LinearExpression2(times[head], times[tail], 1, -1), + travel_time, 1000); } repository->Build(); // Subject under test. @@ -963,10 +976,14 @@ TEST(MinOutgoingFlowHelperTest, SubsetMightBeServedWithKRoutes) { const auto& [tail, head] = arc; // vars[head] >= vars[tail] + load[head]; - repository->Add(literal, {cumul_vars_1[head], 1}, {cumul_vars_1[tail], -1}, - load1[head], 10000); - repository->Add(literal, {cumul_vars_2[head], 1}, {cumul_vars_2[tail], -1}, - load2[head], 10000); + repository->Add( + literal, + LinearExpression2(cumul_vars_1[head], cumul_vars_1[tail], 1, -1), + load1[head], 10000); + repository->Add( + literal, + LinearExpression2(cumul_vars_2[head], cumul_vars_2[tail], 1, -1), + load2[head], 10000); } repository->Build(); @@ -1031,10 +1048,14 @@ TEST(MinOutgoingFlowHelperTest, SubsetMightBeServedWithKRoutesRandom) { const auto& [tail, head] = arc; // vars[head] >= vars[tail] + load[head]; - repository->Add(literal, {cumul_vars_1[head], 1}, {cumul_vars_1[tail], -1}, - load1[head], 10000); - repository->Add(literal, {cumul_vars_2[head], 1}, {cumul_vars_2[tail], -1}, - load2[head], 10000); + repository->Add( + literal, + LinearExpression2::Difference(cumul_vars_1[head], cumul_vars_1[tail]), + load1[head], 10000); + repository->Add( + literal, + LinearExpression2::Difference(cumul_vars_2[head], cumul_vars_2[tail]), + load2[head], 10000); } repository->Build(); @@ -1160,8 +1181,10 @@ TEST(MinOutgoingFlowHelperTest, const Literal literal = literals[arc]; // vars[head] >= vars[tail] + travel_times[arc]; - repository->Add(literal, {cumul_vars[head], 1}, {cumul_vars[tail], -1}, - travel_times[arc], 10000); + repository->Add( + literal, + LinearExpression2::Difference(cumul_vars[head], cumul_vars[tail]), + travel_times[arc], 10000); } repository->Build(); @@ -1387,14 +1410,16 @@ TEST(RouteRelationsHelperTest, Basic) { const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); BinaryRelationRepository repository; - repository.Add(literals[0], {a, 1}, {b, -1}, 50, 1000); - repository.Add(literals[1], {a, 1}, {c, -1}, 70, 1000); - repository.Add(literals[2], {c, 1}, {b, -1}, 40, 1000); - repository.Add(literals[0], {NegationOf(u), -1}, {NegationOf(v), 1}, 4, 100); - repository.Add(literals[1], {u, 1}, {w, -1}, 4, 100); - repository.Add(literals[2], {w, -1}, {v, 1}, -100, -3); - repository.Add(literals[3], {x, 1}, {w, -1}, 5, 100); - repository.Add(literals[4], {z, 1}, {y, -1}, 7, 100); + repository.Add(literals[0], LinearExpression2::Difference(a, b), 50, 1000); + repository.Add(literals[1], LinearExpression2::Difference(a, c), 70, 1000); + repository.Add(literals[2], LinearExpression2::Difference(c, b), 40, 1000); + repository.Add(literals[0], + LinearExpression2(NegationOf(u), NegationOf(v), -1, 1), 4, + 100); + repository.Add(literals[1], LinearExpression2::Difference(u, w), 4, 100); + repository.Add(literals[2], LinearExpression2(w, v, -1, 1), -100, -3); + repository.Add(literals[3], LinearExpression2::Difference(x, w), 5, 100); + repository.Add(literals[4], LinearExpression2::Difference(z, y), 7, 100); repository.Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -1480,15 +1505,16 @@ TEST(RouteRelationsHelperTest, UnenforcedRelations) { const IntegerVariable c = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable d = model.Add(NewIntegerVariable(0, 100)); BinaryRelationRepository repository; - repository.Add(literals[0], {b, 1}, {a, -1}, 1, 1); - repository.Add(literals[1], {c, 1}, {b, -1}, 2, 2); - repository.Add(literals[2], {d, 1}, {c, -1}, 3, 3); - repository.Add(literals[3], {a, 1}, {d, -1}, 4, 4); + RootLevelLinear2Bounds* bounds = model.GetOrCreate(); + repository.Add(literals[0], LinearExpression2::Difference(b, a), 1, 1); + repository.Add(literals[1], LinearExpression2::Difference(c, b), 2, 2); + repository.Add(literals[2], LinearExpression2::Difference(d, c), 3, 3); + repository.Add(literals[3], LinearExpression2::Difference(a, d), 4, 4); // Several unenforced relations on the diagonal arc. The one with the +/-1 // coefficients should be preferred. - repository.Add(Literal(kNoLiteralIndex), {c, 3}, {a, -2}, 1, 9); - repository.Add(Literal(kNoLiteralIndex), {c, 1}, {a, -1}, 5, 5); - repository.Add(Literal(kNoLiteralIndex), {c, 2}, {a, -3}, 3, 8); + bounds->Add(LinearExpression2(c, a, 3, -2), 1, 9); + bounds->Add(LinearExpression2(c, a, 1, -1), 5, 5); + bounds->Add(LinearExpression2(c, a, 2, -3), 3, 8); repository.Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -1529,13 +1555,13 @@ TEST(RouteRelationsHelperTest, SeveralVariablesPerNode) { const IntegerVariable y = model.Add(NewIntegerVariable(0, 10)); const IntegerVariable z = model.Add(NewIntegerVariable(0, 10)); BinaryRelationRepository repository; - repository.Add(literals[0], {b, 1}, {a, -1}, 50, 1000); - repository.Add(literals[1], {c, 1}, {b, -1}, 70, 1000); - repository.Add(literals[0], {z, 1}, {y, -1}, 5, 100); - repository.Add(literals[1], {y, 1}, {x, -1}, 7, 100); + repository.Add(literals[0], LinearExpression2::Difference(b, a), 50, 1000); + repository.Add(literals[1], LinearExpression2::Difference(c, b), 70, 1000); + repository.Add(literals[0], LinearExpression2::Difference(z, y), 5, 100); + repository.Add(literals[1], LinearExpression2::Difference(y, x), 7, 100); // Weird relation linking time and load variables, causing all the variables // to be in a single "dimension". - repository.Add(literals[0], {x, 1}, {a, -1}, 0, 100); + repository.Add(literals[0], LinearExpression2::Difference(x, a), 0, 100); repository.Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -1561,7 +1587,7 @@ TEST(RouteRelationsHelperTest, ComplexVariableRelations) { const IntegerVariable b = model.Add(NewIntegerVariable(0, 1)); BinaryRelationRepository repository; // "complex" relation with non +1/-1 coefficients. - repository.Add(literals[0], {b, 10}, {a, 1}, 0, 150); + repository.Add(literals[0], LinearExpression2(b, a, 10, 1), 0, 150); repository.Build(); const RoutingCumulExpressions cumuls = { @@ -1625,10 +1651,10 @@ TEST(RouteRelationsHelperTest, SeveralRelationsPerArc) { const IntegerVariable b = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable c = model.Add(NewIntegerVariable(0, 100)); BinaryRelationRepository repository; - repository.Add(literals[0], {b, 1}, {a, -1}, 50, 1000); - repository.Add(literals[1], {c, 1}, {b, -1}, 70, 1000); + repository.Add(literals[0], LinearExpression2::Difference(b, a), 50, 1000); + repository.Add(literals[1], LinearExpression2::Difference(c, b), 70, 1000); // Add a second relation for some arc. - repository.Add(literals[1], {c, 2}, {b, -3}, 100, 200); + repository.Add(literals[1], LinearExpression2(c, b, 2, -3), 100, 200); repository.Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -1661,8 +1687,8 @@ TEST(RouteRelationsHelperTest, SeveralArcsPerLiteral) { const IntegerVariable b = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable c = model.Add(NewIntegerVariable(0, 100)); BinaryRelationRepository repository; - repository.Add(literals[0], {b, 1}, {a, -1}, 50, 1000); - repository.Add(literals[0], {c, 1}, {b, -1}, 40, 1000); + repository.Add(literals[0], LinearExpression2::Difference(b, a), 50, 1000); + repository.Add(literals[0], LinearExpression2::Difference(c, b), 40, 1000); repository.Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -1703,13 +1729,13 @@ TEST(RouteRelationsHelperTest, InconsistentRelationIsSkipped) { const IntegerVariable e = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable f = model.Add(NewIntegerVariable(0, 100)); BinaryRelationRepository repository; - repository.Add(literals[0], {b, 1}, {a, -1}, 0, 0); - repository.Add(literals[1], {c, 1}, {b, -1}, 1, 1); - repository.Add(literals[2], {d, 1}, {c, -1}, 2, 2); - repository.Add(literals[3], {e, 1}, {d, -1}, 3, 3); - repository.Add(literals[4], {f, 1}, {b, -1}, 4, 4); + repository.Add(literals[0], LinearExpression2::Difference(b, a), 0, 0); + repository.Add(literals[1], LinearExpression2::Difference(c, b), 1, 1); + repository.Add(literals[2], LinearExpression2::Difference(d, c), 2, 2); + repository.Add(literals[3], LinearExpression2::Difference(e, d), 3, 3); + repository.Add(literals[4], LinearExpression2::Difference(f, b), 4, 4); // Inconsistent relation for arc 5->3 (should be between f and d). - repository.Add(literals[5], {f, 2}, {b, -1}, 5, 5); + repository.Add(literals[5], LinearExpression2(f, b, 2, -1), 5, 5); repository.Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -1763,16 +1789,16 @@ TEST(RouteRelationsHelperTest, InconsistentRelationWithMultipleArcsPerLiteral) { const IntegerVariable d = model.Add(NewIntegerVariable(0, 100)); const IntegerVariable e = model.Add(NewIntegerVariable(0, 100)); BinaryRelationRepository repository; - repository.Add(literals[0], {b, 1}, {a, -1}, 0, 0); - repository.Add(literals[1], {c, 1}, {b, -1}, 1, 1); - repository.Add(literals[2], {d, 1}, {c, -1}, 2, 2); - repository.Add(literals[3], {a, 1}, {d, -1}, 3, 3); + repository.Add(literals[0], LinearExpression2::Difference(b, a), 0, 0); + repository.Add(literals[1], LinearExpression2::Difference(c, b), 1, 1); + repository.Add(literals[2], LinearExpression2::Difference(d, c), 2, 2); + repository.Add(literals[3], LinearExpression2::Difference(a, d), 3, 3); // Inconsistent relation for arc 4->1 (should be between e and b). Note that // arcs 4->1 and 4->3 are enforced by the same literal, thus both should // be true at the same time, hence the crossed bounds below. - repository.Add(literals[4], {e, 1}, {d, -1}, 4, 4); - repository.Add(literals[5], {e, 1}, {d, -1}, 5, 5); + repository.Add(literals[4], LinearExpression2::Difference(e, d), 4, 4); + repository.Add(literals[5], LinearExpression2::Difference(e, d), 5, 5); repository.Build(); const RoutingCumulExpressions cumuls = DetectDimensionsAndCumulExpressions( @@ -2406,7 +2432,8 @@ TEST(CreateCVRPCutGeneratorTest, InfeasiblePathCuts) { const int head = heads[i]; if (tail == 0 || head == 0) continue; // loads[head] >= loads[tail] + demand[tail] - repository->Add(literals[i], {loads[head], 1}, {loads[tail], -1}, + repository->Add(literals[i], + LinearExpression2(loads[head], loads[tail], 1, -1), demands[tail], 10000); } repository->Build(); diff --git a/ortools/sat/samples/assumptions_sample_sat.go b/ortools/sat/samples/assumptions_sample_sat.go index d4564dbd48..f2603241d6 100644 --- a/ortools/sat/samples/assumptions_sample_sat.go +++ b/ortools/sat/samples/assumptions_sample_sat.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/boolean_product_sample_sat.go b/ortools/sat/samples/boolean_product_sample_sat.go index ad93ecde04..cb0185be42 100644 --- a/ortools/sat/samples/boolean_product_sample_sat.go +++ b/ortools/sat/samples/boolean_product_sample_sat.go @@ -19,8 +19,9 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func booleanProductSample() error { diff --git a/ortools/sat/samples/channeling_sample_sat.go b/ortools/sat/samples/channeling_sample_sat.go index 9ce0bfa0d4..a35c6f2337 100644 --- a/ortools/sat/samples/channeling_sample_sat.go +++ b/ortools/sat/samples/channeling_sample_sat.go @@ -19,9 +19,10 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) func channelingSampleSat() error { diff --git a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go index 651f72c849..cc7ca29fc1 100644 --- a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go +++ b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.go @@ -20,9 +20,10 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) const ( diff --git a/ortools/sat/samples/no_overlap_sample_sat.go b/ortools/sat/samples/no_overlap_sample_sat.go index 8b94881aec..e24e7775e8 100644 --- a/ortools/sat/samples/no_overlap_sample_sat.go +++ b/ortools/sat/samples/no_overlap_sample_sat.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/rabbits_and_pheasants_sat.go b/ortools/sat/samples/rabbits_and_pheasants_sat.go index 1a8cad267c..3dc6a51190 100644 --- a/ortools/sat/samples/rabbits_and_pheasants_sat.go +++ b/ortools/sat/samples/rabbits_and_pheasants_sat.go @@ -20,6 +20,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/ranking_sample_sat.go b/ortools/sat/samples/ranking_sample_sat.go index 779d477040..838d8c6e80 100644 --- a/ortools/sat/samples/ranking_sample_sat.go +++ b/ortools/sat/samples/ranking_sample_sat.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/search_for_all_solutions_sample_sat.go b/ortools/sat/samples/search_for_all_solutions_sample_sat.go index 31a1fbd98f..15ba8ec56d 100644 --- a/ortools/sat/samples/search_for_all_solutions_sample_sat.go +++ b/ortools/sat/samples/search_for_all_solutions_sample_sat.go @@ -20,8 +20,9 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func searchForAllSolutionsSampleSat() error { diff --git a/ortools/sat/samples/simple_sat_program.go b/ortools/sat/samples/simple_sat_program.go index 47c151adcf..588ceed1ea 100644 --- a/ortools/sat/samples/simple_sat_program.go +++ b/ortools/sat/samples/simple_sat_program.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/solution_hinting_sample_sat.go b/ortools/sat/samples/solution_hinting_sample_sat.go index 1b7a31f4ee..8ad6434151 100644 --- a/ortools/sat/samples/solution_hinting_sample_sat.go +++ b/ortools/sat/samples/solution_hinting_sample_sat.go @@ -19,6 +19,7 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" ) diff --git a/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go b/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go index 573b6b2c91..46b85cb548 100644 --- a/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go +++ b/ortools/sat/samples/solve_and_print_intermediate_solutions_sample_sat.go @@ -19,8 +19,9 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" - sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" "google.golang.org/protobuf/proto" + + sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" ) func solveAndPrintIntermediateSolutionsSampleSat() error { diff --git a/ortools/sat/samples/solve_with_time_limit_sample_sat.go b/ortools/sat/samples/solve_with_time_limit_sample_sat.go index a600391017..c7b89e8d51 100644 --- a/ortools/sat/samples/solve_with_time_limit_sample_sat.go +++ b/ortools/sat/samples/solve_with_time_limit_sample_sat.go @@ -19,9 +19,10 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) func solveWithTimeLimitSampleSat() error { diff --git a/ortools/sat/samples/step_function_sample_sat.go b/ortools/sat/samples/step_function_sample_sat.go index 21b9e1f044..7fa569d7da 100644 --- a/ortools/sat/samples/step_function_sample_sat.go +++ b/ortools/sat/samples/step_function_sample_sat.go @@ -19,9 +19,10 @@ import ( log "github.com/golang/glog" "github.com/google/or-tools/ortools/sat/go/cpmodel" + "google.golang.org/protobuf/proto" + cmpb "github.com/google/or-tools/ortools/sat/proto/cpmodel" sppb "github.com/google/or-tools/ortools/sat/proto/satparameters" - "google.golang.org/protobuf/proto" ) func stepFunctionSampleSat() error { diff --git a/ortools/sat/sat_decision.cc b/ortools/sat/sat_decision.cc index deb38947e3..ba731aa626 100644 --- a/ortools/sat/sat_decision.cc +++ b/ortools/sat/sat_decision.cc @@ -14,12 +14,14 @@ #include "ortools/sat/sat_decision.h" #include +#include #include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/logging.h" @@ -225,6 +227,27 @@ void SatDecisionPolicy::RandomizeCurrentPolarity() { } } +void SatDecisionPolicy::ResetActivitiesToFollowBestPartialAssignment() { + DCHECK_EQ(trail_.CurrentDecisionLevel(), 0); + CHECK(!activities_.empty()); + const double max_activity = + *absl::c_max_element(activities_) + variable_activity_increment_; + const double kDecay = 0.999; + variable_activity_increment_ = + max_activity / pow(kDecay, best_partial_assignment_.size() + 1); + var_ordering_is_initialized_ = false; + if (max_activity + variable_activity_increment_ > + parameters_.max_variable_activity_value()) { + RescaleVariableActivities(1 / parameters_.max_variable_activity_value()); + } + double weight = 1.0; + for (int i = 0; i < best_partial_assignment_.size(); ++i) { + const Literal l = best_partial_assignment_[i]; + weight *= kDecay; + activities_[l.Variable()] += weight * variable_activity_increment_; + } +} + void SatDecisionPolicy::InitializeVariableOrdering() { const int num_variables = activities_.size(); diff --git a/ortools/sat/sat_decision.h b/ortools/sat/sat_decision.h index acce6c1292..b00fd48882 100644 --- a/ortools/sat/sat_decision.h +++ b/ortools/sat/sat_decision.h @@ -107,13 +107,31 @@ class SatDecisionPolicy { } // Like SetAssignmentPreference() but it can be overridden by phase-saving. - void SetTargetPolarity(Literal l) { - var_polarity_[l.Variable()] = l.IsPositive(); + void SetTargetPolarityIfUnassigned(Literal l) { + if (trail_.Assignment().VariableIsAssigned(l.Variable())) return; + has_target_polarity_[l.Variable()] = true; + target_polarity_[l.Variable()] = var_polarity_[l.Variable()] = + l.IsPositive(); + best_partial_assignment_.push_back(l); + target_length_++; } absl::Span GetBestPartialAssignment() const { return best_partial_assignment_; } - void ClearBestPartialAssignment() { best_partial_assignment_.clear(); } + void ClearBestPartialAssignment() { + target_length_ = 0; + has_target_polarity_.assign(has_target_polarity_.size(), false); + best_partial_assignment_.clear(); + } + + // Increases activities of variables in the best partial assignment to ensure + // they are branched on first in the same order until the next conflict. + // Activities before this call are scaled to become disambiguation terms. + // Future conflicts will bump activity by the largest increase applied by this + // method. + // This acts as a soft-reset of the decision policy, useful when exploring a + // new region of the search space. + void ResetActivitiesToFollowBestPartialAssignment(); private: // Computes an initial variable ordering. diff --git a/ortools/sat/sat_decision_test.cc b/ortools/sat/sat_decision_test.cc index 104a4a566d..798178bfe9 100644 --- a/ortools/sat/sat_decision_test.cc +++ b/ortools/sat/sat_decision_test.cc @@ -95,6 +95,75 @@ TEST(SatDecisionPolicyTest, ErwaHeuristic) { EXPECT_EQ(Literal(BooleanVariable(2), true), decision->NextBranch()); } +TEST(SatDecisionPolicyTest, SetTargetPolarityInStablePhase) { + Model model; + Trail* trail = model.GetOrCreate(); + SatDecisionPolicy* decision = model.GetOrCreate(); + const int num_variables = 100; + trail->Resize(num_variables); + decision->IncreaseNumVariables(num_variables); + + for (int i = 0; i < num_variables; ++i) { + decision->SetTargetPolarityIfUnassigned(Literal(BooleanVariable(i), i % 2)); + } + + decision->SetStablePhase(true); + for (int i = 0; i < num_variables; ++i) { + const Literal literal = decision->NextBranch(); + EXPECT_EQ(literal, Literal(BooleanVariable(literal.Variable()), + literal.Variable().value() % 2)); + trail->EnqueueSearchDecision(literal); + } +} + +TEST(SatDecisionPolicyTest, SetTargetPolarity) { + Model model; + Trail* trail = model.GetOrCreate(); + SatDecisionPolicy* decision = model.GetOrCreate(); + const int num_variables = 100; + trail->Resize(num_variables); + decision->IncreaseNumVariables(num_variables); + + for (int i = 0; i < num_variables; ++i) { + decision->SetTargetPolarityIfUnassigned(Literal(BooleanVariable(i), i % 2)); + } + + decision->SetStablePhase(false); + for (int i = 0; i < num_variables; ++i) { + const Literal literal = decision->NextBranch(); + EXPECT_EQ(literal, Literal(BooleanVariable(literal.Variable()), + literal.Variable().value() % 2)); + trail->EnqueueSearchDecision(literal); + } +} + +TEST(SatDecisionPolicyTest, TestFollowBestPartialAssignment) { + Model model; + model.GetOrCreate()->set_initial_variables_activity(1e9); + Trail* trail = model.GetOrCreate(); + SatDecisionPolicy* decision = model.GetOrCreate(); + const int num_variables = 10; + trail->Resize(num_variables); + decision->IncreaseNumVariables(num_variables); + + for (int i = 0; i < num_variables; ++i) { + decision->SetTargetPolarityIfUnassigned(Literal(BooleanVariable(i), i % 2)); + } + for (int i = 0; i < num_variables - 1; ++i) { + // Bump all suffixes of the best partial assignment, so the last element has + // the highest activity. + decision->BumpVariableActivities( + decision->GetBestPartialAssignment().subspan(i)); + } + decision->ResetActivitiesToFollowBestPartialAssignment(); + + decision->SetStablePhase(false); + for (int i = 0; i < num_variables; ++i) { + const Literal literal = decision->NextBranch(); + EXPECT_EQ(literal, Literal(BooleanVariable(i), i % 2)); + trail->EnqueueSearchDecision(literal); + } +} } // namespace } // namespace sat } // namespace operations_research diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index fb7d23f541..f6502cdfdb 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -24,7 +24,7 @@ option java_multiple_files = true; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 325 +// NEXT TAG: 329 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -703,6 +703,13 @@ message SatParameters { // Allows sharing of the bounds of modified variables at level 0. optional bool share_level_zero_bounds = 114 [default = true]; + // Allows sharing of the bounds on linear2 discovered at level 0. This is + // mainly interesting on scheduling type of problems when we branch on + // precedences. + // + // Warning: This currently non-deterministic. + optional bool share_linear2_bounds = 326 [default = false]; + // Allows sharing of new learned binary clause between workers. optional bool share_binary_clauses = 203 [default = true]; @@ -818,6 +825,11 @@ message SatParameters { // depending on the problem, turning this off may lead to a faster solution. optional bool use_precedences_in_disjunctive_constraint = 74 [default = true]; + // At root level, we might compute the transitive closure of "precedences" + // relations so that we can exploit that in scheduling problems. Setting this + // to zero disable the feature. + optional int32 transitive_precedences_work_limit = 327 [default = 1000000]; + // Create one literal for each disjunction of two pairs of tasks. This slows // down the solve time, but improves the lower bound of the objective in the // makespan case. This will be triggered if the number of intervals is less or @@ -1265,6 +1277,11 @@ message SatParameters { // SPLIT_STRATEGY_BALANCED_TREE and SPLIT_STRATEGY_DISCREPANCY. optional int32 shared_tree_balance_tolerance = 305 [default = 1]; + // How much dtime a worker will wait between proposing splits. + // This limits the contention in splitting the shared tree, and also reduces + // the number of too-easy subtrees that are generates. + optional double shared_tree_split_min_dtime = 328 [default = 0.1]; + // Whether we enumerate all solutions of a problem without objective. Note // that setting this to true automatically disable some presolve reduction // that can remove feasible solution. That is it has the same effect as @@ -1336,10 +1353,15 @@ message SatParameters { optional bool use_lns_only = 101 [default = false]; // Size of the top-n different solutions kept by the solver. - // This parameter must be > 0. - // Currently this only impact the "base" solution chosen for a LNS fragment. + // This parameter must be > 0. Currently, having this larger than one mainly + // impact the "base" solution chosen for a LNS/LS fragment. optional int32 solution_pool_size = 193 [default = 3]; + // In order to not get stuck in local optima, when this is non-zero, we try to + // also work on "older" solutions with a worse objective value so we get a + // chance to follow a different LS/LNS trajectory. + optional int32 alternative_pool_size = 325 [default = 1]; + // Turns on relaxation induced neighborhood generator. optional bool use_rins_lns = 129 [default = true]; diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index ef554832d4..c27cc650f5 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -1013,9 +1013,9 @@ SatSolver::Status SatSolver::EnqueueDecisionAndBacktrackOnConflict( bool SatSolver::EnqueueDecisionIfNotConflicting(Literal true_literal) { SCOPED_TIME_STAT(&stats_); + if (model_is_unsat_) return kUnsatTrailIndex; DCHECK(PropagationIsDone()); - if (model_is_unsat_) return kUnsatTrailIndex; const int current_level = CurrentDecisionLevel(); EnqueueNewDecision(true_literal); if (Propagate()) { diff --git a/ortools/sat/scheduling_cuts.cc b/ortools/sat/scheduling_cuts.cc index 066fc486d0..4c62279d7b 100644 --- a/ortools/sat/scheduling_cuts.cc +++ b/ortools/sat/scheduling_cuts.cc @@ -338,9 +338,9 @@ std::vector FindPossibleDemands(const EnergyEvent& event, void GenerateCumulativeEnergeticCutsWithMakespanAndFixedCapacity( absl::string_view cut_name, const util_intops::StrongVector& lp_values, - std::vector events, IntegerValue capacity, + absl::Span events, IntegerValue capacity, AffineExpression makespan, TimeLimit* time_limit, Model* model, - LinearConstraintManager* manager) { + TopNCuts& top_n_cuts) { // Checks the precondition of the code. IntegerTrail* integer_trail = model->GetOrCreate(); DCHECK(integer_trail->IsFixed(capacity)); @@ -408,7 +408,6 @@ void GenerateCumulativeEnergeticCutsWithMakespanAndFixedCapacity( const double makespan_lp = makespan.LpValue(lp_values); const double makespan_min_lp = ToDouble(makespan_min); LinearConstraintBuilder temp_builder(model); - TopNCuts top_n_cuts(5); for (int i = 0; i + 1 < num_time_points; ++i) { // Checks the time limit if the problem is too big. if (events.size() > 50 && time_limit->LimitReached()) return; @@ -510,15 +509,13 @@ void GenerateCumulativeEnergeticCutsWithMakespanAndFixedCapacity( } } } - - top_n_cuts.TransferToManager(manager); } void GenerateCumulativeEnergeticCuts( absl::string_view cut_name, const util_intops::StrongVector& lp_values, - std::vector events, const AffineExpression& capacity, - TimeLimit* time_limit, Model* model, LinearConstraintManager* manager) { + absl::Span events, const AffineExpression& capacity, + TimeLimit* time_limit, Model* model, TopNCuts& top_n_cuts) { double max_possible_energy_lp = 0.0; for (const EnergyEvent& event : events) { max_possible_energy_lp += event.linearized_energy_lp_value; @@ -549,7 +546,6 @@ void GenerateCumulativeEnergeticCuts( const int num_time_points = time_points.size(); LinearConstraintBuilder temp_builder(model); - TopNCuts top_n_cuts(5); for (int i = 0; i + 1 < num_time_points; ++i) { // Checks the time limit if the problem is too big. if (events.size() > 50 && time_limit->LimitReached()) return; @@ -602,8 +598,6 @@ void GenerateCumulativeEnergeticCuts( } } } - - top_n_cuts.TransferToManager(manager); } CutGenerator CreateCumulativeEnergyCutGenerator( @@ -664,16 +658,24 @@ CutGenerator CreateCumulativeEnergyCutGenerator( events.push_back(e); } - if (makespan.has_value() && integer_trail->IsFixed(capacity)) { - GenerateCumulativeEnergeticCutsWithMakespanAndFixedCapacity( - "CumulativeEnergyM", lp_values, events, - integer_trail->FixedValue(capacity), makespan.value(), time_limit, - model, manager); + TopNCuts top_n_cuts(5); + std::vector> disjoint_events = + SplitEventsInIndendentSets(absl::MakeSpan(events)); + // Can we pass cluster as const. It would mean sorting before. + for (const absl::Span cluster : disjoint_events) { + if (makespan.has_value() && integer_trail->IsFixed(capacity)) { + GenerateCumulativeEnergeticCutsWithMakespanAndFixedCapacity( + "CumulativeEnergyM", lp_values, cluster, + integer_trail->FixedValue(capacity), makespan.value(), time_limit, + model, top_n_cuts); - } else { - GenerateCumulativeEnergeticCuts("CumulativeEnergy", lp_values, events, - capacity, time_limit, model, manager); + } else { + GenerateCumulativeEnergeticCuts("CumulativeEnergy", lp_values, cluster, + capacity, time_limit, model, + top_n_cuts); + } } + top_n_cuts.TransferToManager(manager); return true; }; @@ -716,16 +718,22 @@ CutGenerator CreateNoOverlapEnergyCutGenerator( events.push_back(e); } - if (makespan.has_value()) { - GenerateCumulativeEnergeticCutsWithMakespanAndFixedCapacity( - "NoOverlapEnergyM", lp_values, events, - /*capacity=*/IntegerValue(1), makespan.value(), time_limit, model, - manager); - } else { - GenerateCumulativeEnergeticCuts("NoOverlapEnergy", lp_values, events, - /*capacity=*/IntegerValue(1), time_limit, - model, manager); + TopNCuts top_n_cuts(5); + std::vector> disjoint_events = + SplitEventsInIndendentSets(absl::MakeSpan(events)); + for (const absl::Span cluster : disjoint_events) { + if (makespan.has_value()) { + GenerateCumulativeEnergeticCutsWithMakespanAndFixedCapacity( + "NoOverlapEnergyM", lp_values, cluster, + /*capacity=*/IntegerValue(1), makespan.value(), time_limit, model, + top_n_cuts); + } else { + GenerateCumulativeEnergeticCuts("NoOverlapEnergy", lp_values, cluster, + /*capacity=*/IntegerValue(1), + time_limit, model, top_n_cuts); + } } + top_n_cuts.TransferToManager(manager); return true; }; return result; @@ -889,9 +897,8 @@ struct CachedIntervalData { void GenerateCutsBetweenPairOfNonOverlappingTasks( absl::string_view cut_name, bool ignore_zero_size_intervals, const util_intops::StrongVector& lp_values, - std::vector events, IntegerValue capacity_max, - Model* model, LinearConstraintManager* manager) { - TopNCuts top_n_cuts(5); + absl::Span events, IntegerValue capacity_max, + Model* model, TopNCuts& top_n_cuts) { const int num_events = events.size(); if (num_events <= 1) return; @@ -984,8 +991,6 @@ void GenerateCutsBetweenPairOfNonOverlappingTasks( } } } - - top_n_cuts.TransferToManager(manager); } CutGenerator CreateCumulativePrecedenceCutGenerator( @@ -1014,9 +1019,16 @@ CutGenerator CreateCumulativePrecedenceCutGenerator( } const IntegerValue capacity_max = integer_trail->UpperBound(capacity); - GenerateCutsBetweenPairOfNonOverlappingTasks( - "Cumulative", /* ignore_zero_size_intervals= */ true, - manager->LpValues(), std::move(events), capacity_max, model, manager); + + TopNCuts top_n_cuts(5); + std::vector> disjoint_events = + SplitEventsInIndendentSets(absl::MakeSpan(events)); + for (const absl::Span cluster : disjoint_events) { + GenerateCutsBetweenPairOfNonOverlappingTasks( + "Cumulative", /* ignore_zero_size_intervals= */ true, + manager->LpValues(), cluster, capacity_max, model, top_n_cuts); + } + top_n_cuts.TransferToManager(manager); return true; }; return result; @@ -1042,10 +1054,15 @@ CutGenerator CreateNoOverlapPrecedenceCutGenerator( events.push_back(event); } - GenerateCutsBetweenPairOfNonOverlappingTasks( - "NoOverlap", /* ignore_zero_size_intervals= */ false, - manager->LpValues(), std::move(events), IntegerValue(1), model, - manager); + TopNCuts top_n_cuts(5); + std::vector> disjoint_events = + SplitEventsInIndendentSets(absl::MakeSpan(events)); + for (const absl::Span cluster : disjoint_events) { + GenerateCutsBetweenPairOfNonOverlappingTasks( + "NoOverlap", /* ignore_zero_size_intervals= */ false, + manager->LpValues(), cluster, IntegerValue(1), model, top_n_cuts); + } + top_n_cuts.TransferToManager(manager); return true; }; @@ -1104,21 +1121,32 @@ void CtExhaustiveHelper::Init( const absl::Span events, Model* model) { max_task_index_ = 0; if (events.empty()) return; + // We compute the max_task_index_ from the events early to avoid sorting // the events if there are too many of them. for (const auto& event : events) { max_task_index_ = std::max(max_task_index_, event.task_index); } + BuildPredecessors(events, model); + VLOG(2) << "num_tasks:" << max_task_index_ + 1 + << " num_precedences:" << predecessors_.num_entries() + << " predecessors size:" << predecessors_.size(); +} + +void CtExhaustiveHelper::BuildPredecessors( + const absl::Span events, Model* model) { + predecessors_.clear(); if (events.size() > 100) return; - BinaryRelationsMaps* binary_relations = - model->GetOrCreate(); + ReifiedLinear2Bounds* binary_relations = + model->GetOrCreate(); std::vector sorted_events(events.begin(), events.end()); std::sort(sorted_events.begin(), sorted_events.end(), [](const CompletionTimeEvent& a, const CompletionTimeEvent& b) { return a.task_index < b.task_index; }); + predecessors_.reserve(max_task_index_ + 1); for (const auto& e1 : sorted_events) { for (const auto& e2 : sorted_events) { @@ -1130,9 +1158,6 @@ void CtExhaustiveHelper::Init( } } } - VLOG(2) << "num_tasks:" << max_task_index_ + 1 - << " num_precedences:" << predecessors_.num_entries() - << " predecessors size:" << predecessors_.size(); } bool CtExhaustiveHelper::PermutationIsCompatibleWithPrecedences( @@ -1391,11 +1416,10 @@ CompletionTimeExplorationStatus ComputeMinSumOfWeightedEndMins( // - detect disjoint tasks (no need to crossover to the second part) // - better caching of explored states ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound( - absl::string_view cut_name, std::vector events, - IntegerValue capacity_max, CtExhaustiveHelper& helper, Model* model, - LinearConstraintManager* manager) { - TopNCuts top_n_cuts(5); - + absl::string_view cut_name, + const util_intops::StrongVector& lp_values, + absl::Span events, IntegerValue capacity_max, + CtExhaustiveHelper& helper, Model* model, TopNCuts& top_n_cuts) { // Sort by start min to bucketize by start_min. std::sort( events.begin(), events.end(), @@ -1488,7 +1512,7 @@ ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound( std::string full_name(cut_name); if (cut_use_precedences) full_name.append("_prec"); if (is_lifted) full_name.append("_lifted"); - top_n_cuts.AddCut(cut.Build(), full_name, manager->LpValues()); + top_n_cuts.AddCut(cut.Build(), full_name, lp_values); } // Weighted cuts. @@ -1506,11 +1530,10 @@ ABSL_MUST_USE_RESULT bool GenerateShortCompletionTimeCutsWithExactBound( if (is_lifted) full_name.append("_lifted"); if (cut_use_precedences) full_name.append("_prec"); full_name.append("_weighted"); - top_n_cuts.AddCut(cut.Build(), full_name, manager->LpValues()); + top_n_cuts.AddCut(cut.Build(), full_name, lp_values); } } } - top_n_cuts.TransferToManager(manager); return true; } @@ -1627,9 +1650,10 @@ void AddEventDemandsToCapacitySubsetSum( // - second loop, we add tasks that must contribute after this start time // ordered by increasing end time in the LP relaxation. void GenerateCompletionTimeCutsWithEnergy( - absl::string_view cut_name, std::vector events, - IntegerValue capacity_max, Model* model, LinearConstraintManager* manager) { - TopNCuts top_n_cuts(5); + absl::string_view cut_name, + const util_intops::StrongVector& lp_values, + absl::Span events, IntegerValue capacity_max, + Model* model, TopNCuts& top_n_cuts) { const VariablesAssignment& assignment = model->GetOrCreate()->Assignment(); std::vector tmp_possible_demands; @@ -1780,10 +1804,9 @@ void GenerateCompletionTimeCutsWithEnergy( if (add_energy_to_name) full_name.append("_energy"); if (is_lifted) full_name.append("_lifted"); if (best_uses_subset_sum) full_name.append("_subsetsum"); - top_n_cuts.AddCut(cut.Build(), full_name, manager->LpValues()); + top_n_cuts.AddCut(cut.Build(), full_name, lp_values); } } - top_n_cuts.TransferToManager(manager); } CutGenerator CreateNoOverlapCompletionTimeCutGenerator( @@ -1817,15 +1840,21 @@ CutGenerator CreateNoOverlapCompletionTimeCutGenerator( CtExhaustiveHelper helper; helper.Init(events, model); - if (!GenerateShortCompletionTimeCutsWithExactBound( - "NoOverlapCompletionTimeExhaustive", events, - /*capacity_max=*/IntegerValue(1), helper, model, manager)) { - return false; - } + TopNCuts top_n_cuts(5); + std::vector> disjoint_events = + SplitEventsInIndendentSets(absl::MakeSpan(events)); + for (const absl::Span cluster : disjoint_events) { + if (!GenerateShortCompletionTimeCutsWithExactBound( + "NoOverlapCompletionTimeExhaustive", lp_values, cluster, + /*capacity_max=*/IntegerValue(1), helper, model, top_n_cuts)) { + return false; + } - GenerateCompletionTimeCutsWithEnergy( - "NoOverlapCompletionTimeQueyrane", std::move(events), - /*capacity_max=*/IntegerValue(1), model, manager); + GenerateCompletionTimeCutsWithEnergy( + "NoOverlapCompletionTimeQueyrane", lp_values, cluster, + /*capacity_max=*/IntegerValue(1), model, top_n_cuts); + } + top_n_cuts.TransferToManager(manager); return true; }; if (!generate_cuts(/*time_is_forward=*/true)) return false; @@ -1881,15 +1910,21 @@ CutGenerator CreateCumulativeCompletionTimeCutGenerator( helper.Init(events, model); const IntegerValue capacity_max = integer_trail->UpperBound(capacity); - if (!GenerateShortCompletionTimeCutsWithExactBound( - "CumulativeCompletionTimeExhaustive", events, capacity_max, - helper, model, manager)) { - return false; - } + TopNCuts top_n_cuts(5); + std::vector> disjoint_events = + SplitEventsInIndendentSets(absl::MakeSpan(events)); + for (const absl::Span cluster : disjoint_events) { + if (!GenerateShortCompletionTimeCutsWithExactBound( + "CumulativeCompletionTimeExhaustive", lp_values, cluster, + capacity_max, helper, model, top_n_cuts)) { + return false; + } - GenerateCompletionTimeCutsWithEnergy("CumulativeCompletionTimeQueyrane", - std::move(events), capacity_max, - model, manager); + GenerateCompletionTimeCutsWithEnergy("CumulativeCompletionTimeQueyrane", + lp_values, cluster, capacity_max, + model, top_n_cuts); + } + top_n_cuts.TransferToManager(manager); return true; }; diff --git a/ortools/sat/scheduling_cuts.h b/ortools/sat/scheduling_cuts.h index 920f5a23e6..8b493eefa3 100644 --- a/ortools/sat/scheduling_cuts.h +++ b/ortools/sat/scheduling_cuts.h @@ -174,6 +174,9 @@ class CtExhaustiveHelper { absl::Span permutation); private: + void BuildPredecessors(absl::Span events, + Model* model); + CompactVectorVector predecessors_; int max_task_index_ = 0; std::vector visited_; @@ -215,6 +218,37 @@ CompletionTimeExplorationStatus ComputeMinSumOfWeightedEndMins( double& min_sum_of_weighted_ends, bool& cut_use_precedences, int& exploration_credit); +// Split the list of events in connected components. Two intervals are connected +// if they overlap. It expects the events to have the start_min and end_max +// fields. Note that events are semi-open intervals [start_min, end_max). This +// will filter out components of size one. +template +std::vector> SplitEventsInIndendentSets(absl::Span events) { + if (events.empty()) return {}; + + std::sort(events.begin(), events.end(), [](const E& a, const E& b) { + return std::tie(a.start_min, a.end_max) < std::tie(b.start_min, b.end_max); + }); + const int size = events.size(); + std::vector> result; + IntegerValue max_end_max = events[0].end_max; + int start = 0; + for (int i = 1; i < size; ++i) { + const E& event = events[i]; + if (event.start_min >= max_end_max) { + if (i - start > 1) { + result.push_back(absl::MakeSpan(events.data() + start, i - start)); + } + start = i; + } + max_end_max = std::max(max_end_max, event.end_max); + } + if (size - start > 1) { + result.push_back(absl::MakeSpan(events.data() + start, size - start)); + } + return result; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/scheduling_cuts_test.cc b/ortools/sat/scheduling_cuts_test.cc index 543bafd019..5a51c9b535 100644 --- a/ortools/sat/scheduling_cuts_test.cc +++ b/ortools/sat/scheduling_cuts_test.cc @@ -15,8 +15,8 @@ #include +#include #include -#include #include #include "absl/base/log_severity.h" @@ -587,7 +587,7 @@ double ExactMakespan(absl::Span sizes, std::vector& demands, } builder.Minimize(obj); const CpSolverResponse response = - SolveWithParameters(builder.Build(), "num_search_workers:8"); + SolveWithParameters(builder.Build(), "num_workers:8"); EXPECT_EQ(response.status(), CpSolverStatus::OPTIMAL); return response.objective_value(); } @@ -657,6 +657,35 @@ TEST(ComputeMinSumOfEndMinsTest, RandomCases) { } } +struct SimpleEvent { + IntegerValue start_min; + IntegerValue end_max; + bool operator==(const SimpleEvent& other) const { + return start_min == other.start_min && end_max == other.end_max; + } +}; + +SimpleEvent ConvexHull(absl::Span events) { + SimpleEvent result = events[0]; + for (int i = 1; i < events.size(); ++i) { + result.start_min = std::min(result.start_min, events[i].start_min); + result.end_max = std::max(result.end_max, events[i].end_max); + } + return result; +} + +TEST(SplitEventsInIndendentSetsTest, BasicTest) { + std::vector events = {{0, 10}, {2, 12}, {3, 5}, + {15, 20}, {12, 21}, {30, 35}}; + const std::vector> sets = + SplitEventsInIndendentSets(absl::MakeSpan(events)); + EXPECT_EQ(sets.size(), 2); + EXPECT_EQ(sets[0].size(), 3); + EXPECT_EQ(ConvexHull(sets[0]), SimpleEvent({0, 12})); + EXPECT_EQ(sets[1].size(), 2); + EXPECT_EQ(ConvexHull(sets[1]), SimpleEvent({12, 21})); +} + } // namespace } // namespace sat } // namespace operations_research diff --git a/ortools/sat/scheduling_helpers.cc b/ortools/sat/scheduling_helpers.cc index 69d0bad256..9d25b17a7f 100644 --- a/ortools/sat/scheduling_helpers.cc +++ b/ortools/sat/scheduling_helpers.cc @@ -48,7 +48,8 @@ SchedulingConstraintHelper::SchedulingConstraintHelper( assignment_(sat_solver_->Assignment()), integer_trail_(model->GetOrCreate()), watcher_(model->GetOrCreate()), - precedence_relations_(model->GetOrCreate()), + linear2_bounds_(model->GetOrCreate()), + root_level_lin2_bounds_(model->GetOrCreate()), starts_(std::move(starts)), ends_(std::move(ends)), sizes_(std::move(sizes)), @@ -86,7 +87,8 @@ SchedulingConstraintHelper::SchedulingConstraintHelper(int num_tasks, sat_solver_(model->GetOrCreate()), assignment_(sat_solver_->Assignment()), integer_trail_(model->GetOrCreate()), - precedence_relations_(model->GetOrCreate()), + linear2_bounds_(model->GetOrCreate()), + root_level_lin2_bounds_(model->GetOrCreate()), capacity_(num_tasks), cached_size_min_(new IntegerValue[capacity_]), cached_start_min_(new IntegerValue[capacity_]), @@ -340,27 +342,16 @@ bool SchedulingConstraintHelper::SynchronizeAndSetTimeDirection( return true; } -// TODO(user): be more precise when we know a and b are in disjunction. -// we really just need start_b > start_a, or even >= if duration is non-zero. IntegerValue SchedulingConstraintHelper::GetCurrentMinDistanceBetweenTasks( - int a, int b, bool add_reason_if_after) { + int a, int b) { const AffineExpression before = ends_[a]; const AffineExpression after = starts_[b]; - LinearExpression2 expr(before.var, after.var, before.coeff, -after.coeff); - - // We take the min of the level zero (end_a - start_b) and the one coming from - // a conditional precedence at true. - const IntegerValue conditional_ub = precedence_relations_->UpperBound(expr); - const IntegerValue level_zero_ub = integer_trail_->LevelZeroUpperBound(expr); - const IntegerValue expr_ub = std::min(conditional_ub, level_zero_ub); - + const LinearExpression2 expr(before.var, after.var, before.coeff, + -after.coeff); + const IntegerValue expr_ub = linear2_bounds_->UpperBound(expr); const IntegerValue needed_offset = before.constant - after.constant; const IntegerValue ub_of_end_minus_start = expr_ub + needed_offset; const IntegerValue distance = -ub_of_end_minus_start; - if (add_reason_if_after && distance >= 0 && level_zero_ub > conditional_ub) { - precedence_relations_->AddReasonForUpperBoundLowerThan( - expr, conditional_ub, MutableLiteralReason(), MutableIntegerReason()); - } return distance; } @@ -368,41 +359,31 @@ IntegerValue SchedulingConstraintHelper::GetCurrentMinDistanceBetweenTasks( // associated to task a before task b. However we only call this for task that // are in detectable precedence, which means the normal precedence or linear // propagator should have already propagated that Boolean too. -bool SchedulingConstraintHelper::PropagatePrecedence(int a, int b) { +bool SchedulingConstraintHelper::NotifyLevelZeroPrecedence(int a, int b) { CHECK(IsPresent(a)); CHECK(IsPresent(b)); CHECK_EQ(sat_solver_->CurrentDecisionLevel(), 0); - const AffineExpression before = ends_[a]; - const AffineExpression after = starts_[b]; - if (after.coeff != 1) return true; - if (before.coeff != 1) return true; - if (after.var == kNoIntegerVariable) return true; - if (before.var == kNoIntegerVariable) return true; - if (before.var == after.var) { - if (before.constant <= after.constant) { - return true; - } else { + // Convert ends_[a] <= starts[b] to linear2 <= rhs and canonicalize. + const auto [expr, rhs] = EncodeDifferenceLowerThan(ends_[a], starts_[b], 0); + + // Trivial case. + if (expr.coeffs[0] == 0 && expr.coeffs[1] == 0) { + if (rhs < 0) { sat_solver_->NotifyThatModelIsUnsat(); return false; } + return true; } - const IntegerValue offset = before.constant - after.constant; - const LinearExpression2 expr = - LinearExpression2::Difference(before.var, after.var); - if (precedence_relations_->AddUpperBound(expr, -offset)) { + + if (root_level_lin2_bounds_->AddUpperBound(expr, rhs)) { VLOG(2) << "new relation " << TaskDebugString(a) << " <= " << TaskDebugString(b); - if (before.var == NegationOf(after.var)) { - AddWeightedSumLowerOrEqual({}, {before.var}, {int64_t{2}}, - -offset.value(), model_); - } else { - // TODO(user): Adding new constraint during propagation might not be the - // best idea as it can create some complication. - AddWeightedSumLowerOrEqual({}, {before.var, after.var}, - {int64_t{1}, int64_t{-1}}, -offset.value(), - model_); - } + // TODO(user): Adding new constraint during propagation might not be the + // best idea as it can create some complication. + AddWeightedSumLowerOrEqual({}, {expr.vars[0], expr.vars[1]}, + {expr.coeffs[0].value(), expr.coeffs[1].value()}, + rhs.value(), model_); if (sat_solver_->ModelIsUnsat()) return false; } return true; @@ -496,12 +477,27 @@ SchedulingConstraintHelper::GetEnergyProfile() { return energy_profile_; } -// Produces a relaxed reason for StartMax(before) < EndMin(after). -void SchedulingConstraintHelper::AddReasonForBeingBefore(int before, - int after) { +void SchedulingConstraintHelper::AddReasonForBeingBeforeAssumingNoOverlap( + int before, int after) { AddOtherReason(before); AddOtherReason(after); + // Prefer the linear2 explanation as it is more likely this comes from + // level zero or a single enforcement literal. + // We need Start(after) >= End(before) - SizeMin(before). + // we rewrite as "End(before) - Start(after) <= SizeMin(before). + const auto [expr, ub] = + EncodeDifferenceLowerThan(ends_[before], starts_[after], SizeMin(before)); + if (linear2_bounds_->UpperBound(expr) <= ub) { + AddSizeMinReason(before); + linear2_bounds_->AddReasonForUpperBoundLowerThan(expr, ub, &literal_reason_, + &integer_reason_); + return; + } + + // We will explain StartMax(before) < EndMin(after); + DCHECK_LT(StartMax(before), EndMin(after)); + // The reason will be a linear expression greater than a value. Note that all // coeff must be positive, and we will use the variable lower bound. std::vector vars; diff --git a/ortools/sat/scheduling_helpers.h b/ortools/sat/scheduling_helpers.h index def922023e..2d1daa3876 100644 --- a/ortools/sat/scheduling_helpers.h +++ b/ortools/sat/scheduling_helpers.h @@ -205,18 +205,22 @@ class SchedulingConstraintHelper : public PropagatorInterface { bool IsPresent(LiteralIndex lit) const; bool IsAbsent(LiteralIndex lit) const; - // Return a value so that End(a) + dist <= Start(b). - // Returns kMinInterValue if we don't have any such relation. - IntegerValue GetCurrentMinDistanceBetweenTasks( - int a, int b, bool add_reason_if_after = false); + // Returns a value so that End(a) + dist <= Start(b). + // + // TODO(user): we use this to optimize some reason, but ideally we only want + // to use linear2 bounds here, not bounds coming from trivial bounds. Make + // sure we have the best possible reason. + IntegerValue GetCurrentMinDistanceBetweenTasks(int a, int b); - // We detected a precedence between two tasks. - // If we are at level zero, we might want to add the constraint. - // If we are at positive level, we might want to propagate the associated - // precedence literal if it exists. - bool PropagatePrecedence(int a, int b); + // We detected a precedence between two tasks at level zero. + // This register a new constraint and notify the linear2 root level bounds + // repository. Returns false on conflict. + // + // TODO(user): We could also call this at positive decision level, but it is a + // bit harder to exploit as we will also need to store the reasons. + bool NotifyLevelZeroPrecedence(int a, int b); - // Return the minimum overlap of interval i with the time window [start..end]. + // Return the minimum overlap of task t with the time window [start..end]. // // Note: this is different from the mandatory part of an interval. IntegerValue GetMinOverlap(int t, IntegerValue start, IntegerValue end) const; @@ -273,9 +277,22 @@ class SchedulingConstraintHelper : public PropagatorInterface { void AddEnergyAfterReason(int t, IntegerValue energy_min, IntegerValue time); void AddEnergyMinInIntervalReason(int t, IntegerValue min, IntegerValue max); - // Adds the reason why task "before" must be before task "after". - // That is StartMax(before) < EndMin(after). - void AddReasonForBeingBefore(int before, int after); + // Adds the reason why the task "before" must be before task "after", in + // the sense that "after" can only start at the same time or later than the + // task "before" ends. + // + // Important: this assumes that the two task cannot overlap. So we can have + // a more relaxed reason than Start(after) >= Ends(before). + // + // There are actually many possibilities to explain such relation: + // - StartMax(before) < EndMin(after). + // - We have a linear2: Start(after) >= End(before) - SizeMin(before); + // - etc... + // We try to pick the best one. + // + // TODO(user): Refine the heuritic. Also consider other reason for the + // complex cases where Start() and End() do not use the same integer variable. + void AddReasonForBeingBeforeAssumingNoOverlap(int before, int after); // It is also possible to directly manipulates the underlying reason vectors // that will be used when pushing something. @@ -397,7 +414,8 @@ class SchedulingConstraintHelper : public PropagatorInterface { const VariablesAssignment& assignment_; IntegerTrail* integer_trail_; GenericLiteralWatcher* watcher_; - PrecedenceRelations* precedence_relations_; + Linear2Bounds* linear2_bounds_; + RootLevelLinear2Bounds* root_level_lin2_bounds_; // The current direction of time, true for forward, false for backward. bool current_time_direction_ = true; diff --git a/ortools/sat/shaving_solver.cc b/ortools/sat/shaving_solver.cc index cb35ef5677..488ce0228e 100644 --- a/ortools/sat/shaving_solver.cc +++ b/ortools/sat/shaving_solver.cc @@ -633,9 +633,9 @@ bool VariablesShavingSolver::ResetAndSolveModel(int64_t task_id, State* state, // Use the current best solution as hint. { - auto sols = shared_->response->SolutionsRepository().GetBestNSolutions(1); - if (!sols.empty()) { - const std::vector& solution = sols[0]->variable_values; + auto sol = shared_->response->SolutionPool().BestSolutions().GetSolution(0); + if (sol != nullptr) { + const std::vector& solution = sol->variable_values; auto* hint = shaving_proto->mutable_solution_hint(); hint->clear_vars(); hint->clear_values(); diff --git a/ortools/sat/solution_crush.cc b/ortools/sat/solution_crush.cc index aa2f2a955f..cad64f1252 100644 --- a/ortools/sat/solution_crush.cc +++ b/ortools/sat/solution_crush.cc @@ -281,6 +281,40 @@ void SolutionCrush::MaybeUpdateVarWithSymmetriesToValue( DCHECK_EQ(GetVarValue(var), value); } +void SolutionCrush::MaybeSwapOrbitopeColumns( + absl::Span> orbitope, int row, int pivot_col, + bool value) { + if (!solution_is_loaded_) return; + int col = -1; + for (int c = 0; c < orbitope[row].size(); ++c) { + if (GetLiteralValue(orbitope[row][c]) == value) { + if (col != -1) { + VLOG(2) << "Multiple literals in row with given value"; + return; + } + col = c; + } + } + if (col < pivot_col) { + // Nothing to do. + return; + } + // Swap the value of the literals in column `col` with the value of the ones + // in column `pivot_col`, if they all have a value. + for (int i = 0; i < orbitope.size(); ++i) { + if (!HasValue(PositiveRef(orbitope[i][col]))) return; + if (!HasValue(PositiveRef(orbitope[i][pivot_col]))) return; + } + for (int i = 0; i < orbitope.size(); ++i) { + const int src_lit = orbitope[i][col]; + const int dst_lit = orbitope[i][pivot_col]; + const bool src_value = GetLiteralValue(src_lit); + const bool dst_value = GetLiteralValue(dst_lit); + SetLiteralValue(src_lit, dst_value); + SetLiteralValue(dst_lit, src_value); + } +} + void SolutionCrush::UpdateRefsWithDominance( int ref, int64_t min_value, int64_t max_value, absl::Span> dominating_refs) { diff --git a/ortools/sat/solution_crush.h b/ortools/sat/solution_crush.h index 34c10861cf..4412b34323 100644 --- a/ortools/sat/solution_crush.h +++ b/ortools/sat/solution_crush.h @@ -174,6 +174,13 @@ class SolutionCrush { int var, bool value, absl::Span> generators); + // If at most one literal in `orbitope[row]` is equal to `value`, and if this + // literal is in a column 'col' > `pivot_col`, swaps the value of all the + // literals in columns 'col' and `pivot_col` (if they all have a value). + // Otherwise does nothing. + void MaybeSwapOrbitopeColumns(absl::Span> orbitope, + int row, int pivot_col, bool value); + // Sets the value of the i-th variable in `vars` so that the given constraint // "dotproduct(coeffs, vars values) = rhs" is satisfied, if all the other // variables have a value. i is equal to `var_index` if set. Otherwise it is diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index a4272ca360..6b4344cc08 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -30,9 +30,6 @@ #include #include -#include "absl/hash/hash.h" -#include "absl/log/log.h" -#include "absl/time/time.h" #include "ortools/base/logging.h" #include "ortools/base/timer.h" #if !defined(__PORTABLE_PLATFORM__) @@ -40,11 +37,17 @@ #include "ortools/base/options.h" #endif // __PORTABLE_PLATFORM__ #include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/flags/flag.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/numeric/int128.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -74,6 +77,144 @@ ABSL_FLAG(bool, cp_model_dump_tightened_models, false, namespace operations_research { namespace sat { +std::shared_ptr::Solution> +SharedSolutionPool::Add(SharedSolutionRepository::Solution solution) { + // Only add to the alternative path if it has the correct source id. + if (alternative_path_.num_solutions_to_keep() > 0 && + solution.source_id == alternative_path_.source_id()) { + alternative_path_.Add(solution); + if (solution.rank < best_solutions_.GetBestRank()) { + VLOG(2) << "ALTERNATIVE WIN !"; + } + } + + // For now we only return a solution if it was stored in best_solutions_. + return best_solutions_.Add(std::move(solution)); +} + +void SharedSolutionPool::Synchronize(absl::BitGenRef random) { + // Update the "seeds" for the aternative path. + if (alternative_path_.num_solutions_to_keep() > 0) { + absl::MutexLock mutex_lock(&mutex_); + + auto process_solution = + [this](const SharedSolutionRepository::Solution& solution) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + if (solution.variable_values.empty()) return; + if (solution.rank < min_rank_ || solution.rank > max_rank_) { + // Recompute buckets. + min_rank_ = std::min(min_rank_, solution.rank); + max_rank_ = std::max(max_rank_, solution.rank); + + // We want to store around 100 MB max. + int num_solutions = std::max( + 10, 100'000'000 / solution.variable_values.size()); + const int64_t range = max_rank_ - min_rank_ + 1; + if (num_solutions > range) { + num_solutions = range; + } + + // But if the number of variables is low, we do not want + // to use a lot of space/time just iterating over num_solutions. + // + // TODO(user): Rework the algo to be in + // O(num_different_solutions) rather than initializing the + // maximum amount right away. + num_solutions = std::min(num_solutions, 1'000); + + // Resize and recompute rank_. + // + // seeds_[i] should contains solution in [ranks_[i], + // rank_[i+1]). rank_[0] is always min_rank_. As long as we have + // room, we should have exactly one bucket per rank. + ranks_.resize(num_solutions); + seeds_.resize(num_solutions); + + int64_t offset = (max_rank_ - min_rank_ + 1) / num_solutions; + CHECK_GT(offset, 0); + for (int i = 0; i < num_solutions; ++i) { + ranks_[i] = min_rank_ + + static_cast(absl::int128(i) * + absl::int128(range) / + absl::int128(num_solutions)); + } + + // Move existing solutions to their new bucket. + int to_index = seeds_.size() - 1; + for (int i = seeds_.size(); --i >= 0;) { + if (seeds_[i] == nullptr) continue; + while (to_index >= 0 && ranks_[to_index] > seeds_[i]->rank) { + --to_index; + } + seeds_[to_index] = std::move(seeds_[i]); + } + } + + // rank[limit] is the first > solution.rank. + const int limit = std::upper_bound(ranks_.begin(), ranks_.end(), + solution.rank) - + ranks_.begin(); + CHECK_GT(limit, 0); + seeds_[limit - 1] = + std::make_shared::Solution>( + solution); + }; + + // All solution go through best_solutions_.Add(), so we only need + // to process these here. + best_solutions_.Synchronize(process_solution); + } else { + best_solutions_.Synchronize(); + } + alternative_path_.Synchronize(); + + // If we try to improve the alternate path without success, reset it + // from a random path_seeds_. + // + // TODO(user): find a way to generate random solution and update the seeds + // with them. Shall we do that in a continuous way or only when needed? + if (alternative_path_.num_solutions_to_keep() > 0) { + // Restart the alternative path ? + const int threshold = std::max( + 100, static_cast(std::sqrt(best_solutions_.num_queried()))); + if (alternative_path_.NumRecentlyNonImproving() > threshold) { + VLOG(2) << "Done. num_non_improving: " + << alternative_path_.NumRecentlyNonImproving() + << " achieved: " << alternative_path_.GetBestRank() << " / " + << best_solutions_.GetBestRank(); + alternative_path_.ClearSolutionsAndIncreaseSourceId(); + } + + // If we restarted, or we are at the beginning, pick a seed for the path. + if (alternative_path_.NumSolutions() == 0) { + absl::MutexLock mutex_lock(&mutex_); + + // Pick random bucket with bias. If the bucket is empty, we will scan + // "worse" bucket until we find a solution. We never pick bucket 0. + if (seeds_.size() > 1) { + // Note that LogUniform() is always inclusive. + // TODO(user): Shall we bias even more? + int index = 1 + absl::LogUniform(random, 0, seeds_.size() - 2); + for (; index < seeds_.size(); ++index) { + if (seeds_[index] != nullptr) { + alternative_path_.Add(*seeds_[index]); + alternative_path_.Synchronize(); + VLOG(2) << "RESTART bucket=" << index << "/" << seeds_.size() + << " rank=" << alternative_path_.GetSolution(0)->rank + << " from_optimal=" + << alternative_path_.GetSolution(0)->rank - min_rank_; + break; + } + } + + // The last bucket should never be empty. + CHECK(seeds_.back() != nullptr); + CHECK_LT(index, seeds_.size()); + } + } + } +} + void SharedLPSolutionRepository::NewLPSolution( std::vector lp_solution) { if (lp_solution.empty()) return; @@ -119,7 +260,8 @@ SharedResponseManager::SharedResponseManager(Model* model) : parameters_(*model->GetOrCreate()), wall_timer_(*model->GetOrCreate()), shared_time_limit_(model->GetOrCreate()), - solutions_(parameters_.solution_pool_size(), "feasible solutions"), + random_(model->GetOrCreate()), + solution_pool_(parameters_), logger_(model->GetOrCreate()) { bounds_logging_id_ = logger_->GetNewThrottledId(); } @@ -138,9 +280,9 @@ std::string ProgressMessage(absl::string_view event_or_solution_count, obj_next, solution_info); } -std::string SatProgressMessage(const std::string& event_or_solution_count, +std::string SatProgressMessage(absl::string_view event_or_solution_count, double time_in_seconds, - const std::string& solution_info) { + absl::string_view solution_info) { return absl::StrFormat("#%-5s %6.2fs %s", event_or_solution_count, time_in_seconds, solution_info); } @@ -397,13 +539,15 @@ IntegerValue SharedResponseManager::GetInnerObjectiveUpperBound() { } void SharedResponseManager::Synchronize() { + solution_pool_.Synchronize(*random_); + absl::MutexLock mutex_lock(&mutex_); synchronized_inner_objective_lower_bound_ = IntegerValue(inner_objective_lower_bound_); synchronized_inner_objective_upper_bound_ = IntegerValue(inner_objective_upper_bound_); synchronized_best_status_ = best_status_; - if (solutions_.NumSolutions() > 0) { + if (solution_pool_.BestSolutions().NumSolutions() > 0) { first_solution_solvers_should_stop_ = true; } logger_->FlushPendingThrottledLogs(); @@ -502,7 +646,7 @@ void SharedResponseManager::UnregisterBestBoundCallback(int callback_id) { CpSolverResponse SharedResponseManager::GetResponseInternal( absl::Span variable_values, - const std::string& solution_info) { + absl::string_view solution_info) { CpSolverResponse result; result.set_status(best_status_); if (!unsat_cores_.empty()) { @@ -551,19 +695,19 @@ CpSolverResponse SharedResponseManager::GetResponseInternal( CpSolverResponse SharedResponseManager::GetResponse() { absl::MutexLock mutex_lock(&mutex_); CpSolverResponse result; - if (solutions_.NumSolutions() == 0) { + if (solution_pool_.BestSolutions().NumSolutions() == 0) { result = GetResponseInternal({}, ""); } else { std::shared_ptr::Solution> - solution = solutions_.GetSolution(0); + solution = solution_pool_.BestSolutions().GetSolution(0); result = GetResponseInternal(solution->variable_values, solution->info); } // If this is true, we postsolve and copy all of our solutions. if (parameters_.fill_additional_solutions_in_response()) { std::vector temp; - for (int i = 0; i < solutions_.NumSolutions(); ++i) { - std::shared_ptr::Solution> - solution = solutions_.GetSolution(i); + const int size = solution_pool_.BestSolutions().NumSolutions(); + for (int i = 0; i < size; ++i) { + const auto solution = solution_pool_.BestSolutions().GetSolution(i); temp = solution->variable_values; for (int i = solution_postprocessors_.size(); --i >= 0;) { solution_postprocessors_[i](&temp); @@ -622,8 +766,8 @@ void SharedResponseManager::FillObjectiveValuesInResponse( std::shared_ptr::Solution> SharedResponseManager::NewSolution(absl::Span solution_values, - const std::string& solution_info, - Model* model) { + absl::string_view solution_info, + Model* model, int source_id) { absl::MutexLock mutex_lock(&mutex_); std::shared_ptr::Solution> ret; @@ -634,7 +778,8 @@ SharedResponseManager::NewSolution(absl::Span solution_values, solution.variable_values.assign(solution_values.begin(), solution_values.end()); solution.info = solution_info; - ret = solutions_.Add(solution); + solution.source_id = source_id; + ret = solution_pool_.Add(solution); } else { const int64_t objective_value = ComputeInnerObjective(*objective_or_null_, solution_values); @@ -645,7 +790,8 @@ SharedResponseManager::NewSolution(absl::Span solution_values, solution_values.end()); solution.rank = objective_value; solution.info = solution_info; - ret = solutions_.Add(solution); + solution.source_id = source_id; + ret = solution_pool_.Add(solution); // Ignore any non-strictly improving solution. if (objective_value > inner_objective_upper_bound_) return ret; @@ -666,7 +812,7 @@ SharedResponseManager::NewSolution(absl::Span solution_values, // In single thread, no one is synchronizing the solution manager, so we // should do it from here. if (always_synchronize_) { - solutions_.Synchronize(); + solution_pool_.Synchronize(*random_); first_solution_solvers_should_stop_ = true; } @@ -703,7 +849,7 @@ SharedResponseManager::NewSolution(absl::Span solution_values, } if (logger_->LoggingIsEnabled()) { - std::string solution_message = solution_info; + std::string solution_message(solution_info); if (tmp_postsolved_response.num_booleans() > 0) { absl::StrAppend(&solution_message, " (fixed_bools=", tmp_postsolved_response.num_fixed_booleans(), "/", @@ -1240,14 +1386,27 @@ int UniqueClauseStream::NumLiteralsOfSize(int size) const { SharedClausesManager::SharedClausesManager(bool always_synchronize) : always_synchronize_(always_synchronize) {} -int SharedClausesManager::RegisterNewId(bool may_terminate_early) { +int SharedClausesManager::RegisterNewId(absl::string_view worker_name, + bool may_terminate_early) { absl::MutexLock mutex_lock(&mutex_); num_full_workers_ += may_terminate_early ? 0 : 1; const int id = id_to_last_processed_binary_clause_.size(); id_to_last_processed_binary_clause_.resize(id + 1, 0); id_to_last_returned_batch_.resize(id + 1, -1); id_to_last_finished_batch_.resize(id + 1, -1); - id_to_clauses_exported_.resize(id + 1, 0); + id_to_num_exported_.resize(id + 1, 0); + id_to_worker_name_.resize(id + 1); + id_to_worker_name_[id] = worker_name; + return id; +} + +int SharedLinear2Bounds::RegisterNewId(std::string worker_name) { + absl::MutexLock mutex_lock(&mutex_); + const int id = id_to_worker_name_.size(); + + id_to_stats_.resize(id + 1); + id_to_worker_name_.resize(id + 1); + id_to_worker_name_[id] = worker_name; return id; } @@ -1255,12 +1414,6 @@ bool SharedClausesManager::ShouldReadBatch(int reader_id, int writer_id) { return reader_id != writer_id; } -void SharedClausesManager::SetWorkerNameForId(int id, - absl::string_view worker_name) { - absl::MutexLock mutex_lock(&mutex_); - id_to_worker_name_[id] = worker_name; -} - void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { if (lit2 < lit1) std::swap(lit1, lit2); const auto p = std::make_pair(lit1, lit2); @@ -1270,7 +1423,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { if (inserted) { added_binary_clauses_.push_back(p); if (always_synchronize_) ++last_visible_binary_clause_; - id_to_clauses_exported_[id]++; + id_to_num_exported_[id]++; // Small optim. If the worker is already up to date with clauses to import, // we can mark this new clause as already seen. @@ -1283,7 +1436,7 @@ void SharedClausesManager::AddBinaryClause(int id, int lit1, int lit2) { void SharedClausesManager::AddBatch(int id, CompactVectorVector batch) { absl::MutexLock mutex_lock(&mutex_); - id_to_clauses_exported_[id] += batch.size(); + id_to_num_exported_[id] += batch.size(); pending_batches_.push_back(std::move(batch)); } @@ -1317,16 +1470,44 @@ void SharedClausesManager::GetUnseenBinaryClauses( void SharedClausesManager::LogStatistics(SolverLogger* logger) { absl::MutexLock mutex_lock(&mutex_); - absl::btree_map name_to_clauses; - for (int id = 0; id < id_to_clauses_exported_.size(); ++id) { - if (id_to_clauses_exported_[id] == 0) continue; - name_to_clauses[id_to_worker_name_[id]] = id_to_clauses_exported_[id]; + absl::btree_map name_to_table_line; + for (int id = 0; id < id_to_num_exported_.size(); ++id) { + if (id_to_num_exported_[id] == 0) continue; + name_to_table_line[id_to_worker_name_[id]] = id_to_num_exported_[id]; } - if (!name_to_clauses.empty()) { + if (!name_to_table_line.empty()) { std::vector> table; table.push_back({"Clauses shared", "Num"}); - for (const auto& entry : name_to_clauses) { - table.push_back({FormatName(entry.first), FormatCounter(entry.second)}); + for (const auto& [name, count] : name_to_table_line) { + table.push_back({FormatName(name), FormatCounter(count)}); + } + SOLVER_LOG(logger, FormatTable(table)); + } +} + +// TODO(user): Add some library to simplify this "transposition". Ideally we +// could merge small table with few columns. I am thinking list (row_name, +// col_name, count) + function that create table? +void SharedLinear2Bounds::LogStatistics(SolverLogger* logger) { + absl::MutexLock mutex_lock(&mutex_); + absl::btree_map name_to_table_line; + for (int id = 0; id < id_to_stats_.size(); ++id) { + const Stats stats = id_to_stats_[id]; + if (!stats.empty()) { + name_to_table_line[id_to_worker_name_[id]] = stats; + } + } + for (int import_id = 0; import_id < import_id_to_index_.size(); ++import_id) { + name_to_table_line[import_id_to_name_[import_id]].num_imported = + import_id_to_num_imported_[import_id]; + } + if (!name_to_table_line.empty()) { + std::vector> table; + table.push_back({"Linear2 shared", "New", "Updated", "Imported"}); + for (const auto& [name, stats] : name_to_table_line) { + table.push_back({FormatName(name), FormatCounter(stats.num_new), + FormatCounter(stats.num_update), + FormatCounter(stats.num_imported)}); } SOLVER_LOG(logger, FormatTable(table)); } @@ -1376,6 +1557,69 @@ void SharedClausesManager::Synchronize() { } } +void SharedLinear2Bounds::Add(int id, Key expr, IntegerValue lb, + IntegerValue ub) { + DCHECK(expr.IsCanonicalized()) << expr; + + absl::MutexLock mutex_lock(&mutex_); + auto [it, inserted] = shared_bounds_.insert({expr, {lb, ub}}); + if (inserted) { + // It is new. + id_to_stats_[id].num_new++; + newly_updated_keys_.push_back(expr); + } else { + // Update the individual bounds if the new ones are better. + auto& bounds = it->second; + const bool update_lb = lb > bounds.first; + if (update_lb) bounds.first = lb; + const bool update_ub = ub < bounds.second; + if (update_ub) bounds.second = ub; + if (update_lb || update_ub) { + id_to_stats_[id].num_update++; + newly_updated_keys_.push_back(expr); + } + } +} + +int SharedLinear2Bounds::RegisterNewImportId(std::string name) { + absl::MutexLock mutex_lock(&mutex_); + const int import_id = import_id_to_index_.size(); + import_id_to_name_.push_back(name); + import_id_to_index_.push_back(0); + import_id_to_num_imported_.push_back(0); + return import_id; +} + +std::vector< + std::pair>> +SharedLinear2Bounds::NewlyUpdatedBounds(int import_id) { + std::vector>> result; + + absl::MutexLock mutex_lock(&mutex_); + MaybeCompressNewlyUpdateKeys(); + const int size = newly_updated_keys_.size(); + for (int i = import_id_to_index_[import_id]; i < size; ++i) { + const auto& key = newly_updated_keys_[i]; + result.push_back({key, shared_bounds_[key]}); + } + import_id_to_index_[import_id] = size; + return result; +} + +void SharedLinear2Bounds::MaybeCompressNewlyUpdateKeys() { + int min_index = 0; + for (const int index : import_id_to_index_) { + min_index = std::min(index, min_index); + } + if (min_index == 0) return; + + newly_updated_keys_.erase(newly_updated_keys_.begin(), + newly_updated_keys_.begin() + min_index); + for (int& index_ref : import_id_to_index_) { + index_ref -= min_index; + } +} + void SharedStatistics::AddStats( absl::Span> stats) { absl::MutexLock mutex_lock(&mutex_); diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index c6eadff080..c6babb7d5f 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -61,8 +61,11 @@ template class SharedSolutionRepository { public: explicit SharedSolutionRepository(int num_solutions_to_keep, - absl::string_view name = "") - : name_(name), num_solutions_to_keep_(num_solutions_to_keep) {} + absl::string_view name = "", + int source_id = -1) + : name_(name), + num_solutions_to_keep_(num_solutions_to_keep), + source_id_(source_id) {} // The solution format used by this class. struct Solution { @@ -84,6 +87,8 @@ class SharedSolutionRepository { // Should be private: only SharedSolutionRepository should modify this. mutable int num_selected = 0; + int source_id; // Internal information. + bool operator==(const Solution& other) const { return rank == other.rank && variable_values == other.variable_values; } @@ -100,10 +105,11 @@ class SharedSolutionRepository { int NumSolutions() const; // Returns the solution #i where i must be smaller than NumSolutions(). + // Returns nullptr if i is out of range. std::shared_ptr GetSolution(int index) const; - // Returns the rank of the best known solution. - // You shouldn't call this if NumSolutions() is zero. + // Returns the rank of the best known solution. If there is no solution, this + // will return std::numeric_limits::max(). int64_t GetBestRank() const; std::vector> GetBestNSolutions(int n) const; @@ -131,7 +137,9 @@ class SharedSolutionRepository { // set of added solutions is the same. // // Works in O(num_solutions_to_keep_). - void Synchronize(); + // + // If f() is provided, it will be called on all new solutions. + void Synchronize(std::function f = nullptr); std::vector TableLineStats() const { absl::MutexLock mutex_lock(&mutex_); @@ -139,20 +147,52 @@ class SharedSolutionRepository { FormatCounter(num_queried_), FormatCounter(num_synchronization_)}; } + int64_t NumRecentlyNonImproving() const { + absl::MutexLock mutex_lock(&mutex_); + return num_non_improving_; + } + + void ClearSolutionsAndIncreaseSourceId() { + absl::MutexLock mutex_lock(&mutex_); + new_solutions_.clear(); + solutions_.clear(); + ++source_id_; + } + + int source_id() const { + absl::MutexLock mutex_lock(&mutex_); + return source_id_; + } + + int num_queried() const { + absl::MutexLock mutex_lock(&mutex_); + return num_queried_; + } + + int num_solutions_to_keep() const { return num_solutions_to_keep_; } + protected: const std::string name_; const int num_solutions_to_keep_; mutable absl::Mutex mutex_; + int source_id_ ABSL_GUARDED_BY(mutex_); int64_t num_added_ ABSL_GUARDED_BY(mutex_) = 0; mutable int64_t num_queried_ ABSL_GUARDED_BY(mutex_) = 0; int64_t num_synchronization_ ABSL_GUARDED_BY(mutex_) = 0; + mutable int64_t num_queried_at_last_sync_ ABSL_GUARDED_BY(mutex_) = 0; + mutable int64_t num_non_improving_ ABSL_GUARDED_BY(mutex_) = 0; + // Our two solutions pools, the current one and the new one that will be // merged into the current one on each Synchronize() calls. mutable std::vector tmp_indices_ ABSL_GUARDED_BY(mutex_); std::vector> solutions_ ABSL_GUARDED_BY(mutex_); std::vector> new_solutions_ ABSL_GUARDED_BY(mutex_); + + // For computing orthogonality. + std::vector ABSL_GUARDED_BY(mutex_) distances_; + std::vector ABSL_GUARDED_BY(mutex_) buffer_; }; // Solutions coming from the LP. @@ -165,6 +205,74 @@ class SharedLPSolutionRepository : public SharedSolutionRepository { void NewLPSolution(std::vector lp_solution); }; +// This stores all the feasible solutions the solver know about. +// Moreover, for meta-heuristics, we keep them in different buckets. +class SharedSolutionPool { + public: + explicit SharedSolutionPool(const SatParameters& parameters_) + : best_solutions_(parameters_.solution_pool_size(), "best_solutions"), + alternative_path_(parameters_.alternative_pool_size(), + "alternative_path", /*source_id=*/0) {} + + const SharedSolutionRepository& BestSolutions() const { + return best_solutions_; + } + + // Note that the given random generator is likely local to the thread calling + // this. + std::shared_ptr::Solution> + GetSolutionToImprove(absl::BitGenRef random) const { + // If we seems to have trouble making progress, work on the alternative + // path too. + if (alternative_path_.num_solutions_to_keep() > 0 && + best_solutions_.NumRecentlyNonImproving() > 100 && + absl::Bernoulli(random, 0.5) && alternative_path_.NumSolutions() > 0) { + // Tricky: We might clear the alternative_path_ between NumSolutions() + // and this call. + auto result = alternative_path_.GetRandomBiasedSolution(random); + if (result != nullptr) return result; + } + + if (best_solutions_.NumSolutions() > 0) { + return best_solutions_.GetRandomBiasedSolution(random); + } + return nullptr; + } + + std::shared_ptr::Solution> Add( + SharedSolutionRepository::Solution solution); + + void Synchronize(absl::BitGenRef random); + + void AddTableStats(std::vector>* table) const { + table->push_back(best_solutions_.TableLineStats()); + table->push_back(alternative_path_.TableLineStats()); + } + + private: + // Currently we only have two "pools" of solutions. + SharedSolutionRepository best_solutions_; + SharedSolutionRepository alternative_path_; + + // We also keep a list of possible "path seeds" in n buckets defined according + // to the objective value of the solution. These are updated on Synchronize(). + // Bucket i will only contain the last seen solution in the internal objective + // range [ranks_[i], ranks_[i + 1]). + // + // ranks_[0] should always be min_rank_, and seeds_[0] should be one of the + // best known solution. We usually never select seeds_[0] but keep it around + // for later in case new best solutions are found. + absl::Mutex mutex_; + int64_t max_rank_ ABSL_GUARDED_BY(mutex_) = + std::numeric_limits::min(); + int64_t min_rank_ ABSL_GUARDED_BY(mutex_) = + std::numeric_limits::max(); + std::vector ranks_; + std::vector< + std::shared_ptr::Solution>> + ABSL_GUARDED_BY(mutex_) seeds_; +}; + // Set of best solution from the feasibility jump workers. // // We store (solution, num_violated_constraints), so we have a list of solutions @@ -316,6 +424,13 @@ class SharedResponseManager { void Synchronize(); IntegerValue GetInnerObjectiveLowerBound(); IntegerValue GetInnerObjectiveUpperBound(); + IntegerValue GetBestSolutionObjective() { + if (solution_pool_.BestSolutions().NumSolutions() > 0) { + return solution_pool_.BestSolutions().GetBestRank(); + } else { + return GetInnerObjectiveUpperBound(); + } + } // Returns the current best solution inner objective value or kInt64Max if // there is no solution. @@ -361,7 +476,8 @@ class SharedResponseManager { // stored in the repository. std::shared_ptr::Solution> NewSolution(absl::Span solution_values, - const std::string& solution_info, Model* model = nullptr); + absl::string_view solution_info, Model* model = nullptr, + int source_id = -1); // Changes the solution to reflect the fact that the "improving" problem is // infeasible. This means that if we have a solution, we have proven @@ -380,14 +496,13 @@ class SharedResponseManager { // OPTIMAL and consider the problem solved. bool ProblemIsSolved() const; + bool HasFeasibleSolution() const { + return solution_pool_.BestSolutions().NumSolutions() > 0; + } + // Returns the underlying solution repository where we keep a set of best // solutions. - const SharedSolutionRepository& SolutionsRepository() const { - return solutions_; - } - SharedSolutionRepository* MutableSolutionsRepository() { - return &solutions_; - } + const SharedSolutionPool& SolutionPool() const { return solution_pool_; } // Debug only. Set dump prefix for solutions written to file. void set_dump_prefix(absl::string_view dump_prefix) { @@ -433,11 +548,12 @@ class SharedResponseManager { // Generates a response for callbacks and GetResponse(). CpSolverResponse GetResponseInternal( absl::Span variable_values, - const std::string& solution_info) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + absl::string_view solution_info) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); const SatParameters& parameters_; const WallTimer& wall_timer_; ModelSharedTimeLimit* shared_time_limit_; + ModelRandomGenerator* random_; CpObjectiveProto const* objective_or_null_ = nullptr; mutable absl::Mutex mutex_; @@ -450,7 +566,7 @@ class SharedResponseManager { CpSolverStatus synchronized_best_status_ ABSL_GUARDED_BY(mutex_) = CpSolverStatus::UNKNOWN; std::vector unsat_cores_ ABSL_GUARDED_BY(mutex_); - SharedSolutionRepository solutions_; // Thread-safe. + SharedSolutionPool solution_pool_; // Thread-safe. int num_solutions_ ABSL_GUARDED_BY(mutex_) = 0; int64_t inner_objective_lower_bound_ ABSL_GUARDED_BY(mutex_) = @@ -732,8 +848,7 @@ class SharedClausesManager { std::vector>* new_clauses); // Ids are used to identify which worker is exporting/importing clauses. - int RegisterNewId(bool may_terminate_early); - void SetWorkerNameForId(int id, absl::string_view worker_name); + int RegisterNewId(absl::string_view worker_name, bool may_terminate_early); // Search statistics. void LogStatistics(SolverLogger* logger); @@ -777,8 +892,106 @@ class SharedClausesManager { const bool always_synchronize_ = true; // Stats: - std::vector id_to_clauses_exported_; - absl::flat_hash_map id_to_worker_name_; + std::vector id_to_num_exported_ ABSL_GUARDED_BY(mutex_); + std::vector id_to_num_updated_ ABSL_GUARDED_BY(mutex_); + std::vector id_to_worker_name_ ABSL_GUARDED_BY(mutex_); +}; + +// A class that allows to exchange root level bounds on linear2. +// +// TODO(user): Add Synchronize() support and only publish new bounds when this +// is called. +class SharedLinear2Bounds { + public: + int RegisterNewId(std::string worker_name); + void LogStatistics(SolverLogger* logger); + + // This should only contain canonicalized expression. + // See the code for IsCanonicalized() for the definition. + struct Key { + int vars[2]; + IntegerValue coeffs[2]; + + bool IsCanonicalized() { + return vars[0] >= 0 && vars[1] >= 0 && vars[0] < vars[1] && + std::gcd(coeffs[0].value(), coeffs[1].value()) == 1; + } + + bool operator==(const Key& o) const { + return vars[0] == o.vars[0] && vars[1] == o.vars[1] && + coeffs[0] == o.coeffs[0] && coeffs[1] == o.coeffs[1]; + } + + template + friend H AbslHashValue(H h, const Key& k) { + return H::combine(std::move(h), k.vars[0], k.vars[1], k.coeffs[0], + k.coeffs[1]); + } + + template + friend void AbslStringify(Sink& sink, const Key& k) { + absl::Format(&sink, "%d X%d + %d X%d", k.coeffs[0].value(), k.vars[0], + k.coeffs[1].value(), k.vars[1]); + } + }; + + // Exports new bounds on the given expr (should be canonicalized). + void Add(int id, Key expr, IntegerValue lb, IntegerValue ub); + + // This is called less often, and maybe not every-worker that exports want to + // export, so we use a separate id space. Because we rely on hash map to + // check if a bound is new, it is not such a big deal that a worker re-read + // once the bounds it exported. + int RegisterNewImportId(std::string name); + + // Returns the linear2 and their bounds. + // We only return changes since the last call with the same id. + std::vector>> + NewlyUpdatedBounds(int import_id); + + // This is not filled by NewlyUpdatedBounds() because we want to track the + // bounds that were not already known by the worker at the time of the import, + // and we don't have this information here. + void NotifyNumImported(int import_id, int num) { + absl::MutexLock mutex_lock(&mutex_); + import_id_to_num_imported_[import_id] += num; + } + + private: + void MaybeCompressNewlyUpdateKeys() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + absl::Mutex mutex_; + + // The best known bounds for each key. + absl::flat_hash_map> shared_bounds_ + ABSL_GUARDED_BY(mutex_); + + // Ever growing list of updated position in shared_bounds_. + // Note that we do reduce it in MaybeCompressNewlyUpdateKeys(), but that + // requires all registered workers to have at least imported some bounds. + // + // TODO(user): use indirect addressing so that newly_updated_keys_ can just + // deal with indices, and it is a bit tighter memory wise? We also avoid + // hash-lookups on NewlyUpdatedBounds(). But since this is only called at + // level zero on new bounds, I don't think we care. + std::vector newly_updated_keys_; + + // For import. + std::vector import_id_to_name_ ABSL_GUARDED_BY(mutex_); + std::vector import_id_to_index_ ABSL_GUARDED_BY(mutex_); + std::vector import_id_to_num_imported_ ABSL_GUARDED_BY(mutex_); + + // Just for reporting at the end of the solve. + struct Stats { + int64_t num_new = 0; + int64_t num_update = 0; + int64_t num_imported = 0; // Copy of import_id_to_num_imported_. + bool empty() const { + return num_new == 0 && num_update == 0 && num_imported == 0; + } + }; + std::vector id_to_stats_ ABSL_GUARDED_BY(mutex_); + std::vector id_to_worker_name_ ABSL_GUARDED_BY(mutex_); }; // Simple class to add statistics by name and print them at the end. @@ -807,6 +1020,7 @@ template std::shared_ptr::Solution> SharedSolutionRepository::GetSolution(int i) const { absl::MutexLock mutex_lock(&mutex_); + if (i >= solutions_.size()) return nullptr; ++num_queried_; return solutions_[i]; } @@ -814,7 +1028,7 @@ SharedSolutionRepository::GetSolution(int i) const { template int64_t SharedSolutionRepository::GetBestRank() const { absl::MutexLock mutex_lock(&mutex_); - CHECK_GT(solutions_.size(), 0); + if (solutions_.empty()) return std::numeric_limits::max(); return solutions_[0]->rank; } @@ -823,11 +1037,12 @@ std::vector::Solution>> SharedSolutionRepository::GetBestNSolutions(int n) const { absl::MutexLock mutex_lock(&mutex_); - // Sorted and unique. - DCHECK(absl::c_is_sorted( - solutions_, - [](const std::shared_ptr& a, - const std::shared_ptr& b) { return *a < *b; })); + // Sorted by rank and unique. + DCHECK(absl::c_is_sorted(solutions_, + [](const std::shared_ptr& a, + const std::shared_ptr& b) { + return a->rank < b->rank; + })); DCHECK(absl::c_adjacent_find(solutions_, [](const std::shared_ptr& a, const std::shared_ptr& b) { @@ -855,34 +1070,41 @@ std::shared_ptr::Solution> SharedSolutionRepository::GetRandomBiasedSolution( absl::BitGenRef random) const { absl::MutexLock mutex_lock(&mutex_); + if (solutions_.empty()) return nullptr; ++num_queried_; - const int64_t best_rank = solutions_[0]->rank; + int index = 0; - // As long as we have solution with the best objective that haven't been - // explored too much, we select one uniformly. Otherwise, we select a solution - // from the pool uniformly. - // - // Note(user): Because of the increase of num_selected, this is dependent on - // the order of call. It should be fine for "determinism" because we do - // generate the task of a batch always in the same order. - const int kExplorationThreshold = 100; + if (solutions_.size() > 1) { + const int64_t best_rank = solutions_[0]->rank; - // Select all the best solution with a low enough selection count. - tmp_indices_.clear(); - for (int i = 0; i < solutions_.size(); ++i) { - std::shared_ptr solution = solutions_[i]; - if (solution->rank == best_rank && - solution->num_selected <= kExplorationThreshold) { - tmp_indices_.push_back(i); + // As long as we have solution with the best objective that haven't been + // explored too much, we select one uniformly. Otherwise, we select a + // solution from the pool uniformly. + // + // Note(user): Because of the increase of num_selected, this is dependent on + // the order of call. It should be fine for "determinism" because we do + // generate the task of a batch always in the same order. + const int kExplorationThreshold = 100; + + // Select all the best solution with a low enough selection count. + tmp_indices_.clear(); + for (int i = 0; i < solutions_.size(); ++i) { + std::shared_ptr solution = solutions_[i]; + if (solution->rank == best_rank && + solution->num_selected <= kExplorationThreshold) { + tmp_indices_.push_back(i); + } + } + + if (tmp_indices_.empty()) { + index = absl::Uniform(random, 0, solutions_.size()); + } else { + index = tmp_indices_[absl::Uniform(random, 0, tmp_indices_.size())]; } } - int index = 0; - if (tmp_indices_.empty()) { - index = absl::Uniform(random, 0, solutions_.size()); - } else { - index = tmp_indices_[absl::Uniform(random, 0, tmp_indices_.size())]; - } + CHECK_GE(index, 0); + CHECK_LT(index, solutions_.size()); solutions_[index]->num_selected++; return solutions_[index]; } @@ -896,38 +1118,147 @@ SharedSolutionRepository::Add(Solution solution) { { absl::MutexLock mutex_lock(&mutex_); ++num_added_; + solution_ptr->source_id = source_id_; new_solutions_.push_back(solution_ptr); } return solution_ptr; } template -void SharedSolutionRepository::Synchronize() { +void SharedSolutionRepository::Synchronize( + std::function f) { absl::MutexLock mutex_lock(&mutex_); - if (new_solutions_.empty()) return; + if (new_solutions_.empty()) { + const int64_t diff = num_queried_ - num_queried_at_last_sync_; + num_non_improving_ += diff; + num_queried_at_last_sync_ = num_queried_; + return; + } + + if (f != nullptr) { + gtl::STLStableSortAndRemoveDuplicates( + &new_solutions_, + [](const std::shared_ptr& a, + const std::shared_ptr& b) { return *a < *b; }); + for (const auto& ptr : new_solutions_) { + f(*ptr); + } + } + + const int64_t old_best_rank = solutions_.empty() + ? std::numeric_limits::max() + : solutions_[0]->rank; solutions_.insert(solutions_.end(), new_solutions_.begin(), new_solutions_.end()); new_solutions_.clear(); // We use a stable sort to keep the num_selected count for the already - // existing solutions. - // - // TODO(user): Introduce a notion of orthogonality to diversify the pool? + // existing solutions (in case of duplicates). gtl::STLStableSortAndRemoveDuplicates( &solutions_, [](const std::shared_ptr& a, const std::shared_ptr& b) { return *a < *b; }); + const int64_t new_best_rank = solutions_[0]->rank; + + // If we have more than num_solutions_to_keep_ solutions with the best rank, + // select them via orthogonality. + if (solutions_.size() > num_solutions_to_keep_ && + num_solutions_to_keep_ > 1) { + int num_best = 1; + while (num_best < solutions_.size() && + solutions_[num_best]->rank == new_best_rank) { + ++num_best; + } + + if (num_best > num_solutions_to_keep_ && num_solutions_to_keep_ < 10) { + // We should only be here if a new solution (not in our current set) was + // found. It could be one we saw before but forgot about. We put one + // first. + for (auto& solution : solutions_) { + if (solution->num_selected == 0) { + // TODO(user): randomize amongst new solution? + std::swap(solutions_[0], solution); + break; + } + } + + // We are going to be in O(n^2 * solution_size), so keep n <= 10. + solutions_.resize(std::min(10, num_best)); + + // Fill the pairwise distances. + const int n = solutions_.size(); + distances_.resize(n * n); + const int size = solutions_[0]->variable_values.size(); + for (int i = 0; i < n; ++i) { + for (int j = i + 1; j < n; ++j) { + int64_t dist = 0; + for (int k = 0; k < size; ++k) { + if (solutions_[i]->variable_values[k] != + solutions_[j]->variable_values[k]) { + ++dist; + } + } + distances_[i * n + j] = distances_[j * n + i] = dist; + } + } + + // In order to not get stuck on a subset that always maximize the sum of + // orthogonality, we pick the first element (which should be a new one + // thanks to the swap above), and we maximize the sum of orthogonality + // with the rest. + // + // This way, as we find new solution, the set changes slowly. + const std::vector selected = + FindMostDiverseSubset(num_solutions_to_keep_, n, distances_, buffer_, + /*always_pick_mask = */ 1); + + DCHECK(std::is_sorted(selected.begin(), selected.end())); + int new_size = 0; + for (const int s : selected) { + solutions_[new_size++] = std::move(solutions_[s]); + } + solutions_.resize(new_size); + + if (VLOG_IS_ON(3)) { + int min_count = std::numeric_limits::max(); + int max_count = 0; + for (const auto& s : solutions_) { + CHECK(s != nullptr); + min_count = std::min(s->num_selected, min_count); + max_count = std::max(s->num_selected, max_count); + } + int64_t score = 0; + for (const int i : selected) { + for (const int j : selected) { + if (i > j) score += distances_[i * n + j]; + } + } + LOG(INFO) << name_ << " rank=" << new_best_rank + << " num=" << num_solutions_to_keep_ << "/" << num_best + << " orthogonality=" << score << " count=[" << min_count + << ", " << max_count << "]"; + } + } + } + if (solutions_.size() > num_solutions_to_keep_) { solutions_.resize(num_solutions_to_keep_); } - + CHECK(!solutions_.empty()); if (!solutions_.empty()) { - VLOG(2) << "Solution pool update:" << " num_solutions=" << solutions_.size() + VLOG(4) << "Solution pool update:" << " num_solutions=" << solutions_.size() << " min_rank=" << solutions_[0]->rank << " max_rank=" << solutions_.back()->rank; } num_synchronization_++; + if (new_best_rank < old_best_rank) { + num_non_improving_ = 0; + } else { + const int64_t diff = num_queried_ - num_queried_at_last_sync_; + num_non_improving_ += diff; + } + num_queried_at_last_sync_ = num_queried_; } } // namespace sat diff --git a/ortools/sat/synchronization_test.cc b/ortools/sat/synchronization_test.cc index 00dd4a2550..1ab19d6cbc 100644 --- a/ortools/sat/synchronization_test.cc +++ b/ortools/sat/synchronization_test.cc @@ -834,8 +834,8 @@ TEST(SharedResponseManagerTest, Callback) { TEST(SharedClausesManagerTest, SyncApi) { SharedClausesManager manager(/*always_synchronize=*/true); - EXPECT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - EXPECT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + EXPECT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + EXPECT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); manager.AddBinaryClause(/*id=*/0, 1, 2); std::vector> new_clauses; @@ -922,8 +922,8 @@ TEST(UniqueClauseStreamTest, DropsClauses) { TEST(SharedClausesManagerTest, NonSyncApi) { SharedClausesManager manager(/*always_synchronize=*/false); - EXPECT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - EXPECT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + EXPECT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + EXPECT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); manager.AddBinaryClause(/*id=*/0, 1, 2); std::vector> new_clauses; @@ -971,8 +971,8 @@ TEST(SharedClausesManagerTest, NonSyncApi) { TEST(SharedClausesManagerTest, ShareGlueClauses) { SharedClausesManager manager(/*always_synchronize=*/true); - ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); UniqueClauseStream stream0; UniqueClauseStream stream1; // Add a bunch of clauses that will be skipped batch. @@ -999,8 +999,8 @@ TEST(SharedClausesManagerTest, ShareGlueClauses) { TEST(SharedClausesManagerTest, LbdThresholdIncrease) { SharedClausesManager manager(/*always_synchronize=*/true); - ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); UniqueClauseStream stream0; UniqueClauseStream stream1; const int kExpectedClauses = UniqueClauseStream::kMaxLiteralsPerBatch / 5; @@ -1027,8 +1027,8 @@ TEST(SharedClausesManagerTest, LbdThresholdIncrease) { TEST(SharedClausesManagerTest, LbdThresholdDecrease) { SharedClausesManager manager(/*always_synchronize=*/true); - ASSERT_EQ(0, manager.RegisterNewId(/*may_terminate_early=*/false)); - ASSERT_EQ(1, manager.RegisterNewId(/*may_terminate_early=*/false)); + ASSERT_EQ(0, manager.RegisterNewId("", /*may_terminate_early=*/false)); + ASSERT_EQ(1, manager.RegisterNewId("", /*may_terminate_early=*/false)); UniqueClauseStream stream0; UniqueClauseStream stream1; diff --git a/ortools/sat/util.cc b/ortools/sat/util.cc index 9d047a4f6b..dd90fbbcc1 100644 --- a/ortools/sat/util.cc +++ b/ortools/sat/util.cc @@ -25,6 +25,7 @@ #include "absl/algorithm/container.h" #include "absl/container/btree_set.h" #include "absl/log/check.h" +#include "absl/numeric/bits.h" #include "absl/numeric/int128.h" #include "absl/random/bit_gen_ref.h" #include "absl/random/distributions.h" @@ -1008,5 +1009,48 @@ int64_t MaxBoundedSubsetSumExact::MaxSubsetSum( return result; } +std::vector FindMostDiverseSubset(int k, int n, + absl::Span distances, + std::vector& buffer, + int always_pick_mask) { + CHECK_LE(n, 20); + const int limit = 1 << n; + buffer.assign(limit, 0); + int best_mask; + int best_value = -1; + for (unsigned int mask = 1; mask < limit; ++mask) { + const int hamming_weight = absl::popcount(mask); + + // TODO(user): Increase mask by more than one ? but counting to 1k is fast + // anyway. + if (hamming_weight > k) continue; + int low_bit = -1; + int64_t sum = 0; + for (int i = 0; i < n; ++i) { + if ((mask >> i) & 1) { + if (low_bit == -1) { + low_bit = i; + } else { + sum += distances[low_bit * n + i]; + } + } + } + buffer[mask] = buffer[mask ^ (1 << low_bit)] + sum; + if (hamming_weight == k && buffer[mask] > best_value) { + if ((mask & always_pick_mask) != always_pick_mask) continue; + best_value = buffer[mask]; + best_mask = mask; + } + } + std::vector result; + result.reserve(k); + for (int i = 0; i < n; ++i) { + if ((best_mask >> i) & 1) { + result.push_back(i); + } + } + return result; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 88a5b927d3..3ec89ce7a9 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -391,6 +391,21 @@ int MoveOneUnprocessedLiteralLast( const absl::btree_set& processed, int relevant_prefix_size, std::vector* literals); +// Selects k out of n such that the sum of pairwise distances is maximal. +// distances[i * n + j] = distances[j * n + j] = distances between i and j. +// +// This shall only be called with small n, we CHECK_LE(n, 20). +// Complexity is in O(2 ^ n + n_choose_k * n). +// Memory is in O(2 ^ n). +// +// In case of tie, this will choose deterministically, so one can randomize the +// order first to get a random subset. The returned subset will always be +// sorted. +std::vector FindMostDiverseSubset(int k, int n, + absl::Span distances, + std::vector& buffer, + int always_pick_mask = 0); + // Simple DP to compute the maximum reachable value of a "subset sum" under // a given bound (inclusive). Note that we abort as soon as the computation // become too important. @@ -1005,7 +1020,6 @@ inline void CompactVectorVector::ResetFromTranspose( // // Note 2: adding an arc during an iteration is not supported and the behavior // is undefined. - class DagTopologicalSortIterator { public: DagTopologicalSortIterator() = default; diff --git a/ortools/sat/util_test.cc b/ortools/sat/util_test.cc index 1b2be2db49..9aaef75c60 100644 --- a/ortools/sat/util_test.cc +++ b/ortools/sat/util_test.cc @@ -29,6 +29,7 @@ #include "absl/container/btree_set.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/numeric/bits.h" #include "absl/numeric/int128.h" #include "absl/random/random.h" #include "absl/strings/str_join.h" @@ -1160,6 +1161,94 @@ TEST(DagTopologicalSortIteratorTest, RandomTest) { } } +TEST(FindMostDiverseSubsetTest, Random) { + const int k = 4; + const int n = 10; + absl::BitGen random; + std::vector distances(n * n); + std::vector buffer; + for (int i = 0; i < n; ++i) { + for (int j = i + 1; j < n; ++j) { + distances[i * n + j] = distances[j * n + i] = + absl::Uniform(random, 0, 1000); + } + } + + const std::vector result = + FindMostDiverseSubset(k, n, distances, buffer); + CHECK(std::is_sorted(result.begin(), result.end())); + int64_t result_value = 0; + for (const int i : result) { + for (const int j : result) { + if (i < j) result_value += distances[i * n + j]; + } + } + + int64_t best_seen = 0; + std::vector subset; + const int limit = 1 << n; + for (unsigned int mask = 0; mask < limit; ++mask) { + if (absl::popcount(mask) != k) continue; + subset.clear(); + for (int i = 0; i < n; ++i) { + if ((mask >> i) & 1) subset.push_back(i); + } + int64_t value = 0; + for (const int i : subset) { + for (const int j : subset) { + if (i < j) value += distances[i * n + j]; + } + } + ASSERT_LE(value, result_value); + best_seen = std::max(best_seen, value); + } + EXPECT_EQ(best_seen, result_value); +} + +TEST(FindMostDiverseSubsetTest, RandomButAlwaysPickZero) { + const int k = 5; + const int n = 10; + absl::BitGen random; + std::vector distances(n * n); + std::vector buffer; + for (int i = 0; i < n; ++i) { + for (int j = i + 1; j < n; ++j) { + distances[i * n + j] = distances[j * n + i] = + absl::Uniform(random, 0, 1000); + } + } + + const std::vector result = + FindMostDiverseSubset(k, n, distances, buffer, /*always_pick_mask=*/1); + CHECK(std::is_sorted(result.begin(), result.end())); + int64_t result_value = 0; + for (const int i : result) { + for (const int j : result) { + if (i < j) result_value += distances[i * n + j]; + } + } + + int64_t best_seen = 0; + std::vector subset; + const int limit = 1 << n; + for (unsigned int mask = 1; mask < limit; mask += 2) { // bit 1 always set. + if (absl::popcount(mask) != k) continue; + subset.clear(); + for (int i = 0; i < n; ++i) { + if ((mask >> i) & 1) subset.push_back(i); + } + int64_t value = 0; + for (const int i : subset) { + for (const int j : subset) { + if (i < j) value += distances[i * n + j]; + } + } + ASSERT_LE(value, result_value); + best_seen = std::max(best_seen, value); + } + EXPECT_EQ(best_seen, result_value); +} + } // namespace } // namespace sat } // namespace operations_research diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index 30fc977c0c..11bb2324bb 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -50,11 +50,9 @@ namespace operations_research::sat { namespace { - -// We restart the shared tree 10 times after 2 restarts per worker. After that -// we restart when the tree reaches the maximum allowable number of nodes, but -// still at most once per 2 restarts per worker. -const int kSyncsPerWorkerPerRestart = 2; +// We restart the shared tree 10 times after (on average) 2 tree assignments per +// worker. +const int kAssignmentsPerWorkerPerRestart = 2; const int kNumInitialRestarts = 10; // If you build a tree by expanding the nodes with minimal depth+discrepancy, @@ -149,6 +147,7 @@ ProtoTrail::ProtoTrail() { target_phase_.reserve(kMaxPhaseSize); } void ProtoTrail::PushLevel(const ProtoLiteral& decision, IntegerValue objective_lb, int node_id) { CHECK_GT(node_id, 0); + assigned_at_level_[decision] = decision_indexes_.size(); decision_indexes_.push_back(literals_.size()); literals_.push_back(decision); node_ids_.push_back(node_id); @@ -165,14 +164,14 @@ void ProtoTrail::SetLevelImplied(int level) { DCHECK_LE(level, implications_.size()); SetObjectiveLb(level - 1, ObjectiveLb(level)); const ProtoLiteral decision = Decision(level); - implication_level_[decision] = level - 1; + assigned_at_level_[decision] = level - 1; // We don't store implications for level 0, so only move implications up to // the parent if we are removing level 2 or greater. if (level >= 2) { MutableImplications(level - 1).push_back(decision); } for (const ProtoLiteral& implication : Implications(level)) { - implication_level_[implication] = level - 1; + assigned_at_level_[implication] = level - 1; if (level >= 2) { MutableImplications(level - 1).push_back(implication); } @@ -190,7 +189,7 @@ void ProtoTrail::Clear() { level_to_objective_lbs_.clear(); node_ids_.clear(); target_phase_.clear(); - implication_level_.clear(); + assigned_at_level_.clear(); implications_.clear(); } @@ -232,7 +231,6 @@ SharedTreeManager::SharedTreeManager(Model* model) {.literal = ProtoLiteral(), .objective_lb = shared_response_manager_->GetInnerObjectiveLowerBound(), .trail_info = std::make_unique()}); - unassigned_leaves_.reserve(num_workers_); unassigned_leaves_.push_back(&nodes_.back()); } @@ -278,7 +276,10 @@ bool SharedTreeManager::SyncTree(ProtoTrail& path) { return false; } // Restart after processing updates - we might learn a new objective bound. - if (++num_syncs_since_restart_ / num_workers_ > kSyncsPerWorkerPerRestart && + // Do initial restarts once the tree has been split a reasonable number of + // times. + if (num_leaves_assigned_since_restart_ > + kAssignmentsPerWorkerPerRestart * num_workers_ && num_restarts_ < kNumInitialRestarts) { RestartLockHeld(); path.Clear(); @@ -370,11 +371,10 @@ void SharedTreeManager::ReplaceTree(ProtoTrail& path) { } path.Clear(); while (!unassigned_leaves_.empty()) { - const int i = num_leaves_assigned_++ % unassigned_leaves_.size(); - std::swap(unassigned_leaves_[i], unassigned_leaves_.back()); - Node* leaf = unassigned_leaves_.back(); - unassigned_leaves_.pop_back(); + Node* leaf = unassigned_leaves_.front(); + unassigned_leaves_.pop_front(); if (!leaf->closed && leaf->children[0] == nullptr) { + num_leaves_assigned_since_restart_ += 1; AssignLeaf(path, leaf); path.SetTargetPhase(GetTrailInfo(leaf)->phase); return; @@ -470,8 +470,7 @@ void SharedTreeManager::ProcessNodeChanges() { } if (num_newly_closed > 0) { shared_response_manager_->LogMessageWithThrottling( - "Tree", absl::StrCat("nodes:", nodes_.size(), "/", max_nodes_, - " closed:", num_closed_nodes_, + "Tree", absl::StrCat("closed:", num_closed_nodes_, "/", nodes_.size(), " unassigned:", unassigned_leaves_.size(), " restarts:", num_restarts_)); } @@ -581,7 +580,7 @@ void SharedTreeManager::RestartLockHeld() { num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1; num_closed_nodes_ = 0; num_restarts_ += 1; - num_syncs_since_restart_ = 0; + num_leaves_assigned_since_restart_ = 0; } std::string SharedTreeManager::ShortStatus() const { @@ -728,8 +727,9 @@ bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) { const auto& decision_policy = heuristics_->decision_policies[heuristics_->policy_index]; const int next_level = sat_solver_->CurrentDecisionLevel() + 1; - new_split_available_ = next_level == assigned_tree_.MaxLevel() + 1; - + if (next_level == assigned_tree_.MaxLevel() + 1) { + new_split_available_ = true; + } CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); if (next_level <= assigned_tree_.MaxLevel()) { VLOG(2) << "Following shared trail depth=" << next_level << " " @@ -746,7 +746,8 @@ bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) { void SharedTreeWorker::MaybeProposeSplit() { if (!new_split_available_ || - sat_solver_->CurrentDecisionLevel() != assigned_tree_.MaxLevel() + 1) { + sat_solver_->CurrentDecisionLevel() < assigned_tree_.MaxLevel() + 1 || + time_limit_->GetElapsedDeterministicTime() < next_split_dtime_) { return; } new_split_available_ = false; @@ -754,6 +755,8 @@ void SharedTreeWorker::MaybeProposeSplit() { sat_solver_->Decisions()[assigned_tree_.MaxLevel()].literal; const std::optional encoded = EncodeDecision(split_decision); if (encoded.has_value()) { + next_split_dtime_ = time_limit_->GetElapsedDeterministicTime() + + parameters_->shared_tree_split_min_dtime(); CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); manager_->ProposeSplit(assigned_tree_, *encoded); if (assigned_tree_.MaxLevel() > assigned_tree_literals_.size()) { @@ -778,6 +781,7 @@ bool SharedTreeWorker::ShouldReplaceSubtree() { } bool SharedTreeWorker::SyncWithSharedTree() { + DCHECK_EQ(trail_->CurrentDecisionLevel(), 0); manager_->SyncTree(assigned_tree_); if (ShouldReplaceSubtree()) { ++num_trees_; @@ -793,6 +797,8 @@ bool SharedTreeWorker::SyncWithSharedTree() { !decision_policy_->GetBestPartialAssignment().empty()) { assigned_tree_.ClearTargetPhase(); for (Literal lit : decision_policy_->GetBestPartialAssignment()) { + // Skip saving the phase for anything assigned at the root. + if (trail_->Assignment().LiteralIsAssigned(lit)) continue; // Only set the phase for booleans to avoid creating literals on other // workers. auto encoded = ProtoLiteral::EncodeLiteral(lit, mapping_); @@ -809,8 +815,9 @@ bool SharedTreeWorker::SyncWithSharedTree() { << assigned_tree_.TargetPhase().size(); decision_policy_->ClearBestPartialAssignment(); for (const ProtoLiteral& lit : assigned_tree_.TargetPhase()) { - decision_policy_->SetTargetPolarity(DecodeDecision(lit)); + decision_policy_->SetTargetPolarityIfUnassigned(DecodeDecision(lit)); } + decision_policy_->ResetActivitiesToFollowBestPartialAssignment(); } } // If we commit to this subtree, keep it for at least 1s of dtime. diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h index 1626af4fee..d2f7463f8b 100644 --- a/ortools/sat/work_assignment.h +++ b/ortools/sat/work_assignment.h @@ -135,10 +135,10 @@ class ProtoTrail { // the decision. absl::Span Implications(int level) const; void AddImplication(int level, ProtoLiteral implication) { - auto it = implication_level_.find(implication); - if (it != implication_level_.end() && it->second <= level) return; + auto it = assigned_at_level_.find(implication); + if (it != assigned_at_level_.end() && it->second <= level) return; MutableImplications(level).push_back(implication); - implication_level_[implication] = level; + assigned_at_level_[implication] = level; } IntegerValue ObjectiveLb(int level) const { @@ -153,7 +153,7 @@ class ProtoTrail { // Appends a literal to the target phase, returns false if the phase is full. bool AddPhase(const ProtoLiteral& lit) { if (target_phase_.size() >= kMaxPhaseSize) return false; - if (!implication_level_.contains(lit)) { + if (!IsAssigned(lit)) { target_phase_.push_back(lit); } return true; @@ -164,6 +164,10 @@ class ProtoTrail { if (!AddPhase(lit)) break; } } + bool IsAssigned(const ProtoLiteral& lit) const { + return assigned_at_level_.contains(lit) || + assigned_at_level_.contains(lit.Negated()); + } private: // 256 ProtoLiterals take up 4KiB @@ -179,7 +183,7 @@ class ProtoTrail { // Extra implications that can be propagated at each level but were never // branches in the shared tree. std::vector> implications_; - absl::flat_hash_map implication_level_; + absl::flat_hash_map assigned_at_level_; // The index in the literals_/node_ids_ vectors for the start of each level. std::vector decision_indexes_; @@ -277,7 +281,7 @@ class SharedTreeManager { // Stores the nodes in the search tree. std::deque nodes_ ABSL_GUARDED_BY(mu_); - std::vector unassigned_leaves_ ABSL_GUARDED_BY(mu_); + std::deque unassigned_leaves_ ABSL_GUARDED_BY(mu_); // How many splits we should generate now to keep the desired number of // leaves. @@ -287,7 +291,7 @@ class SharedTreeManager { // communication overhead. If we exceed this, workers become portfolio // workers when no unassigned leaves are available. const int max_nodes_; - int num_leaves_assigned_ ABSL_GUARDED_BY(mu_) = 0; + int num_leaves_assigned_since_restart_ ABSL_GUARDED_BY(mu_) = 0; // Temporary vectors used to maintain the state of the tree when nodes are // closed and/or children are updated. @@ -295,7 +299,6 @@ class SharedTreeManager { std::vector to_update_ ABSL_GUARDED_BY(mu_); int64_t num_restarts_ ABSL_GUARDED_BY(mu_) = 0; - int64_t num_syncs_since_restart_ ABSL_GUARDED_BY(mu_) = 0; int num_closed_nodes_ ABSL_GUARDED_BY(mu_) = 0; }; @@ -355,6 +358,7 @@ class SharedTreeWorker { ProtoTrail assigned_tree_; std::vector assigned_tree_literals_; std::vector> assigned_tree_implications_; + double next_split_dtime_ = 0; // True if the last decision may split the assigned tree and has not yet been // proposed to the SharedTreeManager. diff --git a/ortools/sat/work_assignment_test.cc b/ortools/sat/work_assignment_test.cc index f7d9b3ea55..90a8c287d4 100644 --- a/ortools/sat/work_assignment_test.cc +++ b/ortools/sat/work_assignment_test.cc @@ -39,6 +39,9 @@ TEST(ProtoTrailTest, PushLevel) { EXPECT_EQ(p.MaxLevel(), 1); EXPECT_EQ(p.Decision(1), ProtoLiteral(0, 0)); EXPECT_EQ(p.ObjectiveLb(1), 0); + EXPECT_TRUE(p.IsAssigned(ProtoLiteral(0, 0))); + EXPECT_TRUE(p.IsAssigned(ProtoLiteral(0, 0).Negated())); + EXPECT_FALSE(p.IsAssigned(ProtoLiteral(1, 0))); } TEST(ProtoTrailTest, AddImplications) { @@ -57,6 +60,12 @@ TEST(ProtoTrailTest, AddImplications) { EXPECT_THAT(p.Implications(2), testing::UnorderedElementsAre( ProtoLiteral(5, 0), ProtoLiteral(2, 0), ProtoLiteral(6, 0))); + EXPECT_TRUE(p.IsAssigned(ProtoLiteral(0, 0))); + EXPECT_TRUE(p.IsAssigned(ProtoLiteral(1, 0))); + EXPECT_TRUE(p.IsAssigned(ProtoLiteral(2, 0))); + EXPECT_TRUE(p.IsAssigned(ProtoLiteral(3, 0))); + EXPECT_TRUE(p.IsAssigned(ProtoLiteral(5, 0))); + EXPECT_TRUE(p.IsAssigned(ProtoLiteral(6, 0))); } TEST(ProtoTrailTest, SetLevel1Implied) { @@ -567,6 +576,9 @@ TEST(SharedTreeManagerTest, TrailSharing) { shared_tree_manager->ReplaceTree(trail1); shared_tree_manager->ReplaceTree(trail2); + EXPECT_EQ(shared_tree_manager->NumNodes(), 3); + EXPECT_EQ(trail1.MaxLevel(), 1); + EXPECT_EQ(trail2.MaxLevel(), 1); EXPECT_EQ(trail2.Implications(1).size(), 1); EXPECT_EQ(trail2.TargetPhase().size(), 1); EXPECT_TRUE(trail1.Implications(1).empty());