From 17498776bfd98ff5821a633f05b99f9f79bce486 Mon Sep 17 00:00:00 2001 From: Corentin Le Molgat Date: Fri, 16 May 2025 14:13:06 +0200 Subject: [PATCH] algorithms: export from google3 --- ortools/algorithms/BUILD.bazel | 6 +- ortools/algorithms/radix_sort.h | 135 ++++++++++++++++++------- ortools/algorithms/radix_sort_test.cc | 138 +++++++++++++++++++------- ortools/base/dump_vars.h | 5 +- 4 files changed, 208 insertions(+), 76 deletions(-) diff --git a/ortools/algorithms/BUILD.bazel b/ortools/algorithms/BUILD.bazel index 1d391a77d2..2d520cb7da 100644 --- a/ortools/algorithms/BUILD.bazel +++ b/ortools/algorithms/BUILD.bazel @@ -97,6 +97,7 @@ cc_library( deps = [ "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base", + "@abseil-cpp//absl/base:log_severity", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/numeric:bits", @@ -118,14 +119,13 @@ cc_test( "//ortools/base:dump_vars", "//ortools/base:gmock_main", "//ortools/base:mathutil", - "//ortools/base:timer", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/log", - "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/numeric:bits", + "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/random", "@abseil-cpp//absl/random:bit_gen_ref", "@abseil-cpp//absl/random:distributions", - "@abseil-cpp//absl/time", "@abseil-cpp//absl/types:span", "@com_google_benchmark//:benchmark", ], diff --git a/ortools/algorithms/radix_sort.h b/ortools/algorithms/radix_sort.h index 7ed8764b06..419bd930ef 100644 --- a/ortools/algorithms/radix_sort.h +++ b/ortools/algorithms/radix_sort.h @@ -30,9 +30,6 @@ // But the worst-case performance of RadixSort() is much faster than the // worst-case performance of std::sort(). // To be sure, you should benchmark your use case. -// -// TODO: it could be even faster than that when the values are in [0..N) for a -// known value N that's significantly lower than the max integer value. 
#include #include @@ -45,8 +42,10 @@ #include "absl/algorithm/container.h" #include "absl/base/casts.h" +#include "absl/base/log_severity.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/numeric/bits.h" #include "absl/types/span.h" namespace operations_research { @@ -54,14 +53,24 @@ namespace operations_research { // Sorts an array of int, double, or other numeric types. Up to ~10x faster than // std::sort() when size ≥ 8k: go/radix-sort-bench. See file-level comment. template -void RadixSort(absl::Span values); +void RadixSort( + absl::Span values, + // ADVANCED USAGE: if you're sorting nonnegative integers, and suspect that + // their values use fewer bits than their full bit width, you may improve + // performance by setting `num_bits` to a lower value, for example + // NumBitsForZeroTo(max_value). It might even be faster to scan the values + // once just to do that, e.g., RadixSort(values, + // NumBitsForZeroTo(*absl::c_max_element(values))); + int num_bits = sizeof(T) * 8); + +template +int NumBitsForZeroTo(T max_value); // ADVANCED USAGE: For power users who know which radix_width or num_passes // they need, possibly differing from the canonical values used by RadixSort(). template void RadixSortTpl(absl::Span values); -// TODO(user): Support arbitrary types with an int() or other numerical getter. // TODO(user): Support the user providing already-allocated memory buffers // for the radix counts and/or for the temporary vector copy. @@ -240,49 +249,101 @@ void RadixSortTpl(absl::Span values) { } } -// TODO(user): Expose an API that takes the "max value" as argument, for -// users who want to take advantage of that knowledge to reduce the number of -// passes. 
template -void RadixSort(absl::Span values) { - switch (sizeof(T)) { - case 1: - if (values.size() < 300) { - absl::c_sort(values); +int NumBitsForZeroTo(T max_value) { + if constexpr (!std::is_integral_v) { + return sizeof(T) * 8; + } else { + using U = std::make_unsigned_t; + DCHECK_GE(max_value, 0); + return std::numeric_limits::digits - absl::countl_zero(max_value); + } +} + +#ifdef NDEBUG +const bool DEBUG_MODE = false; +#else +const bool DEBUG_MODE = true; +#endif + +template +void RadixSort(absl::Span values, int num_bits) { + // Debug-check that num_bits is valid w.r.t. the values given. + if constexpr (DEBUG_MODE) { + if constexpr (!std::is_integral_v) { + DCHECK_EQ(num_bits, sizeof(T) * 8); + } else if (!values.empty()) { + auto minmax_it = absl::c_minmax_element(values); + const T min_val = *minmax_it.first; + const T max_val = *minmax_it.second; + if (num_bits == 0) { + DCHECK_EQ(max_val, 0); } else { - RadixSortTpl(values); + using U = std::make_unsigned_t; + // We only shift by num_bits - 1, to avoid potentially shifting by the + // entire bit width, which would be undefined behavior. + DCHECK_LE(static_cast(max_val) >> (num_bits - 1), 1); + DCHECK_LE(static_cast(min_val) >> (num_bits - 1), 1); } - return; - case 2: - if (values.size() < 300) { - absl::c_sort(values); + } + } + + // This shortcut here is important to have early, guarded by as few "if" + // branches as possible, for the use case where the array is very small. + // For larger arrays below, the overhead of a few "if" is negligible. + if (values.size() < 300) { + absl::c_sort(values); + return; + } + + // TODO(user): More complex decision tree, based on benchmarks. This one + // is already nice, but some cases can surely be optimized. 
+ if (num_bits <= 16) { + if (num_bits <= 8) { + RadixSortTpl(values); + } else { + RadixSortTpl(values); + } + } else if (num_bits <= 32) { // num_bits ∈ [17..32] + if (values.size() < 1000) { + if (num_bits <= 24) { + RadixSortTpl(values); } else { - RadixSortTpl(values); - } - return; - case 4: - if (values.size() < 300) { - absl::c_sort(values); - } else if (values.size() < 1000) { RadixSortTpl(values); - } else if (values.size() < 2'500'000) { - RadixSortTpl(values); - } else { - RadixSortTpl(values); } - return; - case 8: - if (values.size() < 5000) { - absl::c_sort(values); - } else if (values.size() < 1'500'000) { + } else if (values.size() < 2'500'000) { + if (num_bits <= 22) { + RadixSortTpl(values); + } else { + RadixSortTpl(values); + } + } else { + RadixSortTpl(values); + } + } else if (num_bits <= 64) { // num_bits ∈ [33..64] + if (values.size() < 5000) { + absl::c_sort(values); + } else if (values.size() < 1'500'000) { + if (num_bits <= 33) { + RadixSortTpl(values); + } else if (num_bits <= 44) { + RadixSortTpl(values); + } else if (num_bits <= 55) { + RadixSortTpl(values); + } else { RadixSortTpl(values); + } + } else { + if (num_bits <= 48) { + RadixSortTpl(values); } else { RadixSortTpl(values); } - return; + } + } else { + LOG(DFATAL) << "RadixSort() called with unsupported value type"; + absl::c_sort(values); } - LOG(DFATAL) << "RadixSort() called with unsupported value type"; - absl::c_sort(values); } } // namespace operations_research diff --git a/ortools/algorithms/radix_sort_test.cc b/ortools/algorithms/radix_sort_test.cc index 88957d3ed0..269b76c459 100644 --- a/ortools/algorithms/radix_sort_test.cc +++ b/ortools/algorithms/radix_sort_test.cc @@ -13,7 +13,6 @@ #include "ortools/algorithms/radix_sort.h" -#include #include #include #include @@ -25,6 +24,8 @@ #include "absl/algorithm/container.h" #include "absl/log/log.h" +#include "absl/numeric/bits.h" +#include "absl/numeric/int128.h" #include "absl/random/bit_gen_ref.h" #include 
"absl/random/distributions.h" #include "absl/random/random.h" @@ -41,6 +42,28 @@ namespace { using ::testing::ElementsAre; using ::testing::IsEmpty; +template +class NumBitsForZeroToTest : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(NumBitsForZeroToTest); + +TYPED_TEST_P(NumBitsForZeroToTest, CorrectnessStressTest) { + absl::BitGen rng; + constexpr int kNumTests = 1'000'000; + for (int test = 0; test < kNumTests; ++test) { + const TypeParam max_val = absl::LogUniform( + rng, 0, std::numeric_limits::max()); + const int num_bits = NumBitsForZeroTo(max_val); + EXPECT_LE(absl::int128{max_val}, absl::int128{1} << num_bits); + } +} + +REGISTER_TYPED_TEST_SUITE_P(NumBitsForZeroToTest, CorrectnessStressTest); +using IntTypes = ::testing::Types; + +INSTANTIATE_TYPED_TEST_SUITE_P(My, NumBitsForZeroToTest, IntTypes); + // If T is a floating-point type, ignores min_val / max_val. template std::vector RandomValues(absl::BitGenRef rng, size_t size, @@ -103,6 +126,9 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortSmallSizes) { // Will we use the standard RadixSort() or the RadixSortTpl<>() variant? const bool use_main_radix_sort = absl::Bernoulli(rng, 0.5); + const bool use_num_bits = std::is_integral_v && + use_main_radix_sort && !allow_negative && + absl::Bernoulli(rng, 0.5); // We potentially test the "power usage" of calling RadixSortTpl<> with // radix_width * num_passes < num_bits(TypeParam), when the actual values @@ -128,7 +154,12 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortSmallSizes) { int radix_width = -1; int num_passes = -1; if (use_main_radix_sort) { - RadixSort(absl::MakeSpan(sorted_values)); + if (use_num_bits) { + RadixSort(absl::MakeSpan(sorted_values), + NumBitsForZeroTo(max_abs_val.value())); + } else { + RadixSort(absl::MakeSpan(sorted_values)); + } } else { // Draw random (radix_width, num_passes) pairs until we get a valid one. 
constexpr int kMaxNumPasses = 8; @@ -147,8 +178,8 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortSmallSizes) { absl::c_sort(expected_values); ASSERT_TRUE(sorted_values == expected_values) << DUMP_VARS(test, use_main_radix_sort, radix_width, num_passes, size, - allow_negative, val_bits, max_abs_val, unsorted_values, - sorted_values, expected_values); + allow_negative, use_num_bits, val_bits, max_abs_val, + unsorted_values, sorted_values, expected_values); } } @@ -205,10 +236,20 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortLargeSizes) { std::vector values = RandomValues(rng, size, allow_negative, /*max_abs_val=*/{}); const bool use_main_radix_sort = absl::Bernoulli(rng, 0.5); + const bool use_num_bits = std::is_integral_v && + use_main_radix_sort && !allow_negative && + absl::Bernoulli(rng, 0.5); + int radix_width = -1; int num_passes = -1; if (use_main_radix_sort) { - RadixSort(absl::MakeSpan(values)); + if (use_num_bits) { + RadixSort( + absl::MakeSpan(values), + NumBitsForZeroTo(size == 0 ? 1 : *absl::c_max_element(values))); + } else { + RadixSort(absl::MakeSpan(values)); + } } else { radix_width = RandomRadixWidth(rng); num_passes = @@ -218,7 +259,7 @@ TYPED_TEST_P(RadixSortTest, RandomizedCorrectnessTestAgainstStdSortLargeSizes) { // Contrary to the 'small' stress test, we don't log the data upon failure. ASSERT_TRUE(absl::c_is_sorted(values)) << DUMP_VARS(test, use_main_radix_sort, radix_width, num_passes, size, - allow_negative); + allow_negative, use_num_bits); } } @@ -237,13 +278,16 @@ template std::vector SortedValues(size_t size) { const T offset = std::is_signed_v ? 
-static_cast(size) / 2 : T{0}; std::vector values(size); - for (size_t i = 0; i < size; ++i) values[i] = i = offset; + for (size_t i = 0; i < size; ++i) values[i] = i + offset; return values; } enum Algo { kStdSort, - kRadixSort, + kRadixSortTpl, + kRadixSortKnownMax, + kRadixSortComputeMax, + kRadixSortWorst, }; enum InputOrder { @@ -280,9 +324,22 @@ void BM_Sort(benchmark::State& state) { to_sort = values; if constexpr (algo == kStdSort) { absl::c_sort(to_sort); - } else { + } else if constexpr (algo == kRadixSortTpl) { absl::Span span{to_sort.data(), to_sort.size()}; RadixSortTpl(span); + } else if constexpr (algo == kRadixSortKnownMax) { + absl::Span span = absl::MakeSpan(to_sort); + RadixSort(span, NumBitsForZeroTo( + max_abs_val.value_or(std::numeric_limits::max()))); + } else if constexpr (algo == kRadixSortComputeMax) { + absl::Span span{to_sort.data(), to_sort.size()}; + RadixSort(span, NumBitsForZeroTo( + size == 0 ? 1 : *absl::c_max_element(to_sort))); + } else if constexpr (algo == kRadixSortWorst) { + absl::Span span{to_sort.data(), to_sort.size()}; + RadixSort(span); + } else { + LOG(DFATAL) << "Unsupported algo: " << algo; } benchmark::DoNotOptimize(to_sort); } @@ -317,114 +374,127 @@ BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(1, 128 << 10); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(16, 2048); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(256, 32 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 32 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(16, 2048); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(256, 32 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 32 << 20); -BENCHMARK(BM_Sort) +BENCHMARK(BM_Sort) ->RangeMultiplier(2) - ->Range(16, 2048); -BENCHMARK(BM_Sort) + ->Range(128 << 10, 32 << 20); +BENCHMARK(BM_Sort) ->RangeMultiplier(2) - ->Range(256, 32 << 20); -BENCHMARK(BM_SortRange(128 << 10, 32 << 20); +BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 32 << 20); 
-BENCHMARK(BM_Sort) + ->RangeMultiplier(2) + ->Range(16, 2048); +BENCHMARK(BM_Sort) + ->RangeMultiplier(2) + ->Range(256, 32 << 20); +BENCHMARK(BM_Sort) + ->RangeMultiplier(2) + ->Range(128 << 10, 32 << 20); + +BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(2048, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(2048, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(2048, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(2048, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(2048, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(2048, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 8 << 20) ->Arg(32 << 20) ->Arg(128 << 20); -BENCHMARK(BM_Sort) ->RangeMultiplier(2) ->Range(128 << 10, 8 << 20) diff --git a/ortools/base/dump_vars.h b/ortools/base/dump_vars.h index 449cb84ac5..8413948cd3 100644 --- a/ortools/base/dump_vars.h +++ b/ortools/base/dump_vars.h @@ -62,6 +62,7 @@ #define DUMP_FOR_EACH_N9(F, a, ...) F(a) DUMP_FOR_EACH_N8(F, __VA_ARGS__) #define DUMP_FOR_EACH_N10(F, a, ...) F(a) DUMP_FOR_EACH_N9(F, __VA_ARGS__) #define DUMP_FOR_EACH_N11(F, a, ...) F(a) DUMP_FOR_EACH_N10(F, __VA_ARGS__) +#define DUMP_FOR_EACH_N12(F, a, ...) F(a) DUMP_FOR_EACH_N11(F, __VA_ARGS__) #define DUMP_CONCATENATE(x, y) x##y #define DUMP_FOR_EACH_(N, F, ...) \ @@ -69,8 +70,8 @@ #define DUMP_NARG(...) 
DUMP_NARG_(__VA_OPT__(__VA_ARGS__, ) DUMP_RSEQ_N()) #define DUMP_NARG_(...) DUMP_ARG_N(__VA_ARGS__) -#define DUMP_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, N, ...) N -#define DUMP_RSEQ_N() 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +#define DUMP_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, N, ...) N +#define DUMP_RSEQ_N() 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 #define DUMP_FOR_EACH(F, ...) \ DUMP_FOR_EACH_(DUMP_NARG(__VA_ARGS__), F __VA_OPT__(, __VA_ARGS__))