diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index a7b4a1b70b..9891972bc3 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1509,7 +1509,8 @@ void PostsolveResponse(const std::string& debug_info, } SharedResponseManager local_response_manager( - /*log_updates=*/false, &mapping_proto, wall_timer); + /*log_updates=*/false, /*enumerate_all_solutions=*/false, + /*solution_limit=*/-1, &mapping_proto, wall_timer); LoadCpModel(mapping_proto, &local_response_manager, &postsolve_model); SolveLoadedCpModel(mapping_proto, &local_response_manager, &postsolve_model); const CpSolverResponse postsolve_response = @@ -1967,7 +1968,8 @@ class LnsSolver : public SubSolver { // parameters that work bests (core, linearization_level, etc...) or // maybe we can just randomize them like for the base solution used. SharedResponseManager local_response_manager( - /*log_updates=*/false, &neighborhood.cp_model, shared_->wall_timer); + /*log_updates=*/false, /*enumerate_all_solutions=*/false, + /*solution_limit=*/-1, &neighborhood.cp_model, shared_->wall_timer); LoadCpModel(neighborhood.cp_model, &local_response_manager, &local_model); QuickSolveWithHint(neighborhood.cp_model, &local_response_manager, &local_model); @@ -2366,8 +2368,10 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { }; } - SharedResponseManager shared_response_manager(log_search, &new_cp_model_proto, - &wall_timer); + SharedResponseManager shared_response_manager( + log_search, params.enumerate_all_solutions(), + params.stop_after_first_solution() ? 1 : -1, &new_cp_model_proto, + &wall_timer); const auto& observers = model->GetOrCreate()->observers; if (!observers.empty()) { shared_response_manager.AddSolutionCallback( diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 6fda72205b..66dc86b7c6 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -81,9 +81,13 @@ void SharedSolutionRepository::Synchronize() { // TODO(user): Experiments and play with the num_solutions_to_keep parameter. SharedResponseManager::SharedResponseManager(bool log_updates, + bool enumerate_all_solutions, + int solution_limit, const CpModelProto* proto, const WallTimer* wall_timer) : log_updates_(log_updates), + enumerate_all_solutions_(enumerate_all_solutions), + solution_limit_(solution_limit), model_proto_(*proto), wall_timer_(*wall_timer), solutions_(/*num_solutions_to_keep=*/10) {} @@ -252,6 +256,8 @@ void SharedResponseManager::NewSolution(const CpSolverResponse& response, absl::MutexLock mutex_lock(&mutex_); CHECK_NE(best_response_.status(), CpSolverStatus::INFEASIBLE); + if (solution_limit_ > 0 && num_solutions_ >= solution_limit_) return; + if (model_proto_.has_objective()) { const int64 objective_value = ComputeInnerObjective(model_proto_.objective(), response); @@ -355,10 +361,16 @@ void SharedResponseManager::SetStatsFromModelInternal(Model* model) { bool SharedResponseManager::ProblemIsSolved() const { absl::MutexLock mutex_lock(&mutex_); + if (solution_limit_ > 0 && num_solutions_ >= solution_limit_) { + return true; + } + // TODO(user): Currently this work because we do not allow enumerate all // solution in multithread. if (!model_proto_.has_objective() && - best_response_.status() == CpSolverStatus::FEASIBLE) { + ((best_response_.status() == CpSolverStatus::FEASIBLE && + !enumerate_all_solutions_) || + best_response_.status() == CpSolverStatus::OPTIMAL)) { return true; } diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index 3e56e09c68..07cdb5909c 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -154,7 +154,8 @@ class SharedResponseManager { public: // If log_updates is true, then all updates to the global "state" will be // logged. This class is responsible for our solver log progress. - SharedResponseManager(bool log_updates_, const CpModelProto* proto, + SharedResponseManager(bool log_updates, bool enumerate_all_solutions, + int solution_limit, const CpModelProto* proto, const WallTimer* wall_timer); // Returns the current solver response. That is the best known response at the @@ -230,6 +231,8 @@ class SharedResponseManager { void SetStatsFromModelInternal(Model* model) EXCLUSIVE_LOCKS_REQUIRED(mutex_); const bool log_updates_; + const bool enumerate_all_solutions_; + const int solution_limit_; const CpModelProto& model_proto_; const WallTimer& wall_timer_;