diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 7d30fd7af6..78866cd780 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -1231,6 +1231,7 @@ cc_library( ":pb_constraint", ":sat_base", ":sat_parameters_cc_proto", + ":synchronization", ":util", "//ortools/base", "//ortools/base:strong_vector", @@ -2601,16 +2602,10 @@ cc_library( ":model", ":sat_base", ":util", - "//ortools/base", "//ortools/base:stl_util", "//ortools/base:strong_vector", - "//ortools/util:saturated_arithmetic", - "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", - "//ortools/util:time_limit", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -3160,6 +3155,7 @@ cc_library( ":util", "//ortools/util:saturated_arithmetic", "//ortools/util:strong_integers", + "//ortools/util:time_limit", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index c093f9cae4..bfee9b5110 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -5882,14 +5882,14 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { indexed_intervals.push_back({x, IntegerValue(context_->StartMin(y)), IntegerValue(context_->EndMax(y))}); } - std::vector> no_overlaps; - ConstructOverlappingSets(/*already_sorted=*/false, &indexed_intervals, - &no_overlaps); - for (const std::vector& no_overlap : no_overlaps) { + CompactVectorVector no_overlaps; + absl::c_sort(indexed_intervals, IndexedInterval::ComparatorByStart()); + ConstructOverlappingSets(absl::MakeSpan(indexed_intervals), &no_overlaps); + for (int i = 0; i < no_overlaps.size(); ++i) { ConstraintProto* new_ct = context_->working_model->add_constraints(); // Unfortunately, the Assign() method does not work in or-tools as the // protobuf int32_t type is not the int type. - for (const int i : no_overlap) { + for (const int i : no_overlaps[i]) { new_ct->mutable_no_overlap()->add_intervals(i); } } diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index 21eb32fc98..7c1e05eb12 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -704,6 +704,13 @@ absl::flat_hash_map GetNamedParameters( new_params.set_optimize_with_lb_tree_search(false); new_params.set_optimize_with_max_hs(false); + // Given that each workers work on a different part of the subtree, it might + // not be a good idea to try to work on a global shared solution. + // + // TODO(user): Experiments more here, in particular we could follow it if + // it falls into the current subtree. + new_params.set_polarity_exploit_ls_hints(false); + strategies["shared_tree"] = new_params; } diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index c341b6fa77..1c3856e6ac 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -682,12 +682,13 @@ void LogFinalStatistics(SharedClasses* shared) { shared->logger->FlushPendingThrottledLogs(/*ignore_rates=*/true); SOLVER_LOG(shared->logger, ""); - shared->stat_tables.Display(shared->logger); + shared->stat_tables->Display(shared->logger); shared->response->DisplayImprovementStatistics(); std::vector> table; table.push_back({"Solution repositories", "Added", "Queried", "Synchro"}); table.push_back(shared->response->SolutionsRepository().TableLineStats()); + table.push_back(shared->ls_hints->TableLineStats()); if (shared->lp_solutions != nullptr) { table.push_back(shared->lp_solutions->TableLineStats()); } @@ -914,35 +915,9 @@ class FullProblemSolver : public SubSolver { shared_->response->first_solution_solvers_should_stop()); } - if (shared->response != nullptr) { - local_model_.Register(shared->response); - } - - if (shared->lp_solutions != nullptr) { - local_model_.Register( - shared->lp_solutions.get()); - } - - if (shared->incomplete_solutions != nullptr) { - local_model_.Register( - shared->incomplete_solutions.get()); - } - - if (shared->bounds != nullptr) { - local_model_.Register(shared->bounds.get()); - } - - if (shared->clauses != nullptr) { - local_model_.Register(shared->clauses.get()); - } - - if (local_parameters.use_shared_tree_search()) { - local_model_.Register(shared->shared_tree_manager); - } - // TODO(user): For now we do not count LNS statistics. We could easily // by registering the SharedStatistics class with LNS local model. - local_model_.Register(shared_->stats); + shared_->RegisterSharedClassesInLocalModel(&local_model_); // Setup the local logger, in multi-thread log_search_progress should be // false by default, but we might turn it on for debugging. It is on by @@ -956,10 +931,10 @@ class FullProblemSolver : public SubSolver { CpSolverResponse response; shared_->response->FillSolveStatsInResponse(&local_model_, &response); shared_->response->AppendResponseToBeMerged(response); - shared_->stat_tables.AddTimingStat(*this); - shared_->stat_tables.AddLpStat(name(), &local_model_); - shared_->stat_tables.AddSearchStat(name(), &local_model_); - shared_->stat_tables.AddClausesStat(name(), &local_model_); + shared_->stat_tables->AddTimingStat(*this); + shared_->stat_tables->AddLpStat(name(), &local_model_); + shared_->stat_tables->AddSearchStat(name(), &local_model_); + shared_->stat_tables->AddClausesStat(name(), &local_model_); } bool IsDone() override { @@ -1104,30 +1079,11 @@ class FeasibilityPumpSolver : public SubSolver { *(local_model_->GetOrCreate()) = local_parameters; shared_->time_limit->UpdateLocalLimit( local_model_->GetOrCreate()); - - if (shared->response != nullptr) { - local_model_->Register(shared->response); - } - - if (shared->lp_solutions != nullptr) { - local_model_->Register( - shared->lp_solutions.get()); - } - - if (shared->incomplete_solutions != nullptr) { - local_model_->Register( - shared->incomplete_solutions.get()); - } - - // Level zero variable bounds sharing. - if (shared_->bounds != nullptr) { - RegisterVariableBoundsLevelZeroImport( - shared_->model_proto, shared_->bounds.get(), local_model_.get()); - } + shared_->RegisterSharedClassesInLocalModel(local_model_.get()); } ~FeasibilityPumpSolver() override { - shared_->stat_tables.AddTimingStat(*this); + shared_->stat_tables->AddTimingStat(*this); } bool IsDone() override { return shared_->SearchIsDone(); } @@ -1216,8 +1172,8 @@ class LnsSolver : public SubSolver { shared_(shared) {} ~LnsSolver() override { - shared_->stat_tables.AddTimingStat(*this); - shared_->stat_tables.AddLnsStat( + shared_->stat_tables->AddTimingStat(*this); + shared_->stat_tables->AddLnsStat( name(), /*num_fully_solved_calls=*/generator_->num_fully_solved_calls(), /*num_calls=*/generator_->num_calls(), @@ -1654,6 +1610,7 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { "synchronization_agent", [shared]() { shared->response->Synchronize(); shared->response->MutableSolutionsRepository()->Synchronize(); + shared->ls_hints->Synchronize(); if (shared->bounds != nullptr) { shared->bounds->Synchronize(); } @@ -1946,7 +1903,7 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { if (num_ls_default > 0) { std::shared_ptr states = std::make_shared( - ls_name, params, &shared->stat_tables); + ls_name, params, shared->stat_tables); for (int i = 0; i < num_ls_default; ++i) { SatParameters local_params = params; local_params.set_random_seed( @@ -1956,14 +1913,15 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { std::make_unique( ls_name, SubSolver::INCOMPLETE, get_linear_model(), local_params, states, shared->time_limit, shared->response, - shared->bounds.get(), shared->stats, &shared->stat_tables)); + shared->bounds.get(), shared->ls_hints, shared->stats, + shared->stat_tables)); } } if (num_ls_lin > 0) { std::shared_ptr lin_states = std::make_shared(lin_ls_name, params, - &shared->stat_tables); + shared->stat_tables); for (int i = 0; i < num_ls_lin; ++i) { SatParameters local_params = params; local_params.set_random_seed( @@ -1973,7 +1931,8 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { std::make_unique( lin_ls_name, SubSolver::INCOMPLETE, get_linear_model(), local_params, lin_states, shared->time_limit, shared->response, - shared->bounds.get(), shared->stats, &shared->stat_tables)); + shared->bounds.get(), shared->ls_hints, shared->stats, + shared->stat_tables)); } } } @@ -2011,13 +1970,13 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { if (local_params.feasibility_jump_linearization_level() == 0) { if (fj_states == nullptr) { fj_states = std::make_shared( - local_params.name(), params, &shared->stat_tables); + local_params.name(), params, shared->stat_tables); } states = fj_states; } else { if (fj_lin_states == nullptr) { fj_lin_states = std::make_shared( - local_params.name(), params, &shared->stat_tables); + local_params.name(), params, shared->stat_tables); } states = fj_lin_states; } @@ -2026,8 +1985,8 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { std::make_unique( local_params.name(), SubSolver::FIRST_SOLUTION, get_linear_model(), local_params, states, shared->time_limit, - shared->response, shared->bounds.get(), shared->stats, - &shared->stat_tables)); + shared->response, shared->bounds.get(), shared->ls_hints, + shared->stats, shared->stat_tables)); } else { first_solution_full_subsolvers.push_back( std::make_unique( diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index 4e0c72359e..8784bda39d 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -1971,8 +1971,10 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model) time_limit(global_model->GetOrCreate()), logger(global_model->GetOrCreate()), stats(global_model->GetOrCreate()), + stat_tables(global_model->GetOrCreate()), response(global_model->GetOrCreate()), - shared_tree_manager(global_model->GetOrCreate()) { + shared_tree_manager(global_model->GetOrCreate()), + ls_hints(global_model->GetOrCreate()) { const SatParameters& params = *global_model->GetOrCreate(); if (params.share_level_zero_bounds()) { @@ -2007,6 +2009,31 @@ SharedClasses::SharedClasses(const CpModelProto* proto, Model* global_model) } } +void SharedClasses::RegisterSharedClassesInLocalModel(Model* local_model) { + // Note that we do not register the logger which is not a shared class. + local_model->Register(response); + local_model->Register(ls_hints); + local_model->Register(shared_tree_manager); + local_model->Register(stats); + local_model->Register(stat_tables); + + // TODO(user): Use parameters and not the presence/absence of these class + // to decide when to use them. + if (lp_solutions != nullptr) { + local_model->Register(lp_solutions.get()); + } + if (incomplete_solutions != nullptr) { + local_model->Register( + incomplete_solutions.get()); + } + if (bounds != nullptr) { + local_model->Register(bounds.get()); + } + if (clauses != nullptr) { + local_model->Register(clauses.get()); + } +} + bool SharedClasses::SearchIsDone() { if (response->ProblemIsSolved()) { // This is for cases where the time limit is checked more often. diff --git a/ortools/sat/cp_model_solver_helpers.h b/ortools/sat/cp_model_solver_helpers.h index d3fe19b03c..ff4403b2fb 100644 --- a/ortools/sat/cp_model_solver_helpers.h +++ b/ortools/sat/cp_model_solver_helpers.h @@ -56,8 +56,10 @@ struct SharedClasses { ModelSharedTimeLimit* const time_limit; SolverLogger* const logger; SharedStatistics* const stats; + SharedStatTables* const stat_tables; SharedResponseManager* const response; SharedTreeManager* const shared_tree_manager; + SharedLsSolutionRepository* const ls_hints; // These can be nullptr depending on the options. std::unique_ptr bounds; @@ -65,8 +67,9 @@ struct SharedClasses { std::unique_ptr incomplete_solutions; std::unique_ptr clauses; - // For displaying summary at the end. - SharedStatTables stat_tables; + // call local_model->Register() on most of the class here, this allow to + // more easily depends on one of the shared class deep within the solver. + void RegisterSharedClassesInLocalModel(Model* local_model); bool SearchIsDone(); }; diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index ea125c893a..5c0fb2456a 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -579,12 +579,12 @@ IntegerValue FindCanonicalValue(IntegerValue lb, IntegerValue ub) { } void SplitDisjointBoxes(const SchedulingConstraintHelper& x, - absl::Span boxes, - std::vector>* result) { + absl::Span boxes, + std::vector>* result) { result->clear(); - std::sort(boxes.begin(), boxes.end(), [&x](int a, int b) { + DCHECK(std::is_sorted(boxes.begin(), boxes.end(), [&x](int a, int b) { return x.ShiftedStartMin(a) < x.ShiftedStartMin(b); - }); + })); int current_start = 0; std::size_t current_length = 1; IntegerValue current_max_end = x.EndMax(boxes[0]); @@ -670,6 +670,7 @@ NonOverlappingRectanglesDisjunctivePropagator:: global_y_(*y), x_(x->NumTasks(), model), watcher_(model->GetOrCreate()), + time_limit_(model->GetOrCreate()), overload_checker_(&x_), forward_detectable_precedences_(true, &x_), backward_detectable_precedences_(false, &x_), @@ -700,7 +701,7 @@ void NonOverlappingRectanglesDisjunctivePropagator::Register( bool NonOverlappingRectanglesDisjunctivePropagator:: FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - bool fast_propagation, const SchedulingConstraintHelper& x, + bool fast_propagation, SchedulingConstraintHelper* x, SchedulingConstraintHelper* y) { // Note that since we only push bounds on x, we cache the value for y just // once. @@ -713,13 +714,13 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: for (int i = temp.size(); --i >= 0;) { const int box = temp[i].task_index; // Ignore absent boxes. - if (x.IsAbsent(box) || y->IsAbsent(box)) continue; + if (x->IsAbsent(box) || y->IsAbsent(box)) continue; // Ignore boxes where the relevant presence literal is only on the y // dimension, or if both intervals are optionals with different literals. - if (x.IsPresent(box) && !y->IsPresent(box)) continue; - if (!x.IsPresent(box) && !y->IsPresent(box) && - x.PresenceLiteral(box) != y->PresenceLiteral(box)) { + if (x->IsPresent(box) && !y->IsPresent(box)) continue; + if (!x->IsPresent(box) && !y->IsPresent(box) && + x->PresenceLiteral(box) != y->PresenceLiteral(box)) { continue; } @@ -732,16 +733,60 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: // Less than 2 boxes, no propagation. if (indexed_boxes_.size() < 2) return true; - ConstructOverlappingSets(/*already_sorted=*/true, &indexed_boxes_, - &events_overlapping_boxes_); + + // In ConstructOverlappingSets() we will always sort the output by + // x.ShiftedStartMin(t). We want to speed that up so we cache the order here. + if (!x->SynchronizeAndSetTimeDirection(x->CurrentTimeIsForward())) { + return false; + } + + // Optim: Abort if all rectangle can be fixed to their mandatory y + minimium + // x position without any overlap. Technically we might still propagate the x + // end in this setting, but the current code will just abort below in + // SplitDisjointBoxes() anyway. + // + // This is guaranteed to be O(N log N) whereas the algo below is O(N ^ 2). + if (indexed_boxes_.size() > 100) { + rectangles_.clear(); + rectangles_.reserve(indexed_boxes_.size()); + for (const auto [box, y_mandatory_start, y_mandatory_end] : + indexed_boxes_) { + // Note that we invert the x/y position here in order to be already sorted + // for FindOneIntersectionIfPresent() + rectangles_.push_back( + {/*x_min=*/y_mandatory_start, /*x_max=*/y_mandatory_end, + /*y_min=*/x->StartMin(box), /*y_max=*/x->EndMin(box)}); + } + const auto opt_pair = FindOneIntersectionIfPresent(rectangles_); + { + const size_t n = rectangles_.size(); + time_limit_->AdvanceDeterministicTime( + static_cast(n) * static_cast(absl::bit_width(n)) * + 1e-8); + } + if (opt_pair == std::nullopt) { + return true; + } + } + + order_.assign(x->NumTasks(), 0); + { + int i = 0; + for (const auto [t, _lit, _time] : x->TaskByIncreasingShiftedStartMin()) { + order_[t] = i++; + } + } + ConstructOverlappingSets(absl::MakeSpan(indexed_boxes_), + &events_overlapping_boxes_, order_); // Split lists of boxes into disjoint set of boxes (w.r.t. overlap). boxes_to_propagate_.clear(); reduced_overlapping_boxes_.clear(); + int work_done = indexed_boxes_.size(); for (int i = 0; i < events_overlapping_boxes_.size(); ++i) { - SplitDisjointBoxes(x, absl::MakeSpan(events_overlapping_boxes_[i]), - &disjoint_boxes_); - for (absl::Span sub_boxes : disjoint_boxes_) { + work_done += events_overlapping_boxes_[i].size(); + SplitDisjointBoxes(*x, events_overlapping_boxes_[i], &disjoint_boxes_); + for (const absl::Span sub_boxes : disjoint_boxes_) { // Boxes are sorted in a stable manner in the Split method. // Note that we do not use reduced_overlapping_boxes_ directly so that // the order of iteration is deterministic. @@ -750,6 +795,9 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: } } + // TODO(user): This is a poor dtime, but we want it not to be zero here. + time_limit_->AdvanceDeterministicTime(static_cast(work_done) * 1e-8); + // And finally propagate. // // TODO(user): Sorting of boxes seems influential on the performance. Test. @@ -759,7 +807,7 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: if (!fast_propagation && boxes.size() <= 2) continue; x_.ClearOtherHelper(); - if (!x_.ResetFromSubset(x, boxes)) return false; + if (!x_.ResetFromSubset(*x, boxes)) return false; // Collect the common overlapping coordinates of all boxes. IntegerValue lb(std::numeric_limits::min()); @@ -815,11 +863,11 @@ bool NonOverlappingRectanglesDisjunctivePropagator::Propagate() { // done by the fast mode. const bool fast_propagation = watcher_->GetCurrentId() == fast_id_; RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - fast_propagation, global_x_, &global_y_)); + fast_propagation, &global_x_, &global_y_)); // We can actually swap dimensions to propagate vertically. RETURN_IF_FALSE(FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - fast_propagation, global_y_, &global_x_)); + fast_propagation, &global_y_, &global_x_)); return true; } diff --git a/ortools/sat/diffn.h b/ortools/sat/diffn.h index 7a1c174549..38938f5154 100644 --- a/ortools/sat/diffn.h +++ b/ortools/sat/diffn.h @@ -31,6 +31,7 @@ #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" +#include "ortools/util/time_limit.h" namespace operations_research { namespace sat { @@ -111,7 +112,7 @@ class NonOverlappingRectanglesDisjunctivePropagator private: bool PropagateOnXWhenOnlyTwoBoxes(); bool FindBoxesThatMustOverlapAHorizontalLineAndPropagate( - bool fast_propagation, const SchedulingConstraintHelper& x, + bool fast_propagation, SchedulingConstraintHelper* x, SchedulingConstraintHelper* y); SchedulingConstraintHelper& global_x_; @@ -119,14 +120,18 @@ class NonOverlappingRectanglesDisjunctivePropagator SchedulingConstraintHelper x_; GenericLiteralWatcher* watcher_; + TimeLimit* time_limit_; int fast_id_; // Propagator id of the "fast" version. + // Temporary data. std::vector indexed_boxes_; - std::vector> events_overlapping_boxes_; + std::vector rectangles_; + std::vector order_; + CompactVectorVector events_overlapping_boxes_; - absl::flat_hash_set> reduced_overlapping_boxes_; - std::vector> boxes_to_propagate_; - std::vector> disjoint_boxes_; + absl::flat_hash_set> reduced_overlapping_boxes_; + std::vector> boxes_to_propagate_; + std::vector> disjoint_boxes_; std::vector non_zero_area_boxes_; DisjunctiveOverloadChecker overload_checker_; diff --git a/ortools/sat/diffn_util.cc b/ortools/sat/diffn_util.cc index d2c7725186..944cb648af 100644 --- a/ortools/sat/diffn_util.cc +++ b/ortools/sat/diffn_util.cc @@ -107,41 +107,33 @@ absl::InlinedVector Rectangle::RegionDifference( return result; } +// TODO(user): Switch to a faster O(n log n) algo. CompactVectorVector GetOverlappingRectangleComponents( absl::Span rectangles, absl::Span active_rectangles) { if (active_rectangles.empty()) return {}; - std::vector rectangles_to_process; - std::vector rectangles_index; - rectangles_to_process.reserve(active_rectangles.size()); - rectangles_index.reserve(active_rectangles.size()); - for (const int r : active_rectangles) { - rectangles_to_process.push_back(rectangles[r]); - rectangles_index.push_back(r); - } + std::vector active_rectangles_copy(active_rectangles.begin(), + active_rectangles.end()); + const int size = active_rectangles_copy.size(); + absl::Span indices = absl::MakeSpan(active_rectangles_copy); - std::vector> intersections = - FindPartialRectangleIntersectionsAlsoEmpty(rectangles_to_process); - const int num_intersections = intersections.size(); - intersections.reserve(num_intersections * 2 + 1); - for (int i = 0; i < num_intersections; ++i) { - intersections.push_back({intersections[i].second, intersections[i].first}); - } - - CompactVectorVector view; - view.ResetFromPairs(intersections, /*minimum_num_nodes=*/rectangles.size()); - CompactVectorVector components; - FindStronglyConnectedComponents(static_cast(rectangles.size()), view, - &components); CompactVectorVector result; - for (int i = 0; i < components.size(); ++i) { - absl::Span component = components[i]; - if (component.size() == 1) continue; - result.Add({}); - for (const int r : component) { - result.AppendToLastVector(rectangles_index[r]); + for (int start = 0; start < size;) { + // Find the component of active_rectangles[start]. + int end = start + 1; + for (int i = start; i < end; i++) { + const Rectangle rect = rectangles[indices[i]]; + for (int j = end; j < size; ++j) { + if (!rect.IsDisjoint(rectangles[indices[j]])) { + std::swap(indices[end++], indices[j]); + } + } } + if (end > start + 1) { + result.Add(indices.subspan(start, end - start)); + } + start = end; } return result; } @@ -435,52 +427,76 @@ absl::Span FilterBoxesThatAreTooLarge( return boxes.subspan(0, new_size); } -void ConstructOverlappingSets(bool already_sorted, - std::vector* intervals, - std::vector>* result) { +void ConstructOverlappingSets(absl::Span intervals, + CompactVectorVector* result, + absl::Span order) { result->clear(); - if (already_sorted) { - DCHECK(std::is_sorted(intervals->begin(), intervals->end(), - IndexedInterval::ComparatorByStart())); - } else { - std::sort(intervals->begin(), intervals->end(), - IndexedInterval::ComparatorByStart()); - } + DCHECK(std::is_sorted(intervals.begin(), intervals.end(), + IndexedInterval::ComparatorByStart())); IntegerValue min_end_in_set = kMaxIntegerValue; - intervals->push_back({-1, kMaxIntegerValue, kMaxIntegerValue}); // Sentinel. - const int size = intervals->size(); // We do a line sweep. The "current" subset crossing the "line" at // (time, time + 1) will be in (*intervals)[start_index, end_index) at the end // of the loop block. int start_index = 0; + const int size = intervals.size(); for (int end_index = 0; end_index < size;) { - const IntegerValue time = (*intervals)[end_index].start; + const IntegerValue time = intervals[end_index].start; // First, if there is some deletion, we will push the "old" set to the // result before updating it. Otherwise, we will have a superset later, so // we just continue for now. if (min_end_in_set <= time) { - result->push_back({}); - min_end_in_set = kMaxIntegerValue; - for (int i = start_index; i < end_index; ++i) { - result->back().push_back((*intervals)[i].index); - if ((*intervals)[i].end <= time) { - std::swap((*intervals)[start_index++], (*intervals)[i]); - } else { - min_end_in_set = std::min(min_end_in_set, (*intervals)[i].end); + // Push the current set to result first if its size is > 1. + if (start_index + 1 < end_index) { + result->Add({}); + for (int i = start_index; i < end_index; ++i) { + result->AppendToLastVector(intervals[i].index); } } - // Do not output subset of size one. - if (result->back().size() == 1) result->pop_back(); + // Update the set. Note that we keep the order. + min_end_in_set = kMaxIntegerValue; + int new_start = end_index; + for (int i = end_index; --i >= start_index;) { + if (intervals[i].end > time) { + min_end_in_set = std::min(min_end_in_set, intervals[i].end); + intervals[--new_start] = intervals[i]; + } + } + start_index = new_start; } // Add all the new intervals starting exactly at "time". - do { - min_end_in_set = std::min(min_end_in_set, (*intervals)[end_index].end); + // Note that we always add at least one here. + const int old_end = end_index; + while (end_index < size && intervals[end_index].start == time) { + min_end_in_set = std::min(min_end_in_set, intervals[end_index].end); ++end_index; - } while (end_index < size && (*intervals)[end_index].start == time); + } + + // If order is not empty, make sure we maintain the order. + // TODO(user): we could only do that when we push a new set. + if (!order.empty() && end_index > old_end) { + std::sort(intervals.data() + old_end, intervals.data() + end_index, + [order](const IndexedInterval& a, const IndexedInterval& b) { + return order[a.index] < order[b.index]; + }); + std::inplace_merge( + intervals.data() + start_index, intervals.data() + old_end, + intervals.data() + end_index, + [order](const IndexedInterval& a, const IndexedInterval& b) { + return order[a.index] < order[b.index]; + }); + } + } + + // Push final set. + if (start_index + 1 < size) { + result->Add({}); + for (int i = start_index; i < size; ++i) { + result->AppendToLastVector(intervals[i].index); + } } } @@ -1928,5 +1944,82 @@ std::vector> FindPartialRectangleIntersectionsAlsoEmpty( return result; } +absl::optional> FindOneIntersectionIfPresent( + absl::Span rectangles) { + DCHECK( + absl::c_is_sorted(rectangles, [](const Rectangle& a, const Rectangle& b) { + return a.x_min < b.x_min; + })); + + // Current y-coordinate intervals that are intersecting the sweep line. + // Note that the interval_set only contains disjoint intervals. + struct Interval { + int index; + IntegerValue y_min; + IntegerValue y_max; + + // IMPORTANT: For correctness, we need later insert to be first! + bool operator<(const Interval& other) const { + if (y_min == other.y_min) return index > other.index; + return y_min < other.y_min; + } + + std::string to_string() const { + return absl::StrCat("[", y_min.value(), ",", y_max.value(), "](", index, + ")"); + } + }; + + // TODO(user): Use fixed binary tree instead, it should be faster. + // We just need insert/erase/previous/next API. + std::set interval_set; + + for (int i = 0; i < rectangles.size(); ++i) { + const IntegerValue x = rectangles[i].x_min; + + // Try to add the y part of this rectangle to the set, if there is an + // intersection, lazily remove it if its x_max is already passed, otherwise + // report the intersection. + const Interval to_insert = {i, rectangles[i].y_min, rectangles[i].y_max}; + auto [it, inserted] = interval_set.insert(to_insert); + DCHECK(inserted); + + // Note that the intersection is either before 'it', or just after it. + if (it != interval_set.begin()) { + auto it_before = it; + --it_before; + + // Lazy erase stale entry. + if (rectangles[it_before->index].x_max <= x) { + interval_set.erase(it_before); + } else { + DCHECK_LE(it_before->y_min, to_insert.y_min); + if (it_before->y_max > to_insert.y_min) { + // Intersection. + return {{it_before->index, i}}; + } + } + } + ++it; + while (it != interval_set.end()) { + // Lazy erase stale entry. + if (rectangles[it->index].x_max <= x) { + auto to_erase = it++; + interval_set.erase(to_erase); + continue; + } + + DCHECK_LE(to_insert.y_min, it->y_min); + if (to_insert.y_max > it->y_min) { + // Intersection. + return {{it->index, i}}; + } + break; + } + } + + return {}; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/diffn_util.h b/ortools/sat/diffn_util.h index 0c8044740f..98200f6216 100644 --- a/ortools/sat/diffn_util.h +++ b/ortools/sat/diffn_util.h @@ -228,16 +228,23 @@ struct IndexedInterval { } }; -// Given n fixed intervals, returns the subsets of intervals that overlap during -// at least one time unit. Note that we only return "maximal" subset and filter -// subset strictly included in another. +// Given n fixed intervals that must be sorted by +// IndexedInterval::ComparatorByStart(), returns the subsets of intervals that +// overlap during at least one time unit. Note that we only return "maximal" +// subset and filter subset strictly included in another. +// +// IMPORTANT: The span of intervals will not be usable after this function! this +// could be changed if needed with an extra copy. // // All Intervals must have a positive size. // // The algo is in O(n log n) + O(result_size) which is usually O(n^2). -void ConstructOverlappingSets(bool already_sorted, - std::vector* intervals, - std::vector>* result); +// +// If the last argument is not empty, we will sort the interval in the result +// according to the given order, i.e. i will be before j if order[i] < order[j]. +void ConstructOverlappingSets(absl::Span intervals, + CompactVectorVector* result, + absl::Span order = {}); // Given n intervals, returns the set of connected components (using the overlap // relation between 2 intervals). Components are sorted by their start, and @@ -702,6 +709,16 @@ std::vector> FindPartialRectangleIntersections( std::vector> FindPartialRectangleIntersectionsAlsoEmpty( absl::Span rectangles); +// This function is faster that the FindPartialRectangleIntersections() if one +// only want to know if there is at least one intersection. It is in O(N log N). +// +// IMPORTANT: this assumes rectangles are already sorted by their x_min. +// +// If a pair {i, j} is returned, we will have i < j, and no intersection in +// the subset of rectanges in [0, j). +absl::optional> FindOneIntersectionIfPresent( + absl::Span rectangles); + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/diffn_util_test.cc b/ortools/sat/diffn_util_test.cc index 2d76136199..98ea25f57e 100644 --- a/ortools/sat/diffn_util_test.cc +++ b/ortools/sat/diffn_util_test.cc @@ -193,8 +193,6 @@ TEST(FilterBoxesThatAreTooLargeTest, BasicTest) { } TEST(ConstructOverlappingSetsTest, BasicTest) { - std::vector> result{{3}}; // To be sure we clear. - // --------------------0 // --------1 --------2 // ------------3 @@ -204,25 +202,30 @@ TEST(ConstructOverlappingSetsTest, BasicTest) { {2, IntegerValue(6), IntegerValue(10)}, {3, IntegerValue(2), IntegerValue(8)}, {4, IntegerValue(3), IntegerValue(6)}}; + absl::c_sort(intervals, IndexedInterval::ComparatorByStart()); // Note that the order is deterministic, but not sorted. - ConstructOverlappingSets(/*already_sorted=*/false, &intervals, &result); - EXPECT_THAT(result, ElementsAre(UnorderedElementsAre(0, 1, 3, 4), - UnorderedElementsAre(3, 0, 2))); + CompactVectorVector result; + result.Add({0, 1, 2}); // To be sure we clear. + ConstructOverlappingSets(absl::MakeSpan(intervals), &result); + EXPECT_THAT(result.AsVectorOfSpan(), + ElementsAre(UnorderedElementsAre(0, 1, 3, 4), + UnorderedElementsAre(3, 0, 2))); } TEST(ConstructOverlappingSetsTest, OneSet) { - std::vector> result{{3}}; // To be sure we clear. - std::vector intervals{ {0, IntegerValue(0), IntegerValue(10)}, {1, IntegerValue(1), IntegerValue(10)}, {2, IntegerValue(2), IntegerValue(10)}, {3, IntegerValue(3), IntegerValue(10)}, {4, IntegerValue(4), IntegerValue(10)}}; + absl::c_sort(intervals, IndexedInterval::ComparatorByStart()); - ConstructOverlappingSets(/*already_sorted=*/false, &intervals, &result); - EXPECT_THAT(result, ElementsAre(ElementsAre(0, 1, 2, 3, 4))); + CompactVectorVector result; + result.Add({0, 1, 2}); // To be sure we clear. + ConstructOverlappingSets(absl::MakeSpan(intervals), &result); + EXPECT_THAT(result.AsVectorOfSpan(), ElementsAre(ElementsAre(0, 1, 2, 3, 4))); } TEST(GetOverlappingIntervalComponentsTest, BasicTest) { @@ -1033,6 +1036,13 @@ TEST(FindPartialIntersections, Random) { for (int k = 0; k < num_runs; k++) { std::vector rectangles = GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, random); + + // We also test FindOneIntersectionIfPresent(). + absl::c_sort(rectangles, [](const Rectangle& a, const Rectangle& b) { + return a.x_min < b.x_min; + }); + EXPECT_EQ(FindOneIntersectionIfPresent(rectangles), std::nullopt); + const int num_to_grow = absl::Uniform(random, 0, 20); for (int i = 0; i < num_to_grow; ++i) { Rectangle& rec = @@ -1049,6 +1059,7 @@ TEST(FindPartialIntersections, Random) { for (const auto& [i, j] : result) { EXPECT_FALSE(rectangles[i].IsDisjoint(rectangles[j])); } + EXPECT_TRUE(GraphsDefineSameConnectedComponents(naive_result, result)) << RenderRectGraph(std::nullopt, rectangles, result); EXPECT_FALSE(HasCycles(result)) @@ -1056,6 +1067,19 @@ TEST(FindPartialIntersections, Random) { if (k == 0) { LOG(INFO) << RenderRectGraph(std::nullopt, rectangles, result); } + + // We also test FindOneIntersectionIfPresent(). + absl::c_sort(rectangles, [](const Rectangle& a, const Rectangle& b) { + return a.x_min < b.x_min; + }); + if (naive_result.empty()) { + EXPECT_EQ(FindOneIntersectionIfPresent(rectangles), std::nullopt); + } else { + auto opt_pair = FindOneIntersectionIfPresent(rectangles); + EXPECT_NE(opt_pair, std::nullopt); + EXPECT_FALSE( + rectangles[opt_pair->first].IsDisjoint(rectangles[opt_pair->second])); + } } } diff --git a/ortools/sat/feasibility_jump.cc b/ortools/sat/feasibility_jump.cc index 1083b3e6e0..ca440ba2b3 100644 --- a/ortools/sat/feasibility_jump.cc +++ b/ortools/sat/feasibility_jump.cc @@ -389,6 +389,11 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { state_->solution = solution.variable_values; ++state_->num_solutions_imported; } else { + if (!first_time) { + // Register this solution before we reset the search. + const int num_violations = evaluator_->ViolatedConstraints().size(); + shared_hints_->AddSolution(state_->solution, num_violations); + } ResetCurrentSolution(/*use_hint=*/first_time, state_->options.use_objective, state_->options.perturbation_probability); diff --git a/ortools/sat/feasibility_jump.h b/ortools/sat/feasibility_jump.h index 921f98591d..610397f4d9 100644 --- a/ortools/sat/feasibility_jump.h +++ b/ortools/sat/feasibility_jump.h @@ -473,6 +473,7 @@ class FeasibilityJumpSolver : public SubSolver { ModelSharedTimeLimit* shared_time_limit, SharedResponseManager* shared_response, SharedBoundsManager* shared_bounds, + SharedLsSolutionRepository* shared_hints, SharedStatistics* shared_stats, SharedStatTables* stat_tables) : SubSolver(name, type), @@ -481,6 +482,7 @@ class FeasibilityJumpSolver : public SubSolver { states_(std::move(ls_states)), shared_time_limit_(shared_time_limit), shared_response_(shared_response), + shared_hints_(shared_hints), stat_tables_(stat_tables), random_(params_), var_domains_(shared_bounds) {} @@ -589,6 +591,7 @@ class FeasibilityJumpSolver : public SubSolver { std::shared_ptr states_; ModelSharedTimeLimit* shared_time_limit_; SharedResponseManager* shared_response_; + SharedLsSolutionRepository* shared_hints_; SharedStatTables* stat_tables_; ModelRandomGenerator random_; diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 5c92481b35..a7566c191a 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -167,15 +167,17 @@ IntegerLiteral SplitUsingBestSolutionValueInRepository( // not executed often, but otherwise it is done for each search decision, // which seems expensive. Improve. std::function FirstUnassignedVarAtItsMinHeuristic( - const std::vector& vars, Model* model) { + absl::Span vars, Model* model) { auto* integer_trail = model->GetOrCreate(); - return [/*copy*/ vars, integer_trail]() { - for (const IntegerVariable var : vars) { - const IntegerLiteral decision = AtMinValue(var, integer_trail); - if (decision.IsValid()) return BooleanOrIntegerLiteral(decision); - } - return BooleanOrIntegerLiteral(); - }; + return + [/*copy*/ vars = std::vector(vars.begin(), vars.end()), + integer_trail]() { + for (const IntegerVariable var : vars) { + const IntegerLiteral decision = AtMinValue(var, integer_trail); + if (decision.IsValid()) return BooleanOrIntegerLiteral(decision); + } + return BooleanOrIntegerLiteral(); + }; } std::function MostFractionalHeuristic(Model* model) { diff --git a/ortools/sat/integer_search.h b/ortools/sat/integer_search.h index 5e68b982d6..a360df320d 100644 --- a/ortools/sat/integer_search.h +++ b/ortools/sat/integer_search.h @@ -172,7 +172,7 @@ IntegerLiteral SplitDomainUsingBestSolutionValue(IntegerVariable var, // // Note that this function will create the associated literal if needed. std::function FirstUnassignedVarAtItsMinHeuristic( - const std::vector& vars, Model* model); + absl::Span vars, Model* model); // Choose the variable with most fractional LP value. std::function MostFractionalHeuristic(Model* model); diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 1a9d9270bf..427374e6b9 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -628,10 +628,10 @@ void AddRoutesCutGenerator(const ConstraintProto& ct, Model* m, // // These property ensures that all other intervals ends before the start of // the makespan interval. -std::optional DetectMakespan( - const std::vector& intervals, - const std::vector& demands, - const AffineExpression& capacity, Model* model) { +std::optional DetectMakespan(absl::Span intervals, + absl::Span demands, + const AffineExpression& capacity, + Model* model) { IntegerTrail* integer_trail = model->GetOrCreate(); IntervalsRepository* repository = model->GetOrCreate(); diff --git a/ortools/sat/sat_decision.cc b/ortools/sat/sat_decision.cc index 2691265350..46e9b4ed6f 100644 --- a/ortools/sat/sat_decision.cc +++ b/ortools/sat/sat_decision.cc @@ -37,7 +37,8 @@ namespace sat { SatDecisionPolicy::SatDecisionPolicy(Model* model) : parameters_(*(model->GetOrCreate())), trail_(*model->GetOrCreate()), - random_(model->GetOrCreate()) {} + random_(model->GetOrCreate()), + ls_hints_(model->GetOrCreate()) {} void SatDecisionPolicy::IncreaseNumVariables(int num_variables) { const int old_num_variables = activities_.size(); @@ -133,6 +134,7 @@ void SatDecisionPolicy::RephaseIfNeeded() { FlipCurrentPolarity(); break; case 7: + if (UseLsSolutionAsInitialPolarity()) break; UseLongestAssignmentAsInitialPolarity(); break; } @@ -188,6 +190,25 @@ void SatDecisionPolicy::UseLongestAssignmentAsInitialPolarity() { best_partial_assignment_.clear(); } +bool SatDecisionPolicy::UseLsSolutionAsInitialPolarity() { + if (!parameters_.polarity_exploit_ls_hints()) return false; + + if (ls_hints_->NumSolutions() == 0) return false; + + // This is in term of proto variable. + // TODO(user): use cp_model_mapping. But this is not needed to experiment + // on pure sat problems. + std::vector solution = + ls_hints_->GetRandomBiasedSolution(*random_).variable_values; + if (solution.size() != var_polarity_.size()) return false; + + for (int i = 0; i < solution.size(); ++i) { + var_polarity_[BooleanVariable(i)] = solution[i] == 1; + } + + return false; +} + void SatDecisionPolicy::FlipCurrentPolarity() { const int num_variables = var_polarity_.size(); for (BooleanVariable var; var < num_variables; ++var) { diff --git a/ortools/sat/sat_decision.h b/ortools/sat/sat_decision.h index 46c94fb1c5..dc5491748b 100644 --- a/ortools/sat/sat_decision.h +++ b/ortools/sat/sat_decision.h @@ -24,6 +24,7 @@ #include "ortools/sat/pb_constraint.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" #include "ortools/util/bitset.h" #include "ortools/util/integer_pq.h" @@ -133,6 +134,9 @@ class SatDecisionPolicy { void FlipCurrentPolarity(); void RandomizeCurrentPolarity(); + // This one returns false if there is no such solution to use. + bool UseLsSolutionAsInitialPolarity(); + // Adds the given variable to var_ordering_ or updates its priority if it is // already present. void PqInsertOrUpdate(BooleanVariable var); @@ -142,6 +146,12 @@ class SatDecisionPolicy { const Trail& trail_; ModelRandomGenerator* random_; + // TODO(user): This is in term of proto indices. Ideally we would need + // CpModelMapping to map that to Booleans but this currently lead to cyclic + // dependencies. For now we just assume one to one correspondence for the + // first entries. This will only work on pure Boolean problems. + SharedLsSolutionRepository* ls_hints_; + // Variable ordering (priority will be adjusted dynamically). queue_elements_ // holds the elements used by var_ordering_ (it uses pointers). // diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index 78d5a7cf3f..d8185ac4aa 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -23,7 +23,7 @@ option java_multiple_files = true; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 309 +// NEXT TAG: 310 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -73,6 +73,10 @@ message SatParameters { // 2 * x the second time, etc... optional int32 polarity_rephase_increment = 168 [default = 1000]; + // If true and we have first solution LS workers, tries in some phase to + // follow a LS solutions that violates has litle constraints as possible. + optional bool polarity_exploit_ls_hints = 309 [default = false]; + // The proportion of polarity chosen at random. Note that this take // precedence over the phase saving heuristic. This is different from // initial_polarity:POLARITY_RANDOM because it will select a new random diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index aba426d0e8..9a8355ac51 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -1608,7 +1608,7 @@ std::vector SatSolver::GetDecisionsFixing( return unsat_assumptions; } -void SatSolver::BumpReasonActivities(const std::vector& literals) { +void SatSolver::BumpReasonActivities(absl::Span literals) { SCOPED_TIME_STAT(&stats_); for (const Literal literal : literals) { const BooleanVariable var = literal.Variable(); diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 372cc122fb..074050519e 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -722,7 +722,7 @@ class SatSolver { // Activity management for clauses. This work the same way at the ones for // variables, but with different parameters. - void BumpReasonActivities(const std::vector& literals); + void BumpReasonActivities(absl::Span literals); void BumpClauseActivity(SatClause* clause); void RescaleClauseActivities(double scaling_factor); void UpdateClauseActivityIncrement(); diff --git a/ortools/sat/shaving_solver.cc b/ortools/sat/shaving_solver.cc index a7ae9ea6bb..30ca5613d6 100644 --- a/ortools/sat/shaving_solver.cc +++ b/ortools/sat/shaving_solver.cc @@ -54,7 +54,7 @@ ObjectiveShavingSolver::ObjectiveShavingSolver( local_proto_(shared->model_proto) {} ObjectiveShavingSolver::~ObjectiveShavingSolver() { - shared_->stat_tables.AddTimingStat(*this); + shared_->stat_tables->AddTimingStat(*this); } bool ObjectiveShavingSolver::TaskIsAvailable() { diff --git a/ortools/sat/subsolver.cc b/ortools/sat/subsolver.cc index d879774dd7..4497bb0bb3 100644 --- a/ortools/sat/subsolver.cc +++ b/ortools/sat/subsolver.cc @@ -69,7 +69,7 @@ int NextSubsolverToSchedule(std::vector>& subsolvers, } void ClearSubsolversThatAreDone( - const std::vector& num_in_flight_per_subsolvers, + absl::Span num_in_flight_per_subsolvers, std::vector>& subsolvers) { for (int i = 0; i < subsolvers.size(); ++i) { if (subsolvers[i] == nullptr) continue; diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 5dd3363788..e116822809 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -878,9 +878,9 @@ SharedBoundsManager::SharedBoundsManager(const CpModelProto& model_proto) } void SharedBoundsManager::ReportPotentialNewBounds( - const std::string& worker_name, const std::vector& variables, - const std::vector& new_lower_bounds, - const std::vector& new_upper_bounds) { + const std::string& worker_name, absl::Span variables, + absl::Span new_lower_bounds, + absl::Span new_upper_bounds) { CHECK_EQ(variables.size(), new_lower_bounds.size()); CHECK_EQ(variables.size(), new_upper_bounds.size()); int num_improvements = 0; diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index 044802e1fe..ce731572cc 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -115,7 +115,7 @@ class SharedSolutionRepository { // right away. One must call Synchronize for this to happen. In order to be // deterministic, this will keep all solutions until Synchronize() is called, // so we need to be careful not to generate too many solutions at once. - void Add(const Solution& solution); + void Add(Solution solution); // Updates the current pool of solution with the one recently added. Note that // we use a stable ordering of solutions, so the final pool will be @@ -147,6 +147,7 @@ class SharedSolutionRepository { std::vector new_solutions_ ABSL_GUARDED_BY(mutex_); }; +// Solutions coming from the LP. class SharedLPSolutionRepository : public SharedSolutionRepository { public: explicit SharedLPSolutionRepository(int num_solutions_to_keep) @@ -156,6 +157,28 @@ class SharedLPSolutionRepository : public SharedSolutionRepository { void NewLPSolution(std::vector lp_solution); }; +// Set of best solution from the feasibility jump workers. +// +// We store (solution, num_violated_constraints), so we have a list of solutions +// that violate as little constraints as possible. This can be used to set the +// phase during SAT search. +// +// TODO(user): We could also use it after first solution to orient a SAT search +// towards better solutions. But then it is a bit trickier to rank solutions +// compared to the old ones. +class SharedLsSolutionRepository : public SharedSolutionRepository { + public: + SharedLsSolutionRepository() + : SharedSolutionRepository(10, "fj solution hints") {} + + void AddSolution(std::vector solution, int num_violations) { + SharedSolutionRepository::Solution sol; + sol.rank = num_violations; + sol.variable_values = std::move(solution); + Add(sol); + } +}; + // Set of partly filled solutions. They are meant to be finished by some lns // worker. // @@ -495,9 +518,9 @@ class SharedBoundsManager { // manager. The manager will compare these bounds changes against its // global state, and incorporate the improving ones. void ReportPotentialNewBounds(const std::string& worker_name, - const std::vector& variables, - const std::vector& new_lower_bounds, - const std::vector& new_upper_bounds); + absl::Span variables, + absl::Span new_lower_bounds, + absl::Span new_upper_bounds); // If we solved a small independent component of the full problem, then we can // in most situation fix the solution on this subspace. @@ -850,11 +873,11 @@ SharedSolutionRepository::GetRandomBiasedSolution( } template -void SharedSolutionRepository::Add(const Solution& solution) { +void SharedSolutionRepository::Add(Solution solution) { if (num_solutions_to_keep_ <= 0) return; absl::MutexLock mutex_lock(&mutex_); ++num_added_; - new_solutions_.push_back(solution); + new_solutions_.push_back(std::move(solution)); } template diff --git a/ortools/sat/var_domination.cc b/ortools/sat/var_domination.cc index 1e28d39dbc..cc2cb4b5cd 100644 --- a/ortools/sat/var_domination.cc +++ b/ortools/sat/var_domination.cc @@ -876,6 +876,14 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { processed[PositiveRef(enf)] = true; processed[positive_ref] = true; context->UpdateRuleStats("dual: affine relation"); + // The new affine relation added below can break the hint if hint(enf) + // is 0. In this case the only constraint blocking `ref` from + // decreasing [`ct` = enf => (var = implied)] does not apply. We can + // thus set the hint of `positive_ref` to `bound` to preserve the hint + // feasibility. + if (context->LiteralSolutionHintIs(enf, false)) { + context->UpdateRefSolutionHint(positive_ref, bound); + } if (RefIsPositive(enf)) { // positive_ref = enf * implied + (1 - enf) * bound. if (!context->StoreAffineRelation( @@ -883,7 +891,8 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { return false; } } else { - // positive_ref = (1 - enf) * implied + enf * bound. + // enf_var = PositiveRef(enf). + // positive_ref = (1 - enf_var) * implied + enf_var * bound. if (!context->StoreAffineRelation(positive_ref, PositiveRef(enf), bound - implied.FixedValue(), implied.FixedValue())) { @@ -958,10 +967,25 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // remove the constraint. if (rhs.IsFixed()) { if (encoding_lit == NegatedRef(ref)) continue; + // Extending `ct` = "not(ref) => encoding_lit" to an equality can + // break the hint only if hint(ref) = hint(encoding_lit) = 1. But + // in this case `ct` is actually not blocking ref from decreasing. + // We can thus set its hint to 0 to preserve the hint feasibility. + if (context->LiteralSolutionHintIs(encoding_lit, true)) { + context->UpdateLiteralSolutionHint(ref, false); + } context->StoreBooleanEqualityRelation(encoding_lit, NegatedRef(ref)); } else { if (encoding_lit == ref) continue; + // Extending `ct` = "not(ref) => not(encoding_lit)" to an equality + // can break the hint only if hint(encoding_lit) = 0 and hint(ref) + // = 1. But in this case `ct` is actually not blocking ref from + // decreasing. We can thus set its hint to 0 to preserve the hint + // feasibility. + if (context->LiteralSolutionHintIs(encoding_lit, false)) { + context->UpdateLiteralSolutionHint(ref, false); + } context->StoreBooleanEqualityRelation(encoding_lit, ref); } context->working_model->mutable_constraints(ct_index)->Clear(); @@ -1045,6 +1069,18 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { ++num_bool_in_near_duplicate_ct; processed[PositiveRef(ref)] = true; processed[PositiveRef(other_ref)] = true; + // If the two hints are different, and since both refs have an + // equivalent blocking constraint, then the constraint is actually + // not blocking the ref at 1 from decreasing. Hence we can set its + // hint to false to preserve the hint feasibility despite the new + // Boolean equality constraint. + if (context->VarHasSolutionHint(PositiveRef(ref)) && + context->VarHasSolutionHint(PositiveRef(other_ref)) && + context->LiteralSolutionHint(ref) != + context->LiteralSolutionHint(other_ref)) { + context->UpdateLiteralSolutionHint(ref, false); + context->UpdateLiteralSolutionHint(other_ref, false); + } context->StoreBooleanEqualityRelation(ref, other_ref); // We can delete one of the constraint since they are duplicate @@ -1427,28 +1463,26 @@ void MaybeUpdateLiteralHintFromDominance(PresolveContext& context, int lit, // Decrements the solution hint of `ref` by the minimum amount necessary to be // in `domain`, and increments the solution hint of one or more -// `dominating_variables` by the same total amount. Does nothing if a hint is -// missing or if it is not possible to increment the hint of the dominating -// variables by the amount subtracted from the hint of the dominated variable. +// `dominating_variables` by the same total amount (or less if it is not +// possible to exactly match this amount). // -// The lower bound of `domain` must be the lower bound of `ref`'s current domain -// in `context`. +// `domain` must be an interval with the same lower bound as `ref`'s current +// domain D in `context`, and whose upper bound must be in D. void MaybeUpdateRefHintFromDominance( PresolveContext& context, int ref, const Domain& domain, const absl::Span dominating_variables) { const std::optional ref_hint = context.GetRefSolutionHint(ref); if (!ref_hint.has_value()) return; - // The quantity to subtract from the solution hint of `ref`. + // The quantity to subtract from the solution hint of `ref`. If the closest + // value of *ref_hint in `domain` is not *ref_hint then it is either the lower + // or upper bound of `domain`, which by hypothesis are in `ref`'s current + // domain D. Hence, in any case, this closest value is in D. const int64_t ref_hint_delta = *ref_hint - domain.ClosestValue(*ref_hint); // If it is 0 there is nothing to do. It might be negative if the solution // hint is not initially feasible (in which case we can't fix it). if (ref_hint_delta <= 0) return; - // First step: check that the hint of the dominating variable(s) can be - // incremented by ref_hint_delta (possibly spread over multiple variables), - // and store the new hint values in `new_ref_hint_value_pairs`. - std::vector> new_ref_hint_value_pairs; - new_ref_hint_value_pairs.push_back({ref, *ref_hint - ref_hint_delta}); + context.UpdateRefSolutionHint(ref, *ref_hint - ref_hint_delta); int64_t remaining_delta = ref_hint_delta; for (const IntegerVariable ivar : dominating_variables) { const int dominating_ref = VarDomination::IntegerVariableToRef(ivar); @@ -1461,17 +1495,10 @@ void MaybeUpdateRefHintFromDominance( *dominating_ref_hint; // This might happen if the solution hint is not initially feasible. if (delta < 0) continue; - new_ref_hint_value_pairs.push_back( - {dominating_ref, *dominating_ref_hint + delta}); + context.UpdateRefSolutionHint(dominating_ref, *dominating_ref_hint + delta); remaining_delta -= delta; if (remaining_delta == 0) break; } - if (remaining_delta != 0) return; - - // Second step: actually update the hints. - for (const auto& [ref, hint] : new_ref_hint_value_pairs) { - context.UpdateRefSolutionHint(ref, hint); - } } bool ProcessAtMostOne( @@ -1894,17 +1921,10 @@ bool ExploitDominanceRelations(const VarDomination& var_domination, context->UpdateRuleStats( "domination: dual strenghtening using dominance"); const Domain reduced_domain = Domain(context->MinOf(ref), lb); - if (dominating_vars.empty()) { - if (!context->IntersectDomainWithAndUpdateHint(ref, - reduced_domain)) { - return false; - } - } else { - MaybeUpdateRefHintFromDominance(*context, ref, reduced_domain, - dominating_vars); - if (!context->IntersectDomainWith(ref, reduced_domain)) { - return false; - } + MaybeUpdateRefHintFromDominance(*context, ref, reduced_domain, + dominating_vars); + if (!context->IntersectDomainWith(ref, reduced_domain)) { + return false; } // The rest of the loop only care about Booleans. diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index e7d5d79d89..dea3480f71 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -554,8 +554,9 @@ void SharedTreeManager::AssignLeaf(ProtoTrail& path, Node* leaf) { if (leaf->implied) { path.SetLevelImplied(path.MaxLevel()); } - if (params_.shared_tree_worker_enable_trail_sharing()) { - for (const auto& [var, lb] : GetTrailInfo(leaf)->implications) { + if (params_.shared_tree_worker_enable_trail_sharing() && + leaf->trail_info != nullptr) { + for (const auto& [var, lb] : leaf->trail_info->implications) { path.AddImplication(path.MaxLevel(), ProtoLiteral(var, lb)); } } @@ -766,7 +767,9 @@ bool SharedTreeWorker::ShouldReplaceSubtree() { // If we have no assignment, try to get one. if (assigned_tree_.MaxLevel() == 0) return true; if (restart_policy_->NumRestarts() < - parameters_->shared_tree_worker_min_restarts_per_subtree()) { + parameters_->shared_tree_worker_min_restarts_per_subtree() || + time_limit_->GetElapsedDeterministicTime() < + earliest_replacement_dtime_) { return false; } return assigned_tree_lbds_.WindowAverage() < @@ -783,7 +786,9 @@ bool SharedTreeWorker::SyncWithSharedTree() { << " target: " << assigned_tree_lbds_.WindowAverage() << " lbd: " << restart_policy_->LbdAverageSinceReset(); if (parameters_->shared_tree_worker_enable_phase_sharing() && - assigned_tree_.MaxLevel() > 0 && + // Only save the phase if we've done a non-trivial amount of work on + // this subtree. + FinishedMinRestarts() && !decision_policy_->GetBestPartialAssignment().empty()) { assigned_tree_.ClearTargetPhase(); for (Literal lit : decision_policy_->GetBestPartialAssignment()) { @@ -797,6 +802,7 @@ bool SharedTreeWorker::SyncWithSharedTree() { manager_->ReplaceTree(assigned_tree_); assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset()); restart_policy_->Reset(); + earliest_replacement_dtime_ = 0; if (parameters_->shared_tree_worker_enable_phase_sharing()) { VLOG(2) << "Importing phase of length: " << assigned_tree_.TargetPhase().size(); @@ -806,6 +812,14 @@ bool SharedTreeWorker::SyncWithSharedTree() { } } } + // If we commit to this subtree, keep it for at least 1s of dtime. + // This allows us to replace obviously bad subtrees quickly, and not replace + // too frequently overall. + if (FinishedMinRestarts() && earliest_replacement_dtime_ >= + time_limit_->GetElapsedDeterministicTime()) { + earliest_replacement_dtime_ = + time_limit_->GetElapsedDeterministicTime() + 1; + } VLOG(2) << "Assigned level: " << assigned_tree_.MaxLevel() << " " << parameters_->name(); assigned_tree_literals_.clear(); @@ -825,7 +839,7 @@ bool SharedTreeWorker::SyncWithSharedTree() { SatSolver::Status SharedTreeWorker::Search( const std::function& feasible_solution_observer) { // Inside GetAssociatedLiteral if a literal becomes fixed at level 0 during - // Search,the code checks it is at level 0 when decoding the literal, but + // Search, the code CHECKs it is at level 0 when decoding the literal, but // the fixed literals are cached, so we can create them now to avoid a // crash. sat_solver_->Backtrack(0); diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h index 162b2ec648..dab12d7165 100644 --- a/ortools/sat/work_assignment.h +++ b/ortools/sat/work_assignment.h @@ -318,6 +318,11 @@ class SharedTreeWorker { bool NextDecision(LiteralIndex* decision_index); void MaybeProposeSplit(); bool ShouldReplaceSubtree(); + bool FinishedMinRestarts() const { + return assigned_tree_.MaxLevel() > 0 && + restart_policy_->NumRestarts() >= + parameters_->shared_tree_worker_min_restarts_per_subtree(); + } // Add any implications to the clause database for the current level. // Return true if any new information was added. @@ -362,6 +367,7 @@ class SharedTreeWorker { // If a tree has worse LBD than the average over the last few trees we replace // the tree. RunningAverage assigned_tree_lbds_; + double earliest_replacement_dtime_ = 0; // Stores the trail index of the last implication added to assigned_tree_. int reversible_trail_index_ = 0;