[CP-SAT] fix #3706, add experimental work-stealing workers, fix minor bugs; improve subsolver multi-thread code
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/blocking_counter.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
@@ -104,6 +105,8 @@ void DeterministicLoop(
|
||||
std::vector<int64_t> num_generated_tasks(subsolvers.size(), 0);
|
||||
std::vector<std::function<void()>> to_run;
|
||||
to_run.reserve(batch_size);
|
||||
ThreadPool pool("DeterministicLoop", num_threads);
|
||||
pool.StartWorkers();
|
||||
while (true) {
|
||||
SynchronizeAll(subsolvers);
|
||||
|
||||
@@ -118,31 +121,39 @@ void DeterministicLoop(
|
||||
}
|
||||
if (to_run.empty()) break;
|
||||
|
||||
// TODO(user): We could reuse the same ThreadPool as long as we wait for all
|
||||
// the task in a batch to finish before scheduling new ones. Not sure how
|
||||
// to easily do that, so for now we just recreate the pool for each to_run.
|
||||
ThreadPool pool("DeterministicLoop", num_threads);
|
||||
pool.StartWorkers();
|
||||
// Schedule each task.
|
||||
absl::BlockingCounter blocking_counter(static_cast<int>(to_run.size()));
|
||||
for (auto& f : to_run) {
|
||||
pool.Schedule(std::move(f));
|
||||
pool.Schedule([f = std::move(f), &blocking_counter]() {
|
||||
f();
|
||||
blocking_counter.DecrementCount();
|
||||
});
|
||||
}
|
||||
to_run.clear();
|
||||
|
||||
// Wait for all tasks of this batch to be done before scheduling another
|
||||
// batch.
|
||||
blocking_counter.Wait();
|
||||
}
|
||||
}
|
||||
|
||||
void NonDeterministicLoop(
|
||||
const std::vector<std::unique_ptr<SubSolver>>& subsolvers,
|
||||
int num_threads) {
|
||||
const int num_threads) {
|
||||
CHECK_GT(num_threads, 0);
|
||||
if (num_threads == 1) {
|
||||
return SequentialLoop(subsolvers);
|
||||
}
|
||||
|
||||
// The mutex will protect these two fields. This is used to only keep
|
||||
// num_threads task in-flight and detect when the search is done.
|
||||
// The mutex guards num_in_flight. This is used to detect when the search is
|
||||
// done.
|
||||
absl::Mutex mutex;
|
||||
absl::CondVar thread_available_condition;
|
||||
int num_scheduled_and_not_done = 0;
|
||||
int num_in_flight = 0; // Guarded by `mutex`.
|
||||
// Predicate to be used with absl::Condition to detect that num_in_flight <
|
||||
// num_threads. Must only be called while locking `mutex`.
|
||||
const auto num_in_flight_lt_num_threads = [&num_in_flight, num_threads]() {
|
||||
return num_in_flight < num_threads;
|
||||
};
|
||||
|
||||
ThreadPool pool("NonDeterministicLoop", num_threads);
|
||||
pool.StartWorkers();
|
||||
@@ -153,18 +164,16 @@ void NonDeterministicLoop(
|
||||
int64_t task_id = 0;
|
||||
std::vector<int64_t> num_generated_tasks(subsolvers.size(), 0);
|
||||
while (true) {
|
||||
// Set to true if no task is pending right now.
|
||||
bool all_done = false;
|
||||
{
|
||||
absl::MutexLock mutex_lock(&mutex);
|
||||
// Wait if num_in_flight == num_threads.
|
||||
const absl::MutexLock mutex_lock(
|
||||
&mutex, absl::Condition(&num_in_flight_lt_num_threads));
|
||||
|
||||
// The stopping condition is that we do not have anything else to generate
|
||||
// once all the task are done and synchronized.
|
||||
if (num_scheduled_and_not_done == 0) all_done = true;
|
||||
|
||||
// Wait if num_scheduled_and_not_done == num_threads.
|
||||
if (num_scheduled_and_not_done == num_threads) {
|
||||
thread_available_condition.Wait(&mutex);
|
||||
}
|
||||
if (num_in_flight == 0) all_done = true;
|
||||
}
|
||||
|
||||
SynchronizeAll(subsolvers);
|
||||
@@ -184,20 +193,16 @@ void NonDeterministicLoop(
|
||||
num_generated_tasks[best]++;
|
||||
{
|
||||
absl::MutexLock mutex_lock(&mutex);
|
||||
num_scheduled_and_not_done++;
|
||||
num_in_flight++;
|
||||
}
|
||||
std::function<void()> task = subsolvers[best]->GenerateTask(task_id++);
|
||||
const std::string name = subsolvers[best]->name();
|
||||
pool.Schedule([task, num_threads, name, &mutex, &num_scheduled_and_not_done,
|
||||
&thread_available_condition]() {
|
||||
pool.Schedule([task = std::move(task), name, &mutex, &num_in_flight]() {
|
||||
task();
|
||||
|
||||
absl::MutexLock mutex_lock(&mutex);
|
||||
const absl::MutexLock mutex_lock(&mutex);
|
||||
VLOG(1) << name << " done.";
|
||||
num_scheduled_and_not_done--;
|
||||
if (num_scheduled_and_not_done == num_threads - 1) {
|
||||
thread_available_condition.SignalAll();
|
||||
}
|
||||
num_in_flight--;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user