diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index bc40dc06ef..491976236a 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -250,6 +250,7 @@ cc_library( ":simplification", ":subsolver", ":synchronization", + ":work_assignment", "//ortools/base", "//ortools/base:file", "//ortools/base:stl_util", @@ -1123,6 +1124,7 @@ cc_library( ":model", ":sat_parameters_cc_proto", ":synchronization", + ":util", "//ortools/glop:revised_simplex", "//ortools/util:logging", "@com_google_absl//absl/container:btree", @@ -1618,6 +1620,32 @@ cc_library( ], ) +cc_library( + name = "work_assignment", + srcs = ["work_assignment.cc"], + hdrs = ["work_assignment.h"], + deps = [ + ":cp_model_mapping", + ":cp_model_utils", + ":integer", + ":integer_search", + ":model", + ":sat_base", + ":sat_parameters_cc_proto", + ":sat_solver", + ":synchronization", + ":util", + "//ortools/util:time_limit", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + cc_binary( name = "sat_runner", srcs = [ diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index ee733b42c0..0c6a53b847 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -184,7 +184,8 @@ const std::function ConstructSearchStrategyInternal( // may be the case if we do a fixed_search. // To store equivalent variables in randomized search. - std::vector active_refs; + TopN top_n_vars( + parameters.search_randomization_tolerance()); int t_index = 0; // Index in strategy.transformations(). 
for (int i = 0; i < strategy.variables().size(); ++i) { @@ -243,29 +244,18 @@ const std::function ConstructSearchStrategyInternal( !parameters.randomize_search()) { break; } else if (parameters.randomize_search()) { - if (value <= - candidate_value + parameters.search_randomization_tolerance()) { - active_refs.push_back({ref, value}); - } + // We keep the top N of 'minimal' values, thus the negation. + top_n_vars.Add(ref, -value); } } if (candidate_value == std::numeric_limits::max()) continue; if (parameters.randomize_search()) { - CHECK(!active_refs.empty()); - const IntegerValue threshold( - candidate_value + parameters.search_randomization_tolerance()); - auto is_above_tolerance = [threshold](const VarValue& entry) { - return entry.value > threshold; - }; - // Remove all values above tolerance. - active_refs.erase(std::remove_if(active_refs.begin(), active_refs.end(), - is_above_tolerance), - active_refs.end()); - const int winner = absl::Uniform(*random, 0, active_refs.size()); - candidate = active_refs[winner].ref; + const auto& elements = top_n_vars.UnorderedElements(); + candidate = elements[absl::Uniform(*random, 0, elements.size())]; } + // TODO(user): Randomize value selection. DecisionStrategyProto::DomainReductionStrategy selection = strategy.domain_reduction_strategy(); if (!RefIsPositive(candidate)) { @@ -615,6 +605,8 @@ std::vector GetDiverseSetOfParameters( // like if there is no lp, or everything is already linearized at level 1. std::vector names; + const int num_workers_to_generate = + base_params.num_workers() - base_params.num_shared_tree_workers(); // We use the default if empty. if (base_params.subsolvers().empty()) { names.push_back("default_lp"); @@ -634,14 +626,14 @@ std::vector GetDiverseSetOfParameters( // Do not add objective_lb_search if core is active and num_workers <= 16. 
if (cp_model.has_objective() && (cp_model.objective().vars().size() == 1 || // core is not active - base_params.num_workers() > 16)) { + num_workers_to_generate > 16)) { names.push_back("objective_lb_search"); } names.push_back("probing"); - if (base_params.num_workers() >= 20) { + if (num_workers_to_generate >= 20) { names.push_back("probing_max_lp"); } - if (base_params.num_workers() >= 24) { + if (num_workers_to_generate >= 24) { names.push_back("objective_lb_search_max_lp"); } #if !defined(__PORTABLE_PLATFORM__) && defined(USE_SCIP) @@ -747,8 +739,9 @@ std::vector GetDiverseSetOfParameters( if (cp_model.has_objective() && !cp_model.objective().vars().empty()) { // If there is an objective, the extra workers will use LNS. // Make sure we have at least min_num_lns_workers() of them. - const int target = std::max( - 1, base_params.num_workers() - base_params.min_num_lns_workers()); + const int target = + std::max(base_params.num_shared_tree_workers() > 0 ? 0 : 1, + num_workers_to_generate - base_params.min_num_lns_workers()); if (!base_params.interleave_search() && result.size() > target) { result.resize(target); } @@ -758,7 +751,7 @@ std::vector GetDiverseSetOfParameters( const bool need_extra_workers = !base_params.interleave_search() && (base_params.use_rins_lns() || base_params.use_feasibility_pump()); - int target = base_params.num_workers(); + int target = num_workers_to_generate; if (need_extra_workers && target > 4) { if (target <= 8) { target -= 1; @@ -784,7 +777,16 @@ std::vector GetFirstSolutionParams( int num_random_qr = 0; while (result.size() < num_params_to_generate) { SatParameters new_params = base_params; - const int base_seed = base_params.random_seed(); + + // Set up randomization. 
+ new_params.set_randomize_search(true); + new_params.set_random_seed( + ValidSumSeed(base_params.random_seed(), result.size())); + new_params.set_search_randomization_tolerance(10); + const double sat_randomization_ratio = 0.02; + new_params.set_random_branches_ratio(sat_randomization_ratio); + new_params.set_random_polarity_ratio(sat_randomization_ratio); + if (num_random <= num_random_qr) { // Random search. // Alternate between automatic search and fixed search (if defined). // @@ -795,17 +797,11 @@ std::vector GetFirstSolutionParams( } else { new_params.set_search_branching(SatParameters::FIXED_SEARCH); } - new_params.set_randomize_search(true); - new_params.set_search_randomization_tolerance(num_random + 1); - new_params.set_random_seed(ValidSumSeed(base_seed, 2 * num_random + 1)); new_params.set_name(absl::StrCat("random_", num_random)); num_random++; } else { // Random quick restart. new_params.set_search_branching( SatParameters::PORTFOLIO_WITH_QUICK_RESTART_SEARCH); - new_params.set_randomize_search(true); - new_params.set_search_randomization_tolerance(num_random_qr + 1); - new_params.set_random_seed(ValidSumSeed(base_seed, 2 * num_random_qr)); new_params.set_name(absl::StrCat("random_quick_restart_", num_random_qr)); num_random_qr++; } @@ -814,5 +810,31 @@ std::vector GetFirstSolutionParams( return result; } +std::vector GetWorkSharingParams( + const SatParameters& base_params, const CpModelProto& cp_model, + int num_params_to_generate) { + std::vector result; + // TODO(user): We could support assumptions, it's just not implemented. + if (!cp_model.assumptions().empty()) return result; + if (num_params_to_generate <= 0) return result; + int num_workers = 0; + while (result.size() < num_params_to_generate) { + // TODO(user): Make the base parameters configurable. 
+ SatParameters new_params = base_params; + std::string name = "shared_"; + const int base_seed = base_params.random_seed(); + new_params.set_random_seed(ValidSumSeed(base_seed, 2 * num_workers + 1)); + new_params.set_search_branching(SatParameters::AUTOMATIC_SEARCH); + new_params.set_use_shared_tree_search(true); + new_params.set_linearization_level(0); + std::string lp_tags[] = {"no", "default", "max"}; + absl::StrAppend(&name, lp_tags[new_params.linearization_level()], "_lp_", + num_workers); + new_params.set_name(name); + num_workers++; + result.push_back(new_params); + } + return result; +} } // namespace sat } // namespace operations_research diff --git a/ortools/sat/cp_model_search.h b/ortools/sat/cp_model_search.h index a8ff5c3725..4fc38a47fe 100644 --- a/ortools/sat/cp_model_search.h +++ b/ortools/sat/cp_model_search.h @@ -106,6 +106,12 @@ std::vector GetFirstSolutionParams( const SatParameters& base_params, const CpModelProto& cp_model, int num_params_to_generate); +// Returns a vector of num_params_to_generate set of parameters to specify +// solvers that cooperatively explore a search tree. 
+std::vector GetWorkSharingParams( + const SatParameters& base_params, const CpModelProto& cp_model, + int num_params_to_generate); + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index ccaa7e6adb..7dc9f0b7a0 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -91,6 +91,7 @@ #include "ortools/sat/subsolver.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" +#include "ortools/sat/work_assignment.h" #include "ortools/util/logging.h" #include "ortools/util/random_engine.h" #if !defined(__PORTABLE_PLATFORM__) @@ -1800,8 +1801,13 @@ void SolveLoadedCpModel(const CpModelProto& model_proto, Model* model) { } } else if (!model_proto.has_objective()) { while (true) { - status = ResetAndSolveIntegerProblem( - mapping.Literals(model_proto.assumptions()), model); + if (parameters.use_shared_tree_search()) { + auto* subtree_worker = model->GetOrCreate(); + status = subtree_worker->Search(solution_observer); + } else { + status = ResetAndSolveIntegerProblem( + mapping.Literals(model_proto.assumptions()), model); + } if (status != SatSolver::Status::FEASIBLE) break; solution_observer(); if (!parameters.enumerate_all_solutions()) break; @@ -1847,6 +1853,9 @@ void SolveLoadedCpModel(const CpModelProto& model_proto, Model* model) { } else { status = model->Mutable()->Optimize(); } + } else if (parameters.use_shared_tree_search()) { + auto* subtree_worker = model->GetOrCreate(); + status = subtree_worker->Search(solution_observer); } else { // TODO(user): This parameter breaks the splitting in chunk of a Solve(). // It should probably be moved into another SubSolver altogether. @@ -2336,18 +2345,37 @@ CpSolverResponse SolvePureSatModel(const CpModelProto& model_proto, #if !defined(__PORTABLE_PLATFORM__) -// Small wrapper to simplify the constructions of the two SubSolver below. 
+// Small wrapper containing all the shared classes between our subsolver +// threads. Note that all these classes can also be retrieved with something +// like global_model->GetOrCreate() but it is not thread-safe to do so. +// +// All the classes here should be thread-safe, or at least safe in the way they +// are accessed. For instance the model_proto will be kept constant for the +// whole duration of the solve. struct SharedClasses { - CpModelProto const* model_proto; - WallTimer* wall_timer; - ModelSharedTimeLimit* time_limit; - SharedBoundsManager* bounds; - SharedResponseManager* response; - SharedRelaxationSolutionRepository* relaxation_solutions; - SharedLPSolutionRepository* lp_solutions; - SharedIncompleteSolutionManager* incomplete_solutions; - SharedClausesManager* clauses; - Model* global_model; + SharedClasses(const CpModelProto* proto, Model* global_model) + : model_proto(proto), + wall_timer(global_model->GetOrCreate()), + time_limit(global_model->GetOrCreate()), + logger(global_model->GetOrCreate()), + stats(global_model->GetOrCreate()), + response(global_model->GetOrCreate()), + shared_tree_manager(global_model->GetOrCreate()) {} + + // These are never nullptr. + const CpModelProto* const model_proto; + WallTimer* const wall_timer; + ModelSharedTimeLimit* const time_limit; + SolverLogger* const logger; + SharedStatistics* const stats; + SharedResponseManager* const response; + SharedTreeManager* const shared_tree_manager; + + // These can be nullptr depending on the options. 
+ std::unique_ptr bounds; + std::unique_ptr lp_solutions; + std::unique_ptr incomplete_solutions; + std::unique_ptr clauses; bool SearchIsDone() { if (response->ProblemIsSolved()) return true; @@ -2381,32 +2409,31 @@ class FullProblemSolver : public SubSolver { local_model_->Register(shared->response); } - if (shared->relaxation_solutions != nullptr) { - local_model_->Register( - shared->relaxation_solutions); - } - if (shared->lp_solutions != nullptr) { - local_model_->Register(shared->lp_solutions); + local_model_->Register( + shared->lp_solutions.get()); } if (shared->incomplete_solutions != nullptr) { local_model_->Register( - shared->incomplete_solutions); + shared->incomplete_solutions.get()); } if (shared->bounds != nullptr) { - local_model_->Register(shared->bounds); + local_model_->Register(shared->bounds.get()); } if (shared->clauses != nullptr) { - local_model_->Register(shared->clauses); + local_model_->Register(shared->clauses.get()); + } + + if (local_parameters.use_shared_tree_search()) { + local_model_->Register(shared->shared_tree_manager); } // TODO(user): For now we do not count LNS statistics. We could easily // by registering the SharedStatistics class with LNS local model. - local_model_->Register( - shared->global_model->GetOrCreate()); + local_model_->Register(shared_->stats); } ~FullProblemSolver() override { @@ -2442,9 +2469,9 @@ class FullProblemSolver : public SubSolver { // at the same time. 
if (shared_->bounds != nullptr) { RegisterVariableBoundsLevelZeroExport( - *shared_->model_proto, shared_->bounds, local_model_.get()); + *shared_->model_proto, shared_->bounds.get(), local_model_.get()); RegisterVariableBoundsLevelZeroImport( - *shared_->model_proto, shared_->bounds, local_model_.get()); + *shared_->model_proto, shared_->bounds.get(), local_model_.get()); } // Note that this is done after the loading, so we will never export @@ -2454,9 +2481,9 @@ class FullProblemSolver : public SubSolver { const int id = shared_->clauses->RegisterNewId(); shared_->clauses->SetWorkerNameForId(id, local_model_->Name()); - RegisterClausesLevelZeroImport(id, shared_->clauses, + RegisterClausesLevelZeroImport(id, shared_->clauses.get(), local_model_.get()); - RegisterClausesExport(id, shared_->clauses, local_model_.get()); + RegisterClausesExport(id, shared_->clauses.get(), local_model_.get()); } if (local_model_->GetOrCreate()->repair_hint()) { @@ -2605,24 +2632,20 @@ class FeasibilityPumpSolver : public SubSolver { local_model_->Register(shared->response); } - if (shared->relaxation_solutions != nullptr) { - local_model_->Register( - shared->relaxation_solutions); - } - if (shared->lp_solutions != nullptr) { - local_model_->Register(shared->lp_solutions); + local_model_->Register( + shared->lp_solutions.get()); } if (shared->incomplete_solutions != nullptr) { local_model_->Register( - shared->incomplete_solutions); + shared->incomplete_solutions.get()); } // Level zero variable bounds sharing. 
if (shared_->bounds != nullptr) { RegisterVariableBoundsLevelZeroImport( - *shared_->model_proto, shared_->bounds, local_model_.get()); + *shared_->model_proto, shared_->bounds.get(), local_model_.get()); } } @@ -2995,7 +3018,6 @@ class LnsSolver : public SubSolver { generator_->AddSolveData(data); if (VLOG_IS_ON(1) && display_lns_info) { - auto* logger = shared_->global_model->GetOrCreate(); std::string s = absl::StrCat(" LNS ", name(), ":"); if (new_solution) { const double base_obj = ScaleObjectiveValue( @@ -3016,8 +3038,8 @@ class LnsSolver : public SubSolver { neighborhood.variables_that_can_be_fixed_to_local_optimum.size(), "]"); } - SOLVER_LOG(logger, s, " [d:", data.difficulty, ", id:", task_id, - ", dtime:", data.deterministic_time, "/", + SOLVER_LOG(shared_->logger, s, " [d:", data.difficulty, + ", id:", task_id, ", dtime:", data.deterministic_time, "/", data.deterministic_limit, ", status:", ProtoEnumToString(data.status), ", #calls:", generator_->num_calls(), @@ -3049,58 +3071,39 @@ void SolveCpModelParallel(const CpModelProto& model_proto, << "Enumerating all solutions in parallel is not supported."; if (global_model->GetOrCreate()->LimitReached()) return; - std::unique_ptr shared_bounds_manager; + SharedClasses shared(&model_proto, global_model); + if (params.share_level_zero_bounds()) { - shared_bounds_manager = std::make_unique(model_proto); - shared_bounds_manager->LoadDebugSolution( + shared.bounds = std::make_unique(model_proto); + shared.bounds->LoadDebugSolution( global_model->GetOrCreate()->DebugSolution()); } - std::unique_ptr - shared_relaxation_solutions; - - auto shared_lp_solutions = std::make_unique( + shared.lp_solutions = std::make_unique( /*num_solutions_to_keep=*/10); - global_model->Register(shared_lp_solutions.get()); + global_model->Register(shared.lp_solutions.get()); // We currently only use the feasiblity pump if it is enabled and some other // parameters are not on. 
- std::unique_ptr shared_incomplete_solutions; const bool use_feasibility_pump = params.use_feasibility_pump() && params.linearization_level() > 0 && !params.use_lns_only() && !params.interleave_search(); if (use_feasibility_pump) { - shared_incomplete_solutions = + shared.incomplete_solutions = std::make_unique(); global_model->Register( - shared_incomplete_solutions.get()); + shared.incomplete_solutions.get()); } // Set up synchronization mode in parallel. const bool always_synchronize = !params.interleave_search() || params.num_workers() <= 1; + shared.response->SetSynchronizationMode(always_synchronize); - std::unique_ptr shared_clauses; if (params.share_binary_clauses()) { - shared_clauses = std::make_unique(always_synchronize); + shared.clauses = std::make_unique(always_synchronize); } - SharedResponseManager* shared_response_manager = - global_model->GetOrCreate(); - shared_response_manager->SetSynchronizationMode(always_synchronize); - - SharedClasses shared; - shared.model_proto = &model_proto; - shared.wall_timer = global_model->GetOrCreate(); - shared.time_limit = global_model->GetOrCreate(); - shared.bounds = shared_bounds_manager.get(); - shared.response = shared_response_manager; - shared.relaxation_solutions = shared_relaxation_solutions.get(); - shared.lp_solutions = shared_lp_solutions.get(); - shared.incomplete_solutions = shared_incomplete_solutions.get(); - shared.clauses = shared_clauses.get(); - shared.global_model = global_model; - // The list of all the SubSolver that will be used in this parallel search. 
std::vector> subsolvers; std::vector> incomplete_subsolvers; @@ -3113,9 +3116,6 @@ void SolveCpModelParallel(const CpModelProto& model_proto, if (shared.bounds != nullptr) { shared.bounds->Synchronize(); } - if (shared.relaxation_solutions != nullptr) { - shared.relaxation_solutions->Synchronize(); - } if (shared.lp_solutions != nullptr) { shared.lp_solutions->Synchronize(); } @@ -3142,6 +3142,13 @@ void SolveCpModelParallel(const CpModelProto& model_proto, "first_solution", local_params, /*split_in_chunks=*/false, &shared)); } else { + for (const SatParameters& local_params : GetWorkSharingParams( + params, model_proto, params.num_shared_tree_workers())) { + subsolvers.push_back(std::make_unique( + local_params.name(), local_params, + /*split_in_chunks=*/params.interleave_search(), &shared)); + num_full_problem_solvers++; + } for (const SatParameters& local_params : GetDiverseSetOfParameters(params, model_proto)) { // TODO(user): This is currently not supported here. @@ -3163,7 +3170,7 @@ void SolveCpModelParallel(const CpModelProto& model_proto, // Add the NeighborhoodGeneratorHelper as a special subsolver so that its // Synchronize() is called before any LNS neighborhood solvers. auto unique_helper = std::make_unique( - &model_proto, ¶ms, shared.response, shared.bounds); + &model_proto, ¶ms, shared.response, shared.bounds.get()); NeighborhoodGeneratorHelper* helper = unique_helper.get(); subsolvers.push_back(std::move(unique_helper)); @@ -3181,16 +3188,16 @@ void SolveCpModelParallel(const CpModelProto& model_proto, // RINS. incomplete_subsolvers.push_back(std::make_unique( std::make_unique( - helper, shared.response, shared.relaxation_solutions, - shared.lp_solutions, /*incomplete_solutions=*/nullptr, + helper, shared.response, nullptr, shared.lp_solutions.get(), + /*incomplete_solutions=*/nullptr, absl::StrCat("rins_lns_", local_params.name())), local_params, helper, &shared)); // RENS. 
incomplete_subsolvers.push_back(std::make_unique( std::make_unique( - helper, /*response_manager=*/nullptr, shared.relaxation_solutions, - shared.lp_solutions, shared.incomplete_solutions, + helper, /*response_manager=*/nullptr, nullptr, + shared.lp_solutions.get(), shared.incomplete_solutions.get(), absl::StrCat("rens_lns_", local_params.name())), local_params, helper, &shared)); } diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 318ca6d9b6..add259b512 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -872,6 +872,8 @@ bool IntegerTrail::UpdateInitialDomain(IntegerVariable var, Domain domain) { if (old_domain == domain) return true; if (domain.IsEmpty()) return false; + const bool lb_changed = domain.Min() > old_domain.Min(); + const bool ub_changed = domain.Max() < old_domain.Max(); (*domains_)[index] = domain; // Update directly the level zero bounds. @@ -886,6 +888,12 @@ bool IntegerTrail::UpdateInitialDomain(IntegerVariable var, Domain domain) { vars_[NegationOf(var)].current_bound = -domain.Max(); integer_trail_[NegationOf(var).value()].bound = -domain.Max(); + // Do not forget to update the watchers. + for (SparseBitset* bitset : watchers_) { + if (lb_changed) bitset->Set(var); + if (ub_changed) bitset->Set(NegationOf(var)); + } + // Update the encoding. return encoder_->UpdateEncodingOnInitialDomainChange(var, domain); } diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index ad6a92be64..f9b02eb0ac 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -579,8 +579,8 @@ class IntegerEncoder { // Gets the literal always set to true, make it if it does not exist. 
Literal GetTrueLiteral() { - DCHECK_EQ(0, sat_solver_->CurrentDecisionLevel()); if (literal_index_true_ == kNoLiteralIndex) { + DCHECK_EQ(0, sat_solver_->CurrentDecisionLevel()); const Literal literal_true = Literal(sat_solver_->NewBooleanVariable(), true); literal_index_true_ = literal_true.Index(); diff --git a/ortools/sat/linear_constraint_manager.h b/ortools/sat/linear_constraint_manager.h index 9c7bc5ef41..fc1d1c32ff 100644 --- a/ortools/sat/linear_constraint_manager.h +++ b/ortools/sat/linear_constraint_manager.h @@ -31,6 +31,7 @@ #include "ortools/sat/linear_constraint.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/util.h" #include "ortools/util/logging.h" #include "ortools/util/strong_integers.h" #include "ortools/util/time_limit.h" @@ -219,7 +220,7 @@ class LinearConstraintManager { // Sparse representation of the objective coeffs indexed by positive variables // indices. Important: We cannot use a dense representation here in the corner - // case where we have many indepedent LPs. Alternatively, we could share a + // case where we have many independent LPs. Alternatively, we could share a // dense vector between all LinearConstraintManager. double sum_of_squared_objective_coeffs_ = 0.0; absl::flat_hash_map objective_map_; @@ -240,59 +241,6 @@ class LinearConstraintManager { int32_t num_deletable_constraints_ = 0; }; -// Keep the top n elements from a stream of elements. -// -// TODO(user): We could use gtl::TopN when/if it gets open sourced. Note that -// we might be slighlty faster here since we use an indirection and don't move -// the Element class around as much. 
-template -class TopN { - public: - explicit TopN(int n) : n_(n) {} - - void Clear() { - heap_.clear(); - elements_.clear(); - } - - void Add(Element e, double score) { - if (heap_.size() < n_) { - const int index = elements_.size(); - heap_.push_back({index, score}); - elements_.push_back(std::move(e)); - if (heap_.size() == n_) { - // TODO(user): We could delay that on the n + 1 push. - std::make_heap(heap_.begin(), heap_.end()); - } - } else { - if (score <= heap_.front().score) return; - const int index_to_replace = heap_.front().index; - elements_[index_to_replace] = std::move(e); - - // If needed, we could be faster here with an update operation. - std::pop_heap(heap_.begin(), heap_.end()); - heap_.back() = {index_to_replace, score}; - std::push_heap(heap_.begin(), heap_.end()); - } - } - - const std::vector& UnorderedElements() const { return elements_; } - - private: - const int n_; - - // We keep a heap of the n lowest score. - struct HeapElement { - int index; // in elements_; - double score; - const double operator<(const HeapElement& other) const { - return score > other.score; - } - }; - std::vector heap_; - std::vector elements_; -}; - // Before adding cuts to the global pool, it is a classical thing to only keep // the top n of a given type during one generation round. This is there to help // doing that. 
@@ -318,7 +266,7 @@ class TopNCuts { std::string name; LinearConstraint cut; }; - TopN cuts_; + TopN cuts_; }; } // namespace sat diff --git a/ortools/sat/parameters_validation.cc b/ortools/sat/parameters_validation.cc index fcb9f0589d..7ae762be5a 100644 --- a/ortools/sat/parameters_validation.cc +++ b/ortools/sat/parameters_validation.cc @@ -147,6 +147,7 @@ std::string ValidateParameters(const SatParameters& params) { "probing_max_lp", "probing_no_lp", "probing", + "pseudo_costs", "quick_restart_max_lp", "quick_restart_no_lp", "quick_restart", diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index c0e3c63d5f..f91c840107 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -23,7 +23,7 @@ option csharp_namespace = "Google.OrTools.Sat"; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 235 +// NEXT TAG: 237 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -1063,6 +1063,16 @@ message SatParameters { // http://aaai.org/ocs/index.php/AAAI/AAAI17/paper/view/14489 optional bool optimize_with_max_hs = 85 [default = false]; + // Enables experimental workstealing-like shared tree search. + // If non-zero, start this many complete worker threads to explore a shared + // search tree. These workers communicate objective bounds and simple decision + // nogoods relating to the shared prefix of the tree, and will avoid exploring + // the same subtrees as one another. + optional int32 num_shared_tree_workers = 235 [default = 0]; + + // Set on shared subtree workers. Users should not set this directly. + optional bool use_shared_tree_search = 236 [default = false]; + // Whether we enumerate all solutions of a problem without objective. Note // that setting this to true automatically disable some presolve reduction // that can remove feasible solution. 
That is it has the same effect as diff --git a/ortools/sat/subsolver.cc b/ortools/sat/subsolver.cc index 21c0ef96cf..e2f25a0231 100644 --- a/ortools/sat/subsolver.cc +++ b/ortools/sat/subsolver.cc @@ -22,6 +22,7 @@ #include "absl/flags/flag.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" @@ -104,6 +105,8 @@ void DeterministicLoop( std::vector num_generated_tasks(subsolvers.size(), 0); std::vector> to_run; to_run.reserve(batch_size); + ThreadPool pool("DeterministicLoop", num_threads); + pool.StartWorkers(); while (true) { SynchronizeAll(subsolvers); @@ -118,31 +121,39 @@ void DeterministicLoop( } if (to_run.empty()) break; - // TODO(user): We could reuse the same ThreadPool as long as we wait for all - // the task in a batch to finish before scheduling new ones. Not sure how - // to easily do that, so for now we just recreate the pool for each to_run. - ThreadPool pool("DeterministicLoop", num_threads); - pool.StartWorkers(); + // Schedule each task. + absl::BlockingCounter blocking_counter(static_cast(to_run.size())); for (auto& f : to_run) { - pool.Schedule(std::move(f)); + pool.Schedule([f = std::move(f), &blocking_counter]() { + f(); + blocking_counter.DecrementCount(); + }); } to_run.clear(); + + // Wait for all tasks of this batch to be done before scheduling another + // batch. + blocking_counter.Wait(); } } void NonDeterministicLoop( const std::vector>& subsolvers, - int num_threads) { + const int num_threads) { CHECK_GT(num_threads, 0); if (num_threads == 1) { return SequentialLoop(subsolvers); } - // The mutex will protect these two fields. This is used to only keep - // num_threads task in-flight and detect when the search is done. + // The mutex guards num_in_flight. This is used to detect when the search is + // done. 
absl::Mutex mutex; - absl::CondVar thread_available_condition; - int num_scheduled_and_not_done = 0; + int num_in_flight = 0; // Guarded by `mutex`. + // Predicate to be used with absl::Condition to detect that num_in_flight < + // num_threads. Must only be called while locking `mutex`. + const auto num_in_flight_lt_num_threads = [&num_in_flight, num_threads]() { + return num_in_flight < num_threads; + }; ThreadPool pool("NonDeterministicLoop", num_threads); pool.StartWorkers(); @@ -153,18 +164,16 @@ void NonDeterministicLoop( int64_t task_id = 0; std::vector num_generated_tasks(subsolvers.size(), 0); while (true) { + // Set to true if no task is pending right now. bool all_done = false; { - absl::MutexLock mutex_lock(&mutex); + // Wait if num_in_flight == num_threads. + const absl::MutexLock mutex_lock( + &mutex, absl::Condition(&num_in_flight_lt_num_threads)); // The stopping condition is that we do not have anything else to generate // once all the task are done and synchronized. - if (num_scheduled_and_not_done == 0) all_done = true; - - // Wait if num_scheduled_and_not_done == num_threads. 
- if (num_scheduled_and_not_done == num_threads) { - thread_available_condition.Wait(&mutex); - } + if (num_in_flight == 0) all_done = true; } SynchronizeAll(subsolvers); @@ -184,20 +193,16 @@ void NonDeterministicLoop( num_generated_tasks[best]++; { absl::MutexLock mutex_lock(&mutex); - num_scheduled_and_not_done++; + num_in_flight++; } std::function task = subsolvers[best]->GenerateTask(task_id++); const std::string name = subsolvers[best]->name(); - pool.Schedule([task, num_threads, name, &mutex, &num_scheduled_and_not_done, - &thread_available_condition]() { + pool.Schedule([task = std::move(task), name, &mutex, &num_in_flight]() { task(); - absl::MutexLock mutex_lock(&mutex); + const absl::MutexLock mutex_lock(&mutex); VLOG(1) << name << " done."; - num_scheduled_and_not_done--; - if (num_scheduled_and_not_done == num_threads - 1) { - thread_available_condition.SignalAll(); - } + num_in_flight--; }); } } diff --git a/ortools/sat/util.h b/ortools/sat/util.h index 336ba7a79c..cc9d7f66f3 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -380,6 +380,61 @@ std::vector>> FullyCompressTuples( absl::Span domain_sizes, std::vector>* tuples); +// Keep the top n elements from a stream of elements. +// +// TODO(user): We could use gtl::TopN when/if it gets open sourced. Note that +// we might be slightly faster here since we use an indirection and don't move +// the Element class around as much. +template +class TopN { + public: + explicit TopN(int n) : n_(n) {} + + void Clear() { + heap_.clear(); + elements_.clear(); + } + + void Add(Element e, Score score) { + if (heap_.size() < n_) { + const int index = elements_.size(); + heap_.push_back({index, score}); + elements_.push_back(std::move(e)); + if (heap_.size() == n_) { + // TODO(user): We could delay that on the n + 1 push. 
+ std::make_heap(heap_.begin(), heap_.end()); + } + } else { + if (score <= heap_.front().score) return; + const int index_to_replace = heap_.front().index; + elements_[index_to_replace] = std::move(e); + + // If needed, we could be faster here with an update operation. + std::pop_heap(heap_.begin(), heap_.end()); + heap_.back() = {index_to_replace, score}; + std::push_heap(heap_.begin(), heap_.end()); + } + } + + bool empty() const { return elements_.empty(); } + + const std::vector& UnorderedElements() const { return elements_; } + + private: + const int n_; + + // We keep a heap of the n highest score. + struct HeapElement { + int index; // in elements_; + Score score; + bool operator<(const HeapElement& other) const { + return score > other.score; + } + }; + std::vector heap_; + std::vector elements_; +}; + // ============================================================================ // Implementation. // ============================================================================ diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc new file mode 100644 index 0000000000..341b524ca5 --- /dev/null +++ b/ortools/sat/work_assignment.cc @@ -0,0 +1,595 @@ +// Copyright 2010-2022 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ortools/sat/work_assignment.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/random/distributions.h" +#include "absl/random/random.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "ortools/base/logging.h" +#include "ortools/sat/cp_model_mapping.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/sat/synchronization.h" +#include "ortools/util/time_limit.h" + +namespace operations_research::sat { + +Literal ProtoLiteral::Decode(CpModelMapping* mapping, + IntegerEncoder* encoder) const { + DCHECK_LT(proto_var_, mapping->NumProtoVariables()); + if (mapping->IsBoolean(proto_var_)) { + return mapping->Literal(proto_var_); + } + return encoder->GetOrCreateAssociatedLiteral(DecodeInteger(mapping)); +} + +IntegerLiteral ProtoLiteral::DecodeInteger(CpModelMapping* mapping) const { + const int positive_var = PositiveRef(proto_var_); + if (!mapping->IsInteger(positive_var)) { + return IntegerLiteral(); + } + if (proto_var_ < 0) { + return IntegerLiteral::LowerOrEqual(mapping->Integer(positive_var), -lb_); + } + return IntegerLiteral::GreaterOrEqual(mapping->Integer(positive_var), lb_); +} + +std::optional ProtoLiteral::EncodeInteger( + IntegerLiteral literal, CpModelMapping* mapping) { + IntegerVariable positive_var = PositiveVariable(literal.var); + const int model_var = + mapping->GetProtoVariableFromIntegerVariable(positive_var); + if (model_var == -1) { + return std::nullopt; + } + ProtoLiteral result{ + literal.var == positive_var ? 
model_var : NegatedRef(model_var), + literal.bound}; + DCHECK_EQ(result.DecodeInteger(mapping), literal); + DCHECK_EQ(result.Negated().DecodeInteger(mapping), literal.Negated()); + return result; +} +std::optional ProtoLiteral::Encode(Literal literal, + CpModelMapping* mapping, + IntegerEncoder* encoder) { + if (literal.Index() == kNoLiteralIndex) { + return std::nullopt; + } + int model_var = + mapping->GetProtoVariableFromBooleanVariable(literal.Variable()); + if (model_var != -1) { + CHECK(mapping->IsBoolean(model_var)); + ProtoLiteral result{ + literal.IsPositive() ? model_var : NegatedRef(model_var), + literal.IsPositive() ? 1 : 0}; + DCHECK_EQ(result.Decode(mapping, encoder), literal); + DCHECK_EQ(result.Negated().Decode(mapping, encoder), literal.Negated()); + return result; + } + for (auto int_lit : encoder->GetIntegerLiterals(literal)) { + auto result = EncodeInteger(int_lit, mapping); + if (result.has_value()) { + DCHECK_EQ(result->DecodeInteger(mapping), int_lit); + DCHECK_EQ(result->Negated().DecodeInteger(mapping), int_lit.Negated()); + return result; + } + } + return std::nullopt; +} + +void ProtoTrail::PushLevel(const ProtoLiteral& decision, + IntegerValue objective_lb, int node_id) { + CHECK_GT(node_id, 0); + decision_indexes_.push_back(literals_.size()); + literals_.push_back(decision); + node_ids_.push_back(node_id); + if (!level_to_objective_lbs_.empty()) { + objective_lb = std::max(level_to_objective_lbs_.back(), objective_lb); + } + level_to_objective_lbs_.push_back(objective_lb); +} + +void ProtoTrail::SetLevelImplied(int level) { + DCHECK_GE(level, 1); + DCHECK_LE(level, decision_indexes_.size()); + SetObjectiveLb(level - 1, ObjectiveLb(level)); + decision_indexes_.erase(decision_indexes_.begin() + level - 1); + level_to_objective_lbs_.erase(level_to_objective_lbs_.begin() + level - 1); +} + +void ProtoTrail::Clear() { + decision_indexes_.clear(); + literals_.clear(); + level_to_objective_lbs_.clear(); + node_ids_.clear(); +} + +void 
ProtoTrail::SetObjectiveLb(int level, IntegerValue objective_lb) { + if (level == 0) return; + level_to_objective_lbs_[level - 1] = + std::max(objective_lb, level_to_objective_lbs_[level - 1]); +} + +absl::Span ProtoTrail::NodeIds(int level) const { + DCHECK_LE(level, decision_indexes_.size()); + int start = level == 0 ? 0 : decision_indexes_[level - 1]; + int end = level == decision_indexes_.size() ? node_ids_.size() + : decision_indexes_[level]; + return absl::MakeSpan(node_ids_.data() + start, end - start); +} + +absl::Span ProtoTrail::Implications(int level) const { + DCHECK_LE(level, decision_indexes_.size()); + int start = level == 0 ? 0 : decision_indexes_[level - 1] + 1; + int end = level == decision_indexes_.size() ? node_ids_.size() + : decision_indexes_[level]; + return absl::MakeSpan(literals_.data() + start, end - start); +} + +SharedTreeManager::SharedTreeManager(Model* model) + : num_workers_( + model->GetOrCreate()->num_shared_tree_workers()), + shared_response_manager_(model->GetOrCreate()), + num_splits_wanted_(num_workers_ - 1), + max_nodes_(128 * num_workers_) { + // Create the root node with a fake literal. + nodes_.push_back({.literal = ProtoLiteral()}); + unassigned_leaves_.reserve(num_workers_); + unassigned_leaves_.push_back(&nodes_.back()); + to_close_.reserve(max_nodes_); + to_update_.reserve(max_nodes_); +} + +int SharedTreeManager::SplitsToGeneratePerWorker() const { + absl::MutexLock mutex_lock(&mu_); + return std::min((num_splits_wanted_ + num_workers_ - 1) / num_workers_, + max_nodes_ - static_cast(nodes_.size())); +} + +bool SharedTreeManager::SyncTree(ProtoTrail& path) { + absl::MutexLock mutex_lock(&mu_); + std::vector> nodes = GetAssignedNodes(path); + if (nodes.back().first->closed) { + path.Clear(); + return false; + } + // We don't rely on these being empty, but we expect them to be. 
+ DCHECK(to_close_.empty()); + DCHECK(to_update_.empty()); + int prev_level = -1; + for (const auto& [node, level] : nodes) { + if (level == prev_level) { + to_close_.push_back(GetSibling(node)); + } + if (level > 0 && node->objective_lb < path.ObjectiveLb(level)) { + node->objective_lb = path.ObjectiveLb(level); + to_update_.push_back(node->parent); + } + prev_level = level; + } + ProcessNodeChanges(); + return true; +} + +void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) { + absl::MutexLock mutex_lock(&mu_); + std::vector> nodes = GetAssignedNodes(path); + if (nodes.back().first->children[0] != nullptr) { + LOG_IF(WARNING, nodes.size() > 1) + << "Cannot resplit previously split node @ " << nodes.back().second + << "/" << nodes.size(); + return; + } + if (nodes_.size() >= max_nodes_) { + VLOG(1) << "Too many nodes to accept split"; + return; + } + if (num_splits_wanted_ <= 0) { + VLOG(1) << "Enough splits for now"; + return; + } + if (path.MaxLevel() > log2(max_nodes_)) { + VLOG(1) << "Tree too unbalanced to accept split"; + return; + } + VLOG_EVERY_N(1, 10) << unassigned_leaves_.size() << " unassigned leaves, " + << nodes_.size() << " subtrees, " << num_splits_wanted_ + << " splits wanted"; + Split(nodes, decision); + auto [new_leaf, level] = nodes.back(); + path.PushLevel(new_leaf->literal, new_leaf->objective_lb, new_leaf->id); +} + +void SharedTreeManager::ReplaceTree(ProtoTrail& path) { + absl::MutexLock mutex_lock(&mu_); + std::vector> nodes = GetAssignedNodes(path); + if (nodes.back().first->children[0] == nullptr && + !nodes.back().first->closed && nodes.size() > 1) { + VLOG(1) << "Returning leaf to be replaced"; + unassigned_leaves_.push_back(nodes.back().first); + } + path.Clear(); + while (!unassigned_leaves_.empty()) { + const int i = num_leaves_assigned_++ % unassigned_leaves_.size(); + std::swap(unassigned_leaves_[i], unassigned_leaves_.back()); + Node* leaf = unassigned_leaves_.back(); + unassigned_leaves_.pop_back(); + 
if (!leaf->closed && leaf->children[0] == nullptr) { + AssignLeaf(path, leaf); + return; + } + } + VLOG(1) << "Assigning root because no unassigned leaves are available"; + // TODO(user): Investigate assigning a random leaf so workers can still + // improve shared tree bounds. +} + +SharedTreeManager::Node* SharedTreeManager::GetSibling(Node* node) { + if (node == nullptr || node->parent == nullptr) return nullptr; + if (node->parent->children[0] != node) { + return node->parent->children[0]; + } + return node->parent->children[1]; +} + +void SharedTreeManager::Split(std::vector>& nodes, + ProtoLiteral lit) { + const auto [parent, level] = nodes.back(); + DCHECK_EQ(parent->children[0], nullptr); + DCHECK_EQ(parent->children[1], nullptr); + parent->children[0] = MakeSubtree(parent, lit); + parent->children[1] = MakeSubtree(parent, lit.Negated()); + nodes.push_back(std::make_pair(parent->children[0], level + 1)); + unassigned_leaves_.push_back(parent->children[1]); + --num_splits_wanted_; +} + +SharedTreeManager::Node* SharedTreeManager::MakeSubtree(Node* parent, + ProtoLiteral literal) { + nodes_.push_back(Node{.literal = literal, + .objective_lb = parent->objective_lb, + .parent = parent, + .id = static_cast(nodes_.size())}); + return &nodes_.back(); +} + +void SharedTreeManager::ProcessNodeChanges() { + while (!to_close_.empty() || !to_update_.empty()) { + while (!to_close_.empty()) { + Node* node = to_close_.back(); + CHECK_NE(node, nullptr); + to_close_.pop_back(); + if (node->closed) continue; + node->closed = true; + // If we are closing a leaf, try to maintain the same number of leaves; + num_splits_wanted_ += (node->children[0] == nullptr); + for (Node* child : node->children) { + if (child == nullptr || child->closed) continue; + to_close_.push_back(child); + } + if (node->parent != nullptr) { + to_update_.push_back(node->parent); + GetSibling(node)->implied = true; + } else { + shared_response_manager_->NotifyThatImprovingProblemIsInfeasible( + 
"shared_tree_manager"); + } + } + if (to_update_.empty()) break; + Node* node = to_update_.back(); + to_update_.pop_back(); + while (node != nullptr && !node->closed && node->children[0] != nullptr) { + bool has_open_child = false; + IntegerValue child_bound = kMaxIntegerValue; + for (const Node* child : node->children) { + if (child->closed) continue; + has_open_child = true; + child_bound = std::min(child->objective_lb, child_bound); + } + if (!has_open_child) { + to_close_.push_back(node); + } else if (child_bound > node->objective_lb) { + node->objective_lb = child_bound; + if (node->parent == nullptr) { + shared_response_manager_->UpdateInnerObjectiveBounds( + "shared_tree_manager", node->objective_lb, kMaxIntegerValue); + node->objective_lb = + shared_response_manager_->GetInnerObjectiveLowerBound(); + } + } else { + break; + } + node = node->parent; + } + } +} + +std::vector> +SharedTreeManager::GetAssignedNodes(const ProtoTrail& path) { + std::vector> nodes({std::make_pair(&nodes_[0], 0)}); + for (int i = 0; i <= path.MaxLevel(); ++i) { + for (int id : path.NodeIds(i)) { + if (id != -1) { + DCHECK_EQ(nodes.back().first, nodes_[id].parent); + nodes.push_back(std::make_pair(&nodes_[id], i)); + } + } + } + return nodes; +} + +void SharedTreeManager::CloseTree(ProtoTrail& path, int level) { + absl::MutexLock mutex_lock(&mu_); + Node* node = &nodes_[path.NodeIds(level).front()]; + VLOG(1) << "Closing subtree at level " << level; + DCHECK(to_close_.empty()); + to_close_.push_back(node); + ProcessNodeChanges(); + path.Clear(); +} + +void SharedTreeManager::AssignLeaf(ProtoTrail& path, Node* leaf) { + if (leaf == &nodes_[0]) { + path.Clear(); + return; + } + AssignLeaf(path, leaf->parent); + path.PushLevel(leaf->literal, leaf->objective_lb, leaf->id); + if (leaf->implied) { + path.SetLevelImplied(path.MaxLevel()); + } +} + +SharedTreeWorker::SharedTreeWorker(Model* model) + : parameters_(model->GetOrCreate()), + shared_response_(model->GetOrCreate()), + 
time_limit_(model->GetOrCreate()), + manager_(model->GetOrCreate()), + mapping_(model->GetOrCreate()), + sat_solver_(model->GetOrCreate()), + trail_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + encoder_(model->GetOrCreate()), + objective_(model->Get()), + random_(model->GetOrCreate()), + helper_(model->GetOrCreate()), + heuristics_(model->GetOrCreate()) {} + +const std::vector& SharedTreeWorker::DecisionReason(int level) { + reason_.clear(); + for (int i = 1; i <= level; ++i) { + reason_.push_back(DecodeDecision(assigned_tree_.Decision(i)).Negated()); + } + return reason_; +} + +bool SharedTreeWorker::AddImplications( + absl::Span implied_literals) { + const int level = sat_solver_->CurrentDecisionLevel(); + // Level 0 implications are unit clauses and are synced elsewhere. + if (level == 0) return false; + if (level > assigned_tree_.MaxLevel()) { + return false; + } + bool added_clause = false; + for (const ProtoLiteral& impl : implied_literals) { + Literal lit(DecodeDecision(impl)); + if (sat_solver_->Assignment().LiteralIsFalse(lit)) { + VLOG(1) << "Closing subtree via impl at " << level + 1 + << " assigned=" << assigned_tree_.MaxLevel(); + integer_trail_->ReportConflict(DecisionReason(level), {}); + manager_->CloseTree(assigned_tree_, level); + return true; + } + if (!sat_solver_->Assignment().LiteralIsTrue(lit)) { + added_clause = true; + integer_trail_->EnqueueLiteral(lit, DecisionReason(level), {}); + VLOG(1) << "Learned shared clause"; + } + } + if (objective_ != nullptr) { + const IntegerValue obj_lb = + integer_trail_->LowerBound(objective_->objective_var); + const IntegerValue obj_ub = + integer_trail_->UpperBound(objective_->objective_var); + if (obj_ub < assigned_tree_.ObjectiveLb(level)) { + integer_trail_->ReportConflict(DecisionReason(level), {}); + manager_->CloseTree(assigned_tree_, level); + return true; + } + if (obj_lb < assigned_tree_.ObjectiveLb(level)) { + integer_trail_->EnqueueLiteral( + 
encoder_->GetOrCreateAssociatedLiteral(IntegerLiteral::GreaterOrEqual( + objective_->objective_var, assigned_tree_.ObjectiveLb(level))), + DecisionReason(level), {}); + VLOG(1) << "Learned shared objective clause"; + return true; + } else { + assigned_tree_.SetObjectiveLb(level, obj_lb); + } + } + return added_clause; +} + +bool SharedTreeWorker::SyncWithLocalTrail() { + const int level = sat_solver_->CurrentDecisionLevel(); + if (level > assigned_tree_.MaxLevel()) { + return true; + } + const int initial_trail_index = trail_->Index(); + bool added_clause = AddImplications(assigned_tree_.Implications(level)); + while (level + 1 <= assigned_tree_.MaxLevel()) { + const ProtoLiteral& shared_lit = assigned_tree_.Decision(level + 1); + Literal decision(DecodeDecision(shared_lit)); + if (sat_solver_->Assignment().LiteralIsTrue(decision)) { + AddImplications(assigned_tree_.Implications(level + 1)); + added_clause = true; + assigned_tree_.SetLevelImplied(level + 1); + continue; + } else if (sat_solver_->Assignment().LiteralIsFalse(Literal(decision))) { + VLOG(1) << "Closing subtree at " << level + 1 + << " assigned=" << assigned_tree_.MaxLevel(); + manager_->CloseTree(assigned_tree_, level + 1); + sat_solver_->Backtrack(0); + return false; + } + break; + } + return !added_clause && initial_trail_index == trail_->Index(); +} + +LiteralIndex SharedTreeWorker::NextDecision() { + const auto& decision_policy = + heuristics_->decision_policies[heuristics_->policy_index]; + const int next_level = sat_solver_->CurrentDecisionLevel() + 1; + if (next_level == assigned_tree_.MaxLevel() + 1 && splits_wanted_ > 0) { + VLOG(1) << "Try split! 
" << parameters_->name(); + Literal decision(helper_->GetDecision(decision_policy)); + std::optional shared_lit = EncodeDecision(decision); + if (shared_lit.has_value() && !sat_solver_->Assignment().LiteralIsAssigned( + Literal(DecodeDecision(*shared_lit)))) { + manager_->ProposeSplit(assigned_tree_, *shared_lit); + --splits_wanted_; + } + return decision.Index(); + } else if (next_level <= assigned_tree_.MaxLevel()) { + VLOG(1) << "Following shared trail depth=" << next_level << " " + << parameters_->name(); + const ProtoLiteral shared_lit = assigned_tree_.Decision(next_level); + Literal decision(DecodeDecision(shared_lit)); + CHECK(!sat_solver_->Assignment().LiteralIsFalse(decision)) + << " at depth " << next_level << " " << parameters_->name(); + CHECK(!sat_solver_->Assignment().LiteralIsTrue(decision)); + return decision.Index(); + } + if (objective_ == nullptr) return helper_->GetDecision(decision_policy); + // If the current node is close to the global lower bound, maybe try to + // improve it. 
+ const IntegerValue root_obj_lb = + integer_trail_->LevelZeroLowerBound(objective_->objective_var); + const IntegerValue root_obj_ub = + integer_trail_->LevelZeroUpperBound(objective_->objective_var); + const IntegerValue obj_split = + root_obj_lb + absl::LogUniform( + *random_, 0, (root_obj_ub - root_obj_lb).value()); + const double kObjectiveSplitProbability = 0.5; + return helper_->GetDecision([&]() -> BooleanOrIntegerLiteral { + IntegerValue obj_lb = integer_trail_->LowerBound(objective_->objective_var); + IntegerValue obj_ub = integer_trail_->UpperBound(objective_->objective_var); + if (obj_lb > obj_split || obj_ub <= obj_split || + next_level > assigned_tree_.MaxLevel() + 1 || + absl::Bernoulli(*random_, 1 - kObjectiveSplitProbability)) { + return decision_policy(); + } + return BooleanOrIntegerLiteral( + IntegerLiteral::LowerOrEqual(objective_->objective_var, obj_split)); + }); +} + +SatSolver::Status SharedTreeWorker::Search( + const std::function& feasible_solution_observer) { + // Inside GetAssociatedLiteral if a literal becomes fixed at level 0 during + // Search,the code checks it is at level 0 when decoding the literal, but + // the fixed literals are cached, so we can create them now to avoid a + // crash. + sat_solver_->Backtrack(0); + encoder_->GetTrueLiteral(); + encoder_->GetFalseLiteral(); + std::vector clause; + while (!time_limit_->LimitReached() && !shared_response_->ProblemIsSolved()) { + if (!sat_solver_->FinishPropagation()) { + return sat_solver_->UnsatStatus(); + } + const int level = sat_solver_->CurrentDecisionLevel(); + if (level == 0) { + splits_wanted_ = manager_->SplitsToGeneratePerWorker(); + VLOG(1) << "Splits wanted: " << splits_wanted_ << " " + << parameters_->name(); + manager_->SyncTree(assigned_tree_); + // If we have no assignment, try to get one. + // We also want to ensure unassigned nodes have their lower bounds bumped + // periodically, so workers need to occasionally replace open trees. 
+ // TODO(user): Ideally we should use some metric to replace a + // subtree when the worker is doing badly. + if (assigned_tree_.MaxLevel() == 0 || absl::Bernoulli(*random_, 1e-2)) { + manager_->ReplaceTree(assigned_tree_); + } + VLOG(1) << "Assigned level: " << assigned_tree_.MaxLevel() << " " + << parameters_->name(); + } + if (heuristics_->restart_policies[heuristics_->policy_index]()) { + heuristics_->policy_index = (heuristics_->policy_index + 1) % + heuristics_->decision_policies.size(); + sat_solver_->Backtrack(0); + continue; + } + if (!helper_->BeforeTakingDecision()) { + return sat_solver_->UnsatStatus(); + } + if (time_limit_->LimitReached() || shared_response_->ProblemIsSolved()) { + break; + } + if (!SyncWithLocalTrail()) continue; + Literal decision(NextDecision()); + if (time_limit_->LimitReached()) return SatSolver::LIMIT_REACHED; + if (decision.Index() == kNoLiteralIndex) { + feasible_solution_observer(); + if (objective_ == nullptr) return SatSolver::FEASIBLE; + const IntegerValue objective = + integer_trail_->LowerBound(objective_->objective_var); + sat_solver_->Backtrack(0); + if (!integer_trail_->Enqueue( + IntegerLiteral::LowerOrEqual(objective_->objective_var, + objective - 1), + {}, {})) { + return SatSolver::INFEASIBLE; + } + + continue; + } + DCHECK(!sat_solver_->Assignment().LiteralIsFalse(decision)); + DCHECK(!sat_solver_->Assignment().LiteralIsTrue(decision)); + if (!helper_->TakeDecision(decision)) { + return sat_solver_->UnsatStatus(); + } + } + + return SatSolver::LIMIT_REACHED; +} + +Literal SharedTreeWorker::DecodeDecision(ProtoLiteral lit) { + return lit.Decode(mapping_, encoder_); +} + +std::optional SharedTreeWorker::EncodeDecision(Literal decision) { + return ProtoLiteral::Encode(decision, mapping_, encoder_); +} + +} // namespace operations_research::sat diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h new file mode 100644 index 0000000000..7ffb3d41c7 --- /dev/null +++ b/ortools/sat/work_assignment.h 
@@ -0,0 +1,249 @@ +// Copyright 2010-2022 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_WORK_ASSIGNMENT_H_ +#define OR_TOOLS_SAT_WORK_ASSIGNMENT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "ortools/sat/cp_model_mapping.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/integer_search.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" +#include "ortools/util/time_limit.h" + +namespace operations_research::sat { + +class ProtoLiteral { + public: + ProtoLiteral() = default; + ProtoLiteral(int var, IntegerValue lb) : proto_var_(var), lb_(lb) {} + ProtoLiteral Negated() const { + return ProtoLiteral(NegatedRef(proto_var_), -lb_ + 1); + } + bool operator==(const ProtoLiteral& other) const { + return proto_var_ == other.proto_var_ && lb_ == other.lb_; + } + bool operator!=(const ProtoLiteral& other) const { return !(*this == other); } + Literal Decode(CpModelMapping*, IntegerEncoder*) const; + static std::optional Encode(Literal, CpModelMapping*, + IntegerEncoder*); + + private: + IntegerLiteral DecodeInteger(CpModelMapping*) const; + static std::optional EncodeInteger(IntegerLiteral, + CpModelMapping*); + + int proto_var_ = 
std::numeric_limits::max(); + IntegerValue lb_ = kMaxIntegerValue; +}; + +// ProtoTrail acts as an intermediate datastructure that can be synced +// with the shared tree and the local Trail as appropriate. +// It's intended that you sync a ProtoTrail with the tree on restart or when +// a subtree is closed, and with the local trail after propagation. +// Specifically it stores objective lower bounds, and literals that have been +// branched on and later proven to be implied by the prior decisions (i.e. they +// can be enqueued at this level). +// TODO(user): It'd be good to store an earlier level at which +// implications may be propagated. +class ProtoTrail { + public: + // Adds a new assigned level to the trail. + void PushLevel(const ProtoLiteral& decision, IntegerValue objective_lb, + int node_id); + + // Asserts that the decision at `level` is implied by earlier decisions. + void SetLevelImplied(int level); + + // Clear the trail, removing all levels. + void Clear(); + + // Set a lower bound on the objective at level. + void SetObjectiveLb(int level, IntegerValue objective_lb); + + // Returns the maximum decision level stored in the trail. + int MaxLevel() const { return decision_indexes_.size(); } + + // Returns the decision assigned at `level`. + ProtoLiteral Decision(int level) const { + CHECK_GE(level, 1); + CHECK_LE(level, decision_indexes_.size()); + return literals_[decision_indexes_[level - 1]]; + } + + // Returns the node ids for decisions and implications at `level`. + absl::Span NodeIds(int level) const; + + // Returns literals which may be propagated at `level`, this does not include + // the decision. + absl::Span Implications(int level) const; + + IntegerValue ObjectiveLb(int level) const { + CHECK_GE(level, 1); + return level_to_objective_lbs_[level - 1]; + } + + absl::Span Literals() const { return literals_; } + + private: + // Parallel vectors encoding the literals and node ids on the trail. 
+ std::vector literals_; + std::vector node_ids_; + + // The index in the literals_/node_ids_ vectors for the start of each level. + std::vector decision_indexes_; + + // The objective lower bound of each level. + std::vector level_to_objective_lbs_; +}; + +// Experimental thread-safe class for managing work assignments between workers. +// This API is intended to investigate Graeme Gange & Peter Stuckey's proposal +// "Scalable Parallelization of Learning Solvers". +// The core idea of this implementation is that workers can maintain several +// ProtoTrails, and periodically replace the "worst" one. +// With 1 assignment per worker, this leads to something very similar to +// Embarassingly Parallel Search. +class SharedTreeManager { + public: + explicit SharedTreeManager(Model* model); + SharedTreeManager(const SharedTreeManager&) = delete; + + int NumWorkers() const { return num_workers_; } + + // Returns the number of splits each worker should propose this restart. + int SplitsToGeneratePerWorker() const; + + // Syncs the state of path with the internal search tree. + // Clears `path` and returns false if the assigned subtree is closed. + bool SyncTree(ProtoTrail& path) ABSL_LOCKS_EXCLUDED(mu_); + + // Assigns a path prefix that the worker should explore. + void ReplaceTree(ProtoTrail& path); + + // Asserts that the subtree in path up to `level` contains no improving + // solutions. + void CloseTree(ProtoTrail& path, int level); + + // Called by workers in order to split the shared tree. + // `path` may or may not be extended by one level, branching on `decision`. 
+ void ProposeSplit(ProtoTrail& path, ProtoLiteral decision); + + private: + struct Node { + ProtoLiteral literal; + IntegerValue objective_lb = kMinIntegerValue; + Node* parent = nullptr; + std::array children = {nullptr, nullptr}; + // A node's id is its index in `nodes_` + int id; + bool closed = false; + bool implied = false; + }; + Node* GetSibling(Node* node) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void Split(std::vector>& nodes, ProtoLiteral lit) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + Node* MakeSubtree(Node* parent, ProtoLiteral literal) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void ProcessNodeChanges() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + std::vector> GetAssignedNodes(const ProtoTrail& path) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + void AssignLeaf(ProtoTrail& path, Node* leaf) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + mutable absl::Mutex mu_; + const int num_workers_; + SharedResponseManager* const shared_response_manager_; + + // Stores the nodes in the search tree. + std::deque nodes_ ABSL_GUARDED_BY(mu_); + std::vector unassigned_leaves_ ABSL_GUARDED_BY(mu_); + + // How many splits we should generate now to keep the desired number of + // leaves. + int num_splits_wanted_; + + // We limit the total nodes generated to cap the RAM usage and communication + // overhead. If we exceed this, workers return to being portfolio workers. + int max_nodes_; + int num_leaves_assigned_ ABSL_GUARDED_BY(mu_) = 0; + + // Temporary vectors used to maintain the state of the tree when nodes are + // closed and/or children are updated. 
+ std::vector to_close_ ABSL_GUARDED_BY(mu_); + std::vector to_update_ ABSL_GUARDED_BY(mu_); +}; + +class SharedTreeWorker { + public: + explicit SharedTreeWorker(Model* model); + SharedTreeWorker(const SharedTreeWorker&) = delete; + SharedTreeWorker& operator=(const SharedTreeWorker&) = delete; + + SatSolver::Status Search( + const std::function& feasible_solution_observer); + + private: + // Syncs the assigned tree with the local trail, ensuring that any new + // implicatons are synced. This is a noop if the search is deeper than the + // assigned tree. Returns false if any clauses were added or for any other + // reason we might need to re-perform propagation. + bool SyncWithLocalTrail(); + Literal DecodeDecision(ProtoLiteral literal); + std::optional EncodeDecision(Literal decision); + LiteralIndex NextDecision(); + + // Add any implications to the clause database for the current level. + // Return true if any new information was added. + bool AddImplications(absl::Span implied_literals); + + const std::vector& DecisionReason(int level); + + SatParameters* parameters_; + SharedResponseManager* shared_response_; + TimeLimit* time_limit_; + SharedTreeManager* manager_; + CpModelMapping* mapping_; + SatSolver* sat_solver_; + Trail* trail_; + IntegerTrail* integer_trail_; + IntegerEncoder* encoder_; + const ObjectiveDefinition* objective_; + ModelRandomGenerator* random_; + IntegerSearchHelper* helper_; + SearchHeuristics* heuristics_; + + ProtoTrail assigned_tree_; + int splits_wanted_ = 1; + + std::vector reason_; +}; + +} // namespace operations_research::sat + +#endif // OR_TOOLS_SAT_WORK_ASSIGNMENT_H_