From cafbcb17a30d7733ca0767bfff7d664f915780c3 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Tue, 5 Sep 2023 17:09:29 +0200 Subject: [PATCH] polish/fix/cleanup includes --- ortools/algorithms/binary_search.h | 154 ++++++++++++++--------- ortools/algorithms/binary_search_test.cc | 42 ++++++- ortools/base/threadpool.cc | 3 +- ortools/base/threadpool.h | 4 +- ortools/lp_data/mps_reader.cc | 14 +-- ortools/lp_data/mps_reader.h | 13 +- 6 files changed, 151 insertions(+), 79 deletions(-) diff --git a/ortools/algorithms/binary_search.h b/ortools/algorithms/binary_search.h index 150617405f..079d90f1e9 100644 --- a/ortools/algorithms/binary_search.h +++ b/ortools/algorithms/binary_search.h @@ -16,8 +16,12 @@ #include #include +#include #include +#include +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/numeric/int128.h" #include "ortools/base/dump_vars.h" #include "ortools/base/logging.h" @@ -99,7 +103,7 @@ Point BinarySearchMidpoint(Point x, Point y); // - We technically do not need the points to be sorted and can use // linear-time median computation to speed this up. // -// TODO(user): replace std::function by absl::AnyInvocable here and in +// TODO(user): replace std::function by absl::FunctionRef here and in // BinarySearch(). template std::pair ConvexMinimum(absl::Span sorted_points, @@ -115,6 +119,15 @@ std::pair ConvexMinimum(bool is_to_the_right, absl::Span sorted_points, std::function f); +// Searches in the range [begin, end), where Point supports basic arithmetic. +template +std::pair RangeConvexMinimum(Point begin, Point end, + absl::FunctionRef f); +template +std::pair RangeConvexMinimum(std::pair current_min, + Point begin, Point end, + absl::FunctionRef f); + // _____________________________________________________________________________ // Implementation. @@ -222,38 +235,89 @@ Point BinarySearch(Point x_true, Point x_false, std::function f) { } template -std::pair ConvexMinimum(absl::Span sorted_points, - std::function f) { - DCHECK(!sorted_points.empty()); - if (sorted_points.size() == 1) { - return {sorted_points[0], f(sorted_points[0])}; +std::pair RangeConvexMinimum(Point begin, Point end, + absl::FunctionRef f) { + DCHECK_LT(begin, end); + const Value size = end - begin; + if (size == 1) { + return {begin, f(begin)}; } // Starts by splitting interval in two with two queries and getting some info. // Note the current min will be outside the interval. - bool is_to_the_right; std::pair current_min; { - DCHECK_GE(sorted_points.size(), 2); - const int i = sorted_points.size() / 2; - const Value v = f(sorted_points[i]); - const int before_i = i - 1; - const Value before_v = f(sorted_points[before_i]); - if (before_v == v) return {sorted_points[before_i], before_v}; + DCHECK_GE(size, 2); + const Point mid = begin + (end - begin) / 2; + DCHECK_GT(mid, begin); + const Value v = f(mid); + const Point before_mid = mid - 1; + const Value before_v = f(before_mid); + if (before_v == v) return {before_mid, before_v}; if (before_v < v) { - // Note that we exclude before_i from the span. - current_min = {sorted_points[before_i], before_v}; - is_to_the_right = true; - sorted_points = sorted_points.subspan(0, std::max(0, before_i)); + // Note that we exclude before_mid from the range. + current_min = {before_mid, before_v}; + end = before_mid; } else { - is_to_the_right = false; - current_min = {sorted_points[i], v}; - sorted_points = sorted_points.subspan(i + 1); + current_min = {mid, v}; + begin = mid + 1; } } - if (sorted_points.empty()) return current_min; - return ConvexMinimum(is_to_the_right, current_min, - sorted_points, std::move(f)); + if (begin >= end) return current_min; + return RangeConvexMinimum(current_min, begin, end, f); +} + +template +std::pair RangeConvexMinimum(std::pair current_min, + Point begin, Point end, + absl::FunctionRef f) { + DCHECK_LT(begin, end); + while ((end - begin) > 1) { + DCHECK(current_min.first < begin || current_min.first >= end); + bool current_is_after_end = current_min.first >= end; + const Point mid = begin + (end - begin) / 2; + const Value v = f(mid); + if (v >= current_min.second) { + // If the midpoint is no better than our current minimum, then the + // global min must lie between our midpoint and our current min. + if (current_is_after_end) { + begin = mid + 1; + } else { + end = mid; + } + } else { + // v < current_min.second, we cannot decide, so we use a second value + // close to v like in the initial step. + DCHECK_GT(mid, begin); + const Point before_mid = mid - 1; + const Value before_v = f(before_mid); + if (before_v == v) return {before_mid, before_v}; + if (before_v < v) { + current_min = {before_mid, before_v}; + current_is_after_end = true; + end = before_mid; + } else { + current_is_after_end = false; + current_min = {mid, v}; + begin = mid + 1; + } + } + } + + if (end - begin == 1) { + const Value v = f(begin); + if (v <= current_min.second) return {begin, v}; + } + return current_min; +} + +template +std::pair ConvexMinimum(absl::Span sorted_points, + std::function f) { + auto index_f = [&](int index) -> Value { return f(sorted_points[index]); }; + const auto& [index, v] = + RangeConvexMinimum(0, sorted_points.size(), index_f); + return {sorted_points[index], v}; } template @@ -261,44 +325,14 @@ std::pair ConvexMinimum(bool is_to_the_right, std::pair current_min, absl::Span sorted_points, std::function f) { - DCHECK(!sorted_points.empty()); - while (sorted_points.size() > 1) { - const int i = sorted_points.size() / 2; - const Value v = f(sorted_points[i]); - if (v >= current_min.second) { - // If the midpoint is no better than our current minimum, then the - // global min must lie between our midpoint and our current min. - if (is_to_the_right) { - sorted_points = sorted_points.subspan(i + 1); - } else { - sorted_points = sorted_points.subspan(0, i); - } - } else { - // v < current_min.second, we cannot decide, so we use a second value - // close to v like in the initial step. - DCHECK_GT(i, 0); - const int before_i = i - 1; - const Value before_v = f(sorted_points[before_i]); - if (before_v == v) return {sorted_points[before_i], before_v}; - if (before_v < v) { - current_min = {sorted_points[before_i], before_v}; - is_to_the_right = true; - sorted_points = sorted_points.subspan(0, std::max(0, before_i)); - } else { - is_to_the_right = false; - current_min = {sorted_points[i], v}; - sorted_points = sorted_points.subspan(i + 1); - } - } - } - - if (!sorted_points.empty()) { - const Value v = f(sorted_points[0]); - if (v <= current_min.second) return {sorted_points[0], v}; - } - return current_min; + auto index_f = [&](int index) -> Value { return f(sorted_points[index]); }; + std::pair index_current_min = std::make_pair( + is_to_the_right ? sorted_points.size() : -1, current_min.second); + const auto& [index, v] = RangeConvexMinimum( + index_current_min, 0, sorted_points.size(), index_f); + if (index == index_current_min.first) return current_min; + return {sorted_points[index], v}; } - } // namespace operations_research #endif // OR_TOOLS_ALGORITHMS_BINARY_SEARCH_H_ diff --git a/ortools/algorithms/binary_search_test.cc b/ortools/algorithms/binary_search_test.cc index ff194c8e1a..d3cb24040a 100644 --- a/ortools/algorithms/binary_search_test.cc +++ b/ortools/algorithms/binary_search_test.cc @@ -328,9 +328,13 @@ TEST(ConvexMinimumTest, ExhaustiveTest) { }); total_num_queries += num_queries; max_num_queries = std::max(max_num_queries, num_queries); - ASSERT_EQ(value, 0); - ASSERT_GE(point, b1); - ASSERT_LE(point, b2); + EXPECT_EQ(value, 0); + EXPECT_GE(point, b1); + EXPECT_LE(point, b2); + // Fail after one example. + ASSERT_TRUE(value == 0 && b1 <= point && point <= b2) + << "queries: " << num_queries << " opt range: [" << b1 << ", " << b2 + << "]"; } } @@ -378,4 +382,36 @@ TEST(ConvexMinimumTest, TwoQueriesIfSizeTwoReversed) { EXPECT_EQ(num_queries, 2); } +TEST(RangeConvexMinimumTest, HugeRangeTest) { + int total_num_queries = 0; + int max_num_queries = 0; + for (int b1 = -100; b1 < 100; ++b1) { + for (int b2 = b1; b2 < b1 + 100; ++b2) { + int num_queries = 0; + const auto [point, value] = RangeConvexMinimum( + std::numeric_limits::min() / 2, + std::numeric_limits::max() / 2, [&](int64_t v) -> double { + ++num_queries; + if (v < b1) { + return b1 - v; + } else if (v > b2) { + return v - b2; + } + return 0; + }); + total_num_queries += num_queries; + max_num_queries = std::max(max_num_queries, num_queries); + EXPECT_EQ(value, 0); + EXPECT_GE(point, b1); + EXPECT_LE(point, b2); + // Don't continue past the first failing example to limit the number of + // errors. + ASSERT_TRUE(value == 0 && b1 <= point && point <= b2) + << "queries: " << num_queries << " opt range: [" << b1 << ", " << b2 + << "]"; + } + } + // 80 is the worst case we would expect from ternary search: 2*log_3(2^63). + EXPECT_LE(max_num_queries, 80); +} } // namespace operations_research diff --git a/ortools/base/threadpool.cc b/ortools/base/threadpool.cc index e82c9c569b..efa9d978b2 100644 --- a/ortools/base/threadpool.cc +++ b/ortools/base/threadpool.cc @@ -14,6 +14,7 @@ #include "ortools/base/threadpool.h" #include "absl/log/check.h" +#include "absl/strings/string_view.h" namespace operations_research { void RunWorker(void* data) { @@ -25,7 +26,7 @@ void RunWorker(void* data) { } } -ThreadPool::ThreadPool(const std::string& prefix, int num_workers) +ThreadPool::ThreadPool(absl::string_view prefix, int num_workers) : num_workers_(num_workers) {} ThreadPool::~ThreadPool() { diff --git a/ortools/base/threadpool.h b/ortools/base/threadpool.h index 7595afed6b..3a89e55740 100644 --- a/ortools/base/threadpool.h +++ b/ortools/base/threadpool.h @@ -22,10 +22,12 @@ #include // NOLINT #include +#include "absl/strings/string_view.h" + namespace operations_research { class ThreadPool { public: - ThreadPool(const std::string& prefix, int num_threads); + ThreadPool(absl::string_view prefix, int num_threads); ~ThreadPool(); void StartWorkers(); diff --git a/ortools/lp_data/mps_reader.cc b/ortools/lp_data/mps_reader.cc index 65b42ccdc2..3e0eba8d2d 100644 --- a/ortools/lp_data/mps_reader.cc +++ b/ortools/lp_data/mps_reader.cc @@ -321,7 +321,7 @@ MPSReaderFormat TemplateFormat(MPSReader::Form form) { } // namespace // Parses instance from a file. -absl::Status MPSReader::ParseFile(const std::string& file_name, +absl::Status MPSReader::ParseFile(absl::string_view file_name, LinearProgram* data, Form form) { DataWrapper data_wrapper(data); return MPSReaderTemplate>() @@ -329,7 +329,7 @@ absl::Status MPSReader::ParseFile(const std::string& file_name, .status(); } -absl::Status MPSReader::ParseFile(const std::string& file_name, +absl::Status MPSReader::ParseFile(absl::string_view file_name, MPModelProto* data, Form form) { DataWrapper data_wrapper(data); return MPSReaderTemplate>() @@ -339,7 +339,7 @@ absl::Status MPSReader::ParseFile(const std::string& file_name, // Loads instance from string. Useful with MapReduce. Automatically detects // the file's format (free or fixed). -absl::Status MPSReader::ParseProblemFromString(const std::string& source, +absl::Status MPSReader::ParseProblemFromString(absl::string_view source, LinearProgram* data, MPSReader::Form form) { DataWrapper data_wrapper(data); @@ -348,7 +348,7 @@ absl::Status MPSReader::ParseProblemFromString(const std::string& source, .status(); } -absl::Status MPSReader::ParseProblemFromString(const std::string& source, +absl::Status MPSReader::ParseProblemFromString(absl::string_view source, MPModelProto* data, MPSReader::Form form) { DataWrapper data_wrapper(data); @@ -357,8 +357,7 @@ absl::Status MPSReader::ParseProblemFromString(const std::string& source, .status(); } -absl::StatusOr MpsDataToMPModelProto( - const std::string& mps_data) { +absl::StatusOr MpsDataToMPModelProto(absl::string_view mps_data) { MPModelProto model; DataWrapper data_wrapper(&model); RETURN_IF_ERROR( @@ -368,8 +367,7 @@ absl::StatusOr MpsDataToMPModelProto( return model; } -absl::StatusOr MpsFileToMPModelProto( - const std::string& mps_file) { +absl::StatusOr MpsFileToMPModelProto(absl::string_view mps_file) { MPModelProto model; DataWrapper data_wrapper(&model); RETURN_IF_ERROR( diff --git a/ortools/lp_data/mps_reader.h b/ortools/lp_data/mps_reader.h index 82fcb740ba..df0dbdde53 100644 --- a/ortools/lp_data/mps_reader.h +++ b/ortools/lp_data/mps_reader.h @@ -29,6 +29,7 @@ #include "absl/base/attributes.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "ortools/linear_solver/linear_solver.pb.h" #include "ortools/lp_data/lp_data.h" @@ -36,10 +37,10 @@ namespace operations_research { namespace glop { // Parses an MPS model from a string. -absl::StatusOr MpsDataToMPModelProto(const std::string& mps_data); +absl::StatusOr MpsDataToMPModelProto(absl::string_view mps_data); // Parses an MPS model from a file. -absl::StatusOr MpsFileToMPModelProto(const std::string& mps_file); +absl::StatusOr MpsFileToMPModelProto(absl::string_view mps_file); // Implementation class. Please use the 2 functions above. // @@ -54,17 +55,17 @@ class ABSL_DEPRECATED("Use the direct methods instead") MPSReader { enum Form { AUTO_DETECT, FREE, FIXED }; // Parses instance from a file. - absl::Status ParseFile(const std::string& file_name, LinearProgram* data, + absl::Status ParseFile(absl::string_view file_name, LinearProgram* data, Form form = AUTO_DETECT); - absl::Status ParseFile(const std::string& file_name, MPModelProto* data, + absl::Status ParseFile(absl::string_view file_name, MPModelProto* data, Form form = AUTO_DETECT); // Loads instance from string. Useful with MapReduce. Automatically detects // the file's format (free or fixed). - absl::Status ParseProblemFromString(const std::string& source, + absl::Status ParseProblemFromString(absl::string_view source, LinearProgram* data, MPSReader::Form form = AUTO_DETECT); - absl::Status ParseProblemFromString(const std::string& source, + absl::Status ParseProblemFromString(absl::string_view source, MPModelProto* data, MPSReader::Form form = AUTO_DETECT); };