Update thread_pool code (#4890)
This commit is contained in:
committed by
Corentin Le Molgat
parent
6969f23df4
commit
d8d50bae68
@@ -535,8 +535,12 @@ cc_library(
|
||||
srcs = ["threadpool.cc"],
|
||||
hdrs = ["threadpool.h"],
|
||||
deps = [
|
||||
"@abseil-cpp//absl/algorithm:container",
|
||||
"@abseil-cpp//absl/base:core_headers",
|
||||
"@abseil-cpp//absl/base:nullability",
|
||||
"@abseil-cpp//absl/functional:any_invocable",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -13,84 +13,120 @@
|
||||
|
||||
#include "ortools/base/threadpool.h"
|
||||
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/functional/any_invocable.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
namespace operations_research {
|
||||
void RunWorker(void* data) {
|
||||
ThreadPool* const thread_pool = reinterpret_cast<ThreadPool*>(data);
|
||||
std::function<void()> work = thread_pool->GetNextTask();
|
||||
while (work != nullptr) {
|
||||
work();
|
||||
work = thread_pool->GetNextTask();
|
||||
}
|
||||
|
||||
// It is a common error to call ThreadPool(workitems.size()), which
|
||||
// crashes when workitems is empty. Prevent those crashes by creating at
|
||||
// least one thread.
|
||||
ThreadPool::ThreadPool(int num_threads)
|
||||
: max_threads_(num_threads == 0 ? 1 : num_threads) {
|
||||
CHECK_GT(max_threads_, 0u);
|
||||
// Spawn a single thread to handle work by default.
|
||||
absl::MutexLock lock(mutex_);
|
||||
SpawnThread();
|
||||
}
|
||||
|
||||
ThreadPool::ThreadPool(int num_threads) : num_workers_(num_threads) {}
|
||||
|
||||
ThreadPool::ThreadPool(absl::string_view /*prefix*/, int num_threads)
|
||||
: num_workers_(num_threads) {}
|
||||
// Overload that also accepts a name prefix. The prefix is not used by this
// implementation; construction is delegated to ThreadPool(int). The parameter
// name is commented out so that builds with unused-parameter warnings stay
// clean.
ThreadPool::ThreadPool(absl::string_view /*prefix*/, int num_threads)
    : ThreadPool(num_threads) {}
|
||||
|
||||
ThreadPool::~ThreadPool() {
|
||||
if (started_) {
|
||||
std::unique_lock<std::mutex> mutex_lock(mutex_);
|
||||
waiting_to_finish_ = true;
|
||||
mutex_lock.unlock();
|
||||
condition_.notify_all();
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
all_workers_[i].join();
|
||||
// Make threads finish up by setting stopping_. Ensure all threads waiting see
|
||||
// this change by signalling their condvar.
|
||||
{
|
||||
absl::MutexLock l(mutex_);
|
||||
stopping_ = true;
|
||||
for (Waiter* absl_nonnull waiter : waiters_) {
|
||||
waiter->cv.Signal();
|
||||
}
|
||||
// Wait until the queue is empty. This implies no new threads will be
|
||||
// spawned, and all existing threads are exiting.
|
||||
auto queue_empty = [this]() ABSL_SHARED_LOCKS_REQUIRED(mutex_) {
|
||||
return queue_.empty();
|
||||
};
|
||||
mutex_.Await(absl::Condition(&queue_empty));
|
||||
}
|
||||
// Join and delete all threads. Because the queue is empty, we know no new
|
||||
// threads will be added to threads_.
|
||||
for (auto& worker : threads_) {
|
||||
worker.join();
|
||||
}
|
||||
}
|
||||
|
||||
void ThreadPool::SetQueueCapacity(int capacity) {
|
||||
CHECK_GT(capacity, num_workers_);
|
||||
CHECK(!started_);
|
||||
queue_capacity_ = capacity;
|
||||
void ThreadPool::SpawnThread() {
|
||||
CHECK_LE(threads_.size(), max_threads_);
|
||||
threads_.emplace_back([this] { RunWorker(); });
|
||||
}
|
||||
|
||||
void ThreadPool::StartWorkers() {
|
||||
started_ = true;
|
||||
for (int i = 0; i < num_workers_; ++i) {
|
||||
all_workers_.push_back(std::thread(&RunWorker, this));
|
||||
void ThreadPool::RunWorker() {
|
||||
{
|
||||
absl::MutexLock lock(mutex_);
|
||||
++running_threads_;
|
||||
}
|
||||
}
|
||||
|
||||
std::function<void()> ThreadPool::GetNextTask() {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
for (;;) {
|
||||
if (!tasks_.empty()) {
|
||||
std::function<void()> task = tasks_.front();
|
||||
tasks_.pop_front();
|
||||
if (tasks_.size() < queue_capacity_ && waiting_for_capacity_) {
|
||||
waiting_for_capacity_ = false;
|
||||
capacity_condition_.notify_all();
|
||||
}
|
||||
return task;
|
||||
}
|
||||
if (waiting_to_finish_) {
|
||||
return nullptr;
|
||||
} else {
|
||||
condition_.wait(lock);
|
||||
while (true) {
|
||||
std::optional<absl::AnyInvocable<void() &&>> item = DequeueWork();
|
||||
if (!item.has_value()) { // Requesting to stop the worker thread.
|
||||
break;
|
||||
}
|
||||
DCHECK(item);
|
||||
std::move (*item)();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ThreadPool::Schedule(std::function<void()> closure) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
while (tasks_.size() >= queue_capacity_) {
|
||||
waiting_for_capacity_ = true;
|
||||
capacity_condition_.wait(lock);
|
||||
void ThreadPool::SignalWaiter() {
|
||||
DCHECK(!queue_.empty());
|
||||
if (waiters_.empty()) {
|
||||
// If there are no waiters, try spawning a new thread to pick up work.
|
||||
if (running_threads_ == threads_.size() && threads_.size() < max_threads_) {
|
||||
SpawnThread();
|
||||
}
|
||||
} else {
|
||||
// If there are waiters we wake the last inserted waiter. Note that we can
|
||||
// signal this waiter multiple times. This is not only ok but it is crucial
|
||||
// to reduce spurious wakeups.
|
||||
waiters_.back()->cv.Signal();
|
||||
}
|
||||
tasks_.push_back(closure);
|
||||
if (started_) {
|
||||
lock.unlock();
|
||||
condition_.notify_all();
|
||||
}
|
||||
|
||||
std::optional<absl::AnyInvocable<void() &&>> ThreadPool::DequeueWork() {
|
||||
// Wait for queue to be not-empty
|
||||
absl::MutexLock m(mutex_);
|
||||
while (queue_.empty() && !stopping_) {
|
||||
Waiter self;
|
||||
waiters_.push_back(&self);
|
||||
self.cv.Wait(&mutex_);
|
||||
waiters_.erase(absl::c_find(waiters_, &self));
|
||||
}
|
||||
if (queue_.empty()) {
|
||||
DCHECK(stopping_);
|
||||
return std::nullopt;
|
||||
}
|
||||
absl::AnyInvocable<void() &&> result = std::move(queue_.front());
|
||||
queue_.pop_front();
|
||||
if (!queue_.empty()) {
|
||||
SignalWaiter();
|
||||
}
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
// Appends `callback` to the work queue and notifies the pool that work is
// available. This function does not wait for queue space.
//
// Scheduling after the pool has begun stopping is a caller bug: it trips a
// DCHECK, and where the DCHECK is inactive the callback is silently dropped.
void ThreadPool::Schedule(absl::AnyInvocable<void() &&> callback) {
  absl::MutexLock lock(mutex_);
  DCHECK(!stopping_) << "Callback added after destructor started";
  if (ABSL_PREDICT_TRUE(!stopping_)) {
    queue_.push_back(std::move(callback));
    // Wake a waiting worker, or let SignalWaiter() decide whether to spawn one.
    SignalWaiter();
  }
}
|
||||
|
||||
} // namespace operations_research
|
||||
|
||||
@@ -14,39 +14,62 @@
|
||||
#ifndef OR_TOOLS_BASE_THREADPOOL_H_
|
||||
#define OR_TOOLS_BASE_THREADPOOL_H_
|
||||
|
||||
#include <condition_variable> // NOLINT
|
||||
#include <functional>
|
||||
#include <list>
|
||||
#include <mutex> // NOLINT
|
||||
#include <string>
|
||||
#include <cstddef>
|
||||
#include <deque>
|
||||
#include <optional>
|
||||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/functional/any_invocable.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
|
||||
namespace operations_research {
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
explicit ThreadPool(int num_threads);
|
||||
ThreadPool(absl::string_view prefix, int num_threads);
|
||||
~ThreadPool();
|
||||
|
||||
void StartWorkers();
|
||||
void Schedule(std::function<void()> closure);
|
||||
std::function<void()> GetNextTask();
|
||||
void SetQueueCapacity(int capacity);
|
||||
void Schedule(absl::AnyInvocable<void() &&> callback);
|
||||
|
||||
private:
|
||||
const int num_workers_;
|
||||
std::list<std::function<void()>> tasks_;
|
||||
std::mutex mutex_;
|
||||
std::condition_variable condition_;
|
||||
std::condition_variable capacity_condition_;
|
||||
bool waiting_to_finish_ = false;
|
||||
bool waiting_for_capacity_ = false;
|
||||
bool started_ = false;
|
||||
int queue_capacity_ = 2e9;
|
||||
std::vector<std::thread> all_workers_;
|
||||
// Waiter for a single thread.
|
||||
struct Waiter {
|
||||
absl::CondVar cv; // signalled when there is work to do
|
||||
};
|
||||
|
||||
// Spawn a single new worker thread.
|
||||
//
|
||||
// REQUIRES: threads_.size() < max_threads_
|
||||
void SpawnThread() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
|
||||
|
||||
void RunWorker();
|
||||
|
||||
// Removes the oldest element from the queue and returns it. Causes the
|
||||
// current thread to wait for producers if the queue is empty. Returns
|
||||
// an empty `std::optional` if the thread pool is shutting down.
|
||||
std::optional<absl::AnyInvocable<void() &&>> DequeueWork()
|
||||
ABSL_LOCKS_EXCLUDED(mutex_);
|
||||
|
||||
// Signals a waiter if there is one, or spawns a thread to try to add a new
|
||||
// waiter.
|
||||
//
|
||||
// REQUIRES: !queue_.empty()
|
||||
void SignalWaiter() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
|
||||
|
||||
mutable absl::Mutex mutex_;
|
||||
absl::CondVar wait_nonfull_ ABSL_GUARDED_BY(mutex_);
|
||||
std::vector<Waiter* absl_nonnull> waiters_ ABSL_GUARDED_BY(mutex_);
|
||||
const size_t max_threads_;
|
||||
std::deque<absl::AnyInvocable<void() &&>> queue_;
|
||||
bool stopping_ ABSL_GUARDED_BY(mutex_) = false;
|
||||
size_t running_threads_ ABSL_GUARDED_BY(mutex_) = 0;
|
||||
std::vector<std::thread> threads_ ABSL_GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
} // namespace operations_research
|
||||
#endif // OR_TOOLS_BASE_THREADPOOL_H_
|
||||
|
||||
@@ -206,7 +206,6 @@ BidirectionalDijkstra<GraphType, DistanceType>::BidirectionalDijkstra(
|
||||
distances_[dir].assign(num_nodes, infinity());
|
||||
parent_arc_[dir].assign(num_nodes, -1);
|
||||
}
|
||||
search_threads_.StartWorkers();
|
||||
}
|
||||
|
||||
template <typename GraphType, typename DistanceType>
|
||||
|
||||
@@ -557,7 +557,6 @@ GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
|
||||
|
||||
{
|
||||
ThreadPool search_threads(2);
|
||||
search_threads.StartWorkers();
|
||||
for (const Direction dir : {FORWARD, BACKWARD}) {
|
||||
search_threads.Schedule([this, dir, &sub_arc_lengths]() {
|
||||
RunHalfConstrainedShortestPathOnDag(
|
||||
|
||||
@@ -761,7 +761,6 @@ void ComputeManyToManyShortestPathsWithMultipleThreads(
|
||||
graph.num_nodes());
|
||||
{
|
||||
std::unique_ptr<ThreadPool> pool(new ThreadPool(num_threads));
|
||||
pool->StartWorkers();
|
||||
for (int i = 0; i < unique_sources.size(); ++i) {
|
||||
pool->Schedule(absl::bind_front(
|
||||
&internal::ComputeOneToManyOnGraph<GraphType>, &graph, &arc_lengths,
|
||||
|
||||
@@ -1155,7 +1155,6 @@ void MPSolver::SolveLazyMutableRequest(LazyMutableCopy<MPModelRequest> request,
|
||||
// the user. They shouldn't matter for polling, but for solving we might
|
||||
// e.g. use a larger stack.
|
||||
ThreadPool thread_pool(/*num_threads=*/1);
|
||||
thread_pool.StartWorkers();
|
||||
thread_pool.Schedule(polling_func);
|
||||
|
||||
// Make sure the interruption notification didn't arrive while waiting to
|
||||
|
||||
@@ -51,9 +51,7 @@ class GoogleThreadPoolScheduler : public Scheduler {
|
||||
public:
|
||||
GoogleThreadPoolScheduler(int num_threads)
|
||||
: num_threads_(num_threads),
|
||||
threadpool_(std::make_unique<ThreadPool>("pdlp", num_threads)) {
|
||||
threadpool_->StartWorkers();
|
||||
}
|
||||
threadpool_(std::make_unique<ThreadPool>("pdlp", num_threads)) {}
|
||||
int num_threads() const override { return num_threads_; };
|
||||
std::string info_string() const override { return "google_threadpool"; };
|
||||
|
||||
@@ -79,7 +77,7 @@ class EigenThreadPoolScheduler : public Scheduler {
|
||||
public:
|
||||
EigenThreadPoolScheduler(int num_threads)
|
||||
: num_threads_(num_threads),
|
||||
eigen_threadpool_(std::make_unique<Eigen::ThreadPool>(num_threads)) {}
|
||||
g3_threadpool_(std::make_unique<Eigen::ThreadPool>(num_threads)) {}
|
||||
int num_threads() const override { return num_threads_; };
|
||||
std::string info_string() const override { return "eigen_threadpool"; };
|
||||
|
||||
@@ -87,7 +85,7 @@ class EigenThreadPoolScheduler : public Scheduler {
|
||||
absl::AnyInvocable<void(int)> do_func) override {
|
||||
Eigen::Barrier eigen_barrier(end - start);
|
||||
for (int i = start; i < end; ++i) {
|
||||
eigen_threadpool_->Schedule([&, i]() {
|
||||
g3_threadpool_->Schedule([&, i]() {
|
||||
do_func(i);
|
||||
eigen_barrier.Notify();
|
||||
});
|
||||
@@ -97,7 +95,7 @@ class EigenThreadPoolScheduler : public Scheduler {
|
||||
|
||||
private:
|
||||
const int num_threads_;
|
||||
std::unique_ptr<Eigen::ThreadPool> eigen_threadpool_ = nullptr;
|
||||
std::unique_ptr<Eigen::ThreadPool> g3_threadpool_ = nullptr;
|
||||
};
|
||||
|
||||
// Makes a scheduler of a given type.
|
||||
|
||||
@@ -142,7 +142,6 @@ void DeterministicLoop(std::vector<std::unique_ptr<SubSolver>>& subsolvers,
|
||||
std::vector<double> timing;
|
||||
to_run.reserve(batch_size);
|
||||
ThreadPool pool(num_threads);
|
||||
pool.StartWorkers();
|
||||
for (int batch_index = 0;; ++batch_index) {
|
||||
VLOG(2) << "Starting deterministic batch of size " << batch_size;
|
||||
SynchronizeAll(subsolvers);
|
||||
@@ -214,7 +213,6 @@ void NonDeterministicLoop(std::vector<std::unique_ptr<SubSolver>>& subsolvers,
|
||||
};
|
||||
|
||||
ThreadPool pool(num_threads);
|
||||
pool.StartWorkers();
|
||||
|
||||
// The lambda below are using little space, but there is no reason
|
||||
// to create millions of them, so we use the blocking nature of
|
||||
|
||||
Reference in New Issue
Block a user