polish/fix/cleanup includes
This commit is contained in:
@@ -16,8 +16,12 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/functional/function_ref.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/numeric/int128.h"
|
||||
#include "ortools/base/dump_vars.h"
|
||||
#include "ortools/base/logging.h"
|
||||
@@ -99,7 +103,7 @@ Point BinarySearchMidpoint(Point x, Point y);
|
||||
// - We technically do not need the points to be sorted and can use
|
||||
// linear-time median computation to speed this up.
|
||||
//
|
||||
// TODO(user): replace std::function by absl::AnyInvocable here and in
|
||||
// TODO(user): replace std::function by absl::FunctionRef here and in
|
||||
// BinarySearch().
|
||||
template <class Point, class Value>
|
||||
std::pair<Point, Value> ConvexMinimum(absl::Span<const Point> sorted_points,
|
||||
@@ -115,6 +119,15 @@ std::pair<Point, Value> ConvexMinimum(bool is_to_the_right,
|
||||
absl::Span<const Point> sorted_points,
|
||||
std::function<Value(Point)> f);
|
||||
|
||||
// Searches in the range [begin, end), where Point supports basic arithmetic.
|
||||
template <class Point, class Value>
|
||||
std::pair<Point, Value> RangeConvexMinimum(Point begin, Point end,
|
||||
absl::FunctionRef<Value(Point)> f);
|
||||
template <class Point, class Value>
|
||||
std::pair<Point, Value> RangeConvexMinimum(std::pair<Point, Value> current_min,
|
||||
Point begin, Point end,
|
||||
absl::FunctionRef<Value(Point)> f);
|
||||
|
||||
// _____________________________________________________________________________
|
||||
// Implementation.
|
||||
|
||||
@@ -222,38 +235,89 @@ Point BinarySearch(Point x_true, Point x_false, std::function<bool(Point)> f) {
|
||||
}
|
||||
|
||||
template <class Point, class Value>
|
||||
std::pair<Point, Value> ConvexMinimum(absl::Span<const Point> sorted_points,
|
||||
std::function<Value(Point)> f) {
|
||||
DCHECK(!sorted_points.empty());
|
||||
if (sorted_points.size() == 1) {
|
||||
return {sorted_points[0], f(sorted_points[0])};
|
||||
std::pair<Point, Value> RangeConvexMinimum(Point begin, Point end,
|
||||
absl::FunctionRef<Value(Point)> f) {
|
||||
DCHECK_LT(begin, end);
|
||||
const Value size = end - begin;
|
||||
if (size == 1) {
|
||||
return {begin, f(begin)};
|
||||
}
|
||||
|
||||
// Starts by splitting interval in two with two queries and getting some info.
|
||||
// Note the current min will be outside the interval.
|
||||
bool is_to_the_right;
|
||||
std::pair<Point, Value> current_min;
|
||||
{
|
||||
DCHECK_GE(sorted_points.size(), 2);
|
||||
const int i = sorted_points.size() / 2;
|
||||
const Value v = f(sorted_points[i]);
|
||||
const int before_i = i - 1;
|
||||
const Value before_v = f(sorted_points[before_i]);
|
||||
if (before_v == v) return {sorted_points[before_i], before_v};
|
||||
DCHECK_GE(size, 2);
|
||||
const Point mid = begin + (end - begin) / 2;
|
||||
DCHECK_GT(mid, begin);
|
||||
const Value v = f(mid);
|
||||
const Point before_mid = mid - 1;
|
||||
const Value before_v = f(before_mid);
|
||||
if (before_v == v) return {before_mid, before_v};
|
||||
if (before_v < v) {
|
||||
// Note that we exclude before_i from the span.
|
||||
current_min = {sorted_points[before_i], before_v};
|
||||
is_to_the_right = true;
|
||||
sorted_points = sorted_points.subspan(0, std::max(0, before_i));
|
||||
// Note that we exclude before_mid from the range.
|
||||
current_min = {before_mid, before_v};
|
||||
end = before_mid;
|
||||
} else {
|
||||
is_to_the_right = false;
|
||||
current_min = {sorted_points[i], v};
|
||||
sorted_points = sorted_points.subspan(i + 1);
|
||||
current_min = {mid, v};
|
||||
begin = mid + 1;
|
||||
}
|
||||
}
|
||||
if (sorted_points.empty()) return current_min;
|
||||
return ConvexMinimum<Point, Value>(is_to_the_right, current_min,
|
||||
sorted_points, std::move(f));
|
||||
if (begin >= end) return current_min;
|
||||
return RangeConvexMinimum<Point, Value>(current_min, begin, end, f);
|
||||
}
|
||||
|
||||
template <class Point, class Value>
|
||||
std::pair<Point, Value> RangeConvexMinimum(std::pair<Point, Value> current_min,
|
||||
Point begin, Point end,
|
||||
absl::FunctionRef<Value(Point)> f) {
|
||||
DCHECK_LT(begin, end);
|
||||
while ((end - begin) > 1) {
|
||||
DCHECK(current_min.first < begin || current_min.first >= end);
|
||||
bool current_is_after_end = current_min.first >= end;
|
||||
const Point mid = begin + (end - begin) / 2;
|
||||
const Value v = f(mid);
|
||||
if (v >= current_min.second) {
|
||||
// If the midpoint is no better than our current minimum, then the
|
||||
// global min must lie between our midpoint and our current min.
|
||||
if (current_is_after_end) {
|
||||
begin = mid + 1;
|
||||
} else {
|
||||
end = mid;
|
||||
}
|
||||
} else {
|
||||
// v < current_min.second, we cannot decide, so we use a second value
|
||||
// close to v like in the initial step.
|
||||
DCHECK_GT(mid, begin);
|
||||
const Point before_mid = mid - 1;
|
||||
const Value before_v = f(before_mid);
|
||||
if (before_v == v) return {before_mid, before_v};
|
||||
if (before_v < v) {
|
||||
current_min = {before_mid, before_v};
|
||||
current_is_after_end = true;
|
||||
end = before_mid;
|
||||
} else {
|
||||
current_is_after_end = false;
|
||||
current_min = {mid, v};
|
||||
begin = mid + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (end - begin == 1) {
|
||||
const Value v = f(begin);
|
||||
if (v <= current_min.second) return {begin, v};
|
||||
}
|
||||
return current_min;
|
||||
}
|
||||
|
||||
template <class Point, class Value>
|
||||
std::pair<Point, Value> ConvexMinimum(absl::Span<const Point> sorted_points,
|
||||
std::function<Value(Point)> f) {
|
||||
auto index_f = [&](int index) -> Value { return f(sorted_points[index]); };
|
||||
const auto& [index, v] =
|
||||
RangeConvexMinimum<int64_t, Value>(0, sorted_points.size(), index_f);
|
||||
return {sorted_points[index], v};
|
||||
}
|
||||
|
||||
template <class Point, class Value>
|
||||
@@ -261,44 +325,14 @@ std::pair<Point, Value> ConvexMinimum(bool is_to_the_right,
|
||||
std::pair<Point, Value> current_min,
|
||||
absl::Span<const Point> sorted_points,
|
||||
std::function<Value(Point)> f) {
|
||||
DCHECK(!sorted_points.empty());
|
||||
while (sorted_points.size() > 1) {
|
||||
const int i = sorted_points.size() / 2;
|
||||
const Value v = f(sorted_points[i]);
|
||||
if (v >= current_min.second) {
|
||||
// If the midpoint is no better than our current minimum, then the
|
||||
// global min must lie between our midpoint and our current min.
|
||||
if (is_to_the_right) {
|
||||
sorted_points = sorted_points.subspan(i + 1);
|
||||
} else {
|
||||
sorted_points = sorted_points.subspan(0, i);
|
||||
}
|
||||
} else {
|
||||
// v < current_min.second, we cannot decide, so we use a second value
|
||||
// close to v like in the initial step.
|
||||
DCHECK_GT(i, 0);
|
||||
const int before_i = i - 1;
|
||||
const Value before_v = f(sorted_points[before_i]);
|
||||
if (before_v == v) return {sorted_points[before_i], before_v};
|
||||
if (before_v < v) {
|
||||
current_min = {sorted_points[before_i], before_v};
|
||||
is_to_the_right = true;
|
||||
sorted_points = sorted_points.subspan(0, std::max(0, before_i));
|
||||
} else {
|
||||
is_to_the_right = false;
|
||||
current_min = {sorted_points[i], v};
|
||||
sorted_points = sorted_points.subspan(i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!sorted_points.empty()) {
|
||||
const Value v = f(sorted_points[0]);
|
||||
if (v <= current_min.second) return {sorted_points[0], v};
|
||||
}
|
||||
return current_min;
|
||||
auto index_f = [&](int index) -> Value { return f(sorted_points[index]); };
|
||||
std::pair<int, Value> index_current_min = std::make_pair(
|
||||
is_to_the_right ? sorted_points.size() : -1, current_min.second);
|
||||
const auto& [index, v] = RangeConvexMinimum<int64_t, Value>(
|
||||
index_current_min, 0, sorted_points.size(), index_f);
|
||||
if (index == index_current_min.first) return current_min;
|
||||
return {sorted_points[index], v};
|
||||
}
|
||||
|
||||
} // namespace operations_research
|
||||
|
||||
#endif // OR_TOOLS_ALGORITHMS_BINARY_SEARCH_H_
|
||||
|
||||
@@ -328,9 +328,13 @@ TEST(ConvexMinimumTest, ExhaustiveTest) {
|
||||
});
|
||||
total_num_queries += num_queries;
|
||||
max_num_queries = std::max(max_num_queries, num_queries);
|
||||
ASSERT_EQ(value, 0);
|
||||
ASSERT_GE(point, b1);
|
||||
ASSERT_LE(point, b2);
|
||||
EXPECT_EQ(value, 0);
|
||||
EXPECT_GE(point, b1);
|
||||
EXPECT_LE(point, b2);
|
||||
// Fail after one example.
|
||||
ASSERT_TRUE(value == 0 && b1 <= point && point <= b2)
|
||||
<< "queries: " << num_queries << " opt range: [" << b1 << ", " << b2
|
||||
<< "]";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -378,4 +382,36 @@ TEST(ConvexMinimumTest, TwoQueriesIfSizeTwoReversed) {
|
||||
EXPECT_EQ(num_queries, 2);
|
||||
}
|
||||
|
||||
TEST(RangeConvexMinimumTest, HugeRangeTest) {
|
||||
int total_num_queries = 0;
|
||||
int max_num_queries = 0;
|
||||
for (int b1 = -100; b1 < 100; ++b1) {
|
||||
for (int b2 = b1; b2 < b1 + 100; ++b2) {
|
||||
int num_queries = 0;
|
||||
const auto [point, value] = RangeConvexMinimum<int64_t, double>(
|
||||
std::numeric_limits<int64_t>::min() / 2,
|
||||
std::numeric_limits<int64_t>::max() / 2, [&](int64_t v) -> double {
|
||||
++num_queries;
|
||||
if (v < b1) {
|
||||
return b1 - v;
|
||||
} else if (v > b2) {
|
||||
return v - b2;
|
||||
}
|
||||
return 0;
|
||||
});
|
||||
total_num_queries += num_queries;
|
||||
max_num_queries = std::max(max_num_queries, num_queries);
|
||||
EXPECT_EQ(value, 0);
|
||||
EXPECT_GE(point, b1);
|
||||
EXPECT_LE(point, b2);
|
||||
// Don't continue past the first failing example to limit the number of
|
||||
// errors.
|
||||
ASSERT_TRUE(value == 0 && b1 <= point && point <= b2)
|
||||
<< "queries: " << num_queries << " opt range: [" << b1 << ", " << b2
|
||||
<< "]";
|
||||
}
|
||||
}
|
||||
// 80 is the worst case we would expect from ternary search: 2*log_3(2^63).
|
||||
EXPECT_LE(max_num_queries, 80);
|
||||
}
|
||||
} // namespace operations_research
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ortools/base/threadpool.h"
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
|
||||
namespace operations_research {
|
||||
void RunWorker(void* data) {
|
||||
@@ -25,7 +26,7 @@ void RunWorker(void* data) {
|
||||
}
|
||||
}
|
||||
|
||||
ThreadPool::ThreadPool(const std::string& prefix, int num_workers)
|
||||
ThreadPool::ThreadPool(absl::string_view prefix, int num_workers)
|
||||
: num_workers_(num_workers) {}
|
||||
|
||||
ThreadPool::~ThreadPool() {
|
||||
|
||||
@@ -22,10 +22,12 @@
|
||||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
|
||||
namespace operations_research {
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(const std::string& prefix, int num_threads);
|
||||
ThreadPool(absl::string_view prefix, int num_threads);
|
||||
~ThreadPool();
|
||||
|
||||
void StartWorkers();
|
||||
|
||||
@@ -321,7 +321,7 @@ MPSReaderFormat TemplateFormat(MPSReader::Form form) {
|
||||
} // namespace
|
||||
|
||||
// Parses instance from a file.
|
||||
absl::Status MPSReader::ParseFile(const std::string& file_name,
|
||||
absl::Status MPSReader::ParseFile(absl::string_view file_name,
|
||||
LinearProgram* data, Form form) {
|
||||
DataWrapper<LinearProgram> data_wrapper(data);
|
||||
return MPSReaderTemplate<DataWrapper<LinearProgram>>()
|
||||
@@ -329,7 +329,7 @@ absl::Status MPSReader::ParseFile(const std::string& file_name,
|
||||
.status();
|
||||
}
|
||||
|
||||
absl::Status MPSReader::ParseFile(const std::string& file_name,
|
||||
absl::Status MPSReader::ParseFile(absl::string_view file_name,
|
||||
MPModelProto* data, Form form) {
|
||||
DataWrapper<MPModelProto> data_wrapper(data);
|
||||
return MPSReaderTemplate<DataWrapper<MPModelProto>>()
|
||||
@@ -339,7 +339,7 @@ absl::Status MPSReader::ParseFile(const std::string& file_name,
|
||||
|
||||
// Loads instance from string. Useful with MapReduce. Automatically detects
|
||||
// the file's format (free or fixed).
|
||||
absl::Status MPSReader::ParseProblemFromString(const std::string& source,
|
||||
absl::Status MPSReader::ParseProblemFromString(absl::string_view source,
|
||||
LinearProgram* data,
|
||||
MPSReader::Form form) {
|
||||
DataWrapper<LinearProgram> data_wrapper(data);
|
||||
@@ -348,7 +348,7 @@ absl::Status MPSReader::ParseProblemFromString(const std::string& source,
|
||||
.status();
|
||||
}
|
||||
|
||||
absl::Status MPSReader::ParseProblemFromString(const std::string& source,
|
||||
absl::Status MPSReader::ParseProblemFromString(absl::string_view source,
|
||||
MPModelProto* data,
|
||||
MPSReader::Form form) {
|
||||
DataWrapper<MPModelProto> data_wrapper(data);
|
||||
@@ -357,8 +357,7 @@ absl::Status MPSReader::ParseProblemFromString(const std::string& source,
|
||||
.status();
|
||||
}
|
||||
|
||||
absl::StatusOr<MPModelProto> MpsDataToMPModelProto(
|
||||
const std::string& mps_data) {
|
||||
absl::StatusOr<MPModelProto> MpsDataToMPModelProto(absl::string_view mps_data) {
|
||||
MPModelProto model;
|
||||
DataWrapper<MPModelProto> data_wrapper(&model);
|
||||
RETURN_IF_ERROR(
|
||||
@@ -368,8 +367,7 @@ absl::StatusOr<MPModelProto> MpsDataToMPModelProto(
|
||||
return model;
|
||||
}
|
||||
|
||||
absl::StatusOr<MPModelProto> MpsFileToMPModelProto(
|
||||
const std::string& mps_file) {
|
||||
absl::StatusOr<MPModelProto> MpsFileToMPModelProto(absl::string_view mps_file) {
|
||||
MPModelProto model;
|
||||
DataWrapper<MPModelProto> data_wrapper(&model);
|
||||
RETURN_IF_ERROR(
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
#include "absl/base/attributes.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/linear_solver/linear_solver.pb.h"
|
||||
#include "ortools/lp_data/lp_data.h"
|
||||
|
||||
@@ -36,10 +37,10 @@ namespace operations_research {
|
||||
namespace glop {
|
||||
|
||||
// Parses an MPS model from a string.
|
||||
absl::StatusOr<MPModelProto> MpsDataToMPModelProto(const std::string& mps_data);
|
||||
absl::StatusOr<MPModelProto> MpsDataToMPModelProto(absl::string_view mps_data);
|
||||
|
||||
// Parses an MPS model from a file.
|
||||
absl::StatusOr<MPModelProto> MpsFileToMPModelProto(const std::string& mps_file);
|
||||
absl::StatusOr<MPModelProto> MpsFileToMPModelProto(absl::string_view mps_file);
|
||||
|
||||
// Implementation class. Please use the 2 functions above.
|
||||
//
|
||||
@@ -54,17 +55,17 @@ class ABSL_DEPRECATED("Use the direct methods instead") MPSReader {
|
||||
enum Form { AUTO_DETECT, FREE, FIXED };
|
||||
|
||||
// Parses instance from a file.
|
||||
absl::Status ParseFile(const std::string& file_name, LinearProgram* data,
|
||||
absl::Status ParseFile(absl::string_view file_name, LinearProgram* data,
|
||||
Form form = AUTO_DETECT);
|
||||
|
||||
absl::Status ParseFile(const std::string& file_name, MPModelProto* data,
|
||||
absl::Status ParseFile(absl::string_view file_name, MPModelProto* data,
|
||||
Form form = AUTO_DETECT);
|
||||
// Loads instance from string. Useful with MapReduce. Automatically detects
|
||||
// the file's format (free or fixed).
|
||||
absl::Status ParseProblemFromString(const std::string& source,
|
||||
absl::Status ParseProblemFromString(absl::string_view source,
|
||||
LinearProgram* data,
|
||||
MPSReader::Form form = AUTO_DETECT);
|
||||
absl::Status ParseProblemFromString(const std::string& source,
|
||||
absl::Status ParseProblemFromString(absl::string_view source,
|
||||
MPModelProto* data,
|
||||
MPSReader::Form form = AUTO_DETECT);
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user