re-export n_choose_k
This commit is contained in:
@@ -534,3 +534,39 @@ cc_test(
|
||||
"//ortools/base:gmock_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "n_choose_k",
|
||||
srcs = ["n_choose_k.cc"],
|
||||
hdrs = ["n_choose_k.h"],
|
||||
deps = [
|
||||
":binary_search",
|
||||
"//ortools/base:mathutil",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/numeric:int128",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/time",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "n_choose_k_test",
|
||||
srcs = ["n_choose_k_test.cc"],
|
||||
deps = [
|
||||
":n_choose_k",
|
||||
"//ortools/base:dump_vars",
|
||||
"//ortools/base:fuzztest",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/base:mathutil",
|
||||
"//ortools/util:flat_matrix",
|
||||
"@com_google_absl//absl/numeric:int128",
|
||||
"@com_google_absl//absl/random",
|
||||
"@com_google_absl//absl/random:distributions",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_benchmark//:benchmark",
|
||||
],
|
||||
)
|
||||
|
||||
169
ortools/algorithms/n_choose_k.cc
Normal file
169
ortools/algorithms/n_choose_k.cc
Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright 2010-2024 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/algorithms/n_choose_k.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/numeric/int128.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/time/clock.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "ortools/algorithms/binary_search.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/base/mathutil.h"
|
||||
|
||||
namespace operations_research {
|
||||
namespace {
|
||||
// This is the actual computation. It's in O(k).
|
||||
template <typename Int>
|
||||
Int InternalChoose(Int n, Int k) {
|
||||
DCHECK_LE(k, n - k);
|
||||
DCHECK_GT(k, 0); // Having k>0 lets us start with i=2 (small optimization).
|
||||
// We compute n * (n-1) * ... * (n-k+1) / k! in the best possible order to
|
||||
// guarantee exact results, while trying to avoid overflows. It's not
|
||||
// perfect: we finish with a division by k, which means that me may overflow
|
||||
// even if the result doesn't (by a factor of up to k).
|
||||
Int result = n;
|
||||
for (Int i = 2; i <= k; ++i) {
|
||||
result *= n + 1 - i;
|
||||
result /= i; // The product of i consecutive numbers is divisible by i!.
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// This function precomputes the maximum N such that (N choose K) doesn't
|
||||
// overflow, for all K.
|
||||
// When `overflows_intermediate_computation` is true, "overflow" means
|
||||
// "some overflow happens inside InternalChoose<int64_t>()", and when it's false
|
||||
// it simply means "the result doesn't fit in an int64_t".
|
||||
// This is only used in contexts where K ≤ N-K, which implies N ≥ 2K, thus we
|
||||
// can stop when (2K Choose K) overflows, because at and beyond such K,
|
||||
// (N Choose K) will always overflow. In practice that happens for K=31 or 34
|
||||
// depending on `overflows_intermediate_computation`.
|
||||
template <class Int>
|
||||
std::vector<Int> LastNThatDoesNotOverflowForAllK(
|
||||
bool overflows_intermediate_computation) {
|
||||
absl::Time start_time = absl::Now();
|
||||
// Given the algorithm used in InternalChoose(), it's not hard to
|
||||
// find out when (N choose K) overflows an int64_t during its internal
|
||||
// computation: that's when (N choose K) > MAX_INT / k.
|
||||
|
||||
// For K ≤ 2, we hardcode the values of the maximum N. That's because
|
||||
// the binary search done below uses MathUtil::LogCombinations, which only
|
||||
// works on int32_t, and that's problematic for the max N we get for K=2.
|
||||
//
|
||||
// For K=2, we want N(N-1) ≤ 2^num_digits, or N(N-1)/2 ≤ 2^num_digits if
|
||||
// !overflows_intermediate_computation, i.e. N(N-1) ≤ 2^(num_digits+1).
|
||||
// Then, when d is even, N(N-1) ≤ 2^d ⇔ N ≤ 2^(d/2), which is simple.
|
||||
// When d is odd, it's harder: N(N-1)≈(N-0.5)² and thus we get the bound
|
||||
// N ≤ pow(2.0, d/2)+0.5.
|
||||
const int bound_digits = std::numeric_limits<Int>::digits +
|
||||
(overflows_intermediate_computation ? 0 : 1);
|
||||
std::vector<Int> result = {
|
||||
std::numeric_limits<Int>::max(), // K=0
|
||||
std::numeric_limits<Int>::max(), // K=1
|
||||
bound_digits % 2 == 0
|
||||
? Int{1} << (bound_digits / 2)
|
||||
: static_cast<Int>(
|
||||
0.5 + std::pow(2.0, 0.5 * std::numeric_limits<Int>::digits)),
|
||||
};
|
||||
// We find the last N with binary search, for all K. We stop growing K
|
||||
// when (2*K Choose K) overflows.
|
||||
for (Int k = 3;; ++k) {
|
||||
const double max_log_comb =
|
||||
overflows_intermediate_computation
|
||||
? std::numeric_limits<Int>::digits * std::log(2) - std::log(k)
|
||||
: std::numeric_limits<Int>::digits * std::log(2);
|
||||
result.push_back(BinarySearch<Int>(
|
||||
/*x_true*/ k,
|
||||
// x_false=X, X needs to be large enough so that X choose 3 overflows:
|
||||
// (X choose 3)≈(X-1)³/6, so we pick X = 2+6*2^(num_digits/3+1).
|
||||
/*x_false=*/
|
||||
(static_cast<Int>(
|
||||
2 + 6 * std::pow(2.0, std::numeric_limits<Int>::digits / 3 + 1))),
|
||||
[k, max_log_comb](Int n) {
|
||||
return MathUtil::LogCombinations(n, k) <= max_log_comb;
|
||||
}));
|
||||
if (result.back() < 2 * k) {
|
||||
result.pop_back();
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Some DCHECKs for int64_t, which should validate the general formulaes.
|
||||
if constexpr (std::numeric_limits<Int>::digits == 63) {
|
||||
DCHECK_EQ(result.size(),
|
||||
overflows_intermediate_computation
|
||||
? 31 // 60 Choose 30 < 2^63/30 but 62 Choose 31 > 2^63/31.
|
||||
: 34); // 66 Choose 33 < 2^63 but 68 Choose 34 > 2^63.
|
||||
}
|
||||
VLOG(1) << "LastNThatDoesNotOverflowForAllK(): " << absl::Now() - start_time;
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Int>
|
||||
bool NChooseKIntermediateComputationOverflowsInt(Int n, Int k) {
|
||||
DCHECK_LE(k, n - k);
|
||||
static const auto* const result =
|
||||
new std::vector<Int>(LastNThatDoesNotOverflowForAllK<Int>(
|
||||
/*overflows_intermediate_computation=*/true));
|
||||
return k < result->size() ? n > (*result)[k] : true;
|
||||
}
|
||||
|
||||
template <typename Int>
|
||||
bool NChooseKResultOverflowsInt(Int n, Int k) {
|
||||
DCHECK_LE(k, n - k);
|
||||
static const auto* const result =
|
||||
new std::vector<Int>(LastNThatDoesNotOverflowForAllK<Int>(
|
||||
/*overflows_intermediate_computation=*/false));
|
||||
return k < result->size() ? n > (*result)[k] : true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// NOTE(user): If performance ever matters, we could simply precompute and
|
||||
// store all (N choose K) that don't overflow, there aren't that many of them:
|
||||
// only a few tens of thousands, after removing simple cases like k ≤ 5.
|
||||
absl::StatusOr<int64_t> NChooseK(int64_t n, int64_t k) {
|
||||
if (n < 0) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat("n is negative (%d)", n));
|
||||
}
|
||||
if (k < 0) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat("k is negative (%d)", k));
|
||||
}
|
||||
if (k > n) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("k=%d is greater than n=%d", k, n));
|
||||
}
|
||||
if (k > n / 2) k = n - k;
|
||||
if (k == 0) return 1;
|
||||
if (n < std::numeric_limits<uint32_t>::max() &&
|
||||
!NChooseKIntermediateComputationOverflowsInt<uint32_t>(n, k)) {
|
||||
return static_cast<int64_t>(InternalChoose<uint32_t>(n, k));
|
||||
}
|
||||
if (!NChooseKIntermediateComputationOverflowsInt<int64_t>(n, k)) {
|
||||
return InternalChoose<uint64_t>(n, k);
|
||||
}
|
||||
if (NChooseKResultOverflowsInt<int64_t>(n, k)) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("(%d choose %d) overflows int64", n, k));
|
||||
}
|
||||
return static_cast<int64_t>(InternalChoose<absl::uint128>(n, k));
|
||||
}
|
||||
|
||||
} // namespace operations_research
|
||||
34
ortools/algorithms/n_choose_k.h
Normal file
34
ortools/algorithms/n_choose_k.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright 2010-2024 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_
|
||||
#define OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
|
||||
namespace operations_research {
|
||||
// Returns the number of ways to choose k elements among n, ignoring the order,
|
||||
// i.e., the binomial coefficient (n, k).
|
||||
// This is like std::exp(MathUtil::LogCombinations(n, k)), but faster, with
|
||||
// perfect accuracy, and returning an error iff the result would overflow an
|
||||
// int64_t or if an argument is invalid (i.e., n < 0, k < 0, or k > n).
|
||||
//
|
||||
// NOTE(user): If you need a variation of this, ask the authors: it's very easy
|
||||
// to add. E.g., other int types, other behaviors (e.g., return 0 if k > n, or
|
||||
// std::numeric_limits<int64_t>::max() on overflow, etc).
|
||||
absl::StatusOr<int64_t> NChooseK(int64_t n, int64_t k);
|
||||
} // namespace operations_research
|
||||
|
||||
#endif // OR_TOOLS_ALGORITHMS_N_CHOOSE_K_H_
|
||||
323
ortools/algorithms/n_choose_k_test.cc
Normal file
323
ortools/algorithms/n_choose_k_test.cc
Normal file
@@ -0,0 +1,323 @@
|
||||
// Copyright 2010-2024 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/algorithms/n_choose_k.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/numeric/int128.h"
|
||||
#include "absl/random/distributions.h"
|
||||
#include "absl/random/random.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "benchmark/benchmark.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/dump_vars.h"
|
||||
//#include "ortools/base/fuzztest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/base/mathutil.h"
|
||||
#include "ortools/util/flat_matrix.h"
|
||||
|
||||
namespace operations_research {
|
||||
namespace {
|
||||
//using ::fuzztest::NonNegative;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::status::IsOkAndHolds;
|
||||
using ::testing::status::StatusIs;
|
||||
|
||||
constexpr int64_t kint64max = std::numeric_limits<int64_t>::max();
|
||||
|
||||
TEST(NChooseKTest, TrivialErrorCases) {
|
||||
absl::BitGen random;
|
||||
constexpr int kNumTests = 100'000;
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t x = absl::LogUniform<int64_t>(random, 0, kint64max);
|
||||
EXPECT_THAT(NChooseK(-1, x), StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("n is negative")));
|
||||
EXPECT_THAT(NChooseK(x, -1), StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("k is negative")));
|
||||
if (x != kint64max) {
|
||||
EXPECT_THAT(NChooseK(x, x + 1),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("greater than n")));
|
||||
}
|
||||
ASSERT_FALSE(HasFailure()) << DUMP_VARS(t, x);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NChooseKTest, Symmetry) {
|
||||
absl::BitGen random;
|
||||
constexpr int kNumTests = 1'000'000;
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, 0, kint64max);
|
||||
const int64_t k = absl::LogUniform<int64_t>(random, 0, n);
|
||||
const absl::StatusOr<int64_t> result1 = NChooseK(n, k);
|
||||
const absl::StatusOr<int64_t> result2 = NChooseK(n, n - k);
|
||||
if (result1.ok()) {
|
||||
ASSERT_THAT(result2, IsOkAndHolds(result1.value())) << DUMP_VARS(t, n, k);
|
||||
} else {
|
||||
ASSERT_EQ(result2.status().code(), result1.status().code())
|
||||
<< DUMP_VARS(t, n, k, result1, result2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NChooseKTest, Invariant) {
|
||||
absl::BitGen random;
|
||||
constexpr int kNumTests = 1'000'000;
|
||||
int num_tested_invariants = 0;
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, 2, 100);
|
||||
const int64_t k = absl::LogUniform<int64_t>(random, 1, n - 1);
|
||||
const absl::StatusOr<int64_t> n_k = NChooseK(n, k);
|
||||
const absl::StatusOr<int64_t> nm1_k = NChooseK(n - 1, k);
|
||||
const absl::StatusOr<int64_t> nm1_km1 = NChooseK(n - 1, k - 1);
|
||||
if (n_k.ok()) {
|
||||
++num_tested_invariants;
|
||||
ASSERT_OK(nm1_k);
|
||||
ASSERT_OK(nm1_km1);
|
||||
ASSERT_EQ(n_k.value(), nm1_k.value() + nm1_km1.value())
|
||||
<< DUMP_VARS(t, n, k, n_k, nm1_k, nm1_km1);
|
||||
}
|
||||
}
|
||||
EXPECT_GE(num_tested_invariants, kNumTests / 10);
|
||||
}
|
||||
|
||||
TEST(NChooseKTest, ComparisonAgainstClosedFormsForK0) {
|
||||
for (int64_t n : {int64_t{0}, int64_t{1}, kint64max}) {
|
||||
EXPECT_THAT(NChooseK(n, 0), IsOkAndHolds(1)) << n;
|
||||
}
|
||||
absl::BitGen random;
|
||||
constexpr int kNumTests = 1'000'000;
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, 0, kint64max);
|
||||
ASSERT_THAT(NChooseK(n, 0), IsOkAndHolds(1)) << DUMP_VARS(n, t);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NChooseKTest, ComparisonAgainstClosedFormsForK1) {
|
||||
for (int64_t n : {int64_t{1}, kint64max}) {
|
||||
EXPECT_THAT(NChooseK(n, 1), IsOkAndHolds(n));
|
||||
}
|
||||
absl::BitGen random;
|
||||
constexpr int kNumTests = 1'000'000;
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, 1, kint64max);
|
||||
ASSERT_THAT(NChooseK(n, 1), IsOkAndHolds(n)) << DUMP_VARS(t);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NChooseKTest, ComparisonAgainstClosedFormsForK2) {
|
||||
// 2^32 Choose 2 = 2^32 × (2^32-1) / 2 = 2^63 - 2^31 < kint64max,
|
||||
// but (2^32+1) Choose 2 = 2^63 + 2^31 overflows.
|
||||
constexpr int64_t max_n = int64_t{1} << 32;
|
||||
for (int64_t n : {int64_t{2}, max_n}) {
|
||||
const int64_t n_choose_2 =
|
||||
static_cast<int64_t>(absl::uint128(n) * (n - 1) / 2);
|
||||
EXPECT_THAT(NChooseK(n, 2), IsOkAndHolds(n_choose_2)) << DUMP_VARS(n);
|
||||
}
|
||||
EXPECT_THAT(NChooseK(max_n + 1, 2),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("overflows int64")));
|
||||
|
||||
absl::BitGen random;
|
||||
constexpr int kNumTests = 100'000;
|
||||
// Random valid results.
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, 2, max_n);
|
||||
const int64_t n_choose_2 =
|
||||
static_cast<int64_t>(absl::uint128(n) * (n - 1) / 2);
|
||||
ASSERT_THAT(NChooseK(n, 2), IsOkAndHolds(n_choose_2)) << DUMP_VARS(t, n);
|
||||
}
|
||||
// Random overflows.
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, max_n + 1, kint64max);
|
||||
ASSERT_THAT(NChooseK(n, 2), StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("overflows int64")))
|
||||
<< DUMP_VARS(t, n);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NChooseKTest, ComparisonAgainstClosedFormsForK3) {
|
||||
// This is 1 + ∛6×2^21. Checked manually on Google's scientific calculator.
|
||||
const int64_t max_n =
|
||||
static_cast<int64_t>(1 + std::pow(6, 1.0 / 3) * std::pow(2, 21));
|
||||
for (int64_t n : {int64_t{3}, max_n}) {
|
||||
const int64_t n_choose_3 =
|
||||
static_cast<int64_t>(absl::uint128(n) * (n - 1) * (n - 2) / 6);
|
||||
EXPECT_THAT(NChooseK(n, 3), IsOkAndHolds(n_choose_3)) << DUMP_VARS(n);
|
||||
}
|
||||
EXPECT_THAT(NChooseK(max_n + 1, 3),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("overflows int64")));
|
||||
|
||||
absl::BitGen random;
|
||||
constexpr int kNumTests = 100'000;
|
||||
// Random valid results.
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, 3, max_n);
|
||||
const int64_t n_choose_3 =
|
||||
static_cast<int64_t>(absl::uint128(n) * (n - 1) * (n - 2) / 6);
|
||||
ASSERT_THAT(NChooseK(n, 3), IsOkAndHolds(n_choose_3)) << DUMP_VARS(t, n);
|
||||
}
|
||||
// Random overflows.
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, max_n + 1, kint64max);
|
||||
ASSERT_THAT(NChooseK(n, 3), StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("overflows int64")))
|
||||
<< DUMP_VARS(t, n);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NChooseKTest, ComparisonAgainstClosedFormsForK4) {
|
||||
// This is 1.5 + ∜24 × 2^(63/4).
|
||||
// Checked manually on Google's scientific calculator.
|
||||
const int64_t max_n =
|
||||
static_cast<int64_t>(1.5 + std::pow(24, 1.0 / 4) * std::pow(2, 63.0 / 4));
|
||||
for (int64_t n : {int64_t{4}, max_n}) {
|
||||
const int64_t n_choose_4 = static_cast<int64_t>(absl::uint128(n) * (n - 1) *
|
||||
(n - 2) * (n - 3) / 24);
|
||||
EXPECT_THAT(NChooseK(n, 4), IsOkAndHolds(n_choose_4)) << DUMP_VARS(n);
|
||||
}
|
||||
EXPECT_THAT(NChooseK(max_n + 1, 4),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("overflows int64")));
|
||||
|
||||
absl::BitGen random;
|
||||
constexpr int kNumTests = 100'000;
|
||||
// Random valid results.
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, 4, max_n);
|
||||
const int64_t n_choose_4 = static_cast<int64_t>(absl::uint128(n) * (n - 1) *
|
||||
(n - 2) * (n - 3) / 24);
|
||||
ASSERT_THAT(NChooseK(n, 4), IsOkAndHolds(n_choose_4)) << DUMP_VARS(t, n);
|
||||
}
|
||||
// Random overflows.
|
||||
for (int t = 0; t < kNumTests; ++t) {
|
||||
const int64_t n = absl::LogUniform<int64_t>(random, max_n + 1, kint64max);
|
||||
ASSERT_THAT(NChooseK(n, 4), StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("overflows int64")))
|
||||
<< DUMP_VARS(t, n);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(NChooseKTest, ComparisonAgainstPascalTriangleForK5OrAbove) {
|
||||
// Fill the Pascal triangle. Use -1 for int64_t overflows. We go up to n =
|
||||
// 17000 because (17000 Choose 5) ≈ 1.2e19 which overflows an int64_t.
|
||||
constexpr int max_n = 17000;
|
||||
FlatMatrix<int64_t> triangle(max_n + 1, max_n + 1);
|
||||
for (int n = 0; n <= max_n; ++n) {
|
||||
triangle[n][0] = 1;
|
||||
triangle[n][n] = 1;
|
||||
for (int i = 1; i < n; ++i) {
|
||||
const int64_t a = triangle[n - 1][i - 1];
|
||||
const int64_t b = triangle[n - 1][i];
|
||||
if (a < 0 || b < 0 || absl::int128(a) + b > kint64max) {
|
||||
triangle[n][i] = -1;
|
||||
} else {
|
||||
triangle[n][i] = a + b;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Checking all 17000²/2 slots would be too expensive, so we check each
|
||||
// "column" downwards until the first 10 overflows, and stop.
|
||||
for (int k = 5; k < max_n; ++k) {
|
||||
int num_overflows = 0;
|
||||
for (int n = k + 5; n < max_n; ++n) {
|
||||
if (num_overflows > 0) EXPECT_EQ(triangle[n][k], -1);
|
||||
if (triangle[n][k] < 0) {
|
||||
++num_overflows;
|
||||
EXPECT_THAT(NChooseK(n, k), StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("overflows int64")));
|
||||
if (num_overflows > 10) break;
|
||||
} else {
|
||||
EXPECT_THAT(NChooseK(n, k), IsOkAndHolds(triangle[n][k]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MatchesLogCombinations(int n, int k) {
|
||||
if (n < k) {
|
||||
std::swap(k, n);
|
||||
}
|
||||
const auto exact = NChooseK(n, k);
|
||||
const double log_approx = MathUtil::LogCombinations(n, k);
|
||||
if (exact.ok()) {
|
||||
// We accepted to compute the exact value, make sure that it matches the
|
||||
// approximation.
|
||||
ASSERT_NEAR(log(exact.value()), log_approx, 0.0001);
|
||||
} else {
|
||||
// We declined to compute the exact value, make sure that we had a good
|
||||
// reason to, i.e. that the result did indeed overflow.
|
||||
ASSERT_THAT(exact, StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("overflows int64")));
|
||||
const double approx = exp(log_approx);
|
||||
ASSERT_GE(std::nextafter(approx, std::numeric_limits<double>::infinity()),
|
||||
std::numeric_limits<int64_t>::max())
|
||||
<< "we declined to compute the exact value of NChooseK(" << n << ", "
|
||||
<< k << "), but the log value is " << log_approx
|
||||
<< " (value: " << approx << "), which fits in int64_t";
|
||||
}
|
||||
}
|
||||
/*
|
||||
FUZZ_TEST(NChooseKTest, MatchesLogCombinations)
|
||||
// Ideally we'd test with `uint64_t`, but `LogCombinations` only accepts
|
||||
// `int`.
|
||||
.WithDomains(NonNegative<int>(), NonNegative<int>());
|
||||
*/
|
||||
template <int kMaxN, auto algo>
|
||||
void BM_NChooseK(benchmark::State& state) {
|
||||
static constexpr int kNumInputs = 1000;
|
||||
// Use deterministic random numbers to avoid noise.
|
||||
std::mt19937 gen(42);
|
||||
std::uniform_int_distribution<int64_t> random(0, kMaxN);
|
||||
std::vector<std::pair<int64_t, int64_t>> inputs;
|
||||
inputs.reserve(kNumInputs);
|
||||
for (int i = 0; i < kNumInputs; ++i) {
|
||||
int64_t n = random(gen);
|
||||
int64_t k = random(gen);
|
||||
if (n < k) {
|
||||
std::swap(n, k);
|
||||
}
|
||||
inputs.emplace_back(n, k);
|
||||
}
|
||||
// Force the one-time, costly static initializations of NChooseK() to happen
|
||||
// before the benchmark starts.
|
||||
auto result = NChooseK(62, 31);
|
||||
benchmark::DoNotOptimize(result);
|
||||
|
||||
// Start the benchmark.
|
||||
for (auto _ : state) {
|
||||
for (const auto [n, k] : inputs) {
|
||||
auto result = algo(n, k);
|
||||
benchmark::DoNotOptimize(result);
|
||||
}
|
||||
}
|
||||
state.SetItemsProcessed(state.iterations() * kNumInputs);
|
||||
}
|
||||
BENCHMARK(BM_NChooseK<30, operations_research::NChooseK>); // int32_t domain.
|
||||
BENCHMARK(
|
||||
BM_NChooseK<60, operations_research::NChooseK>); // int{32,64} domain.
|
||||
BENCHMARK(
|
||||
BM_NChooseK<100, operations_research::NChooseK>); // int{32,64,128} domain.
|
||||
BENCHMARK(
|
||||
BM_NChooseK<100, MathUtil::LogCombinations>); // int{32,64,128} domain.
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research
|
||||
Reference in New Issue
Block a user