graph: export from google3
This commit is contained in:
@@ -766,6 +766,8 @@ cc_test(
|
||||
":io",
|
||||
"//ortools/base:dump_vars",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/base:intops",
|
||||
"//ortools/base:strong_vector",
|
||||
"//ortools/util:flat_matrix",
|
||||
"@abseil-cpp//absl/algorithm:container",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
|
||||
@@ -81,9 +81,6 @@ PathWithLength ConstrainedShortestPathsOnDag(
|
||||
// Advanced API.
|
||||
// -----------------------------------------------------------------------------
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
struct GraphPathWithLength {
|
||||
double length = 0.0;
|
||||
// The returned arc indices points into the `arcs_with_length` passed to the
|
||||
@@ -97,9 +94,6 @@ struct GraphPathWithLength {
|
||||
// computations efficiently on the given DAG (on which resources do not change).
|
||||
// `GraphType` can use one of the interfaces defined in `util/graph/graph.h`.
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
class ConstrainedShortestPathsOnDagWrapper {
|
||||
public:
|
||||
using NodeIndex = typename GraphType::NodeIndex;
|
||||
@@ -285,9 +279,6 @@ std::vector<T> GetInversePermutation(const std::vector<T>& permutation);
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
ConstrainedShortestPathsOnDagWrapper<GraphType>::
|
||||
ConstrainedShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const std::vector<double>* arc_lengths,
|
||||
@@ -543,9 +534,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
|
||||
GraphType>::RunConstrainedShortestPathOnDag() {
|
||||
if (source_is_destination_.has_value()) {
|
||||
@@ -664,9 +652,6 @@ GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
void ConstrainedShortestPathsOnDagWrapper<GraphType>::
|
||||
RunHalfConstrainedShortestPathOnDag(
|
||||
const GraphType& reverse_graph, absl::Span<const double> arc_lengths,
|
||||
@@ -792,9 +777,6 @@ void ConstrainedShortestPathsOnDagWrapper<GraphType>::
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
typename GraphType::ArcIndex
|
||||
ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
|
||||
const GraphType& graph, absl::Span<const double> arc_lengths,
|
||||
@@ -879,9 +861,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
std::vector<typename GraphType::ArcIndex>
|
||||
ConstrainedShortestPathsOnDagWrapper<GraphType>::ArcPathTo(
|
||||
const int best_label_index,
|
||||
@@ -901,9 +880,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::ArcPathTo(
|
||||
}
|
||||
|
||||
template <typename GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
std::vector<typename GraphType::NodeIndex>
|
||||
ConstrainedShortestPathsOnDagWrapper<GraphType>::NodePathImpliedBy(
|
||||
absl::Span<const ArcIndex> arc_path, const GraphType& graph) const {
|
||||
|
||||
@@ -15,9 +15,7 @@
|
||||
#define OR_TOOLS_GRAPH_DAG_SHORTEST_PATH_H_
|
||||
|
||||
#include <cmath>
|
||||
#if __cplusplus >= 202002L
|
||||
#include <concepts>
|
||||
#endif
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
@@ -82,39 +80,17 @@ std::vector<PathWithLength> KShortestPathsOnDag(
|
||||
// -----------------------------------------------------------------------------
|
||||
// Advanced API.
|
||||
// -----------------------------------------------------------------------------
|
||||
// This concept only enforces the standard graph API needed for all algorithms
|
||||
// on DAGs. One could add the requirement of being a DAG wihtin this concept
|
||||
// (which is done before running the algorithm).
|
||||
#if __cplusplus >= 202002L
|
||||
template <class GraphType>
|
||||
concept DagGraphType = requires(GraphType graph) {
|
||||
{ typename GraphType::NodeIndex{} };
|
||||
{ typename GraphType::ArcIndex{} };
|
||||
{ graph.num_nodes() } -> std::same_as<typename GraphType::NodeIndex>;
|
||||
{ graph.num_arcs() } -> std::same_as<typename GraphType::ArcIndex>;
|
||||
{ graph.OutgoingArcs(typename GraphType::NodeIndex{}) };
|
||||
{
|
||||
graph.Tail(typename GraphType::ArcIndex{})
|
||||
} -> std::same_as<typename GraphType::NodeIndex>;
|
||||
{
|
||||
graph.Head(typename GraphType::ArcIndex{})
|
||||
} -> std::same_as<typename GraphType::NodeIndex>;
|
||||
{ graph.Build() };
|
||||
};
|
||||
#endif
|
||||
|
||||
// A wrapper that holds the memory needed to run many shortest path computations
|
||||
// efficiently on the given DAG. One call of `RunShortestPathOnDag()` has time
|
||||
// complexity O(|E| + |V|) and space complexity O(|V|).
|
||||
// `GraphType` can use any of the interfaces defined in `util/graph/graph.h`.
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
// `ArcLengthContainer` can be any container of doubles.
|
||||
template <class GraphType, typename ArcLengthContainer = std::vector<double>>
|
||||
class ShortestPathsOnDagWrapper {
|
||||
public:
|
||||
using NodeIndex = typename GraphType::NodeIndex;
|
||||
using ArcIndex = typename GraphType::ArcIndex;
|
||||
using ArcLengths = ArcLengthContainer;
|
||||
|
||||
// IMPORTANT: All arguments must outlive the class.
|
||||
//
|
||||
@@ -138,7 +114,7 @@ class ShortestPathsOnDagWrapper {
|
||||
// so will obviously invalidate the result API of the last shortest path run,
|
||||
// which could return an upper bound, junk, or crash.
|
||||
ShortestPathsOnDagWrapper(const GraphType* graph,
|
||||
const std::vector<double>* arc_lengths,
|
||||
const ArcLengths* arc_lengths,
|
||||
absl::Span<const NodeIndex> topological_order);
|
||||
|
||||
// Computes the shortest path to all reachable nodes from the given sources.
|
||||
@@ -151,7 +127,9 @@ class ShortestPathsOnDagWrapper {
|
||||
const std::vector<NodeIndex>& reached_nodes() const { return reached_nodes_; }
|
||||
|
||||
// Returns the length of the shortest path from `node`'s source to `node`.
|
||||
double LengthTo(NodeIndex node) const { return length_from_sources_[node]; }
|
||||
double LengthTo(NodeIndex node) const {
|
||||
return length_from_sources_[static_cast<size_t>(node)];
|
||||
}
|
||||
std::vector<double> LengthTo() const { return length_from_sources_; }
|
||||
|
||||
// Returns the list of all the arcs in the shortest path from `node`'s
|
||||
@@ -164,12 +142,12 @@ class ShortestPathsOnDagWrapper {
|
||||
|
||||
// Accessors to the underlying graph and arc lengths.
|
||||
const GraphType& graph() const { return *graph_; }
|
||||
const std::vector<double>& arc_lengths() const { return *arc_lengths_; }
|
||||
const ArcLengths& arc_lengths() const { return *arc_lengths_; }
|
||||
|
||||
private:
|
||||
static constexpr double kInf = std::numeric_limits<double>::infinity();
|
||||
const GraphType* const graph_;
|
||||
const std::vector<double>* const arc_lengths_;
|
||||
const ArcLengths* const arc_lengths_;
|
||||
absl::Span<const NodeIndex> const topological_order_;
|
||||
|
||||
// Data about the last call of the RunShortestPathOnDag() function.
|
||||
@@ -185,14 +163,12 @@ class ShortestPathsOnDagWrapper {
|
||||
// `GraphType` can use any of the interfaces defined in `util/graph/graph.h`.
|
||||
// IMPORTANT: Only use if `path_count > 1` (k > 1) otherwise use
|
||||
// `ShortestPathsOnDagWrapper`.
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengthContainer = std::vector<double>>
|
||||
class KShortestPathsOnDagWrapper {
|
||||
public:
|
||||
using NodeIndex = typename GraphType::NodeIndex;
|
||||
using ArcIndex = typename GraphType::ArcIndex;
|
||||
using ArcLengths = ArcLengthContainer;
|
||||
|
||||
// IMPORTANT: All arguments must outlive the class.
|
||||
//
|
||||
@@ -216,7 +192,7 @@ class KShortestPathsOnDagWrapper {
|
||||
// so will obviously invalidate the result API of the last shortest path run,
|
||||
// which could return an upper bound, junk, or crash.
|
||||
KShortestPathsOnDagWrapper(const GraphType* graph,
|
||||
const std::vector<double>* arc_lengths,
|
||||
const ArcLengths* arc_lengths,
|
||||
absl::Span<const NodeIndex> topological_order,
|
||||
int path_count);
|
||||
|
||||
@@ -244,14 +220,14 @@ class KShortestPathsOnDagWrapper {
|
||||
|
||||
// Accessors to the underlying graph and arc lengths.
|
||||
const GraphType& graph() const { return *graph_; }
|
||||
const std::vector<double>& arc_lengths() const { return *arc_lengths_; }
|
||||
const ArcLengths& arc_lengths() const { return *arc_lengths_; }
|
||||
int path_count() const { return path_count_; }
|
||||
|
||||
private:
|
||||
static constexpr double kInf = std::numeric_limits<double>::infinity();
|
||||
|
||||
const GraphType* const graph_;
|
||||
const std::vector<double>* const arc_lengths_;
|
||||
const ArcLengths* const arc_lengths_;
|
||||
absl::Span<const NodeIndex> const topological_order_;
|
||||
const int path_count_;
|
||||
|
||||
@@ -269,10 +245,7 @@ class KShortestPathsOnDagWrapper {
|
||||
std::vector<NodeIndex> reached_nodes_;
|
||||
};
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
absl::Status TopologicalOrderIsValid(
|
||||
const GraphType& graph,
|
||||
absl::Span<const typename GraphType::NodeIndex> topological_order);
|
||||
@@ -286,9 +259,6 @@ absl::Status TopologicalOrderIsValid(
|
||||
// (2) assign into an index rather than with push_back
|
||||
// (3) return by absl::Span (or return a copy) with known size.
|
||||
template <typename GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
std::vector<typename GraphType::NodeIndex> NodePathImpliedBy(
|
||||
absl::Span<const typename GraphType::ArcIndex> arc_path,
|
||||
const GraphType& graph) {
|
||||
@@ -303,47 +273,47 @@ std::vector<typename GraphType::NodeIndex> NodePathImpliedBy(
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
void CheckNodeIsValid(typename GraphType::NodeIndex node,
|
||||
const GraphType& graph) {
|
||||
CHECK_GE(node, 0) << "Node must be nonnegative. Input value: " << node;
|
||||
CHECK_GE(node, typename GraphType::NodeIndex(0))
|
||||
<< "Node must be nonnegative. Input value: " << node;
|
||||
CHECK_LT(node, graph.num_nodes())
|
||||
<< "Node must be a valid node. Input value: " << node
|
||||
<< ". Number of nodes in the input graph: " << graph.num_nodes();
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
absl::Status TopologicalOrderIsValid(
|
||||
const GraphType& graph,
|
||||
absl::Span<const typename GraphType::NodeIndex> topological_order) {
|
||||
using NodeIndex = typename GraphType::NodeIndex;
|
||||
using ArcIndex = typename GraphType::ArcIndex;
|
||||
const NodeIndex num_nodes = graph.num_nodes();
|
||||
if (topological_order.size() != num_nodes) {
|
||||
if (topological_order.size() != static_cast<size_t>(num_nodes)) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"topological_order.size() = %i, != graph.num_nodes() = %i",
|
||||
"topological_order.size() = %i, != graph.num_nodes() = %v",
|
||||
topological_order.size(), num_nodes));
|
||||
}
|
||||
std::vector<NodeIndex> inverse_topology(num_nodes, -1);
|
||||
for (NodeIndex node = 0; node < topological_order.size(); ++node) {
|
||||
if (inverse_topology[topological_order[node]] >= 0) {
|
||||
std::vector<NodeIndex> inverse_topology(static_cast<size_t>(num_nodes),
|
||||
GraphType::kNilNode);
|
||||
for (NodeIndex node(0); node < num_nodes; ++node) {
|
||||
if (inverse_topology[static_cast<size_t>(
|
||||
topological_order[static_cast<size_t>(node)])] !=
|
||||
GraphType::kNilNode) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("node % i appears twice in topological order",
|
||||
topological_order[node]));
|
||||
absl::StrFormat("node %v appears twice in topological order",
|
||||
topological_order[static_cast<size_t>(node)]));
|
||||
}
|
||||
inverse_topology[topological_order[node]] = node;
|
||||
inverse_topology[static_cast<size_t>(
|
||||
topological_order[static_cast<size_t>(node)])] = node;
|
||||
}
|
||||
for (NodeIndex tail = 0; tail < num_nodes; ++tail) {
|
||||
for (NodeIndex tail(0); tail < num_nodes; ++tail) {
|
||||
for (const ArcIndex arc : graph.OutgoingArcs(tail)) {
|
||||
const NodeIndex head = graph.Head(arc);
|
||||
if (inverse_topology[tail] >= inverse_topology[head]) {
|
||||
if (inverse_topology[static_cast<size_t>(tail)] >=
|
||||
inverse_topology[static_cast<size_t>(head)]) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"arc (%i, %i) is inconsistent with topological order", tail, head));
|
||||
"arc (%v, %v) is inconsistent with topological order", tail, head));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -353,21 +323,20 @@ absl::Status TopologicalOrderIsValid(
|
||||
// -----------------------------------------------------------------------------
|
||||
// ShortestPathsOnDagWrapper implementation.
|
||||
// -----------------------------------------------------------------------------
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const std::vector<double>* arc_lengths,
|
||||
template <class GraphType, typename ArcLengths>
|
||||
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::ShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const ArcLengths* arc_lengths,
|
||||
absl::Span<const NodeIndex> topological_order)
|
||||
: graph_(graph),
|
||||
arc_lengths_(arc_lengths),
|
||||
topological_order_(topological_order) {
|
||||
const size_t num_nodes = static_cast<size_t>(graph_->num_nodes());
|
||||
CHECK(graph_ != nullptr);
|
||||
CHECK(arc_lengths_ != nullptr);
|
||||
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
|
||||
CHECK_GT(num_nodes, 0) << "The graph is empty: it has no nodes";
|
||||
#ifndef NDEBUG
|
||||
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
|
||||
CHECK_EQ(typename GraphType::ArcIndex(arc_lengths_->size()),
|
||||
graph_->num_arcs());
|
||||
for (const double arc_length : *arc_lengths_) {
|
||||
CHECK(arc_length != -kInf && !std::isnan(arc_length))
|
||||
<< absl::StrFormat("length cannot be -inf nor NaN");
|
||||
@@ -378,16 +347,13 @@ ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
|
||||
|
||||
// Memory allocation is done here and only once in order to avoid reallocation
|
||||
// at each call of `RunShortestPathOnDag()` for better performance.
|
||||
length_from_sources_.resize(graph_->num_nodes(), kInf);
|
||||
incoming_shortest_path_arc_.resize(graph_->num_nodes(), -1);
|
||||
reached_nodes_.reserve(graph_->num_nodes());
|
||||
length_from_sources_.resize(num_nodes, kInf);
|
||||
incoming_shortest_path_arc_.resize(num_nodes, GraphType::kNilArc);
|
||||
reached_nodes_.reserve(num_nodes);
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
|
||||
template <class GraphType, typename ArcLengths>
|
||||
void ShortestPathsOnDagWrapper<GraphType, ArcLengths>::RunShortestPathOnDag(
|
||||
absl::Span<const NodeIndex> sources) {
|
||||
// Caching the vector addresses allow to not fetch it on each access.
|
||||
const absl::Span<double> length_from_sources =
|
||||
@@ -398,7 +364,7 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
|
||||
// performance, so it only makes sense for nodes that are reachable from at
|
||||
// least one source, the other ones will contain junk.
|
||||
for (const NodeIndex node : reached_nodes_) {
|
||||
length_from_sources[node] = kInf;
|
||||
length_from_sources[static_cast<size_t>(node)] = kInf;
|
||||
}
|
||||
DCHECK(std::all_of(length_from_sources.begin(), length_from_sources.end(),
|
||||
[](double l) { return l == kInf; }));
|
||||
@@ -406,11 +372,12 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
|
||||
|
||||
for (const NodeIndex source : sources) {
|
||||
CheckNodeIsValid(source, *graph_);
|
||||
length_from_sources[source] = 0.0;
|
||||
length_from_sources[static_cast<size_t>(source)] = 0.0;
|
||||
}
|
||||
|
||||
for (const NodeIndex tail : topological_order_) {
|
||||
const double length_to_tail = length_from_sources[tail];
|
||||
const double length_to_tail =
|
||||
length_from_sources[static_cast<size_t>(tail)];
|
||||
// Stop exploring a node as soon as its length to all sources is +inf.
|
||||
if (length_to_tail == kInf) {
|
||||
continue;
|
||||
@@ -418,37 +385,35 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
|
||||
reached_nodes_.push_back(tail);
|
||||
for (const ArcIndex arc : graph_->OutgoingArcs(tail)) {
|
||||
const NodeIndex head = graph_->Head(arc);
|
||||
DCHECK(arc_lengths[arc] != -kInf);
|
||||
const double length_to_head = arc_lengths[arc] + length_to_tail;
|
||||
if (length_to_head < length_from_sources[head]) {
|
||||
length_from_sources[head] = length_to_head;
|
||||
incoming_shortest_path_arc_[head] = arc;
|
||||
DCHECK(arc_lengths[static_cast<size_t>(arc)] != -kInf);
|
||||
const double length_to_head =
|
||||
arc_lengths[static_cast<size_t>(arc)] + length_to_tail;
|
||||
if (length_to_head < length_from_sources[static_cast<size_t>(head)]) {
|
||||
length_from_sources[static_cast<size_t>(head)] = length_to_head;
|
||||
incoming_shortest_path_arc_[static_cast<size_t>(head)] = arc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
bool ShortestPathsOnDagWrapper<GraphType>::IsReachable(NodeIndex node) const {
|
||||
template <class GraphType, typename ArcLengths>
|
||||
bool ShortestPathsOnDagWrapper<GraphType, ArcLengths>::IsReachable(
|
||||
NodeIndex node) const {
|
||||
CheckNodeIsValid(node, *graph_);
|
||||
return length_from_sources_[node] < kInf;
|
||||
return length_from_sources_[static_cast<size_t>(node)] < kInf;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<typename GraphType::ArcIndex>
|
||||
ShortestPathsOnDagWrapper<GraphType>::ArcPathTo(NodeIndex node) const {
|
||||
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::ArcPathTo(
|
||||
NodeIndex node) const {
|
||||
CHECK(IsReachable(node));
|
||||
std::vector<ArcIndex> arc_path;
|
||||
NodeIndex current_node = node;
|
||||
for (int i = 0; i < graph_->num_nodes(); ++i) {
|
||||
ArcIndex current_arc = incoming_shortest_path_arc_[current_node];
|
||||
if (current_arc == -1) {
|
||||
for (NodeIndex i(0); i < graph_->num_nodes(); ++i) {
|
||||
ArcIndex current_arc =
|
||||
incoming_shortest_path_arc_[static_cast<size_t>(current_node)];
|
||||
if (current_arc == GraphType::kNilArc) {
|
||||
break;
|
||||
}
|
||||
arc_path.push_back(current_arc);
|
||||
@@ -458,12 +423,10 @@ ShortestPathsOnDagWrapper<GraphType>::ArcPathTo(NodeIndex node) const {
|
||||
return arc_path;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<typename GraphType::NodeIndex>
|
||||
ShortestPathsOnDagWrapper<GraphType>::NodePathTo(NodeIndex node) const {
|
||||
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::NodePathTo(
|
||||
NodeIndex node) const {
|
||||
const std::vector<typename GraphType::ArcIndex> arc_path = ArcPathTo(node);
|
||||
if (arc_path.empty()) {
|
||||
return {node};
|
||||
@@ -474,12 +437,9 @@ ShortestPathsOnDagWrapper<GraphType>::NodePathTo(NodeIndex node) const {
|
||||
// -----------------------------------------------------------------------------
|
||||
// KShortestPathsOnDagWrapper implementation.
|
||||
// -----------------------------------------------------------------------------
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const std::vector<double>* arc_lengths,
|
||||
template <class GraphType, typename ArcLengths>
|
||||
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::KShortestPathsOnDagWrapper(
|
||||
const GraphType* graph, const ArcLengths* arc_lengths,
|
||||
absl::Span<const NodeIndex> topological_order, const int path_count)
|
||||
: graph_(graph),
|
||||
arc_lengths_(arc_lengths),
|
||||
@@ -487,10 +447,12 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
path_count_(path_count) {
|
||||
CHECK(graph_ != nullptr);
|
||||
CHECK(arc_lengths_ != nullptr);
|
||||
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
|
||||
const size_t num_nodes = static_cast<size_t>(graph_->num_nodes());
|
||||
CHECK_GT(num_nodes, 0) << "The graph is empty: it has no nodes";
|
||||
CHECK_GT(path_count_, 0) << "path_count must be greater than 0";
|
||||
#ifndef NDEBUG
|
||||
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
|
||||
CHECK_EQ(typename GraphType::ArcIndex(arc_lengths_->size()),
|
||||
graph_->num_arcs());
|
||||
for (const double arc_length : *arc_lengths_) {
|
||||
CHECK(arc_length != -kInf && !std::isnan(arc_length))
|
||||
<< absl::StrFormat("length cannot be -inf nor NaN");
|
||||
@@ -501,9 +463,9 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
|
||||
// TODO(b/332475713): Optimize if reverse graph is already provided in
|
||||
// `GraphType`.
|
||||
const int num_arcs = graph_->num_arcs();
|
||||
const ArcIndex num_arcs = graph_->num_arcs();
|
||||
reverse_graph_ = GraphType(graph_->num_nodes(), num_arcs);
|
||||
for (ArcIndex arc_index = 0; arc_index < num_arcs; ++arc_index) {
|
||||
for (ArcIndex arc_index(0); arc_index < num_arcs; ++arc_index) {
|
||||
reverse_graph_.AddArc(graph->Head(arc_index), graph->Tail(arc_index));
|
||||
}
|
||||
std::vector<ArcIndex> permutation;
|
||||
@@ -511,7 +473,7 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
arc_indices_.resize(permutation.size());
|
||||
if (!permutation.empty()) {
|
||||
for (int i = 0; i < permutation.size(); ++i) {
|
||||
arc_indices_[permutation[i]] = i;
|
||||
arc_indices_[static_cast<size_t>(permutation[i])] = ArcIndex(i);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -521,19 +483,16 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
|
||||
incoming_shortest_paths_arc_.resize(path_count_);
|
||||
incoming_shortest_paths_index_.resize(path_count_);
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
lengths_from_sources_[k].resize(graph_->num_nodes(), kInf);
|
||||
incoming_shortest_paths_arc_[k].resize(graph_->num_nodes(), -1);
|
||||
incoming_shortest_paths_index_[k].resize(graph_->num_nodes(), -1);
|
||||
lengths_from_sources_[k].resize(num_nodes, kInf);
|
||||
incoming_shortest_paths_arc_[k].resize(num_nodes, GraphType::kNilArc);
|
||||
incoming_shortest_paths_index_[k].resize(num_nodes, -1);
|
||||
}
|
||||
is_source_.resize(graph_->num_nodes(), false);
|
||||
reached_nodes_.reserve(graph_->num_nodes());
|
||||
is_source_.resize(num_nodes, false);
|
||||
reached_nodes_.reserve(num_nodes);
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
template <class GraphType, typename ArcLengths>
|
||||
void KShortestPathsOnDagWrapper<GraphType, ArcLengths>::RunKShortestPathOnDag(
|
||||
absl::Span<const NodeIndex> sources) {
|
||||
// Caching the vector addresses allow to not fetch it on each access.
|
||||
const absl::Span<const double> arc_lengths = *arc_lengths_;
|
||||
@@ -544,9 +503,9 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
// least one source, the other ones will contain junk.
|
||||
|
||||
for (const NodeIndex node : reached_nodes_) {
|
||||
is_source_[node] = false;
|
||||
is_source_[static_cast<size_t>(node)] = false;
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
lengths_from_sources_[k][node] = kInf;
|
||||
lengths_from_sources_[k][static_cast<size_t>(node)] = kInf;
|
||||
}
|
||||
}
|
||||
reached_nodes_.clear();
|
||||
@@ -560,14 +519,14 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
|
||||
for (const NodeIndex source : sources) {
|
||||
CheckNodeIsValid(source, *graph_);
|
||||
is_source_[source] = true;
|
||||
is_source_[static_cast<size_t>(source)] = true;
|
||||
}
|
||||
|
||||
struct IncomingArcPath {
|
||||
double path_length = 0.0;
|
||||
ArcIndex arc_index = 0;
|
||||
ArcIndex arc_index = ArcIndex(0);
|
||||
double arc_length = 0.0;
|
||||
NodeIndex from = 0;
|
||||
NodeIndex from = NodeIndex(0);
|
||||
int path_index = 0;
|
||||
|
||||
bool operator<(const IncomingArcPath& other) const {
|
||||
@@ -580,18 +539,19 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
auto comp = std::greater<IncomingArcPath>();
|
||||
for (const NodeIndex to : topological_order_) {
|
||||
min_heap.clear();
|
||||
if (is_source_[to]) {
|
||||
min_heap.push_back({.arc_index = -1});
|
||||
if (is_source_[static_cast<size_t>(to)]) {
|
||||
min_heap.push_back({.arc_index = GraphType::kNilArc});
|
||||
}
|
||||
for (const ArcIndex reverse_arc_index : reverse_graph_.OutgoingArcs(to)) {
|
||||
const ArcIndex arc_index = arc_indices.empty()
|
||||
? reverse_arc_index
|
||||
: arc_indices[reverse_arc_index];
|
||||
const ArcIndex arc_index =
|
||||
arc_indices.empty()
|
||||
? reverse_arc_index
|
||||
: arc_indices[static_cast<size_t>(reverse_arc_index)];
|
||||
const NodeIndex from = graph_->Tail(arc_index);
|
||||
const double arc_length = arc_lengths[arc_index];
|
||||
const double arc_length = arc_lengths[static_cast<size_t>(arc_index)];
|
||||
DCHECK(arc_length != -kInf);
|
||||
const double path_length =
|
||||
lengths_from_sources_.front()[from] + arc_length;
|
||||
lengths_from_sources_.front()[static_cast<size_t>(from)] + arc_length;
|
||||
if (path_length == kInf) {
|
||||
continue;
|
||||
}
|
||||
@@ -608,17 +568,21 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
std::pop_heap(min_heap.begin(), min_heap.end(), comp);
|
||||
IncomingArcPath& incoming_arc_path = min_heap.back();
|
||||
lengths_from_sources_[k][to] = incoming_arc_path.path_length;
|
||||
incoming_shortest_paths_arc_[k][to] = incoming_arc_path.arc_index;
|
||||
incoming_shortest_paths_index_[k][to] = incoming_arc_path.path_index;
|
||||
if (incoming_arc_path.arc_index != -1 &&
|
||||
lengths_from_sources_[k][static_cast<size_t>(to)] =
|
||||
incoming_arc_path.path_length;
|
||||
incoming_shortest_paths_arc_[k][static_cast<size_t>(to)] =
|
||||
incoming_arc_path.arc_index;
|
||||
incoming_shortest_paths_index_[k][static_cast<size_t>(to)] =
|
||||
incoming_arc_path.path_index;
|
||||
if (incoming_arc_path.arc_index != GraphType::kNilArc &&
|
||||
incoming_arc_path.path_index < path_count_ - 1 &&
|
||||
lengths_from_sources_[incoming_arc_path.path_index + 1]
|
||||
[incoming_arc_path.from] < kInf) {
|
||||
[static_cast<size_t>(incoming_arc_path.from)] <
|
||||
kInf) {
|
||||
++incoming_arc_path.path_index;
|
||||
incoming_arc_path.path_length =
|
||||
lengths_from_sources_[incoming_arc_path.path_index]
|
||||
[incoming_arc_path.from] +
|
||||
[static_cast<size_t>(incoming_arc_path.from)] +
|
||||
incoming_arc_path.arc_length;
|
||||
std::push_heap(min_heap.begin(), min_heap.end(), comp);
|
||||
} else {
|
||||
@@ -631,25 +595,22 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
|
||||
}
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
bool KShortestPathsOnDagWrapper<GraphType>::IsReachable(NodeIndex node) const {
|
||||
template <class GraphType, typename ArcLengths>
|
||||
bool KShortestPathsOnDagWrapper<GraphType, ArcLengths>::IsReachable(
|
||||
NodeIndex node) const {
|
||||
CheckNodeIsValid(node, *graph_);
|
||||
return lengths_from_sources_.front()[node] < kInf;
|
||||
return lengths_from_sources_.front()[static_cast<size_t>(node)] < kInf;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
std::vector<double> KShortestPathsOnDagWrapper<GraphType>::LengthsTo(
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<double>
|
||||
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::LengthsTo(
|
||||
NodeIndex node) const {
|
||||
std::vector<double> lengths_to;
|
||||
lengths_to.reserve(path_count_);
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
const double length_to = lengths_from_sources_[k][node];
|
||||
const double length_to =
|
||||
lengths_from_sources_[k][static_cast<size_t>(node)];
|
||||
if (length_to == kInf) {
|
||||
break;
|
||||
}
|
||||
@@ -658,30 +619,30 @@ std::vector<double> KShortestPathsOnDagWrapper<GraphType>::LengthsTo(
|
||||
return lengths_to;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<std::vector<typename GraphType::ArcIndex>>
|
||||
KShortestPathsOnDagWrapper<GraphType>::ArcPathsTo(NodeIndex node) const {
|
||||
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::ArcPathsTo(
|
||||
NodeIndex node) const {
|
||||
std::vector<std::vector<ArcIndex>> arc_paths;
|
||||
arc_paths.reserve(path_count_);
|
||||
for (int k = 0; k < path_count_; ++k) {
|
||||
if (lengths_from_sources_[k][node] == kInf) {
|
||||
if (lengths_from_sources_[k][static_cast<size_t>(node)] == kInf) {
|
||||
break;
|
||||
}
|
||||
std::vector<ArcIndex> arc_path;
|
||||
int current_path_index = k;
|
||||
NodeIndex current_node = node;
|
||||
for (int i = 0; i < graph_->num_nodes(); ++i) {
|
||||
for (NodeIndex i(0); i < graph_->num_nodes(); ++i) {
|
||||
ArcIndex current_arc =
|
||||
incoming_shortest_paths_arc_[current_path_index][current_node];
|
||||
if (current_arc == -1) {
|
||||
incoming_shortest_paths_arc_[current_path_index]
|
||||
[static_cast<size_t>(current_node)];
|
||||
if (current_arc == GraphType::kNilArc) {
|
||||
break;
|
||||
}
|
||||
arc_path.push_back(current_arc);
|
||||
current_path_index =
|
||||
incoming_shortest_paths_index_[current_path_index][current_node];
|
||||
incoming_shortest_paths_index_[current_path_index]
|
||||
[static_cast<size_t>(current_node)];
|
||||
current_node = graph_->Tail(current_arc);
|
||||
}
|
||||
absl::c_reverse(arc_path);
|
||||
@@ -690,12 +651,10 @@ KShortestPathsOnDagWrapper<GraphType>::ArcPathsTo(NodeIndex node) const {
|
||||
return arc_paths;
|
||||
}
|
||||
|
||||
template <class GraphType>
|
||||
#if __cplusplus >= 202002L
|
||||
requires DagGraphType<GraphType>
|
||||
#endif
|
||||
template <class GraphType, typename ArcLengths>
|
||||
std::vector<std::vector<typename GraphType::NodeIndex>>
|
||||
KShortestPathsOnDagWrapper<GraphType>::NodePathsTo(NodeIndex node) const {
|
||||
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::NodePathsTo(
|
||||
NodeIndex node) const {
|
||||
const std::vector<std::vector<ArcIndex>> arc_paths = ArcPathsTo(node);
|
||||
std::vector<std::vector<NodeIndex>> node_paths(arc_paths.size());
|
||||
for (int k = 0; k < arc_paths.size(); ++k) {
|
||||
|
||||
@@ -29,6 +29,8 @@
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/dump_vars.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/base/strong_vector.h"
|
||||
#include "ortools/graph/graph.h"
|
||||
#include "ortools/graph/graph_io.h"
|
||||
#include "ortools/util/flat_matrix.h"
|
||||
@@ -75,7 +77,7 @@ TEST(TopologicalOrderIsValidTest, ValidateTopologicalOrder) {
|
||||
TEST(ShortestPathOnDagTest, EmptyGraph) {
|
||||
EXPECT_DEATH(ShortestPathsOnDag(/*num_nodes=*/0, /*arcs_with_length=*/{},
|
||||
/*source=*/0, /*destination=*/0),
|
||||
"num_nodes\\(\\) > 0");
|
||||
"num_nodes > 0");
|
||||
}
|
||||
|
||||
TEST(ShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
|
||||
@@ -89,7 +91,7 @@ TEST(ShortestPathOnDagTest, NonExistingSourceBecauseTooLarge) {
|
||||
EXPECT_DEATH(
|
||||
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
|
||||
/*source=*/3, /*destination=*/1),
|
||||
"num_nodes\\(\\)");
|
||||
"num_nodes");
|
||||
}
|
||||
|
||||
TEST(ShortestPathOnDagTest, NonExistingDestinationBecauseNegative) {
|
||||
@@ -103,7 +105,7 @@ TEST(ShortestPathOnDagTest, NonExistingDestinationBecauseTooLarge) {
|
||||
EXPECT_DEATH(
|
||||
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
|
||||
/*source=*/0, /*destination=*/3),
|
||||
"num_nodes\\(\\)");
|
||||
"num_nodes");
|
||||
}
|
||||
|
||||
TEST(ShortestPathOnDagTest, Cycle) {
|
||||
@@ -287,6 +289,37 @@ TEST(ShortestPathOnDagTest, UpdateCost) {
|
||||
/*node_path=*/ElementsAre(source, b, destination)));
|
||||
}
|
||||
|
||||
DEFINE_STRONG_INT_TYPE(NodeIndex, int32_t);
|
||||
DEFINE_STRONG_INT_TYPE(ArcIndex, int32_t);
|
||||
|
||||
TEST(ShortestPathsOnDagWrapperTest, StrongIndices) {
|
||||
const NodeIndex source_1(0);
|
||||
const NodeIndex source_2(1);
|
||||
const NodeIndex destination(2);
|
||||
const NodeIndex num_nodes(3);
|
||||
util::ListGraph<NodeIndex, ArcIndex> graph(num_nodes,
|
||||
/*arc_capacity=*/ArcIndex(2));
|
||||
|
||||
using ArcLengths = util_intops::StrongVector<ArcIndex, double>;
|
||||
ArcLengths arc_lengths;
|
||||
graph.AddArc(source_1, destination);
|
||||
arc_lengths.push_back(-6.0);
|
||||
graph.AddArc(source_2, destination);
|
||||
arc_lengths.push_back(3.0);
|
||||
const std::vector<NodeIndex> topological_order = {source_2, source_1,
|
||||
destination};
|
||||
ShortestPathsOnDagWrapper<util::ListGraph<NodeIndex, ArcIndex>, ArcLengths>
|
||||
shortest_path_on_dag(&graph, &arc_lengths, topological_order);
|
||||
shortest_path_on_dag.RunShortestPathOnDag({source_1, source_2});
|
||||
|
||||
EXPECT_TRUE(shortest_path_on_dag.IsReachable(destination));
|
||||
EXPECT_THAT(shortest_path_on_dag.LengthTo(destination), -6.0);
|
||||
EXPECT_THAT(shortest_path_on_dag.ArcPathTo(destination),
|
||||
ElementsAre(ArcIndex(0)));
|
||||
EXPECT_THAT(shortest_path_on_dag.NodePathTo(destination),
|
||||
ElementsAre(source_1, destination));
|
||||
}
|
||||
|
||||
TEST(ShortestPathsOnDagWrapperTest, MultipleSources) {
|
||||
const int source_1 = 0;
|
||||
const int source_2 = 1;
|
||||
@@ -634,7 +667,7 @@ TEST(KShortestPathOnDagTest, EmptyGraph) {
|
||||
EXPECT_DEATH(
|
||||
KShortestPathsOnDag(/*num_nodes=*/0, /*arcs_with_length=*/{},
|
||||
/*source=*/0, /*destination=*/0, /*path_count=*/2),
|
||||
"num_nodes\\(\\) > 0");
|
||||
"num_nodes > 0");
|
||||
}
|
||||
|
||||
TEST(KShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
|
||||
@@ -648,7 +681,7 @@ TEST(KShortestPathOnDagTest, NonExistingSourceBecauseTooLarge) {
|
||||
EXPECT_DEATH(
|
||||
KShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
|
||||
/*source=*/3, /*destination=*/1, /*path_count=*/2),
|
||||
"num_nodes\\(\\)");
|
||||
"num_nodes");
|
||||
}
|
||||
|
||||
TEST(KShortestPathOnDagTest, NonExistingDestinationBecauseNegative) {
|
||||
@@ -936,6 +969,38 @@ TEST(KShortestPathOnDagTest, UpdateCost) {
|
||||
/*node_path=*/ElementsAre(source, a, destination))));
|
||||
}
|
||||
|
||||
TEST(KShortestPathsOnDagWrapperTest, StrongIndices) {
|
||||
const NodeIndex source_1(0);
|
||||
const NodeIndex source_2(1);
|
||||
const NodeIndex destination(2);
|
||||
const NodeIndex num_nodes(3);
|
||||
util::ListGraph<NodeIndex, ArcIndex> graph(num_nodes,
|
||||
/*arc_capacity=*/ArcIndex(2));
|
||||
|
||||
using ArcLengths = util_intops::StrongVector<ArcIndex, double>;
|
||||
ArcLengths arc_lengths;
|
||||
graph.AddArc(source_1, destination);
|
||||
arc_lengths.push_back(-6.0);
|
||||
graph.AddArc(source_2, destination);
|
||||
arc_lengths.push_back(3.0);
|
||||
const std::vector<NodeIndex> topological_order = {source_2, source_1,
|
||||
destination};
|
||||
const int path_count = 2;
|
||||
KShortestPathsOnDagWrapper<util::ListGraph<NodeIndex, ArcIndex>, ArcLengths>
|
||||
shortest_paths_on_dag(&graph, &arc_lengths, topological_order,
|
||||
path_count);
|
||||
shortest_paths_on_dag.RunKShortestPathOnDag({source_1, source_2});
|
||||
|
||||
EXPECT_TRUE(shortest_paths_on_dag.IsReachable(destination));
|
||||
EXPECT_THAT(shortest_paths_on_dag.LengthsTo(destination),
|
||||
ElementsAre(-6.0, 3.0));
|
||||
EXPECT_THAT(shortest_paths_on_dag.ArcPathsTo(destination),
|
||||
ElementsAre(ElementsAre(ArcIndex(0)), ElementsAre(ArcIndex(1))));
|
||||
EXPECT_THAT(shortest_paths_on_dag.NodePathsTo(destination),
|
||||
ElementsAre(ElementsAre(source_1, destination),
|
||||
ElementsAre(source_2, destination)));
|
||||
}
|
||||
|
||||
TEST(KShortestPathsOnDagWrapperTest, MultipleSources) {
|
||||
const int source_1 = 0;
|
||||
const int source_2 = 1;
|
||||
|
||||
@@ -286,8 +286,12 @@ class BaseGraph {
|
||||
|
||||
// Constants that will never be a valid node or arc.
|
||||
// They are the maximum possible node and arc capacity.
|
||||
static const NodeIndexType kNilNode;
|
||||
static const ArcIndexType kNilArc;
|
||||
static_assert(std::numeric_limits<NodeIndexType>::is_specialized);
|
||||
static constexpr NodeIndexType kNilNode =
|
||||
std::numeric_limits<NodeIndexType>::max();
|
||||
static_assert(std::numeric_limits<ArcIndexType>::is_specialized);
|
||||
static constexpr ArcIndexType kNilArc =
|
||||
std::numeric_limits<ArcIndexType>::max();
|
||||
|
||||
protected:
|
||||
// Functions commented when defined because they are implementation details.
|
||||
@@ -590,6 +594,26 @@ class SVector {
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// Graph traits, to allow algorithms to manipulate graphs as adjacency lists.
|
||||
// This works with any graph type, and any object that has:
|
||||
// - a size() method returning the number of nodes.
|
||||
// - an operator[] method taking a node index and returning a range of neighbour
|
||||
// node indices.
|
||||
// One common example is using `std::vector<std::vector<int>>` to represent
|
||||
// adjacency lists.
|
||||
template <typename Graph>
|
||||
struct GraphTraits {
|
||||
private:
|
||||
// The type of the range returned by `operator[]`.
|
||||
using NeighborRangeType = std::decay_t<
|
||||
decltype(std::declval<Graph>()[std::declval<Graph>().size()])>;
|
||||
|
||||
public:
|
||||
// The index type for nodes of the graph.
|
||||
using NodeIndex =
|
||||
std::decay_t<decltype(*(std::declval<NeighborRangeType>().begin()))>;
|
||||
};
|
||||
|
||||
// Basic graph implementation without reverse arc. This class also serves as a
|
||||
// documentation for the generic graph interface (minus the part related to
|
||||
// reverse arcs).
|
||||
@@ -1121,18 +1145,6 @@ BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::AllForwardArcs()
|
||||
return IntegerRange<ArcIndexType>(ArcIndexType(0), num_arcs_);
|
||||
}
|
||||
|
||||
template <typename NodeIndexType, typename ArcIndexType,
|
||||
bool HasNegativeReverseArcs>
|
||||
const NodeIndexType
|
||||
BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::kNilNode =
|
||||
std::numeric_limits<NodeIndexType>::max();
|
||||
|
||||
template <typename NodeIndexType, typename ArcIndexType,
|
||||
bool HasNegativeReverseArcs>
|
||||
const ArcIndexType
|
||||
BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::kNilArc =
|
||||
std::numeric_limits<ArcIndexType>::max();
|
||||
|
||||
template <typename NodeIndexType, typename ArcIndexType,
|
||||
bool HasNegativeReverseArcs>
|
||||
NodeIndexType BaseGraph<NodeIndexType, ArcIndexType,
|
||||
|
||||
@@ -57,6 +57,8 @@ class BeginEndWrapper {
|
||||
|
||||
Iterator begin() const { return begin_; }
|
||||
Iterator end() const { return end_; }
|
||||
|
||||
// Available only if `Iterator` is a random access iterator.
|
||||
size_t size() const { return end_ - begin_; }
|
||||
|
||||
bool empty() const { return begin() == end(); }
|
||||
@@ -127,8 +129,6 @@ class IntegerRangeIterator
|
||||
#endif
|
||||
{
|
||||
public:
|
||||
// TODO(b/385094969): This should be `IntegerType` for integers,
|
||||
// `IntegerType:value_type` for strong signed integer types.
|
||||
using difference_type = ptrdiff_t;
|
||||
using value_type = IntegerType;
|
||||
|
||||
@@ -210,7 +210,7 @@ class IntegerRangeIterator
|
||||
|
||||
friend difference_type operator-(const IntegerRangeIterator l,
|
||||
const IntegerRangeIterator r) {
|
||||
return l.index_ - r.index_;
|
||||
return static_cast<difference_type>(l.index_ - r.index_);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -248,9 +248,7 @@ class ChasingIterator
|
||||
#endif
|
||||
{
|
||||
public:
|
||||
// TODO(b/385094969): This should be `IntegerType` for integers,
|
||||
// `IntegerType:value_type` for strong signed integer types.
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using difference_type = ptrdiff_t;
|
||||
using value_type = IndexT;
|
||||
|
||||
ChasingIterator() : index_(sentinel), next_(nullptr) {}
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
#ifndef UTIL_GRAPH_TOPOLOGICALSORTER_H__
|
||||
#define UTIL_GRAPH_TOPOLOGICALSORTER_H__
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
@@ -84,7 +85,9 @@ namespace graph {
|
||||
// FastTopologicalSort(util::StaticGraph<>::FromArcs(num_nodes, arcs)));
|
||||
//
|
||||
template <class AdjacencyLists> // vector<vector<int>>, util::StaticGraph<>, ..
|
||||
absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
|
||||
absl::StatusOr<
|
||||
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
|
||||
FastTopologicalSort(const AdjacencyLists& adj);
|
||||
|
||||
// Finds a cycle in the directed graph given as argument: nodes are dense
|
||||
// integers in 0..num_nodes-1, and (directed) arcs are pairs of nodes
|
||||
@@ -93,7 +96,9 @@ absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
|
||||
// if the cycle 1->4->3->1 exists.
|
||||
// If the graph is acyclic, returns an empty vector.
|
||||
template <class AdjacencyLists> // vector<vector<int>>, util::StaticGraph<>, ..
|
||||
absl::StatusOr<std::vector<int>> FindCycleInGraph(const AdjacencyLists& adj);
|
||||
absl::StatusOr<
|
||||
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
|
||||
FindCycleInGraph(const AdjacencyLists& adj);
|
||||
|
||||
} // namespace graph
|
||||
|
||||
@@ -615,38 +620,38 @@ std::vector<T> StableTopologicalSortOrDie(
|
||||
}
|
||||
|
||||
template <class AdjacencyLists>
|
||||
absl::StatusOr<std::vector<int>> FastTopologicalSort(
|
||||
const AdjacencyLists& adj) {
|
||||
const size_t num_nodes = adj.size();
|
||||
if (num_nodes > std::numeric_limits<int>::max()) {
|
||||
return absl::InvalidArgumentError("More than kint32max nodes");
|
||||
absl::StatusOr<std::vector<typename GraphTraits<AdjacencyLists>::NodeIndex>>
|
||||
FastTopologicalSort(const AdjacencyLists& adj) {
|
||||
using NodeIndex = typename GraphTraits<AdjacencyLists>::NodeIndex;
|
||||
if (adj.size() > std::numeric_limits<NodeIndex>::max()) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Too many nodes: adj.size()=%v", adj.size()));
|
||||
}
|
||||
std::vector<int> indegree(num_nodes, 0);
|
||||
std::vector<int> topo_order;
|
||||
topo_order.reserve(num_nodes);
|
||||
for (int from = 0; from < num_nodes; ++from) {
|
||||
for (const int head : adj[from]) {
|
||||
// We cast to unsigned int to test "head < 0 || head ≥ num_nodes" with a
|
||||
// single test. Microbenchmarks showed a ~1% overall performance gain.
|
||||
if (static_cast<uint32_t>(head) >= num_nodes) {
|
||||
const NodeIndex num_nodes(adj.size());
|
||||
std::vector<NodeIndex> indegree(static_cast<size_t>(num_nodes), NodeIndex(0));
|
||||
std::vector<NodeIndex> topo_order;
|
||||
topo_order.reserve(static_cast<size_t>(num_nodes));
|
||||
for (NodeIndex from(0); from < num_nodes; ++from) {
|
||||
for (const NodeIndex head : adj[from]) {
|
||||
if (!(NodeIndex(0) <= head && head < num_nodes)) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Invalid arc in adj[%d]: %d (num_nodes=%d)", from,
|
||||
absl::StrFormat("Invalid arc in adj[%v]: %v (num_nodes=%v)", from,
|
||||
head, num_nodes));
|
||||
}
|
||||
// NOTE(user): We could detect self-arcs here (head == from) and exit
|
||||
// early, but microbenchmarks show a 2 to 4% slow-down if we do it, so we
|
||||
// simply rely on self-arcs being detected as cycles in the topo sort.
|
||||
++indegree[head];
|
||||
++indegree[static_cast<size_t>(head)];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < num_nodes; ++i) {
|
||||
if (!indegree[i]) topo_order.push_back(i);
|
||||
for (NodeIndex i(0); i < num_nodes; ++i) {
|
||||
if (!indegree[static_cast<size_t>(i)]) topo_order.push_back(i);
|
||||
}
|
||||
size_t num_visited = 0;
|
||||
while (num_visited < topo_order.size()) {
|
||||
const int from = topo_order[num_visited++];
|
||||
for (const int head : adj[from]) {
|
||||
if (!--indegree[head]) topo_order.push_back(head);
|
||||
const NodeIndex from = topo_order[num_visited++];
|
||||
for (const NodeIndex head : adj[from]) {
|
||||
if (!--indegree[static_cast<size_t>(head)]) topo_order.push_back(head);
|
||||
}
|
||||
}
|
||||
if (topo_order.size() < static_cast<size_t>(num_nodes)) {
|
||||
@@ -656,77 +661,99 @@ absl::StatusOr<std::vector<int>> FastTopologicalSort(
|
||||
}
|
||||
|
||||
template <class AdjacencyLists>
|
||||
absl::StatusOr<std::vector<int>> FindCycleInGraph(const AdjacencyLists& adj) {
|
||||
const size_t num_nodes = adj.size();
|
||||
if (num_nodes > std::numeric_limits<int>::max()) {
|
||||
absl::StatusOr<
|
||||
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
|
||||
FindCycleInGraph(const AdjacencyLists& adj) {
|
||||
using NodeIndex = typename GraphTraits<AdjacencyLists>::NodeIndex;
|
||||
if (adj.size() > std::numeric_limits<NodeIndex>::max()) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Too many nodes: adj.size()=%d", adj.size()));
|
||||
absl::StrFormat("Too many nodes: adj.size()=%v", adj.size()));
|
||||
}
|
||||
const NodeIndex num_nodes(adj.size());
|
||||
|
||||
// First pass to validate that inputs are valid.
|
||||
for (NodeIndex node(0); node < NodeIndex(node); ++node) {
|
||||
for (const NodeIndex head : adj[node]) {
|
||||
if (head >= num_nodes) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Invalid child %v in adj[%v]", head, node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// To find a cycle, we start a DFS from each yet-unvisited node and
|
||||
// try to find a cycle, if we don't find it then we know for sure that
|
||||
// no cycle is reachable from any of the explored nodes (so, we don't
|
||||
// explore them in later DFSs).
|
||||
std::vector<bool> no_cycle_reachable_from(num_nodes, false);
|
||||
std::vector<bool> no_cycle_reachable_from(static_cast<size_t>(num_nodes),
|
||||
false);
|
||||
// The DFS stack will contain a chain of nodes, from the root of the
|
||||
// DFS to the current leaf.
|
||||
struct DfsState {
|
||||
int node;
|
||||
NodeIndex node;
|
||||
// Points at the first child node that we did *not* yet look at.
|
||||
int adj_list_index;
|
||||
explicit DfsState(int _node) : node(_node), adj_list_index(0) {}
|
||||
decltype(adj[NodeIndex(0)].begin()) children;
|
||||
decltype(adj[NodeIndex(0)].end()) children_end;
|
||||
|
||||
explicit DfsState(NodeIndex _node,
|
||||
const decltype(adj[NodeIndex(0)])& neighbours)
|
||||
: node(_node),
|
||||
children(neighbours.begin()),
|
||||
children_end(neighbours.end()) {}
|
||||
};
|
||||
std::vector<DfsState> dfs_stack;
|
||||
std::vector<bool> in_cur_stack(num_nodes, false);
|
||||
for (int start_node = 0; start_node < static_cast<int>(num_nodes);
|
||||
std::vector<bool> visited(static_cast<size_t>(num_nodes), false);
|
||||
for (NodeIndex start_node(0); start_node < NodeIndex(num_nodes);
|
||||
++start_node) {
|
||||
if (no_cycle_reachable_from[start_node]) continue;
|
||||
if (no_cycle_reachable_from[static_cast<size_t>(start_node)]) continue;
|
||||
|
||||
// Start the DFS.
|
||||
dfs_stack.push_back(DfsState(start_node));
|
||||
in_cur_stack[start_node] = true;
|
||||
visited[static_cast<size_t>(start_node)] = true;
|
||||
dfs_stack.push_back(DfsState(start_node, adj[start_node]));
|
||||
while (!dfs_stack.empty()) {
|
||||
DfsState* cur_state = &dfs_stack.back();
|
||||
if (static_cast<size_t>(cur_state->adj_list_index) >=
|
||||
adj[cur_state->node].size()) {
|
||||
no_cycle_reachable_from[cur_state->node] = true;
|
||||
in_cur_stack[cur_state->node] = false;
|
||||
DfsState* const cur_state = &dfs_stack.back();
|
||||
while (
|
||||
cur_state->children != cur_state->children_end &&
|
||||
no_cycle_reachable_from[static_cast<size_t>(*cur_state->children)]) {
|
||||
++cur_state->children;
|
||||
}
|
||||
if (cur_state->children == cur_state->children_end) {
|
||||
no_cycle_reachable_from[static_cast<size_t>(cur_state->node)] = true;
|
||||
dfs_stack.pop_back();
|
||||
continue;
|
||||
}
|
||||
// Look at the current child, and increase the current state's
|
||||
// adj_list_index.
|
||||
// TODO(user): Caching adj[cur_state->node] in a local stack to improve
|
||||
// locality and so that the [] operator is called exactly once per node.
|
||||
const int child = adj[cur_state->node][cur_state->adj_list_index++];
|
||||
if (static_cast<size_t>(child) >= num_nodes) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"Invalid child %d in adj[%d]", child, cur_state->node));
|
||||
}
|
||||
if (no_cycle_reachable_from[child]) continue;
|
||||
if (in_cur_stack[child]) {
|
||||
|
||||
const NodeIndex child = *cur_state->children;
|
||||
// At that point the child is either:
|
||||
// - visited and all finalized (all its children are visited). We know
|
||||
// that it's not part of a cycle, otherwise we'd already have
|
||||
// returned.
|
||||
// - visited and not finalized (some of its children are not visited).
|
||||
// That means that we've reached it again from a child, so we've found
|
||||
// a cycle.
|
||||
// - not visited. We push it on the stack and explore it.
|
||||
if (no_cycle_reachable_from[static_cast<size_t>(child)]) continue;
|
||||
if (visited[static_cast<size_t>(child)]) {
|
||||
// We detected a cycle! It corresponds to the tail end of dfs_stack,
|
||||
// in reverse order, until we find "child".
|
||||
int cycle_start = dfs_stack.size() - 1;
|
||||
size_t cycle_start = dfs_stack.size() - 1;
|
||||
while (dfs_stack[cycle_start].node != child) --cycle_start;
|
||||
const int cycle_size = dfs_stack.size() - cycle_start;
|
||||
std::vector<int> cycle(cycle_size);
|
||||
for (int c = 0; c < cycle_size; ++c) {
|
||||
const size_t cycle_size = dfs_stack.size() - cycle_start;
|
||||
std::vector<NodeIndex> cycle(cycle_size);
|
||||
for (size_t c = 0; c < cycle_size; ++c) {
|
||||
cycle[c] = dfs_stack[cycle_start + c].node;
|
||||
}
|
||||
return cycle;
|
||||
}
|
||||
|
||||
// Push the child onto the stack.
|
||||
dfs_stack.push_back(DfsState(child));
|
||||
in_cur_stack[child] = true;
|
||||
// Verify that its adjacency list seems valid.
|
||||
if (adj[child].size() > std::numeric_limits<int>::max()) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"Invalid adj[%d].size() = %d", child, adj[child].size()));
|
||||
}
|
||||
dfs_stack.push_back(DfsState(child, adj[child]));
|
||||
visited[static_cast<size_t>(child)] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// If we're here, then all the DFS stopped, and there is no cycle.
|
||||
return std::vector<int>{};
|
||||
return std::vector<NodeIndex>{};
|
||||
}
|
||||
|
||||
} // namespace graph
|
||||
|
||||
Reference in New Issue
Block a user