diff --git a/ortools/base/BUILD.bazel b/ortools/base/BUILD.bazel index 7d1015517a..a731d99bc8 100644 --- a/ortools/base/BUILD.bazel +++ b/ortools/base/BUILD.bazel @@ -535,8 +535,12 @@ cc_library( srcs = ["threadpool.cc"], hdrs = ["threadpool.h"], deps = [ + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/base:core_headers", + "@abseil-cpp//absl/base:nullability", + "@abseil-cpp//absl/functional:any_invocable", "@abseil-cpp//absl/log:check", - "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/synchronization", ], ) diff --git a/ortools/base/threadpool.cc b/ortools/base/threadpool.cc index 5ba6cd6821..6dd04c662e 100644 --- a/ortools/base/threadpool.cc +++ b/ortools/base/threadpool.cc @@ -13,84 +13,120 @@ #include "ortools/base/threadpool.h" -#include -#include +#include +#include +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/base/thread_annotations.h" +#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" namespace operations_research { -void RunWorker(void* data) { - ThreadPool* const thread_pool = reinterpret_cast(data); - std::function work = thread_pool->GetNextTask(); - while (work != nullptr) { - work(); - work = thread_pool->GetNextTask(); - } + +// It is a common error to call ThreadPool(workitems.size()), which +// crashes when workitems is empty. Prevent those crashes by creating at +// least one thread. +ThreadPool::ThreadPool(int num_threads) + : max_threads_(num_threads == 0 ? 1 : num_threads) { + CHECK_GT(max_threads_, 0u); + // Spawn a single thread to handle work by default. + absl::MutexLock lock(mutex_); + SpawnThread(); } -ThreadPool::ThreadPool(int num_threads) : num_workers_(num_threads) {} - -ThreadPool::ThreadPool(absl::string_view /*prefix*/, int num_threads) - : num_workers_(num_threads) {} +ThreadPool::ThreadPool(absl::string_view prefix, int num_threads) + : ThreadPool(num_threads) {} ThreadPool::~ThreadPool() { - if (started_) { - std::unique_lock mutex_lock(mutex_); - waiting_to_finish_ = true; - mutex_lock.unlock(); - condition_.notify_all(); - for (int i = 0; i < num_workers_; ++i) { - all_workers_[i].join(); + // Make threads finish up by setting stopping_. Ensure all threads waiting see + // this change by signalling their condvar. + { + absl::MutexLock l(mutex_); + stopping_ = true; + for (Waiter* absl_nonnull waiter : waiters_) { + waiter->cv.Signal(); } + // Wait until the queue is empty. This implies no new threads will be + // spawned, and all existing threads are exiting. + auto queue_empty = [this]() ABSL_SHARED_LOCKS_REQUIRED(mutex_) { + return queue_.empty(); + }; + mutex_.Await(absl::Condition(&queue_empty)); + } + // Join and delete all threads. Because the queue is empty, we know no new + // threads will be added to threads_. + for (auto& worker : threads_) { + worker.join(); } } -void ThreadPool::SetQueueCapacity(int capacity) { - CHECK_GT(capacity, num_workers_); - CHECK(!started_); - queue_capacity_ = capacity; +void ThreadPool::SpawnThread() { + CHECK_LE(threads_.size(), max_threads_); + threads_.emplace_back([this] { RunWorker(); }); } -void ThreadPool::StartWorkers() { - started_ = true; - for (int i = 0; i < num_workers_; ++i) { - all_workers_.push_back(std::thread(&RunWorker, this)); +void ThreadPool::RunWorker() { + { + absl::MutexLock lock(mutex_); + ++running_threads_; } -} - -std::function ThreadPool::GetNextTask() { - std::unique_lock lock(mutex_); - for (;;) { - if (!tasks_.empty()) { - std::function task = tasks_.front(); - tasks_.pop_front(); - if (tasks_.size() < queue_capacity_ && waiting_for_capacity_) { - waiting_for_capacity_ = false; - capacity_condition_.notify_all(); - } - return task; - } - if (waiting_to_finish_) { - return nullptr; - } else { - condition_.wait(lock); + while (true) { + std::optional> item = DequeueWork(); + if (!item.has_value()) { // Requesting to stop the worker thread. + break; } + DCHECK(item); + std::move (*item)(); } - return nullptr; } -void ThreadPool::Schedule(std::function closure) { - std::unique_lock lock(mutex_); - while (tasks_.size() >= queue_capacity_) { - waiting_for_capacity_ = true; - capacity_condition_.wait(lock); +void ThreadPool::SignalWaiter() { + DCHECK(!queue_.empty()); + if (waiters_.empty()) { + // If there are no waiters, try spawning a new thread to pick up work. + if (running_threads_ == threads_.size() && threads_.size() < max_threads_) { + SpawnThread(); + } + } else { + // If there are waiters we wake the last inserted waiter. Note that we can + // signal this waiter multiple times. This is not only ok but it is crucial + // to reduce spurious wakeups. + waiters_.back()->cv.Signal(); } - tasks_.push_back(closure); - if (started_) { - lock.unlock(); - condition_.notify_all(); +} + +std::optional> ThreadPool::DequeueWork() { + // Wait for queue to be not-empty + absl::MutexLock m(mutex_); + while (queue_.empty() && !stopping_) { + Waiter self; + waiters_.push_back(&self); + self.cv.Wait(&mutex_); + waiters_.erase(absl::c_find(waiters_, &self)); } + if (queue_.empty()) { + DCHECK(stopping_); + return std::nullopt; + } + absl::AnyInvocable result = std::move(queue_.front()); + queue_.pop_front(); + if (!queue_.empty()) { + SignalWaiter(); + } + return std::move(result); +} + +void ThreadPool::Schedule(absl::AnyInvocable callback) { + // Wait for queue to be not-full + absl::MutexLock m(mutex_); + DCHECK(!stopping_) << "Callback added after destructor started"; + if (ABSL_PREDICT_FALSE(stopping_)) return; + queue_.push_back(std::move(callback)); + SignalWaiter(); } } // namespace operations_research diff --git a/ortools/base/threadpool.h b/ortools/base/threadpool.h index 4a69390822..78130046b4 100644 --- a/ortools/base/threadpool.h +++ b/ortools/base/threadpool.h @@ -14,39 +14,62 @@ #ifndef OR_TOOLS_BASE_THREADPOOL_H_ #define OR_TOOLS_BASE_THREADPOOL_H_ -#include // NOLINT -#include -#include -#include // NOLINT -#include +#include +#include +#include #include // NOLINT #include +#include "absl/base/nullability.h" +#include "absl/base/thread_annotations.h" +#include "absl/functional/any_invocable.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" namespace operations_research { + class ThreadPool { public: explicit ThreadPool(int num_threads); ThreadPool(absl::string_view prefix, int num_threads); ~ThreadPool(); - void StartWorkers(); - void Schedule(std::function closure); - std::function GetNextTask(); - void SetQueueCapacity(int capacity); + void Schedule(absl::AnyInvocable callback); private: - const int num_workers_; - std::list> tasks_; - std::mutex mutex_; - std::condition_variable condition_; - std::condition_variable capacity_condition_; - bool waiting_to_finish_ = false; - bool waiting_for_capacity_ = false; - bool started_ = false; - int queue_capacity_ = 2e9; - std::vector all_workers_; + // Waiter for a single thread. + struct Waiter { + absl::CondVar cv; // signalled when there is work to do + }; + + // Spawn a single new worker thread. + // + // REQUIRES: threads_.size() < max_threads_ + void SpawnThread() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + void RunWorker(); + + // Removes the oldest element from the queue and returns it. Causes the + // current thread to wait for producers if the queue is empty. Returns + // an empty `std::optional` if the thread pool is shutting down. + std::optional> DequeueWork() + ABSL_LOCKS_EXCLUDED(mutex_); + + // Signals a waiter if there is one, or spawns a thread to try to add a new + // waiter. + // + // REQUIRES: !queue_.empty() + void SignalWaiter() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + mutable absl::Mutex mutex_; + absl::CondVar wait_nonfull_ ABSL_GUARDED_BY(mutex_); + std::vector waiters_ ABSL_GUARDED_BY(mutex_); + const size_t max_threads_; + std::deque> queue_; + bool stopping_ ABSL_GUARDED_BY(mutex_) = false; + size_t running_threads_ ABSL_GUARDED_BY(mutex_) = 0; + std::vector threads_ ABSL_GUARDED_BY(mutex_); }; + } // namespace operations_research #endif // OR_TOOLS_BASE_THREADPOOL_H_ diff --git a/ortools/graph/bidirectional_dijkstra.h b/ortools/graph/bidirectional_dijkstra.h index fe82dc5d06..04d93a9e2c 100644 --- a/ortools/graph/bidirectional_dijkstra.h +++ b/ortools/graph/bidirectional_dijkstra.h @@ -206,7 +206,6 @@ BidirectionalDijkstra::BidirectionalDijkstra( distances_[dir].assign(num_nodes, infinity()); parent_arc_[dir].assign(num_nodes, -1); } - search_threads_.StartWorkers(); } template diff --git a/ortools/graph/dag_constrained_shortest_path.h b/ortools/graph/dag_constrained_shortest_path.h index ede783b923..7e140d2397 100644 --- a/ortools/graph/dag_constrained_shortest_path.h +++ b/ortools/graph/dag_constrained_shortest_path.h @@ -557,7 +557,6 @@ GraphPathWithLength ConstrainedShortestPathsOnDagWrapper< { ThreadPool search_threads(2); - search_threads.StartWorkers(); for (const Direction dir : {FORWARD, BACKWARD}) { search_threads.Schedule([this, dir, &sub_arc_lengths]() { RunHalfConstrainedShortestPathOnDag( diff --git a/ortools/graph/shortest_paths.h b/ortools/graph/shortest_paths.h index 68071e3a3b..585882d6dc 100644 --- a/ortools/graph/shortest_paths.h +++ b/ortools/graph/shortest_paths.h @@ -761,7 +761,6 @@ void ComputeManyToManyShortestPathsWithMultipleThreads( graph.num_nodes()); { std::unique_ptr pool(new ThreadPool(num_threads)); - pool->StartWorkers(); for (int i = 0; i < unique_sources.size(); ++i) { pool->Schedule(absl::bind_front( &internal::ComputeOneToManyOnGraph, &graph, &arc_lengths, diff --git a/ortools/linear_solver/linear_solver.cc b/ortools/linear_solver/linear_solver.cc index 8c781bceca..db5fce17f9 100644 --- a/ortools/linear_solver/linear_solver.cc +++ b/ortools/linear_solver/linear_solver.cc @@ -1155,7 +1155,6 @@ void MPSolver::SolveLazyMutableRequest(LazyMutableCopy request, // the user. They shouldn't matter for polling, but for solving we might // e.g. use a larger stack. ThreadPool thread_pool(/*num_threads=*/1); - thread_pool.StartWorkers(); thread_pool.Schedule(polling_func); // Make sure the interruption notification didn't arrive while waiting to diff --git a/ortools/pdlp/scheduler.h b/ortools/pdlp/scheduler.h index 9667c99417..59ddd5ccf5 100644 --- a/ortools/pdlp/scheduler.h +++ b/ortools/pdlp/scheduler.h @@ -51,9 +51,7 @@ class GoogleThreadPoolScheduler : public Scheduler { public: GoogleThreadPoolScheduler(int num_threads) : num_threads_(num_threads), - threadpool_(std::make_unique("pdlp", num_threads)) { - threadpool_->StartWorkers(); - } + threadpool_(std::make_unique("pdlp", num_threads)) {} int num_threads() const override { return num_threads_; }; std::string info_string() const override { return "google_threadpool"; }; @@ -79,7 +77,7 @@ class EigenThreadPoolScheduler : public Scheduler { public: EigenThreadPoolScheduler(int num_threads) : num_threads_(num_threads), - eigen_threadpool_(std::make_unique(num_threads)) {} + g3_threadpool_(std::make_unique(num_threads)) {} int num_threads() const override { return num_threads_; }; std::string info_string() const override { return "eigen_threadpool"; }; @@ -87,7 +85,7 @@ class EigenThreadPoolScheduler : public Scheduler { absl::AnyInvocable do_func) override { Eigen::Barrier eigen_barrier(end - start); for (int i = start; i < end; ++i) { - eigen_threadpool_->Schedule([&, i]() { + g3_threadpool_->Schedule([&, i]() { do_func(i); eigen_barrier.Notify(); }); @@ -97,7 +95,7 @@ class EigenThreadPoolScheduler : public Scheduler { private: const int num_threads_; - std::unique_ptr eigen_threadpool_ = nullptr; + std::unique_ptr g3_threadpool_ = nullptr; }; // Makes a scheduler of a given type. diff --git a/ortools/sat/subsolver.cc b/ortools/sat/subsolver.cc index 237515be0e..ea50718b26 100644 --- a/ortools/sat/subsolver.cc +++ b/ortools/sat/subsolver.cc @@ -142,7 +142,6 @@ void DeterministicLoop(std::vector>& subsolvers, std::vector timing; to_run.reserve(batch_size); ThreadPool pool(num_threads); - pool.StartWorkers(); for (int batch_index = 0;; ++batch_index) { VLOG(2) << "Starting deterministic batch of size " << batch_size; SynchronizeAll(subsolvers); @@ -214,7 +213,6 @@ void NonDeterministicLoop(std::vector>& subsolvers, }; ThreadPool pool(num_threads); - pool.StartWorkers(); // The lambda below are using little space, but there is no reason // to create millions of them, so we use the blocking nature of