diff --git a/ortools/sat/cp_model_lns.cc b/ortools/sat/cp_model_lns.cc index 53bb6da5dd..5281a662a4 100644 --- a/ortools/sat/cp_model_lns.cc +++ b/ortools/sat/cp_model_lns.cc @@ -1247,7 +1247,7 @@ double NeighborhoodGenerator::GetUCBScore(int64_t total_num_calls) const { return current_average_ + sqrt((2 * log(total_num_calls)) / num_calls_); } -double NeighborhoodGenerator::Synchronize() { +absl::Span NeighborhoodGenerator::Synchronize() { absl::MutexLock mutex_lock(&generator_mutex_); // To make the whole update process deterministic, we currently sort the @@ -1258,7 +1258,7 @@ double NeighborhoodGenerator::Synchronize() { int num_fully_solved_in_batch = 0; int num_not_fully_solved_in_batch = 0; - double total_dtime = 0.0; + tmp_dtimes_.clear(); for (const SolveData& data : solve_data_) { ++num_calls_; @@ -1304,7 +1304,7 @@ double NeighborhoodGenerator::Synchronize() { current_average_ = 0.9 * current_average_ + 0.1 * gain_per_time_unit; } - total_dtime += data.deterministic_time; + tmp_dtimes_.push_back(data.deterministic_time); } // Update the difficulty. @@ -1327,7 +1327,7 @@ double NeighborhoodGenerator::Synchronize() { } solve_data_.clear(); - return total_dtime; + return tmp_dtimes_; } std::vector diff --git a/ortools/sat/cp_model_lns.h b/ortools/sat/cp_model_lns.h index bc151fb217..91b823b4e8 100644 --- a/ortools/sat/cp_model_lns.h +++ b/ortools/sat/cp_model_lns.h @@ -475,9 +475,9 @@ class NeighborhoodGenerator { } // Process all the recently added solve data and update this generator - // score and difficulty. This returns the sum of the deterministic time of + // score and difficulty. This returns list of the deterministic time of // each SolveData. - double Synchronize(); + absl::Span Synchronize(); // Returns a short description of the generator. std::string name() const { return name_; } @@ -528,6 +528,7 @@ class NeighborhoodGenerator { private: std::vector solve_data_; + std::vector tmp_dtimes_; // Current parameters to be used when generating/solving a neighborhood with // this generator. Only updated on Synchronize(). diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 680555185a..dcf164ef21 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1721,9 +1721,13 @@ class LnsSolver : public SubSolver { } void Synchronize() override { - const double dtime = generator_->Synchronize(); - AddTaskDeterministicDuration(dtime); - shared_->time_limit->AdvanceDeterministicTime(dtime); + double sum = 0.0; + const absl::Span dtimes = generator_->Synchronize(); + for (const double dtime : dtimes) { + sum += dtime; + AddTaskDeterministicDuration(dtime); + } + shared_->time_limit->AdvanceDeterministicTime(sum); } private: