sync: cp-sat bug fixes, stringview, fix strong int missing stl code, update graph
This commit is contained in:
@@ -455,6 +455,58 @@ namespace std {
|
||||
template <typename StrongIntName, typename ValueType>
|
||||
struct hash<util_intops::StrongInt<StrongIntName, ValueType>>
|
||||
: util_intops::StrongInt<StrongIntName, ValueType>::Hasher {};
|
||||
|
||||
template <typename TagType, typename NativeType>
|
||||
struct numeric_limits<util_intops::StrongInt<TagType, NativeType>> {
|
||||
private:
|
||||
using StrongIntT = util_intops::StrongInt<TagType, NativeType>;
|
||||
|
||||
public:
|
||||
// NOLINTBEGIN(google3-readability-class-member-naming)
|
||||
static constexpr bool is_specialized = true;
|
||||
static constexpr bool is_signed = numeric_limits<NativeType>::is_signed;
|
||||
static constexpr bool is_integer = numeric_limits<NativeType>::is_integer;
|
||||
static constexpr bool is_exact = numeric_limits<NativeType>::is_exact;
|
||||
static constexpr bool has_infinity = numeric_limits<NativeType>::has_infinity;
|
||||
static constexpr bool has_quiet_NaN =
|
||||
numeric_limits<NativeType>::has_quiet_NaN;
|
||||
static constexpr bool has_signaling_NaN =
|
||||
numeric_limits<NativeType>::has_signaling_NaN;
|
||||
static constexpr float_denorm_style has_denorm =
|
||||
numeric_limits<NativeType>::has_denorm;
|
||||
static constexpr bool has_denorm_loss =
|
||||
numeric_limits<NativeType>::has_denorm_loss;
|
||||
static constexpr float_round_style round_style =
|
||||
numeric_limits<NativeType>::round_style;
|
||||
static constexpr bool is_iec559 = numeric_limits<NativeType>::is_iec559;
|
||||
static constexpr bool is_bounded = numeric_limits<NativeType>::is_bounded;
|
||||
static constexpr bool is_modulo = numeric_limits<NativeType>::is_modulo;
|
||||
static constexpr int digits = numeric_limits<NativeType>::digits;
|
||||
static constexpr int digits10 = numeric_limits<NativeType>::digits10;
|
||||
static constexpr int max_digits10 = numeric_limits<NativeType>::max_digits10;
|
||||
static constexpr int radix = numeric_limits<NativeType>::radix;
|
||||
static constexpr int min_exponent = numeric_limits<NativeType>::min_exponent;
|
||||
static constexpr int min_exponent10 =
|
||||
numeric_limits<NativeType>::min_exponent10;
|
||||
static constexpr int max_exponent = numeric_limits<NativeType>::max_exponent;
|
||||
static constexpr int max_exponent10 =
|
||||
numeric_limits<NativeType>::max_exponent10;
|
||||
static constexpr bool traps = numeric_limits<NativeType>::traps;
|
||||
static constexpr bool tinyness_before =
|
||||
numeric_limits<NativeType>::tinyness_before;
|
||||
// NOLINTEND(google3-readability-class-member-naming)
|
||||
|
||||
static constexpr StrongIntT(min)() { return StrongIntT(numeric_limits<NativeType>::min()); }
|
||||
static constexpr StrongIntT lowest() { return StrongIntT(numeric_limits<NativeType>::min()); }
|
||||
static constexpr StrongIntT(max)() { return StrongIntT(numeric_limits<NativeType>::max()); }
|
||||
static constexpr StrongIntT epsilon() { return StrongIntT(numeric_limits<NativeType>::epsilon()); }
|
||||
static constexpr StrongIntT round_error() { return StrongIntT(numeric_limits<NativeType>::round_error()); }
|
||||
static constexpr StrongIntT infinity() { return StrongIntT(numeric_limits<NativeType>::infinity()); }
|
||||
static constexpr StrongIntT quiet_NaN() { return StrongIntT(numeric_limits<NativeType>::quiet_NaN()); }
|
||||
static constexpr StrongIntT signaling_NaN() { return StrongIntT(numeric_limits<NativeType>::signaling_NaN()); }
|
||||
static constexpr StrongIntT denorm_min() { return StrongIntT(numeric_limits<NativeType>::denorm_min()); }
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif // OR_TOOLS_BASE_STRONG_INT_H_
|
||||
|
||||
@@ -766,6 +766,8 @@ cc_test(
|
||||
":io",
|
||||
"//ortools/base:dump_vars",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/base:intops",
|
||||
"//ortools/base:strong_vector",
|
||||
"//ortools/util:flat_matrix",
|
||||
"@abseil-cpp//absl/algorithm:container",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
|
||||
@@ -81,9 +81,6 @@ PathWithLength ConstrainedShortestPathsOnDag(
|
||||
// Advanced API.
|
||||
// -----------------------------------------------------------------------------
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
struct GraphPathWithLength {
|
||||
double length = 0.0;
|
||||
// The returned arc indices points into the `arcs_with_length` passed to the
|
||||
@@ -97,9 +94,6 @@ struct GraphPathWithLength {
|
||||
// computations efficiently on the given DAG (on which resources do not change).
|
||||
// `GraphType` can use one of the interfaces defined in `util/graph/graph.h`.
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
class ConstrainedShortestPathsOnDagWrapper {
|
||||
public:
|
||||
using NodeIndex = typename GraphType::NodeIndex;
|
||||
@@ -285,9 +279,6 @@ std::vector<T> GetInversePermutation(const std::vector<T>& permutation);
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
ConstrainedShortestPathsOnDagWrapper<GraphType>::
|
||||
ConstrainedShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const std::vector<double>* arc_lengths,
|
||||
@@ -543,9 +534,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
|
||||
GraphType>::RunConstrainedShortestPathOnDag() {
|
||||
if (source_is_destination_.has_value()) {
|
||||
@@ -664,9 +652,6 @@ GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
void ConstrainedShortestPathsOnDagWrapper<GraphType>::
|
||||
RunHalfConstrainedShortestPathOnDag(
|
||||
const GraphType& reverse_graph, absl::Span<const double> arc_lengths,
|
||||
@@ -792,9 +777,6 @@ void ConstrainedShortestPathsOnDagWrapper<GraphType>::
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
typename GraphType::ArcIndex
|
||||
ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
|
||||
const GraphType& graph, absl::Span<const double> arc_lengths,
|
||||
@@ -879,9 +861,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
std::vector<typename GraphType::ArcIndex>
|
||||
ConstrainedShortestPathsOnDagWrapper<GraphType>::ArcPathTo(
|
||||
const int best_label_index,
|
||||
@@ -901,9 +880,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::ArcPathTo(
|
||||
}
|
||||
|
||||
template <typename GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
std::vector<typename GraphType::NodeIndex>
|
||||
ConstrainedShortestPathsOnDagWrapper<GraphType>::NodePathImpliedBy(
|
||||
absl::Span<const ArcIndex> arc_path, const GraphType& graph) const {
|
||||
|
||||
@@ -15,9 +15,7 @@
|
||||
#define OR_TOOLS_GRAPH_DAG_SHORTEST_PATH_H_
|
||||
|
||||
#include <cmath>
|
||||
#if __cplusplus >= 202002L
|
||||
#include <concepts>
|
||||
#endif
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
@@ -82,39 +80,17 @@ std::vector<PathWithLength> KShortestPathsOnDag(
|
||||
// -----------------------------------------------------------------------------
|
||||
// Advanced API.
|
||||
// -----------------------------------------------------------------------------
|
||||
// This concept only enforces the standard graph API needed for all algorithms
|
||||
// on DAGs. One could add the requirement of being a DAG wihtin this concept
|
||||
// (which is done before running the algorithm).
|
||||
#if __cplusplus >= 202002L
|
||||
template <class GraphType>
|
||||
concept DagGraphType = requires(GraphType graph) {
|
||||
{ typename GraphType::NodeIndex{} };
|
||||
{ typename GraphType::ArcIndex{} };
|
||||
{ graph.num_nodes() } -> std::same_as<typename GraphType::NodeIndex>;
|
||||
{ graph.num_arcs() } -> std::same_as<typename GraphType::ArcIndex>;
|
||||
{ graph.OutgoingArcs(typename GraphType::NodeIndex{}) };
|
||||
{
|
||||
graph.Tail(typename GraphType::ArcIndex{})
|
||||
} -> std::same_as<typename GraphType::NodeIndex>;
|
||||
{
|
||||
graph.Head(typename GraphType::ArcIndex{})
|
||||
} -> std::same_as<typename GraphType::NodeIndex>;
|
||||
{ graph.Build() };
|
||||
};
|
||||
#endif
|
||||
|
||||
// A wrapper that holds the memory needed to run many shortest path computations
|
||||
// efficiently on the given DAG. One call of `RunShortestPathOnDag()` has time
|
||||
// complexity O(|E| + |V|) and space complexity O(|V|).
|
||||
// `GraphType` can use any of the interfaces defined in `util/graph/graph.h`.
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
// `ArcLengthContainer` can be any container of doubles.
|
||||
template <class GraphType, typename ArcLengthContainer = std::vector<double>>
|
||||
class ShortestPathsOnDagWrapper {
|
||||
public:
|
||||
using NodeIndex = typename GraphType::NodeIndex;
|
||||
using ArcIndex = typename GraphType::ArcIndex;
|
||||
using ArcLengths = ArcLengthContainer;
|
||||
|
||||
// IMPORTANT: All arguments must outlive the class.
|
||||
//
|
||||
@@ -138,7 +114,7 @@ class ShortestPathsOnDagWrapper {
|
||||
// so will obviously invalidate the result API of the last shortest path run,
|
||||
// which could return an upper bound, junk, or crash.
|
||||
ShortestPathsOnDagWrapper(const GraphType* graph,
|
||||
const std::vector<double>* arc_lengths,
|
||||
const ArcLengths* arc_lengths,
|
||||
absl::Span<const NodeIndex> topological_order);
|
||||
|
||||
// Computes the shortest path to all reachable nodes from the given sources.
|
||||
@@ -151,7 +127,9 @@ class ShortestPathsOnDagWrapper {
|
||||
const std::vector<NodeIndex>& reached_nodes() const { return reached_nodes_; }
|
||||
|
||||
// Returns the length of the shortest path from `node`'s source to `node`.
|
||||
double LengthTo(NodeIndex node) const { return length_from_sources_[node]; }
|
||||
double LengthTo(NodeIndex node) const {
|
||||
return length_from_sources_[static_cast<size_t>(node)];
|
||||
}
|
||||
std::vector<double> LengthTo() const { return length_from_sources_; }
|
||||
|
||||
// Returns the list of all the arcs in the shortest path from `node`'s
|
||||
@@ -164,12 +142,12 @@ class ShortestPathsOnDagWrapper {
|
||||
|
||||
// Accessors to the underlying graph and arc lengths.
|
||||
const GraphType& graph() const { return *graph_; }
|
||||
const std::vector<double>& arc_lengths() const { return *arc_lengths_; }
|
||||
const ArcLengths& arc_lengths() const { return *arc_lengths_; }
|
||||
|
||||
private:
|
||||
static constexpr double kInf = std::numeric_limits<double>::infinity();
|
||||
const GraphType* const graph_;
|
||||
const std::vector<double>* const arc_lengths_;
|
||||
const ArcLengths* const arc_lengths_;
|
||||
absl::Span<const NodeIndex> const topological_order_;
|
||||
|
||||
// Data about the last call of the RunShortestPathOnDag() function.
|
||||
@@ -185,14 +163,12 @@ class ShortestPathsOnDagWrapper {
|
||||
// `GraphType` can use any of the interfaces defined in `util/graph/graph.h`.
|
||||
// IMPORTANT: Only use if `path_count > 1` (k > 1) otherwise use
|
||||
// `ShortestPathsOnDagWrapper`.
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengthContainer = std::vector<double>>
|
||||
class KShortestPathsOnDagWrapper {
|
||||
public:
|
||||
using NodeIndex = typename GraphType::NodeIndex;
|
||||
using ArcIndex = typename GraphType::ArcIndex;
|
||||
using ArcLengths = ArcLengthContainer;
|
||||
|
||||
// IMPORTANT: All arguments must outlive the class.
|
||||
//
|
||||
@@ -216,7 +192,7 @@ class KShortestPathsOnDagWrapper {
|
||||
// so will obviously invalidate the result API of the last shortest path run,
|
||||
// which could return an upper bound, junk, or crash.
|
||||
KShortestPathsOnDagWrapper(const GraphType* graph,
|
||||
const std::vector<double>* arc_lengths,
|
||||
const ArcLengths* arc_lengths,
|
||||
absl::Span<const NodeIndex> topological_order,
|
||||
int path_count);
|
||||
|
||||
@@ -244,14 +220,14 @@ class KShortestPathsOnDagWrapper {
|
||||
|
||||
// Accessors to the underlying graph and arc lengths.
|
||||
const GraphType& graph() const { return *graph_; }
|
||||
const std::vector<double>& arc_lengths() const { return *arc_lengths_; }
|
||||
const ArcLengths& arc_lengths() const { return *arc_lengths_; }
|
||||
int path_count() const { return path_count_; }
|
||||
|
||||
private:
|
||||
static constexpr double kInf = std::numeric_limits<double>::infinity();
|
||||
|
||||
const GraphType* const graph_;
|
||||
const std::vector<double>* const arc_lengths_;
|
||||
const ArcLengths* const arc_lengths_;
|
||||
absl::Span<const NodeIndex> const topological_order_;
|
||||
const int path_count_;
|
||||
|
||||
@@ -269,10 +245,7 @@ class KShortestPathsOnDagWrapper {
|
||||
std::vector<NodeIndex> reached_nodes_;
|
||||
};
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
absl::Status TopologicalOrderIsValid(
|
||||
const GraphType& graph,
|
||||
absl::Span<const typename GraphType::NodeIndex> topological_order);
|
||||
@@ -286,9 +259,6 @@ absl::Status TopologicalOrderIsValid(
|
||||
// (2) assign into an index rather than with push_back
|
||||
// (3) return by absl::Span (or return a copy) with known size.
|
||||
template <typename GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
std::vector<typename GraphType::NodeIndex> NodePathImpliedBy(
|
||||
absl::Span<const typename GraphType::ArcIndex> arc_path,
|
||||
const GraphType& graph) {
|
||||
@@ -303,47 +273,47 @@ std::vector<typename GraphType::NodeIndex> NodePathImpliedBy(
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
void CheckNodeIsValid(typename GraphType::NodeIndex node,
|
||||
const GraphType& graph) {
|
||||
CHECK_GE(node, 0) << "Node must be nonnegative. Input value: " << node;
|
||||
CHECK_GE(node, typename GraphType::NodeIndex(0))
|
||||
<< "Node must be nonnegative. Input value: " << node;
|
||||
CHECK_LT(node, graph.num_nodes())
|
||||
<< "Node must be a valid node. Input value: " << node
|
||||
<< ". Number of nodes in the input graph: " << graph.num_nodes();
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
absl::Status TopologicalOrderIsValid(
|
||||
const GraphType& graph,
|
||||
absl::Span<const typename GraphType::NodeIndex> topological_order) {
|
||||
using NodeIndex = typename GraphType::NodeIndex;
|
||||
using ArcIndex = typename GraphType::ArcIndex;
|
||||
const NodeIndex num_nodes = graph.num_nodes();
|
||||
if (topological_order.size() != num_nodes) {
|
||||
if (topological_order.size() != static_cast<size_t>(num_nodes)) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"topological_order.size() = %i, != graph.num_nodes() = %i",
|
||||
"topological_order.size() = %i, != graph.num_nodes() = %v",
|
||||
topological_order.size(), num_nodes));
|
||||
}
|
||||
std::vector<NodeIndex> inverse_topology(num_nodes, -1);
|
||||
for (NodeIndex node = 0; node < topological_order.size(); ++node) {
|
||||
if (inverse_topology[topological_order[node]] >= 0) {
|
||||
std::vector<NodeIndex> inverse_topology(static_cast<size_t>(num_nodes),
|
||||
GraphType::kNilNode);
|
||||
for (NodeIndex node(0); node < num_nodes; ++node) {
|
||||
if (inverse_topology[static_cast<size_t>(
|
||||
topological_order[static_cast<size_t>(node)])] !=
|
||||
GraphType::kNilNode) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("node % i appears twice in topological order",
|
||||
topological_order[node]));
|
||||
absl::StrFormat("node %v appears twice in topological order",
|
||||
topological_order[static_cast<size_t>(node)]));
|
||||
}
|
||||
inverse_topology[topological_order[node]] = node;
|
||||
inverse_topology[static_cast<size_t>(
|
||||
topological_order[static_cast<size_t>(node)])] = node;
|
||||
}
|
||||
for (NodeIndex tail = 0; tail < num_nodes; ++tail) {
|
||||
for (NodeIndex tail(0); tail < num_nodes; ++tail) {
|
||||
for (const ArcIndex arc : graph.OutgoingArcs(tail)) {
|
||||
const NodeIndex head = graph.Head(arc);
|
||||
if (inverse_topology[tail] >= inverse_topology[head]) {
|
||||
if (inverse_topology[static_cast<size_t>(tail)] >=
|
||||
inverse_topology[static_cast<size_t>(head)]) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"arc (%i, %i) is inconsistent with topological order", tail, head));
|
||||
"arc (%v, %v) is inconsistent with topological order", tail, head));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -353,21 +323,20 @@ absl::Status TopologicalOrderIsValid(
|
||||
// -----------------------------------------------------------------------------
|
||||
// ShortestPathsOnDagWrapper implementation.
|
||||
// -----------------------------------------------------------------------------
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const std::vector<double>* arc_lengths,
|
||||
template <class GraphType, typename ArcLengths>
|
||||
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::ShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const ArcLengths* arc_lengths,
|
||||
absl::Span<const NodeIndex> topological_order)
|
||||
: graph_(graph),
|
||||
arc_lengths_(arc_lengths),
|
||||
topological_order_(topological_order) {
|
||||
const size_t num_nodes = static_cast<size_t>(graph_->num_nodes());
|
||||
CHECK(graph_ != nullptr);
|
||||
CHECK(arc_lengths_ != nullptr);
|
||||
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
|
||||
CHECK_GT(num_nodes, 0) << "The graph is empty: it has no nodes";
|
||||
#ifndef NDEBUG
|
||||
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
|
||||
CHECK_EQ(typename GraphType::ArcIndex(arc_lengths_->size()),
|
||||
graph_->num_arcs());
|
||||
for (const double arc_length : *arc_lengths_) {
|
||||
CHECK(arc_length != -kInf && !std::isnan(arc_length))
|
||||
<< absl::StrFormat("length cannot be -inf nor NaN");
|
||||
@@ -378,16 +347,13 @@ ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
|
||||
|
||||
// Memory allocation is done here and only once in order to avoid reallocation
|
||||
// at each call of `RunShortestPathOnDag()` for better performance.
|
||||
length_from_sources_.resize(graph_->num_nodes(), kInf);
|
||||
incoming_shortest_path_arc_.resize(graph_->num_nodes(), -1);
|
||||
reached_nodes_.reserve(graph_->num_nodes());
|
||||
length_from_sources_.resize(num_nodes, kInf);
|
||||
incoming_shortest_path_arc_.resize(num_nodes, GraphType::kNilArc);
|
||||
reached_nodes_.reserve(num_nodes);
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
|
||||
template <class GraphType, typename ArcLengths>
|
||||
void ShortestPathsOnDagWrapper<GraphType, ArcLengths>::RunShortestPathOnDag(
|
||||
absl::Span<const NodeIndex> sources) {
|
||||
// Caching the vector addresses allow to not fetch it on each access.
|
||||
const absl::Span<double> length_from_sources =
|
||||
@@ -398,7 +364,7 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
|
||||
// performance, so it only makes sense for nodes that are reachable from at
|
||||
// least one source, the other ones will contain junk.
|
||||
for (const NodeIndex node : reached_nodes_) {
|
||||
length_from_sources[node] = kInf;
|
||||
length_from_sources[static_cast<size_t>(node)] = kInf;
|
||||
}
|
||||
DCHECK(std::all_of(length_from_sources.begin(), length_from_sources.end(),
|
||||
[](double l) { return l == kInf; }));
|
||||
@@ -406,11 +372,12 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
|
||||
|
||||
for (const NodeIndex source : sources) {
|
||||
CheckNodeIsValid(source, *graph_);
|
||||
length_from_sources[source] = 0.0;
|
||||
length_from_sources[static_cast<size_t>(source)] = 0.0;
|
||||
}
|
||||
|
||||
for (const NodeIndex tail : topological_order_) {
|
||||
const double length_to_tail = length_from_sources[tail];
|
||||
const double length_to_tail =
|
||||
length_from_sources[static_cast<size_t>(tail)];
|
||||
// Stop exploring a node as soon as its length to all sources is +inf.
|
||||
if (length_to_tail == kInf) {
|
||||
continue;
|
||||
@@ -418,37 +385,35 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
|
||||
reached_nodes_.push_back(tail);
|
||||
for (const ArcIndex arc : graph_->OutgoingArcs(tail)) {
|
||||
const NodeIndex head = graph_->Head(arc);
|
||||
DCHECK(arc_lengths[arc] != -kInf);
|
||||
const double length_to_head = arc_lengths[arc] + length_to_tail;
|
||||
if (length_to_head < length_from_sources[head]) {
|
||||
length_from_sources[head] = length_to_head;
|
||||
incoming_shortest_path_arc_[head] = arc;
|
||||
DCHECK(arc_lengths[static_cast<size_t>(arc)] != -kInf);
|
||||
const double length_to_head =
|
||||
arc_lengths[static_cast<size_t>(arc)] + length_to_tail;
|
||||
if (length_to_head < length_from_sources[static_cast<size_t>(head)]) {
|
||||
length_from_sources[static_cast<size_t>(head)] = length_to_head;
|
||||
incoming_shortest_path_arc_[static_cast<size_t>(head)] = arc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
bool ShortestPathsOnDagWrapper<GraphType>::IsReachable(NodeIndex node) const {
|
||||
template <class GraphType, typename ArcLengths>
|
||||
bool ShortestPathsOnDagWrapper<GraphType, ArcLengths>::IsReachable(
|
||||
NodeIndex node) const {
|
||||
CheckNodeIsValid(node, *graph_);
|
||||
return length_from_sources_[node] < kInf;
|
||||
return length_from_sources_[static_cast<size_t>(node)] < kInf;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<typename GraphType::ArcIndex>
|
||||
ShortestPathsOnDagWrapper<GraphType>::ArcPathTo(NodeIndex node) const {
|
||||
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::ArcPathTo(
|
||||
NodeIndex node) const {
|
||||
CHECK(IsReachable(node));
|
||||
std::vector<ArcIndex> arc_path;
|
||||
NodeIndex current_node = node;
|
||||
for (int i = 0; i < graph_->num_nodes(); ++i) {
|
||||
ArcIndex current_arc = incoming_shortest_path_arc_[current_node];
|
||||
if (current_arc == -1) {
|
||||
for (NodeIndex i(0); i < graph_->num_nodes(); ++i) {
|
||||
ArcIndex current_arc =
|
||||
incoming_shortest_path_arc_[static_cast<size_t>(current_node)];
|
||||
if (current_arc == GraphType::kNilArc) {
|
||||
break;
|
||||
}
|
||||
arc_path.push_back(current_arc);
|
||||
@@ -458,12 +423,10 @@ ShortestPathsOnDagWrapper<GraphType>::ArcPathTo(NodeIndex node) const {
|
||||
return arc_path;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<typename GraphType::NodeIndex>
|
||||
ShortestPathsOnDagWrapper<GraphType>::NodePathTo(NodeIndex node) const {
|
||||
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::NodePathTo(
|
||||
NodeIndex node) const {
|
||||
const std::vector<typename GraphType::ArcIndex> arc_path = ArcPathTo(node);
|
||||
if (arc_path.empty()) {
|
||||
return {node};
|
||||
@@ -474,12 +437,9 @@ ShortestPathsOnDagWrapper<GraphType>::NodePathTo(NodeIndex node) const {
|
||||
// -----------------------------------------------------------------------------
|
||||
// KShortestPathsOnDagWrapper implementation.
|
||||
// -----------------------------------------------------------------------------
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const std::vector<double>* arc_lengths,
|
||||
template <class GraphType, typename ArcLengths>
|
||||
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::KShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const ArcLengths* arc_lengths,
|
||||
absl::Span<const NodeIndex> topological_order, const int path_count)
|
||||
: graph_(graph),
|
||||
arc_lengths_(arc_lengths),
|
||||
@@ -487,10 +447,12 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
path_count_(path_count) {
|
||||
CHECK(graph_ != nullptr);
|
||||
CHECK(arc_lengths_ != nullptr);
|
||||
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
|
||||
const size_t num_nodes = static_cast<size_t>(graph_->num_nodes());
|
||||
CHECK_GT(num_nodes, 0) << "The graph is empty: it has no nodes";
|
||||
CHECK_GT(path_count_, 0) << "path_count must be greater than 0";
|
||||
#ifndef NDEBUG
|
||||
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
|
||||
CHECK_EQ(typename GraphType::ArcIndex(arc_lengths_->size()),
|
||||
graph_->num_arcs());
|
||||
for (const double arc_length : *arc_lengths_) {
|
||||
CHECK(arc_length != -kInf && !std::isnan(arc_length))
|
||||
<< absl::StrFormat("length cannot be -inf nor NaN");
|
||||
@@ -501,9 +463,9 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
|
||||
// TODO(b/332475713): Optimize if reverse graph is already provided in
|
||||
// `GraphType`.
|
||||
const int num_arcs = graph_->num_arcs();
|
||||
const ArcIndex num_arcs = graph_->num_arcs();
|
||||
reverse_graph_ = GraphType(graph_->num_nodes(), num_arcs);
|
||||
for (ArcIndex arc_index = 0; arc_index < num_arcs; ++arc_index) {
|
||||
for (ArcIndex arc_index(0); arc_index < num_arcs; ++arc_index) {
|
||||
reverse_graph_.AddArc(graph->Head(arc_index), graph->Tail(arc_index));
|
||||
}
|
||||
std::vector<ArcIndex> permutation;
|
||||
@@ -511,7 +473,7 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
arc_indices_.resize(permutation.size());
|
||||
if (!permutation.empty()) {
|
||||
for (int i = 0; i < permutation.size(); ++i) {
|
||||
arc_indices_[permutation[i]] = i;
|
||||
arc_indices_[static_cast<size_t>(permutation[i])] = ArcIndex(i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -521,19 +483,16 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
incoming_shortest_paths_arc_.resize(path_count_);
|
||||
incoming_shortest_paths_index_.resize(path_count_);
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
lengths_from_sources_[k].resize(graph_->num_nodes(), kInf);
|
||||
incoming_shortest_paths_arc_[k].resize(graph_->num_nodes(), -1);
|
||||
incoming_shortest_paths_index_[k].resize(graph_->num_nodes(), -1);
|
||||
lengths_from_sources_[k].resize(num_nodes, kInf);
|
||||
incoming_shortest_paths_arc_[k].resize(num_nodes, GraphType::kNilArc);
|
||||
incoming_shortest_paths_index_[k].resize(num_nodes, -1);
|
||||
}
|
||||
is_source_.resize(graph_->num_nodes(), false);
|
||||
reached_nodes_.reserve(graph_->num_nodes());
|
||||
is_source_.resize(num_nodes, false);
|
||||
reached_nodes_.reserve(num_nodes);
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
template <class GraphType, typename ArcLengths>
|
||||
void KShortestPathsOnDagWrapper<GraphType, ArcLengths>::RunKShortestPathOnDag(
|
||||
absl::Span<const NodeIndex> sources) {
|
||||
// Caching the vector addresses allow to not fetch it on each access.
|
||||
const absl::Span<const double> arc_lengths = *arc_lengths_;
|
||||
@@ -544,9 +503,9 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
// least one source, the other ones will contain junk.
|
||||
|
||||
for (const NodeIndex node : reached_nodes_) {
|
||||
is_source_[node] = false;
|
||||
is_source_[static_cast<size_t>(node)] = false;
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
lengths_from_sources_[k][node] = kInf;
|
||||
lengths_from_sources_[k][static_cast<size_t>(node)] = kInf;
|
||||
}
|
||||
}
|
||||
reached_nodes_.clear();
|
||||
@@ -560,14 +519,14 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
|
||||
for (const NodeIndex source : sources) {
|
||||
CheckNodeIsValid(source, *graph_);
|
||||
is_source_[source] = true;
|
||||
is_source_[static_cast<size_t>(source)] = true;
|
||||
}
|
||||
|
||||
struct IncomingArcPath {
|
||||
double path_length = 0.0;
|
||||
ArcIndex arc_index = 0;
|
||||
ArcIndex arc_index = ArcIndex(0);
|
||||
double arc_length = 0.0;
|
||||
NodeIndex from = 0;
|
||||
NodeIndex from = NodeIndex(0);
|
||||
int path_index = 0;
|
||||
|
||||
bool operator<(const IncomingArcPath& other) const {
|
||||
@@ -580,18 +539,19 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
auto comp = std::greater<IncomingArcPath>();
|
||||
for (const NodeIndex to : topological_order_) {
|
||||
min_heap.clear();
|
||||
if (is_source_[to]) {
|
||||
min_heap.push_back({.arc_index = -1});
|
||||
if (is_source_[static_cast<size_t>(to)]) {
|
||||
min_heap.push_back({.arc_index = GraphType::kNilArc});
|
||||
}
|
||||
for (const ArcIndex reverse_arc_index : reverse_graph_.OutgoingArcs(to)) {
|
||||
const ArcIndex arc_index = arc_indices.empty()
|
||||
? reverse_arc_index
|
||||
: arc_indices[reverse_arc_index];
|
||||
const ArcIndex arc_index =
|
||||
arc_indices.empty()
|
||||
? reverse_arc_index
|
||||
: arc_indices[static_cast<size_t>(reverse_arc_index)];
|
||||
const NodeIndex from = graph_->Tail(arc_index);
|
||||
const double arc_length = arc_lengths[arc_index];
|
||||
const double arc_length = arc_lengths[static_cast<size_t>(arc_index)];
|
||||
DCHECK(arc_length != -kInf);
|
||||
const double path_length =
|
||||
lengths_from_sources_.front()[from] + arc_length;
|
||||
lengths_from_sources_.front()[static_cast<size_t>(from)] + arc_length;
|
||||
if (path_length == kInf) {
|
||||
continue;
|
||||
}
|
||||
@@ -608,17 +568,21 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
std::pop_heap(min_heap.begin(), min_heap.end(), comp);
|
||||
IncomingArcPath& incoming_arc_path = min_heap.back();
|
||||
lengths_from_sources_[k][to] = incoming_arc_path.path_length;
|
||||
incoming_shortest_paths_arc_[k][to] = incoming_arc_path.arc_index;
|
||||
incoming_shortest_paths_index_[k][to] = incoming_arc_path.path_index;
|
||||
if (incoming_arc_path.arc_index != -1 &&
|
||||
lengths_from_sources_[k][static_cast<size_t>(to)] =
|
||||
incoming_arc_path.path_length;
|
||||
incoming_shortest_paths_arc_[k][static_cast<size_t>(to)] =
|
||||
incoming_arc_path.arc_index;
|
||||
incoming_shortest_paths_index_[k][static_cast<size_t>(to)] =
|
||||
incoming_arc_path.path_index;
|
||||
if (incoming_arc_path.arc_index != GraphType::kNilArc &&
|
||||
incoming_arc_path.path_index < path_count_ - 1 &&
|
||||
lengths_from_sources_[incoming_arc_path.path_index + 1]
|
||||
[incoming_arc_path.from] < kInf) {
|
||||
[static_cast<size_t>(incoming_arc_path.from)] <
|
||||
kInf) {
|
||||
++incoming_arc_path.path_index;
|
||||
incoming_arc_path.path_length =
|
||||
lengths_from_sources_[incoming_arc_path.path_index]
|
||||
[incoming_arc_path.from] +
|
||||
[static_cast<size_t>(incoming_arc_path.from)] +
|
||||
incoming_arc_path.arc_length;
|
||||
std::push_heap(min_heap.begin(), min_heap.end(), comp);
|
||||
} else {
|
||||
@@ -631,25 +595,22 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
}
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
bool KShortestPathsOnDagWrapper<GraphType>::IsReachable(NodeIndex node) const {
|
||||
template <class GraphType, typename ArcLengths>
|
||||
bool KShortestPathsOnDagWrapper<GraphType, ArcLengths>::IsReachable(
|
||||
NodeIndex node) const {
|
||||
CheckNodeIsValid(node, *graph_);
|
||||
return lengths_from_sources_.front()[node] < kInf;
|
||||
return lengths_from_sources_.front()[static_cast<size_t>(node)] < kInf;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
std::vector<double> KShortestPathsOnDagWrapper<GraphType>::LengthsTo(
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<double>
|
||||
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::LengthsTo(
|
||||
NodeIndex node) const {
|
||||
std::vector<double> lengths_to;
|
||||
lengths_to.reserve(path_count_);
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
const double length_to = lengths_from_sources_[k][node];
|
||||
const double length_to =
|
||||
lengths_from_sources_[k][static_cast<size_t>(node)];
|
||||
if (length_to == kInf) {
|
||||
break;
|
||||
}
|
||||
@@ -658,30 +619,30 @@ std::vector<double> KShortestPathsOnDagWrapper<GraphType>::LengthsTo(
|
||||
return lengths_to;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<std::vector<typename GraphType::ArcIndex>>
|
||||
KShortestPathsOnDagWrapper<GraphType>::ArcPathsTo(NodeIndex node) const {
|
||||
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::ArcPathsTo(
|
||||
NodeIndex node) const {
|
||||
std::vector<std::vector<ArcIndex>> arc_paths;
|
||||
arc_paths.reserve(path_count_);
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
if (lengths_from_sources_[k][node] == kInf) {
|
||||
if (lengths_from_sources_[k][static_cast<size_t>(node)] == kInf) {
|
||||
break;
|
||||
}
|
||||
std::vector<ArcIndex> arc_path;
|
||||
int current_path_index = k;
|
||||
NodeIndex current_node = node;
|
||||
for (int i = 0; i < graph_->num_nodes(); ++i) {
|
||||
for (NodeIndex i(0); i < graph_->num_nodes(); ++i) {
|
||||
ArcIndex current_arc =
|
||||
incoming_shortest_paths_arc_[current_path_index][current_node];
|
||||
if (current_arc == -1) {
|
||||
incoming_shortest_paths_arc_[current_path_index]
|
||||
[static_cast<size_t>(current_node)];
|
||||
if (current_arc == GraphType::kNilArc) {
|
||||
break;
|
||||
}
|
||||
arc_path.push_back(current_arc);
|
||||
current_path_index =
|
||||
incoming_shortest_paths_index_[current_path_index][current_node];
|
||||
incoming_shortest_paths_index_[current_path_index]
|
||||
[static_cast<size_t>(current_node)];
|
||||
current_node = graph_->Tail(current_arc);
|
||||
}
|
||||
absl::c_reverse(arc_path);
|
||||
@@ -690,12 +651,10 @@ KShortestPathsOnDagWrapper<GraphType>::ArcPathsTo(NodeIndex node) const {
|
||||
return arc_paths;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<std::vector<typename GraphType::NodeIndex>>
|
||||
KShortestPathsOnDagWrapper<GraphType>::NodePathsTo(NodeIndex node) const {
|
||||
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::NodePathsTo(
|
||||
NodeIndex node) const {
|
||||
const std::vector<std::vector<ArcIndex>> arc_paths = ArcPathsTo(node);
|
||||
std::vector<std::vector<NodeIndex>> node_paths(arc_paths.size());
|
||||
for (int k = 0; k < arc_paths.size(); ++k) {
|
||||
|
||||
@@ -29,6 +29,8 @@
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/dump_vars.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/base/strong_vector.h"
|
||||
#include "ortools/graph/graph.h"
|
||||
#include "ortools/graph/graph_io.h"
|
||||
#include "ortools/util/flat_matrix.h"
|
||||
@@ -75,7 +77,7 @@ TEST(TopologicalOrderIsValidTest, ValidateTopologicalOrder) {
|
||||
TEST(ShortestPathOnDagTest, EmptyGraph) {
|
||||
EXPECT_DEATH(ShortestPathsOnDag(/*num_nodes=*/0, /*arcs_with_length=*/{},
|
||||
/*source=*/0, /*destination=*/0),
|
||||
"num_nodes\\(\\) > 0");
|
||||
"num_nodes > 0");
|
||||
}
|
||||
|
||||
TEST(ShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
|
||||
@@ -89,7 +91,7 @@ TEST(ShortestPathOnDagTest, NonExistingSourceBecauseTooLarge) {
|
||||
EXPECT_DEATH(
|
||||
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
|
||||
/*source=*/3, /*destination=*/1),
|
||||
"num_nodes\\(\\)");
|
||||
"num_nodes");
|
||||
}
|
||||
|
||||
TEST(ShortestPathOnDagTest, NonExistingDestinationBecauseNegative) {
|
||||
@@ -103,7 +105,7 @@ TEST(ShortestPathOnDagTest, NonExistingDestinationBecauseTooLarge) {
|
||||
EXPECT_DEATH(
|
||||
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
|
||||
/*source=*/0, /*destination=*/3),
|
||||
"num_nodes\\(\\)");
|
||||
"num_nodes");
|
||||
}
|
||||
|
||||
TEST(ShortestPathOnDagTest, Cycle) {
|
||||
@@ -287,6 +289,37 @@ TEST(ShortestPathOnDagTest, UpdateCost) {
|
||||
/*node_path=*/ElementsAre(source, b, destination)));
|
||||
}
|
||||
|
||||
DEFINE_STRONG_INT_TYPE(NodeIndex, int32_t);
|
||||
DEFINE_STRONG_INT_TYPE(ArcIndex, int32_t);
|
||||
|
||||
TEST(ShortestPathsOnDagWrapperTest, StrongIndices) {
|
||||
const NodeIndex source_1(0);
|
||||
const NodeIndex source_2(1);
|
||||
const NodeIndex destination(2);
|
||||
const NodeIndex num_nodes(3);
|
||||
util::ListGraph<NodeIndex, ArcIndex> graph(num_nodes,
|
||||
/*arc_capacity=*/ArcIndex(2));
|
||||
|
||||
using ArcLengths = util_intops::StrongVector<ArcIndex, double>;
|
||||
ArcLengths arc_lengths;
|
||||
graph.AddArc(source_1, destination);
|
||||
arc_lengths.push_back(-6.0);
|
||||
graph.AddArc(source_2, destination);
|
||||
arc_lengths.push_back(3.0);
|
||||
const std::vector<NodeIndex> topological_order = {source_2, source_1,
|
||||
destination};
|
||||
ShortestPathsOnDagWrapper<util::ListGraph<NodeIndex, ArcIndex>, ArcLengths>
|
||||
shortest_path_on_dag(&graph, &arc_lengths, topological_order);
|
||||
shortest_path_on_dag.RunShortestPathOnDag({source_1, source_2});
|
||||
|
||||
EXPECT_TRUE(shortest_path_on_dag.IsReachable(destination));
|
||||
EXPECT_THAT(shortest_path_on_dag.LengthTo(destination), -6.0);
|
||||
EXPECT_THAT(shortest_path_on_dag.ArcPathTo(destination),
|
||||
ElementsAre(ArcIndex(0)));
|
||||
EXPECT_THAT(shortest_path_on_dag.NodePathTo(destination),
|
||||
ElementsAre(source_1, destination));
|
||||
}
|
||||
|
||||
TEST(ShortestPathsOnDagWrapperTest, MultipleSources) {
|
||||
const int source_1 = 0;
|
||||
const int source_2 = 1;
|
||||
@@ -634,7 +667,7 @@ TEST(KShortestPathOnDagTest, EmptyGraph) {
|
||||
EXPECT_DEATH(
|
||||
KShortestPathsOnDag(/*num_nodes=*/0, /*arcs_with_length=*/{},
|
||||
/*source=*/0, /*destination=*/0, /*path_count=*/2),
|
||||
"num_nodes\\(\\) > 0");
|
||||
"num_nodes > 0");
|
||||
}
|
||||
|
||||
TEST(KShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
|
||||
@@ -648,7 +681,7 @@ TEST(KShortestPathOnDagTest, NonExistingSourceBecauseTooLarge) {
|
||||
EXPECT_DEATH(
|
||||
KShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
|
||||
/*source=*/3, /*destination=*/1, /*path_count=*/2),
|
||||
"num_nodes\\(\\)");
|
||||
"num_nodes");
|
||||
}
|
||||
|
||||
TEST(KShortestPathOnDagTest, NonExistingDestinationBecauseNegative) {
|
||||
@@ -936,6 +969,38 @@ TEST(KShortestPathOnDagTest, UpdateCost) {
|
||||
/*node_path=*/ElementsAre(source, a, destination))));
|
||||
}
|
||||
|
||||
TEST(KShortestPathsOnDagWrapperTest, StrongIndices) {
|
||||
const NodeIndex source_1(0);
|
||||
const NodeIndex source_2(1);
|
||||
const NodeIndex destination(2);
|
||||
const NodeIndex num_nodes(3);
|
||||
util::ListGraph<NodeIndex, ArcIndex> graph(num_nodes,
|
||||
/*arc_capacity=*/ArcIndex(2));
|
||||
|
||||
using ArcLengths = util_intops::StrongVector<ArcIndex, double>;
|
||||
ArcLengths arc_lengths;
|
||||
graph.AddArc(source_1, destination);
|
||||
arc_lengths.push_back(-6.0);
|
||||
graph.AddArc(source_2, destination);
|
||||
arc_lengths.push_back(3.0);
|
||||
const std::vector<NodeIndex> topological_order = {source_2, source_1,
|
||||
destination};
|
||||
const int path_count = 2;
|
||||
KShortestPathsOnDagWrapper<util::ListGraph<NodeIndex, ArcIndex>, ArcLengths>
|
||||
shortest_paths_on_dag(&graph, &arc_lengths, topological_order,
|
||||
path_count);
|
||||
shortest_paths_on_dag.RunKShortestPathOnDag({source_1, source_2});
|
||||
|
||||
EXPECT_TRUE(shortest_paths_on_dag.IsReachable(destination));
|
||||
EXPECT_THAT(shortest_paths_on_dag.LengthsTo(destination),
|
||||
ElementsAre(-6.0, 3.0));
|
||||
EXPECT_THAT(shortest_paths_on_dag.ArcPathsTo(destination),
|
||||
ElementsAre(ElementsAre(ArcIndex(0)), ElementsAre(ArcIndex(1))));
|
||||
EXPECT_THAT(shortest_paths_on_dag.NodePathsTo(destination),
|
||||
ElementsAre(ElementsAre(source_1, destination),
|
||||
ElementsAre(source_2, destination)));
|
||||
}
|
||||
|
||||
TEST(KShortestPathsOnDagWrapperTest, MultipleSources) {
|
||||
const int source_1 = 0;
|
||||
const int source_2 = 1;
|
||||
|
||||
@@ -286,8 +286,12 @@ class BaseGraph {
|
||||
|
||||
// Constants that will never be a valid node or arc.
|
||||
// They are the maximum possible node and arc capacity.
|
||||
static const NodeIndexType kNilNode;
|
||||
static const ArcIndexType kNilArc;
|
||||
static_assert(std::numeric_limits<NodeIndexType>::is_specialized);
|
||||
static constexpr NodeIndexType kNilNode =
|
||||
std::numeric_limits<NodeIndexType>::max();
|
||||
static_assert(std::numeric_limits<ArcIndexType>::is_specialized);
|
||||
static constexpr ArcIndexType kNilArc =
|
||||
std::numeric_limits<ArcIndexType>::max();
|
||||
|
||||
protected:
|
||||
// Functions commented when defined because they are implementation details.
|
||||
@@ -590,6 +594,26 @@ class SVector {
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// Graph traits, to allow algorithms to manipulate graphs as adjacency lists.
|
||||
// This works with any graph type, and any object that has:
|
||||
// - a size() method returning the number of nodes.
|
||||
// - an operator[] method taking a node index and returning a range of neighbour
|
||||
// node indices.
|
||||
// One common example is using `std::vector<std::vector<int>>` to represent
|
||||
// adjacency lists.
|
||||
template <typename Graph>
|
||||
struct GraphTraits {
|
||||
private:
|
||||
// The type of the range returned by `operator[]`.
|
||||
using NeighborRangeType = std::decay_t<
|
||||
decltype(std::declval<Graph>()[std::declval<Graph>().size()])>;
|
||||
|
||||
public:
|
||||
// The index type for nodes of the graph.
|
||||
using NodeIndex =
|
||||
std::decay_t<decltype(*(std::declval<NeighborRangeType>().begin()))>;
|
||||
};
|
||||
|
||||
// Basic graph implementation without reverse arc. This class also serves as a
|
||||
// documentation for the generic graph interface (minus the part related to
|
||||
// reverse arcs).
|
||||
@@ -1121,18 +1145,6 @@ BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::AllForwardArcs()
|
||||
return IntegerRange<ArcIndexType>(ArcIndexType(0), num_arcs_);
|
||||
}
|
||||
|
||||
template <typename NodeIndexType, typename ArcIndexType,
|
||||
bool HasNegativeReverseArcs>
|
||||
const NodeIndexType
|
||||
BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::kNilNode =
|
||||
std::numeric_limits<NodeIndexType>::max();
|
||||
|
||||
template <typename NodeIndexType, typename ArcIndexType,
|
||||
bool HasNegativeReverseArcs>
|
||||
const ArcIndexType
|
||||
BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::kNilArc =
|
||||
std::numeric_limits<ArcIndexType>::max();
|
||||
|
||||
template <typename NodeIndexType, typename ArcIndexType,
|
||||
bool HasNegativeReverseArcs>
|
||||
NodeIndexType BaseGraph<NodeIndexType, ArcIndexType,
|
||||
|
||||
@@ -57,6 +57,8 @@ class BeginEndWrapper {
|
||||
|
||||
Iterator begin() const { return begin_; }
|
||||
Iterator end() const { return end_; }
|
||||
|
||||
// Available only if `Iterator` is a random access iterator.
|
||||
size_t size() const { return end_ - begin_; }
|
||||
|
||||
bool empty() const { return begin() == end(); }
|
||||
@@ -127,8 +129,6 @@ class IntegerRangeIterator
|
||||
#endif
|
||||
{
|
||||
public:
|
||||
// TODO(b/385094969): This should be `IntegerType` for integers,
|
||||
// `IntegerType:value_type` for strong signed integer types.
|
||||
using difference_type = ptrdiff_t;
|
||||
using value_type = IntegerType;
|
||||
|
||||
@@ -210,7 +210,7 @@ class IntegerRangeIterator
|
||||
|
||||
friend difference_type operator-(const IntegerRangeIterator l,
|
||||
const IntegerRangeIterator r) {
|
||||
return l.index_ - r.index_;
|
||||
return static_cast<difference_type>(l.index_ - r.index_);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -248,9 +248,7 @@ class ChasingIterator
|
||||
#endif
|
||||
{
|
||||
public:
|
||||
// TODO(b/385094969): This should be `IntegerType` for integers,
|
||||
// `IntegerType:value_type` for strong signed integer types.
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using difference_type = ptrdiff_t;
|
||||
using value_type = IndexT;
|
||||
|
||||
ChasingIterator() : index_(sentinel), next_(nullptr) {}
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
#ifndef UTIL_GRAPH_TOPOLOGICALSORTER_H__
|
||||
#define UTIL_GRAPH_TOPOLOGICALSORTER_H__
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
@@ -84,7 +85,9 @@ namespace graph {
|
||||
// FastTopologicalSort(util::StaticGraph<>::FromArcs(num_nodes, arcs)));
|
||||
//
|
||||
template <class AdjacencyLists> // vector<vector<int>>, util::StaticGraph<>, ..
|
||||
absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
|
||||
absl::StatusOr<
|
||||
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
|
||||
FastTopologicalSort(const AdjacencyLists& adj);
|
||||
|
||||
// Finds a cycle in the directed graph given as argument: nodes are dense
|
||||
// integers in 0..num_nodes-1, and (directed) arcs are pairs of nodes
|
||||
@@ -93,7 +96,9 @@ absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
|
||||
// if the cycle 1->4->3->1 exists.
|
||||
// If the graph is acyclic, returns an empty vector.
|
||||
template <class AdjacencyLists> // vector<vector<int>>, util::StaticGraph<>, ..
|
||||
absl::StatusOr<std::vector<int>> FindCycleInGraph(const AdjacencyLists& adj);
|
||||
absl::StatusOr<
|
||||
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
|
||||
FindCycleInGraph(const AdjacencyLists& adj);
|
||||
|
||||
} // namespace graph
|
||||
|
||||
@@ -615,38 +620,38 @@ std::vector<T> StableTopologicalSortOrDie(
|
||||
}
|
||||
|
||||
template <class AdjacencyLists>
|
||||
absl::StatusOr<std::vector<int>> FastTopologicalSort(
|
||||
const AdjacencyLists& adj) {
|
||||
const size_t num_nodes = adj.size();
|
||||
if (num_nodes > std::numeric_limits<int>::max()) {
|
||||
return absl::InvalidArgumentError("More than kint32max nodes");
|
||||
absl::StatusOr<std::vector<typename GraphTraits<AdjacencyLists>::NodeIndex>>
|
||||
FastTopologicalSort(const AdjacencyLists& adj) {
|
||||
using NodeIndex = typename GraphTraits<AdjacencyLists>::NodeIndex;
|
||||
if (adj.size() > std::numeric_limits<NodeIndex>::max()) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Too many nodes: adj.size()=%v", adj.size()));
|
||||
}
|
||||
std::vector<int> indegree(num_nodes, 0);
|
||||
std::vector<int> topo_order;
|
||||
topo_order.reserve(num_nodes);
|
||||
for (int from = 0; from < num_nodes; ++from) {
|
||||
for (const int head : adj[from]) {
|
||||
// We cast to unsigned int to test "head < 0 || head ≥ num_nodes" with a
|
||||
// single test. Microbenchmarks showed a ~1% overall performance gain.
|
||||
if (static_cast<uint32_t>(head) >= num_nodes) {
|
||||
const NodeIndex num_nodes(adj.size());
|
||||
std::vector<NodeIndex> indegree(static_cast<size_t>(num_nodes), NodeIndex(0));
|
||||
std::vector<NodeIndex> topo_order;
|
||||
topo_order.reserve(static_cast<size_t>(num_nodes));
|
||||
for (NodeIndex from(0); from < num_nodes; ++from) {
|
||||
for (const NodeIndex head : adj[from]) {
|
||||
if (!(NodeIndex(0) <= head && head < num_nodes)) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Invalid arc in adj[%d]: %d (num_nodes=%d)", from,
|
||||
absl::StrFormat("Invalid arc in adj[%v]: %v (num_nodes=%v)", from,
|
||||
head, num_nodes));
|
||||
}
|
||||
// NOTE(user): We could detect self-arcs here (head == from) and exit
|
||||
// early, but microbenchmarks show a 2 to 4% slow-down if we do it, so we
|
||||
// simply rely on self-arcs being detected as cycles in the topo sort.
|
||||
++indegree[head];
|
||||
++indegree[static_cast<size_t>(head)];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < num_nodes; ++i) {
|
||||
if (!indegree[i]) topo_order.push_back(i);
|
||||
for (NodeIndex i(0); i < num_nodes; ++i) {
|
||||
if (!indegree[static_cast<size_t>(i)]) topo_order.push_back(i);
|
||||
}
|
||||
size_t num_visited = 0;
|
||||
while (num_visited < topo_order.size()) {
|
||||
const int from = topo_order[num_visited++];
|
||||
for (const int head : adj[from]) {
|
||||
if (!--indegree[head]) topo_order.push_back(head);
|
||||
const NodeIndex from = topo_order[num_visited++];
|
||||
for (const NodeIndex head : adj[from]) {
|
||||
if (!--indegree[static_cast<size_t>(head)]) topo_order.push_back(head);
|
||||
}
|
||||
}
|
||||
if (topo_order.size() < static_cast<size_t>(num_nodes)) {
|
||||
@@ -656,77 +661,99 @@ absl::StatusOr<std::vector<int>> FastTopologicalSort(
|
||||
}
|
||||
|
||||
template <class AdjacencyLists>
|
||||
absl::StatusOr<std::vector<int>> FindCycleInGraph(const AdjacencyLists& adj) {
|
||||
const size_t num_nodes = adj.size();
|
||||
if (num_nodes > std::numeric_limits<int>::max()) {
|
||||
absl::StatusOr<
|
||||
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
|
||||
FindCycleInGraph(const AdjacencyLists& adj) {
|
||||
using NodeIndex = typename GraphTraits<AdjacencyLists>::NodeIndex;
|
||||
if (adj.size() > std::numeric_limits<NodeIndex>::max()) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Too many nodes: adj.size()=%d", adj.size()));
|
||||
absl::StrFormat("Too many nodes: adj.size()=%v", adj.size()));
|
||||
}
|
||||
const NodeIndex num_nodes(adj.size());
|
||||
|
||||
// First pass to validate that inputs are valid.
|
||||
for (NodeIndex node(0); node < NodeIndex(node); ++node) {
|
||||
for (const NodeIndex head : adj[node]) {
|
||||
if (head >= num_nodes) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Invalid child %v in adj[%v]", head, node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// To find a cycle, we start a DFS from each yet-unvisited node and
|
||||
// try to find a cycle, if we don't find it then we know for sure that
|
||||
// no cycle is reachable from any of the explored nodes (so, we don't
|
||||
// explore them in later DFSs).
|
||||
std::vector<bool> no_cycle_reachable_from(num_nodes, false);
|
||||
std::vector<bool> no_cycle_reachable_from(static_cast<size_t>(num_nodes),
|
||||
false);
|
||||
// The DFS stack will contain a chain of nodes, from the root of the
|
||||
// DFS to the current leaf.
|
||||
struct DfsState {
|
||||
int node;
|
||||
NodeIndex node;
|
||||
// Points at the first child node that we did *not* yet look at.
|
||||
int adj_list_index;
|
||||
explicit DfsState(int _node) : node(_node), adj_list_index(0) {}
|
||||
decltype(adj[NodeIndex(0)].begin()) children;
|
||||
decltype(adj[NodeIndex(0)].end()) children_end;
|
||||
|
||||
explicit DfsState(NodeIndex _node,
|
||||
const decltype(adj[NodeIndex(0)])& neighbours)
|
||||
: node(_node),
|
||||
children(neighbours.begin()),
|
||||
children_end(neighbours.end()) {}
|
||||
};
|
||||
std::vector<DfsState> dfs_stack;
|
||||
std::vector<bool> in_cur_stack(num_nodes, false);
|
||||
for (int start_node = 0; start_node < static_cast<int>(num_nodes);
|
||||
std::vector<bool> visited(static_cast<size_t>(num_nodes), false);
|
||||
for (NodeIndex start_node(0); start_node < NodeIndex(num_nodes);
|
||||
++start_node) {
|
||||
if (no_cycle_reachable_from[start_node]) continue;
|
||||
if (no_cycle_reachable_from[static_cast<size_t>(start_node)]) continue;
|
||||
|
||||
// Start the DFS.
|
||||
dfs_stack.push_back(DfsState(start_node));
|
||||
in_cur_stack[start_node] = true;
|
||||
visited[static_cast<size_t>(start_node)] = true;
|
||||
dfs_stack.push_back(DfsState(start_node, adj[start_node]));
|
||||
while (!dfs_stack.empty()) {
|
||||
DfsState* cur_state = &dfs_stack.back();
|
||||
if (static_cast<size_t>(cur_state->adj_list_index) >=
|
||||
adj[cur_state->node].size()) {
|
||||
no_cycle_reachable_from[cur_state->node] = true;
|
||||
in_cur_stack[cur_state->node] = false;
|
||||
DfsState* const cur_state = &dfs_stack.back();
|
||||
while (
|
||||
cur_state->children != cur_state->children_end &&
|
||||
no_cycle_reachable_from[static_cast<size_t>(*cur_state->children)]) {
|
||||
++cur_state->children;
|
||||
}
|
||||
if (cur_state->children == cur_state->children_end) {
|
||||
no_cycle_reachable_from[static_cast<size_t>(cur_state->node)] = true;
|
||||
dfs_stack.pop_back();
|
||||
continue;
|
||||
}
|
||||
// Look at the current child, and increase the current state's
|
||||
// adj_list_index.
|
||||
// TODO(user): Caching adj[cur_state->node] in a local stack to improve
|
||||
// locality and so that the [] operator is called exactly once per node.
|
||||
const int child = adj[cur_state->node][cur_state->adj_list_index++];
|
||||
if (static_cast<size_t>(child) >= num_nodes) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"Invalid child %d in adj[%d]", child, cur_state->node));
|
||||
}
|
||||
if (no_cycle_reachable_from[child]) continue;
|
||||
if (in_cur_stack[child]) {
|
||||
|
||||
const NodeIndex child = *cur_state->children;
|
||||
// At that point the child is either:
|
||||
// - visited and all finalized (all its children are visited). We know
|
||||
// that it's not part of a cycle, otherwise we'd already have
|
||||
// returned.
|
||||
// - visited and not finalized (some of its children are not visited).
|
||||
// That means that we've reached it again from a child, so we've found
|
||||
// a cycle.
|
||||
// - not visited. We push it on the stack and explore it.
|
||||
if (no_cycle_reachable_from[static_cast<size_t>(child)]) continue;
|
||||
if (visited[static_cast<size_t>(child)]) {
|
||||
// We detected a cycle! It corresponds to the tail end of dfs_stack,
|
||||
// in reverse order, until we find "child".
|
||||
int cycle_start = dfs_stack.size() - 1;
|
||||
size_t cycle_start = dfs_stack.size() - 1;
|
||||
while (dfs_stack[cycle_start].node != child) --cycle_start;
|
||||
const int cycle_size = dfs_stack.size() - cycle_start;
|
||||
std::vector<int> cycle(cycle_size);
|
||||
for (int c = 0; c < cycle_size; ++c) {
|
||||
const size_t cycle_size = dfs_stack.size() - cycle_start;
|
||||
std::vector<NodeIndex> cycle(cycle_size);
|
||||
for (size_t c = 0; c < cycle_size; ++c) {
|
||||
cycle[c] = dfs_stack[cycle_start + c].node;
|
||||
}
|
||||
return cycle;
|
||||
}
|
||||
|
||||
// Push the child onto the stack.
|
||||
dfs_stack.push_back(DfsState(child));
|
||||
in_cur_stack[child] = true;
|
||||
// Verify that its adjacency list seems valid.
|
||||
if (adj[child].size() > std::numeric_limits<int>::max()) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"Invalid adj[%d].size() = %d", child, adj[child].size()));
|
||||
}
|
||||
dfs_stack.push_back(DfsState(child, adj[child]));
|
||||
visited[static_cast<size_t>(child)] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// If we're here, then all the DFS stopped, and there is no cycle.
|
||||
return std::vector<int>{};
|
||||
return std::vector<NodeIndex>{};
|
||||
}
|
||||
|
||||
} // namespace graph
|
||||
|
||||
@@ -346,7 +346,7 @@ void LoadGurobiFunctions(DynamicLibrary* gurobi_dynamic_library) {
|
||||
|
||||
std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
|
||||
std::vector<std::string> potential_paths;
|
||||
const std::vector<std::string> kGurobiVersions = {
|
||||
const std::vector<absl::string_view> kGurobiVersions = {
|
||||
"1201", "1200", "1103", "1102", "1101", "1100", "1003",
|
||||
"1002", "1001", "1000", "952", "951", "950", "911",
|
||||
"910", "903", "902", "811", "801", "752"};
|
||||
@@ -355,8 +355,8 @@ std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
|
||||
// Look for libraries pointed by GUROBI_HOME first.
|
||||
const char* gurobi_home_from_env = getenv("GUROBI_HOME");
|
||||
if (gurobi_home_from_env != nullptr) {
|
||||
for (const std::string& version : kGurobiVersions) {
|
||||
const std::string lib = version.substr(0, version.size() - 1);
|
||||
for (const absl::string_view version : kGurobiVersions) {
|
||||
const absl::string_view lib = version.substr(0, version.size() - 1);
|
||||
#if defined(_MSC_VER) // Windows
|
||||
potential_paths.push_back(
|
||||
absl::StrCat(gurobi_home_from_env, "\\bin\\gurobi", lib, ".dll"));
|
||||
@@ -376,8 +376,8 @@ std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
|
||||
}
|
||||
|
||||
// Search for canonical places.
|
||||
for (const std::string& version : kGurobiVersions) {
|
||||
const std::string lib = version.substr(0, version.size() - 1);
|
||||
for (const absl::string_view version : kGurobiVersions) {
|
||||
const absl::string_view lib = version.substr(0, version.size() - 1);
|
||||
#if defined(_MSC_VER) // Windows
|
||||
potential_paths.push_back(absl::StrCat("C:\\Program Files\\gurobi", version,
|
||||
"\\win64\\bin\\gurobi", lib,
|
||||
@@ -407,7 +407,7 @@ std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
|
||||
}
|
||||
|
||||
#if defined(__GNUC__) // path in linux64 gurobi/optimizer docker image.
|
||||
for (const std::string& version :
|
||||
for (const absl::string_view version :
|
||||
{"12.0.1", "12.0.0", "11.0.3", "11.0.2", "11.0.1", "11.0.0", "10.0.3",
|
||||
"10.0.2", "10.0.1", "10.0.0", "9.5.2", "9.5.1", "9.5.0"}) {
|
||||
potential_paths.push_back(
|
||||
@@ -418,7 +418,7 @@ std::vector<std::string> GurobiDynamicLibraryPotentialPaths() {
|
||||
}
|
||||
|
||||
absl::Status LoadGurobiDynamicLibrary(
|
||||
std::vector<std::string> potential_paths) {
|
||||
std::vector<absl::string_view> potential_paths) {
|
||||
static std::once_flag gurobi_loading_done;
|
||||
static absl::Status gurobi_load_status;
|
||||
static DynamicLibrary gurobi_library;
|
||||
@@ -431,7 +431,7 @@ absl::Status LoadGurobiDynamicLibrary(
|
||||
GurobiDynamicLibraryPotentialPaths();
|
||||
potential_paths.insert(potential_paths.end(), canonical_paths.begin(),
|
||||
canonical_paths.end());
|
||||
for (const std::string& path : potential_paths) {
|
||||
for (const absl::string_view path : potential_paths) {
|
||||
if (gurobi_library.TryToLoad(path)) {
|
||||
LOG(INFO) << "Found the Gurobi library in '" << path << ".";
|
||||
break;
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/dynamic_library.h"
|
||||
#include "ortools/base/logging.h"
|
||||
|
||||
@@ -52,7 +53,7 @@ bool GurobiIsCorrectlyInstalled();
|
||||
// Successive calls are no-op.
|
||||
//
|
||||
// Note that it does not check if a token license can be grabbed.
|
||||
absl::Status LoadGurobiDynamicLibrary(std::vector<std::string> potential_paths);
|
||||
absl::Status LoadGurobiDynamicLibrary(std::vector<absl::string_view> potential_paths);
|
||||
|
||||
// The list of #define and extern std::function<> below is generated directly
|
||||
// from gurobi_c.h via parse_header.py
|
||||
|
||||
@@ -20,12 +20,13 @@
|
||||
#include "absl/flags/usage.h"
|
||||
#include "absl/log/globals.h"
|
||||
#include "absl/log/initialize.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/gurobi/environment.h"
|
||||
#include "ortools/sat/cp_model_solver.h"
|
||||
#include "ortools/sat/cp_model_solver_helpers.h"
|
||||
|
||||
namespace operations_research {
|
||||
void CppBridge::InitLogging(const std::string& usage) {
|
||||
void CppBridge::InitLogging(absl::string_view usage) {
|
||||
absl::SetProgramUsageMessage(usage);
|
||||
absl::InitializeLog();
|
||||
}
|
||||
@@ -41,7 +42,7 @@ void CppBridge::SetFlags(const CppFlags& flags) {
|
||||
absl::SetFlag(&FLAGS_cp_model_dump_response, flags.cp_model_dump_response);
|
||||
}
|
||||
|
||||
bool CppBridge::LoadGurobiSharedLibrary(const std::string& full_library_path) {
|
||||
bool CppBridge::LoadGurobiSharedLibrary(absl::string_view full_library_path) {
|
||||
return LoadGurobiDynamicLibrary({full_library_path}).ok();
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/base/version.h"
|
||||
#include "ortools/sat/cp_model_solver_helpers.h"
|
||||
@@ -86,7 +87,7 @@ class CppBridge {
|
||||
*
|
||||
* This must be called once before any other library from OR-Tools are used.
|
||||
*/
|
||||
static void InitLogging(const std::string& usage);
|
||||
static void InitLogging(absl::string_view usage);
|
||||
|
||||
/**
|
||||
* Shutdown the C++ logging layer.
|
||||
@@ -111,7 +112,7 @@ class CppBridge {
|
||||
* You need to pass the full path, including the shared library file.
|
||||
* It returns true if the library was found and correctly loaded.
|
||||
*/
|
||||
static bool LoadGurobiSharedLibrary(const std::string& full_library_path);
|
||||
static bool LoadGurobiSharedLibrary(absl::string_view full_library_path);
|
||||
|
||||
/**
|
||||
* Delete a temporary C++ byte array.
|
||||
|
||||
@@ -26,11 +26,11 @@ import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.SimpleFileVisitor;
|
||||
import java.nio.file.attribute.BasicFileAttributes;
|
||||
import java.util.AbstractMap;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.AbstractMap;
|
||||
import java.util.Objects;
|
||||
|
||||
/** Load native libraries needed for using ortools-java. */
|
||||
@@ -144,28 +144,28 @@ public class Loader {
|
||||
URI resourceURI = getNativeResourceURI();
|
||||
Path tempPath = unpackNativeResources(resourceURI);
|
||||
// libraries order does matter <LibraryName, isMandatory> !
|
||||
List<Map.Entry<String,Boolean>> dlls = Arrays.asList(
|
||||
(new AbstractMap.SimpleEntry("zlib1", true)),
|
||||
(new AbstractMap.SimpleEntry("abseil_dll", true)),
|
||||
(new AbstractMap.SimpleEntry("re2", true)),
|
||||
(new AbstractMap.SimpleEntry("libutf8_validity", true)),
|
||||
(new AbstractMap.SimpleEntry("libprotobuf", true)),
|
||||
(new AbstractMap.SimpleEntry("highs", false)),
|
||||
(new AbstractMap.SimpleEntry("libscip", false)),
|
||||
(new AbstractMap.SimpleEntry("ortools", true)),
|
||||
(new AbstractMap.SimpleEntry("jniortools", true)));
|
||||
List<Map.Entry<String, Boolean>> dlls =
|
||||
Arrays.asList((new AbstractMap.SimpleEntry("zlib1", true)),
|
||||
(new AbstractMap.SimpleEntry("abseil_dll", true)),
|
||||
(new AbstractMap.SimpleEntry("re2", true)),
|
||||
(new AbstractMap.SimpleEntry("libutf8_validity", true)),
|
||||
(new AbstractMap.SimpleEntry("libprotobuf", true)),
|
||||
(new AbstractMap.SimpleEntry("highs", false)),
|
||||
(new AbstractMap.SimpleEntry("libscip", false)),
|
||||
(new AbstractMap.SimpleEntry("ortools", true)),
|
||||
(new AbstractMap.SimpleEntry("jniortools", true)));
|
||||
|
||||
for (Map.Entry<String,Boolean> dll : dlls) {
|
||||
for (Map.Entry<String, Boolean> dll : dlls) {
|
||||
try {
|
||||
//System.out.println("System.load(" + dll.getKey() + ")");
|
||||
// System.out.println("System.load(" + dll.getKey() + ")");
|
||||
System.load(tempPath.resolve(RESOURCE_PATH)
|
||||
.resolve(System.mapLibraryName(dll.getKey()))
|
||||
.toAbsolutePath()
|
||||
.toString());
|
||||
} catch (UnsatisfiedLinkError e) {
|
||||
System.out.println("System.load(" + dll.getKey() + ") failed!");
|
||||
if(dll.getValue()) {
|
||||
throw new RuntimeException(e);
|
||||
if (dll.getValue()) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,8 +126,6 @@ CP solver built on top of the SAT solver:
|
||||
Propagation algorithms for the cumulative scheduling constraint.
|
||||
* [cumulative_energy.h](../sat/cumulative_energy.h):
|
||||
Propagation algorithms for a more general cumulative constraint.
|
||||
* [theta_tree.h](../sat/theta_tree.h):
|
||||
Data structure used in the cumulative/disjunctive propagation algorithm.
|
||||
|
||||
### Packing constraints
|
||||
|
||||
|
||||
@@ -1876,46 +1876,18 @@ TEST(LinMaxExpansionTest, GoldenTest) {
|
||||
variables { domain: 0 domain: 1 }
|
||||
constraints {}
|
||||
constraints {
|
||||
linear {
|
||||
vars: 0
|
||||
vars: 1
|
||||
coeffs: 1
|
||||
coeffs: -2
|
||||
domain: -1
|
||||
domain: 9223372036854775806
|
||||
}
|
||||
linear { vars: 0 vars: 1 coeffs: 1 coeffs: -2 domain: -1 domain: 5 }
|
||||
}
|
||||
constraints {
|
||||
linear {
|
||||
vars: 0
|
||||
vars: 2
|
||||
coeffs: 1
|
||||
coeffs: -1
|
||||
domain: -4
|
||||
domain: 9223372036854775803
|
||||
}
|
||||
linear { vars: 0 vars: 2 coeffs: 1 coeffs: -1 domain: -4 domain: 5 }
|
||||
}
|
||||
constraints {
|
||||
enforcement_literal: 3
|
||||
linear {
|
||||
vars: 0
|
||||
vars: 1
|
||||
coeffs: 1
|
||||
coeffs: -2
|
||||
domain: -9223372036854775808
|
||||
domain: -1
|
||||
}
|
||||
linear { vars: 0 vars: 1 coeffs: 1 coeffs: -2 domain: -10 domain: -1 }
|
||||
}
|
||||
constraints {
|
||||
enforcement_literal: -4
|
||||
linear {
|
||||
vars: 0
|
||||
vars: 2
|
||||
coeffs: 1
|
||||
coeffs: -1
|
||||
domain: -9223372036854775808
|
||||
domain: -4
|
||||
}
|
||||
linear { vars: 0 vars: 2 coeffs: 1 coeffs: -1 domain: -6 domain: -4 }
|
||||
}
|
||||
)pb");
|
||||
EXPECT_THAT(initial_model, testing::EqualsProto(expected_model));
|
||||
|
||||
@@ -2657,10 +2657,18 @@ bool PresolveContext::CanonicalizeLinearConstraint(ConstraintProto* ct) {
|
||||
const bool result = CanonicalizeLinearExpressionInternal(
|
||||
ct->enforcement_literal(), ct->mutable_linear(), &offset, &tmp_terms_,
|
||||
this);
|
||||
if (offset != 0) {
|
||||
FillDomainInProto(
|
||||
ReadDomainFromProto(ct->linear()).AdditionWith(Domain(-offset)),
|
||||
ct->mutable_linear());
|
||||
const auto [min_activity, max_activity] = ComputeMinMaxActivity(ct->linear());
|
||||
const Domain implied = Domain(min_activity, max_activity);
|
||||
const Domain original_domain =
|
||||
ReadDomainFromProto(ct->linear()).AdditionWith(Domain(-offset));
|
||||
const Domain tight_domain = implied.IntersectionWith(original_domain);
|
||||
if (tight_domain.IsEmpty()) {
|
||||
// Canonicalization is not the right place to handle unsat constraints.
|
||||
// Let's just replace the domain by one that is overflow-safe.
|
||||
const Domain bad_domain = Domain(implied.Max() + 1, implied.Max() + 2);
|
||||
FillDomainInProto(bad_domain, ct->mutable_linear());
|
||||
} else {
|
||||
FillDomainInProto(tight_domain, ct->mutable_linear());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -1053,7 +1053,7 @@ TEST(PresolveContextTest, CanonicalizeLinearConstraint) {
|
||||
linear {
|
||||
vars: [ 0, 1, 2 ]
|
||||
coeffs: [ -2, 2, -2 ]
|
||||
domain: [ 0, 1000 ]
|
||||
domain: [ 0, 16 ]
|
||||
}
|
||||
)pb");
|
||||
EXPECT_THAT(working_model.constraints(0), testing::EqualsProto(expected));
|
||||
|
||||
@@ -575,7 +575,7 @@ PYBIND11_MODULE(set_cover, m) {
|
||||
return {cleared.begin(), cleared.end()};
|
||||
});
|
||||
m.def("clear_random_subsets",
|
||||
[](const std::vector<BaseInt>& focus, BaseInt num_subsets,
|
||||
[](absl::Span<const BaseInt> focus, BaseInt num_subsets,
|
||||
SetCoverInvariant* inv) -> std::vector<BaseInt> {
|
||||
const std::vector<SubsetIndex> cleared = ClearRandomSubsets(
|
||||
VectorIntToVectorSubsetIndex(focus), num_subsets, inv);
|
||||
|
||||
@@ -101,62 +101,66 @@ class ThetaLambdaTree {
|
||||
// allows to keep the same memory for each call.
|
||||
void Reset(int num_events);
|
||||
|
||||
// Recomputes the values of internal nodes of the tree from the values in the
|
||||
// leaves. We enable batching modifications to the tree by providing
|
||||
// DelayedXXX() methods that run in O(1), but those methods do not
|
||||
// update internal nodes. This breaks tree invariants, so that GetXXX()
|
||||
// methods will not reflect modifications made to events.
|
||||
// RecomputeTreeForDelayedOperations() restores those invariants in O(n).
|
||||
// Thus, batching operations can be done by first doing calls to DelayedXXX()
|
||||
// methods, then calling RecomputeTreeForDelayedOperations() once.
|
||||
void RecomputeTreeForDelayedOperations();
|
||||
|
||||
// Makes event present and updates its initial envelope and min/max energies.
|
||||
// The initial_envelope must be >= ThetaLambdaTreeNegativeInfinity().
|
||||
// This updates the tree in O(log n).
|
||||
void AddOrUpdateEvent(int event, IntegerType initial_envelope,
|
||||
IntegerType energy_min, IntegerType energy_max);
|
||||
|
||||
// Delayed version of AddOrUpdateEvent(),
|
||||
// see RecomputeTreeForDelayedOperations().
|
||||
void DelayedAddOrUpdateEvent(int event, IntegerType initial_envelope,
|
||||
IntegerType energy_min, IntegerType energy_max);
|
||||
IntegerType energy_min, IntegerType energy_max) {
|
||||
DCHECK_LE(0, energy_min);
|
||||
DCHECK_LE(energy_min, energy_max);
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {.envelope = initial_envelope + energy_min,
|
||||
.envelope_opt = initial_envelope + energy_max,
|
||||
.sum_of_energy_min = energy_min,
|
||||
.max_of_energy_delta = energy_max - energy_min};
|
||||
RefreshNode(node);
|
||||
}
|
||||
|
||||
// Adds event to the lambda part of the tree only.
|
||||
// This will leave GetEnvelope() unchanged, only GetOptionalEnvelope() can
|
||||
// be affected. This is done by setting envelope to IntegerTypeMinimumValue(),
|
||||
// be affected, by setting envelope to std::numeric_limits<>::min(),
|
||||
// energy_min to 0, and initial_envelope_opt and energy_max to the parameters.
|
||||
// This updates the tree in O(log n).
|
||||
void AddOrUpdateOptionalEvent(int event, IntegerType initial_envelope_opt,
|
||||
IntegerType energy_max);
|
||||
|
||||
// Delayed version of AddOrUpdateOptionalEvent(),
|
||||
// see RecomputeTreeForDelayedOperations().
|
||||
void DelayedAddOrUpdateOptionalEvent(int event,
|
||||
IntegerType initial_envelope_opt,
|
||||
IntegerType energy_max);
|
||||
IntegerType energy_max) {
|
||||
DCHECK_LE(0, energy_max);
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {.envelope = std::numeric_limits<IntegerType>::min(),
|
||||
.envelope_opt = initial_envelope_opt + energy_max,
|
||||
.sum_of_energy_min = IntegerType{0},
|
||||
.max_of_energy_delta = energy_max};
|
||||
RefreshNode(node);
|
||||
}
|
||||
|
||||
// Makes event absent, compute the new envelope in O(log n).
|
||||
void RemoveEvent(int event);
|
||||
|
||||
// Delayed version of RemoveEvent(), see RecomputeTreeForDelayedOperations().
|
||||
void DelayedRemoveEvent(int event);
|
||||
void RemoveEvent(int event) {
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {.envelope = std::numeric_limits<IntegerType>::min(),
|
||||
.envelope_opt = std::numeric_limits<IntegerType>::min(),
|
||||
.sum_of_energy_min = IntegerType{0},
|
||||
.max_of_energy_delta = IntegerType{0}};
|
||||
RefreshNode(node);
|
||||
}
|
||||
|
||||
// Returns the maximum envelope using all the energy_min in O(1).
|
||||
// If theta is empty, returns ThetaLambdaTreeNegativeInfinity().
|
||||
IntegerType GetEnvelope() const;
|
||||
// If theta is empty, returns std::numeric_limits<>::min().
|
||||
IntegerType GetEnvelope() const { return tree_[1].envelope; }
|
||||
|
||||
// Returns the maximum envelope using the energy min of all task but
|
||||
// one and the energy max of the last one in O(1).
|
||||
// If theta and lambda are empty, returns ThetaLambdaTreeNegativeInfinity().
|
||||
IntegerType GetOptionalEnvelope() const;
|
||||
// If theta and lambda are empty, returns std::numeric_limits<>::min().
|
||||
IntegerType GetOptionalEnvelope() const { return tree_[1].envelope_opt; }
|
||||
|
||||
// Computes the maximum event s.t. GetEnvelopeOf(event) > envelope_max.
|
||||
// There must be such an event, i.e. GetEnvelope() > envelope_max.
|
||||
// This finds the maximum event e such that
|
||||
// initial_envelope(e) + sum_{e' >= e} energy_min(e') > target_envelope.
|
||||
// This operation is O(log n).
|
||||
int GetMaxEventWithEnvelopeGreaterThan(IntegerType target_envelope) const;
|
||||
int GetMaxEventWithEnvelopeGreaterThan(IntegerType target_envelope) const {
|
||||
DCHECK_LT(target_envelope, tree_[1].envelope);
|
||||
IntegerType unused;
|
||||
return GetEventFromLeaf(
|
||||
GetMaxLeafWithEnvelopeGreaterThan(1, target_envelope, &unused));
|
||||
}
|
||||
|
||||
// Returns initial_envelope(event) + sum_{event' >= event} energy_min(event'),
|
||||
// in time O(log n).
|
||||
@@ -181,7 +185,14 @@ class ThetaLambdaTree {
|
||||
// This operation is O(log n).
|
||||
void GetEventsWithOptionalEnvelopeGreaterThan(
|
||||
IntegerType target_envelope, int* critical_event, int* optional_event,
|
||||
IntegerType* available_energy) const;
|
||||
IntegerType* available_energy) const {
|
||||
int critical_leaf;
|
||||
int optional_leaf;
|
||||
GetLeavesWithOptionalEnvelopeGreaterThan(target_envelope, &critical_leaf,
|
||||
&optional_leaf, available_energy);
|
||||
*critical_event = GetEventFromLeaf(critical_leaf);
|
||||
*optional_event = GetEventFromLeaf(optional_leaf);
|
||||
}
|
||||
|
||||
// Getters.
|
||||
IntegerType EnergyMin(int event) const {
|
||||
@@ -196,10 +207,36 @@ class ThetaLambdaTree {
|
||||
IntegerType max_of_energy_delta;
|
||||
};
|
||||
|
||||
TreeNode ComposeTreeNodes(const TreeNode& left, const TreeNode& right);
|
||||
TreeNode ComposeTreeNodes(const TreeNode& left, const TreeNode& right) {
|
||||
return {
|
||||
.envelope =
|
||||
std::max(right.envelope, left.envelope + right.sum_of_energy_min),
|
||||
.envelope_opt =
|
||||
std::max(right.envelope_opt,
|
||||
right.sum_of_energy_min +
|
||||
std::max(left.envelope_opt,
|
||||
left.envelope + right.max_of_energy_delta)),
|
||||
.sum_of_energy_min = left.sum_of_energy_min + right.sum_of_energy_min,
|
||||
.max_of_energy_delta =
|
||||
std::max(right.max_of_energy_delta, left.max_of_energy_delta)};
|
||||
}
|
||||
|
||||
int GetLeafFromEvent(int event) const;
|
||||
int GetEventFromLeaf(int leaf) const;
|
||||
int GetLeafFromEvent(int event) const {
|
||||
DCHECK_LE(0, event);
|
||||
DCHECK_LT(event, num_events_);
|
||||
// Keeping the ordering of events is important, so the first set of events
|
||||
// must be mapped to the set of leaves at depth d, and the second set of
|
||||
// events must be mapped to the set of leaves at depth d-1.
|
||||
const int r = power_of_two_ + event;
|
||||
return r < 2 * num_leaves_ ? r : r - num_leaves_;
|
||||
}
|
||||
|
||||
int GetEventFromLeaf(int leaf) const {
|
||||
DCHECK_GE(leaf, num_leaves_);
|
||||
DCHECK_LT(leaf, 2 * num_leaves_);
|
||||
const int r = leaf - power_of_two_;
|
||||
return r >= 0 ? r : r + num_leaves_;
|
||||
}
|
||||
|
||||
// Propagates the change of leaf energies and envelopes towards the root.
|
||||
void RefreshNode(int node);
|
||||
@@ -225,32 +262,12 @@ class ThetaLambdaTree {
|
||||
int num_leaves_;
|
||||
int power_of_two_;
|
||||
|
||||
// A bool used in debug mode, to check that sequences of delayed operations
|
||||
// are ended by Reset() or RecomputeTreeForDelayedOperations().
|
||||
bool leaf_nodes_have_delayed_operations_ = false;
|
||||
|
||||
// Envelopes and energies of nodes.
|
||||
std::vector<TreeNode> tree_;
|
||||
};
|
||||
|
||||
template <typename IntegerType>
|
||||
typename ThetaLambdaTree<IntegerType>::TreeNode
|
||||
ThetaLambdaTree<IntegerType>::ComposeTreeNodes(const TreeNode& left,
|
||||
const TreeNode& right) {
|
||||
return {std::max(right.envelope, left.envelope + right.sum_of_energy_min),
|
||||
std::max(right.envelope_opt,
|
||||
right.sum_of_energy_min +
|
||||
std::max(left.envelope_opt,
|
||||
left.envelope + right.max_of_energy_delta)),
|
||||
left.sum_of_energy_min + right.sum_of_energy_min,
|
||||
std::max(right.max_of_energy_delta, left.max_of_energy_delta)};
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::Reset(int num_events) {
|
||||
#ifndef NDEBUG
|
||||
leaf_nodes_have_delayed_operations_ = false;
|
||||
#endif
|
||||
// Because the algorithm needs to access a node sibling (i.e. node_index ^ 1),
|
||||
// our tree will always have an even number of leaves, just large enough to
|
||||
// fit our number of events. And at least 2 for the empty tree case.
|
||||
@@ -258,9 +275,11 @@ void ThetaLambdaTree<IntegerType>::Reset(int num_events) {
|
||||
num_leaves_ = std::max(2, num_events + (num_events & 1));
|
||||
|
||||
const int num_nodes = 2 * num_leaves_;
|
||||
tree_.assign(num_nodes, TreeNode{std::numeric_limits<IntegerType>::min(),
|
||||
std::numeric_limits<IntegerType>::min(),
|
||||
IntegerType{0}, IntegerType{0}});
|
||||
tree_.assign(num_nodes,
|
||||
TreeNode{.envelope = std::numeric_limits<IntegerType>::min(),
|
||||
.envelope_opt = std::numeric_limits<IntegerType>::min(),
|
||||
.sum_of_energy_min = IntegerType{0},
|
||||
.max_of_energy_delta = IntegerType{0}});
|
||||
|
||||
// If num_leaves is not a power or two, the last depth of the tree will not be
|
||||
// full, and the array will look like:
|
||||
@@ -270,147 +289,8 @@ void ThetaLambdaTree<IntegerType>::Reset(int num_events) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
int ThetaLambdaTree<IntegerType>::GetLeafFromEvent(int event) const {
|
||||
DCHECK_LE(0, event);
|
||||
DCHECK_LT(event, num_events_);
|
||||
// Keeping the ordering of events is important, so the first set of events
|
||||
// must be mapped to the set of leaves at depth d, and the second set of
|
||||
// events must be mapped to the set of leaves at depth d-1.
|
||||
const int r = power_of_two_ + event;
|
||||
return r < 2 * num_leaves_ ? r : r - num_leaves_;
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
int ThetaLambdaTree<IntegerType>::GetEventFromLeaf(int leaf) const {
|
||||
DCHECK_GE(leaf, num_leaves_);
|
||||
DCHECK_LT(leaf, 2 * num_leaves_);
|
||||
const int r = leaf - power_of_two_;
|
||||
return r >= 0 ? r : r + num_leaves_;
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::RecomputeTreeForDelayedOperations() {
|
||||
#ifndef NDEBUG
|
||||
leaf_nodes_have_delayed_operations_ = false;
|
||||
#endif
|
||||
// Only recompute internal nodes.
|
||||
const int last_internal_node = tree_.size() / 2 - 1;
|
||||
for (int node = last_internal_node; node >= 1; --node) {
|
||||
const int right = 2 * node + 1;
|
||||
const int left = 2 * node;
|
||||
tree_[node] = ComposeTreeNodes(tree_[left], tree_[right]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::DelayedAddOrUpdateEvent(
|
||||
int event, IntegerType initial_envelope, IntegerType energy_min,
|
||||
IntegerType energy_max) {
|
||||
#ifndef NDEBUG
|
||||
leaf_nodes_have_delayed_operations_ = true;
|
||||
#endif
|
||||
DCHECK_LE(0, energy_min);
|
||||
DCHECK_LE(energy_min, energy_max);
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {initial_envelope + energy_min, initial_envelope + energy_max,
|
||||
energy_min, energy_max - energy_min};
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::AddOrUpdateEvent(
|
||||
int event, IntegerType initial_envelope, IntegerType energy_min,
|
||||
IntegerType energy_max) {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
DCHECK_LE(0, energy_min);
|
||||
DCHECK_LE(energy_min, energy_max);
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {initial_envelope + energy_min, initial_envelope + energy_max,
|
||||
energy_min, energy_max - energy_min};
|
||||
RefreshNode(node);
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::AddOrUpdateOptionalEvent(
|
||||
int event, IntegerType initial_envelope_opt, IntegerType energy_max) {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
DCHECK_LE(0, energy_max);
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {std::numeric_limits<IntegerType>::min(),
|
||||
initial_envelope_opt + energy_max, IntegerType{0}, energy_max};
|
||||
RefreshNode(node);
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::DelayedAddOrUpdateOptionalEvent(
|
||||
int event, IntegerType initial_envelope_opt, IntegerType energy_max) {
|
||||
#ifndef NDEBUG
|
||||
leaf_nodes_have_delayed_operations_ = true;
|
||||
#endif
|
||||
DCHECK_LE(0, energy_max);
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {std::numeric_limits<IntegerType>::min(),
|
||||
initial_envelope_opt + energy_max, IntegerType{0}, energy_max};
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::RemoveEvent(int event) {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {std::numeric_limits<IntegerType>::min(),
|
||||
std::numeric_limits<IntegerType>::min(), IntegerType{0},
|
||||
IntegerType{0}};
|
||||
RefreshNode(node);
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::DelayedRemoveEvent(int event) {
|
||||
#ifndef NDEBUG
|
||||
leaf_nodes_have_delayed_operations_ = true;
|
||||
#endif
|
||||
const int node = GetLeafFromEvent(event);
|
||||
tree_[node] = {std::numeric_limits<IntegerType>::min(),
|
||||
std::numeric_limits<IntegerType>::min(), IntegerType{0},
|
||||
IntegerType{0}};
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
IntegerType ThetaLambdaTree<IntegerType>::GetEnvelope() const {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
return tree_[1].envelope;
|
||||
}
|
||||
template <typename IntegerType>
|
||||
IntegerType ThetaLambdaTree<IntegerType>::GetOptionalEnvelope() const {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
return tree_[1].envelope_opt;
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
int ThetaLambdaTree<IntegerType>::GetMaxEventWithEnvelopeGreaterThan(
|
||||
IntegerType target_envelope) const {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
DCHECK_LT(target_envelope, tree_[1].envelope);
|
||||
IntegerType unused;
|
||||
return GetEventFromLeaf(
|
||||
GetMaxLeafWithEnvelopeGreaterThan(1, target_envelope, &unused));
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::GetEventsWithOptionalEnvelopeGreaterThan(
|
||||
IntegerType target_envelope, int* critical_event, int* optional_event,
|
||||
IntegerType* available_energy) const {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
int critical_leaf;
|
||||
int optional_leaf;
|
||||
GetLeavesWithOptionalEnvelopeGreaterThan(target_envelope, &critical_leaf,
|
||||
&optional_leaf, available_energy);
|
||||
*critical_event = GetEventFromLeaf(critical_leaf);
|
||||
*optional_event = GetEventFromLeaf(optional_leaf);
|
||||
}
|
||||
|
||||
template <typename IntegerType>
|
||||
IntegerType ThetaLambdaTree<IntegerType>::GetEnvelopeOf(int event) const {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
const int leaf = GetLeafFromEvent(event);
|
||||
IntegerType envelope = tree_[leaf].envelope;
|
||||
for (int node = leaf; node > 1; node >>= 1) {
|
||||
@@ -434,7 +314,6 @@ void ThetaLambdaTree<IntegerType>::RefreshNode(int node) {
|
||||
template <typename IntegerType>
|
||||
int ThetaLambdaTree<IntegerType>::GetMaxLeafWithEnvelopeGreaterThan(
|
||||
int node, IntegerType target_envelope, IntegerType* extra) const {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
DCHECK_LT(target_envelope, tree_[node].envelope);
|
||||
while (node < num_leaves_) {
|
||||
const int left = node << 1;
|
||||
@@ -454,7 +333,6 @@ int ThetaLambdaTree<IntegerType>::GetMaxLeafWithEnvelopeGreaterThan(
|
||||
|
||||
template <typename IntegerType>
|
||||
int ThetaLambdaTree<IntegerType>::GetLeafWithMaxEnergyDelta(int node) const {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
const IntegerType delta_node = tree_[node].max_of_energy_delta;
|
||||
while (node < num_leaves_) {
|
||||
const int left = node << 1;
|
||||
@@ -474,7 +352,6 @@ template <typename IntegerType>
|
||||
void ThetaLambdaTree<IntegerType>::GetLeavesWithOptionalEnvelopeGreaterThan(
|
||||
IntegerType target_envelope, int* critical_leaf, int* optional_leaf,
|
||||
IntegerType* available_energy) const {
|
||||
DCHECK(!leaf_nodes_have_delayed_operations_);
|
||||
DCHECK_LT(target_envelope, tree_[1].envelope_opt);
|
||||
int node = 1;
|
||||
while (node < num_leaves_) {
|
||||
|
||||
Reference in New Issue
Block a user