graph: export from google3

This commit is contained in:
Corentin Le Molgat
2025-03-31 15:16:06 +02:00
parent e0a0a597a0
commit 6153511356
7 changed files with 324 additions and 285 deletions

View File

@@ -766,6 +766,8 @@ cc_test(
":io",
"//ortools/base:dump_vars",
"//ortools/base:gmock_main",
"//ortools/base:intops",
"//ortools/base:strong_vector",
"//ortools/util:flat_matrix",
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/log:check",

View File

@@ -81,9 +81,6 @@ PathWithLength ConstrainedShortestPathsOnDag(
// Advanced API.
// -----------------------------------------------------------------------------
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
struct GraphPathWithLength {
double length = 0.0;
// The returned arc indices points into the `arcs_with_length` passed to the
@@ -97,9 +94,6 @@ struct GraphPathWithLength {
// computations efficiently on the given DAG (on which resources do not change).
// `GraphType` can use one of the interfaces defined in `util/graph/graph.h`.
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
class ConstrainedShortestPathsOnDagWrapper {
public:
using NodeIndex = typename GraphType::NodeIndex;
@@ -285,9 +279,6 @@ std::vector<T> GetInversePermutation(const std::vector<T>& permutation);
// -----------------------------------------------------------------------------
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
ConstrainedShortestPathsOnDagWrapper<GraphType>::
ConstrainedShortestPathsOnDagWrapper(
const GraphType* graph, const std::vector<double>* arc_lengths,
@@ -543,9 +534,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
GraphType>::RunConstrainedShortestPathOnDag() {
if (source_is_destination_.has_value()) {
@@ -664,9 +652,6 @@ GraphPathWithLength<GraphType> ConstrainedShortestPathsOnDagWrapper<
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
void ConstrainedShortestPathsOnDagWrapper<GraphType>::
RunHalfConstrainedShortestPathOnDag(
const GraphType& reverse_graph, absl::Span<const double> arc_lengths,
@@ -792,9 +777,6 @@ void ConstrainedShortestPathsOnDagWrapper<GraphType>::
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
typename GraphType::ArcIndex
ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
const GraphType& graph, absl::Span<const double> arc_lengths,
@@ -879,9 +861,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::MergeHalfRuns(
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
std::vector<typename GraphType::ArcIndex>
ConstrainedShortestPathsOnDagWrapper<GraphType>::ArcPathTo(
const int best_label_index,
@@ -901,9 +880,6 @@ ConstrainedShortestPathsOnDagWrapper<GraphType>::ArcPathTo(
}
template <typename GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
std::vector<typename GraphType::NodeIndex>
ConstrainedShortestPathsOnDagWrapper<GraphType>::NodePathImpliedBy(
absl::Span<const ArcIndex> arc_path, const GraphType& graph) const {

View File

@@ -15,9 +15,7 @@
#define OR_TOOLS_GRAPH_DAG_SHORTEST_PATH_H_
#include <cmath>
#if __cplusplus >= 202002L
#include <concepts>
#endif
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>
@@ -82,39 +80,17 @@ std::vector<PathWithLength> KShortestPathsOnDag(
// -----------------------------------------------------------------------------
// Advanced API.
// -----------------------------------------------------------------------------
// This concept only enforces the standard graph API needed for all algorithms
// on DAGs. One could add the requirement of being a DAG wihtin this concept
// (which is done before running the algorithm).
#if __cplusplus >= 202002L
template <class GraphType>
concept DagGraphType = requires(GraphType graph) {
{ typename GraphType::NodeIndex{} };
{ typename GraphType::ArcIndex{} };
{ graph.num_nodes() } -> std::same_as<typename GraphType::NodeIndex>;
{ graph.num_arcs() } -> std::same_as<typename GraphType::ArcIndex>;
{ graph.OutgoingArcs(typename GraphType::NodeIndex{}) };
{
graph.Tail(typename GraphType::ArcIndex{})
} -> std::same_as<typename GraphType::NodeIndex>;
{
graph.Head(typename GraphType::ArcIndex{})
} -> std::same_as<typename GraphType::NodeIndex>;
{ graph.Build() };
};
#endif
// A wrapper that holds the memory needed to run many shortest path computations
// efficiently on the given DAG. One call of `RunShortestPathOnDag()` has time
// complexity O(|E| + |V|) and space complexity O(|V|).
// `GraphType` can use any of the interfaces defined in `util/graph/graph.h`.
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
// `ArcLengthContainer` can be any container of doubles.
template <class GraphType, typename ArcLengthContainer = std::vector<double>>
class ShortestPathsOnDagWrapper {
public:
using NodeIndex = typename GraphType::NodeIndex;
using ArcIndex = typename GraphType::ArcIndex;
using ArcLengths = ArcLengthContainer;
// IMPORTANT: All arguments must outlive the class.
//
@@ -138,7 +114,7 @@ class ShortestPathsOnDagWrapper {
// so will obviously invalidate the result API of the last shortest path run,
// which could return an upper bound, junk, or crash.
ShortestPathsOnDagWrapper(const GraphType* graph,
const std::vector<double>* arc_lengths,
const ArcLengths* arc_lengths,
absl::Span<const NodeIndex> topological_order);
// Computes the shortest path to all reachable nodes from the given sources.
@@ -151,7 +127,9 @@ class ShortestPathsOnDagWrapper {
const std::vector<NodeIndex>& reached_nodes() const { return reached_nodes_; }
// Returns the length of the shortest path from `node`'s source to `node`.
double LengthTo(NodeIndex node) const { return length_from_sources_[node]; }
double LengthTo(NodeIndex node) const {
return length_from_sources_[static_cast<size_t>(node)];
}
std::vector<double> LengthTo() const { return length_from_sources_; }
// Returns the list of all the arcs in the shortest path from `node`'s
@@ -164,12 +142,12 @@ class ShortestPathsOnDagWrapper {
// Accessors to the underlying graph and arc lengths.
const GraphType& graph() const { return *graph_; }
const std::vector<double>& arc_lengths() const { return *arc_lengths_; }
const ArcLengths& arc_lengths() const { return *arc_lengths_; }
private:
static constexpr double kInf = std::numeric_limits<double>::infinity();
const GraphType* const graph_;
const std::vector<double>* const arc_lengths_;
const ArcLengths* const arc_lengths_;
absl::Span<const NodeIndex> const topological_order_;
// Data about the last call of the RunShortestPathOnDag() function.
@@ -185,14 +163,12 @@ class ShortestPathsOnDagWrapper {
// `GraphType` can use any of the interfaces defined in `util/graph/graph.h`.
// IMPORTANT: Only use if `path_count > 1` (k > 1) otherwise use
// `ShortestPathsOnDagWrapper`.
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengthContainer = std::vector<double>>
class KShortestPathsOnDagWrapper {
public:
using NodeIndex = typename GraphType::NodeIndex;
using ArcIndex = typename GraphType::ArcIndex;
using ArcLengths = ArcLengthContainer;
// IMPORTANT: All arguments must outlive the class.
//
@@ -216,7 +192,7 @@ class KShortestPathsOnDagWrapper {
// so will obviously invalidate the result API of the last shortest path run,
// which could return an upper bound, junk, or crash.
KShortestPathsOnDagWrapper(const GraphType* graph,
const std::vector<double>* arc_lengths,
const ArcLengths* arc_lengths,
absl::Span<const NodeIndex> topological_order,
int path_count);
@@ -244,14 +220,14 @@ class KShortestPathsOnDagWrapper {
// Accessors to the underlying graph and arc lengths.
const GraphType& graph() const { return *graph_; }
const std::vector<double>& arc_lengths() const { return *arc_lengths_; }
const ArcLengths& arc_lengths() const { return *arc_lengths_; }
int path_count() const { return path_count_; }
private:
static constexpr double kInf = std::numeric_limits<double>::infinity();
const GraphType* const graph_;
const std::vector<double>* const arc_lengths_;
const ArcLengths* const arc_lengths_;
absl::Span<const NodeIndex> const topological_order_;
const int path_count_;
@@ -269,10 +245,7 @@ class KShortestPathsOnDagWrapper {
std::vector<NodeIndex> reached_nodes_;
};
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
absl::Status TopologicalOrderIsValid(
const GraphType& graph,
absl::Span<const typename GraphType::NodeIndex> topological_order);
@@ -286,9 +259,6 @@ absl::Status TopologicalOrderIsValid(
// (2) assign into an index rather than with push_back
// (3) return by absl::Span (or return a copy) with known size.
template <typename GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
std::vector<typename GraphType::NodeIndex> NodePathImpliedBy(
absl::Span<const typename GraphType::ArcIndex> arc_path,
const GraphType& graph) {
@@ -303,47 +273,47 @@ std::vector<typename GraphType::NodeIndex> NodePathImpliedBy(
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
void CheckNodeIsValid(typename GraphType::NodeIndex node,
const GraphType& graph) {
CHECK_GE(node, 0) << "Node must be nonnegative. Input value: " << node;
CHECK_GE(node, typename GraphType::NodeIndex(0))
<< "Node must be nonnegative. Input value: " << node;
CHECK_LT(node, graph.num_nodes())
<< "Node must be a valid node. Input value: " << node
<< ". Number of nodes in the input graph: " << graph.num_nodes();
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
absl::Status TopologicalOrderIsValid(
const GraphType& graph,
absl::Span<const typename GraphType::NodeIndex> topological_order) {
using NodeIndex = typename GraphType::NodeIndex;
using ArcIndex = typename GraphType::ArcIndex;
const NodeIndex num_nodes = graph.num_nodes();
if (topological_order.size() != num_nodes) {
if (topological_order.size() != static_cast<size_t>(num_nodes)) {
return absl::InvalidArgumentError(absl::StrFormat(
"topological_order.size() = %i, != graph.num_nodes() = %i",
"topological_order.size() = %i, != graph.num_nodes() = %v",
topological_order.size(), num_nodes));
}
std::vector<NodeIndex> inverse_topology(num_nodes, -1);
for (NodeIndex node = 0; node < topological_order.size(); ++node) {
if (inverse_topology[topological_order[node]] >= 0) {
std::vector<NodeIndex> inverse_topology(static_cast<size_t>(num_nodes),
GraphType::kNilNode);
for (NodeIndex node(0); node < num_nodes; ++node) {
if (inverse_topology[static_cast<size_t>(
topological_order[static_cast<size_t>(node)])] !=
GraphType::kNilNode) {
return absl::InvalidArgumentError(
absl::StrFormat("node % i appears twice in topological order",
topological_order[node]));
absl::StrFormat("node %v appears twice in topological order",
topological_order[static_cast<size_t>(node)]));
}
inverse_topology[topological_order[node]] = node;
inverse_topology[static_cast<size_t>(
topological_order[static_cast<size_t>(node)])] = node;
}
for (NodeIndex tail = 0; tail < num_nodes; ++tail) {
for (NodeIndex tail(0); tail < num_nodes; ++tail) {
for (const ArcIndex arc : graph.OutgoingArcs(tail)) {
const NodeIndex head = graph.Head(arc);
if (inverse_topology[tail] >= inverse_topology[head]) {
if (inverse_topology[static_cast<size_t>(tail)] >=
inverse_topology[static_cast<size_t>(head)]) {
return absl::InvalidArgumentError(absl::StrFormat(
"arc (%i, %i) is inconsistent with topological order", tail, head));
"arc (%v, %v) is inconsistent with topological order", tail, head));
}
}
}
@@ -353,21 +323,20 @@ absl::Status TopologicalOrderIsValid(
// -----------------------------------------------------------------------------
// ShortestPathsOnDagWrapper implementation.
// -----------------------------------------------------------------------------
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
const GraphType* graph, const std::vector<double>* arc_lengths,
template <class GraphType, typename ArcLengths>
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::ShortestPathsOnDagWrapper(
const GraphType* graph, const ArcLengths* arc_lengths,
absl::Span<const NodeIndex> topological_order)
: graph_(graph),
arc_lengths_(arc_lengths),
topological_order_(topological_order) {
const size_t num_nodes = static_cast<size_t>(graph_->num_nodes());
CHECK(graph_ != nullptr);
CHECK(arc_lengths_ != nullptr);
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
CHECK_GT(num_nodes, 0) << "The graph is empty: it has no nodes";
#ifndef NDEBUG
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
CHECK_EQ(typename GraphType::ArcIndex(arc_lengths_->size()),
graph_->num_arcs());
for (const double arc_length : *arc_lengths_) {
CHECK(arc_length != -kInf && !std::isnan(arc_length))
<< absl::StrFormat("length cannot be -inf nor NaN");
@@ -378,16 +347,13 @@ ShortestPathsOnDagWrapper<GraphType>::ShortestPathsOnDagWrapper(
// Memory allocation is done here and only once in order to avoid reallocation
// at each call of `RunShortestPathOnDag()` for better performance.
length_from_sources_.resize(graph_->num_nodes(), kInf);
incoming_shortest_path_arc_.resize(graph_->num_nodes(), -1);
reached_nodes_.reserve(graph_->num_nodes());
length_from_sources_.resize(num_nodes, kInf);
incoming_shortest_path_arc_.resize(num_nodes, GraphType::kNilArc);
reached_nodes_.reserve(num_nodes);
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
template <class GraphType, typename ArcLengths>
void ShortestPathsOnDagWrapper<GraphType, ArcLengths>::RunShortestPathOnDag(
absl::Span<const NodeIndex> sources) {
// Caching the vector addresses allow to not fetch it on each access.
const absl::Span<double> length_from_sources =
@@ -398,7 +364,7 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
// performance, so it only makes sense for nodes that are reachable from at
// least one source, the other ones will contain junk.
for (const NodeIndex node : reached_nodes_) {
length_from_sources[node] = kInf;
length_from_sources[static_cast<size_t>(node)] = kInf;
}
DCHECK(std::all_of(length_from_sources.begin(), length_from_sources.end(),
[](double l) { return l == kInf; }));
@@ -406,11 +372,12 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
for (const NodeIndex source : sources) {
CheckNodeIsValid(source, *graph_);
length_from_sources[source] = 0.0;
length_from_sources[static_cast<size_t>(source)] = 0.0;
}
for (const NodeIndex tail : topological_order_) {
const double length_to_tail = length_from_sources[tail];
const double length_to_tail =
length_from_sources[static_cast<size_t>(tail)];
// Stop exploring a node as soon as its length to all sources is +inf.
if (length_to_tail == kInf) {
continue;
@@ -418,37 +385,35 @@ void ShortestPathsOnDagWrapper<GraphType>::RunShortestPathOnDag(
reached_nodes_.push_back(tail);
for (const ArcIndex arc : graph_->OutgoingArcs(tail)) {
const NodeIndex head = graph_->Head(arc);
DCHECK(arc_lengths[arc] != -kInf);
const double length_to_head = arc_lengths[arc] + length_to_tail;
if (length_to_head < length_from_sources[head]) {
length_from_sources[head] = length_to_head;
incoming_shortest_path_arc_[head] = arc;
DCHECK(arc_lengths[static_cast<size_t>(arc)] != -kInf);
const double length_to_head =
arc_lengths[static_cast<size_t>(arc)] + length_to_tail;
if (length_to_head < length_from_sources[static_cast<size_t>(head)]) {
length_from_sources[static_cast<size_t>(head)] = length_to_head;
incoming_shortest_path_arc_[static_cast<size_t>(head)] = arc;
}
}
}
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
bool ShortestPathsOnDagWrapper<GraphType>::IsReachable(NodeIndex node) const {
template <class GraphType, typename ArcLengths>
bool ShortestPathsOnDagWrapper<GraphType, ArcLengths>::IsReachable(
NodeIndex node) const {
CheckNodeIsValid(node, *graph_);
return length_from_sources_[node] < kInf;
return length_from_sources_[static_cast<size_t>(node)] < kInf;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
std::vector<typename GraphType::ArcIndex>
ShortestPathsOnDagWrapper<GraphType>::ArcPathTo(NodeIndex node) const {
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::ArcPathTo(
NodeIndex node) const {
CHECK(IsReachable(node));
std::vector<ArcIndex> arc_path;
NodeIndex current_node = node;
for (int i = 0; i < graph_->num_nodes(); ++i) {
ArcIndex current_arc = incoming_shortest_path_arc_[current_node];
if (current_arc == -1) {
for (NodeIndex i(0); i < graph_->num_nodes(); ++i) {
ArcIndex current_arc =
incoming_shortest_path_arc_[static_cast<size_t>(current_node)];
if (current_arc == GraphType::kNilArc) {
break;
}
arc_path.push_back(current_arc);
@@ -458,12 +423,10 @@ ShortestPathsOnDagWrapper<GraphType>::ArcPathTo(NodeIndex node) const {
return arc_path;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
std::vector<typename GraphType::NodeIndex>
ShortestPathsOnDagWrapper<GraphType>::NodePathTo(NodeIndex node) const {
ShortestPathsOnDagWrapper<GraphType, ArcLengths>::NodePathTo(
NodeIndex node) const {
const std::vector<typename GraphType::ArcIndex> arc_path = ArcPathTo(node);
if (arc_path.empty()) {
return {node};
@@ -474,12 +437,9 @@ ShortestPathsOnDagWrapper<GraphType>::NodePathTo(NodeIndex node) const {
// -----------------------------------------------------------------------------
// KShortestPathsOnDagWrapper implementation.
// -----------------------------------------------------------------------------
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
const GraphType* graph, const std::vector<double>* arc_lengths,
template <class GraphType, typename ArcLengths>
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::KShortestPathsOnDagWrapper(
const GraphType* graph, const ArcLengths* arc_lengths,
absl::Span<const NodeIndex> topological_order, const int path_count)
: graph_(graph),
arc_lengths_(arc_lengths),
@@ -487,10 +447,12 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
path_count_(path_count) {
CHECK(graph_ != nullptr);
CHECK(arc_lengths_ != nullptr);
CHECK_GT(graph_->num_nodes(), 0) << "The graph is empty: it has no nodes";
const size_t num_nodes = static_cast<size_t>(graph_->num_nodes());
CHECK_GT(num_nodes, 0) << "The graph is empty: it has no nodes";
CHECK_GT(path_count_, 0) << "path_count must be greater than 0";
#ifndef NDEBUG
CHECK_EQ(arc_lengths_->size(), graph_->num_arcs());
CHECK_EQ(typename GraphType::ArcIndex(arc_lengths_->size()),
graph_->num_arcs());
for (const double arc_length : *arc_lengths_) {
CHECK(arc_length != -kInf && !std::isnan(arc_length))
<< absl::StrFormat("length cannot be -inf nor NaN");
@@ -501,9 +463,9 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
// TODO(b/332475713): Optimize if reverse graph is already provided in
// `GraphType`.
const int num_arcs = graph_->num_arcs();
const ArcIndex num_arcs = graph_->num_arcs();
reverse_graph_ = GraphType(graph_->num_nodes(), num_arcs);
for (ArcIndex arc_index = 0; arc_index < num_arcs; ++arc_index) {
for (ArcIndex arc_index(0); arc_index < num_arcs; ++arc_index) {
reverse_graph_.AddArc(graph->Head(arc_index), graph->Tail(arc_index));
}
std::vector<ArcIndex> permutation;
@@ -511,7 +473,7 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
arc_indices_.resize(permutation.size());
if (!permutation.empty()) {
for (int i = 0; i < permutation.size(); ++i) {
arc_indices_[permutation[i]] = i;
arc_indices_[static_cast<size_t>(permutation[i])] = ArcIndex(i);
}
}
@@ -521,19 +483,16 @@ KShortestPathsOnDagWrapper<GraphType>::KShortestPathsOnDagWrapper(
incoming_shortest_paths_arc_.resize(path_count_);
incoming_shortest_paths_index_.resize(path_count_);
for (int k = 0; k < path_count_; ++k) {
lengths_from_sources_[k].resize(graph_->num_nodes(), kInf);
incoming_shortest_paths_arc_[k].resize(graph_->num_nodes(), -1);
incoming_shortest_paths_index_[k].resize(graph_->num_nodes(), -1);
lengths_from_sources_[k].resize(num_nodes, kInf);
incoming_shortest_paths_arc_[k].resize(num_nodes, GraphType::kNilArc);
incoming_shortest_paths_index_[k].resize(num_nodes, -1);
}
is_source_.resize(graph_->num_nodes(), false);
reached_nodes_.reserve(graph_->num_nodes());
is_source_.resize(num_nodes, false);
reached_nodes_.reserve(num_nodes);
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
template <class GraphType, typename ArcLengths>
void KShortestPathsOnDagWrapper<GraphType, ArcLengths>::RunKShortestPathOnDag(
absl::Span<const NodeIndex> sources) {
// Caching the vector addresses allow to not fetch it on each access.
const absl::Span<const double> arc_lengths = *arc_lengths_;
@@ -544,9 +503,9 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
// least one source, the other ones will contain junk.
for (const NodeIndex node : reached_nodes_) {
is_source_[node] = false;
is_source_[static_cast<size_t>(node)] = false;
for (int k = 0; k < path_count_; ++k) {
lengths_from_sources_[k][node] = kInf;
lengths_from_sources_[k][static_cast<size_t>(node)] = kInf;
}
}
reached_nodes_.clear();
@@ -560,14 +519,14 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
for (const NodeIndex source : sources) {
CheckNodeIsValid(source, *graph_);
is_source_[source] = true;
is_source_[static_cast<size_t>(source)] = true;
}
struct IncomingArcPath {
double path_length = 0.0;
ArcIndex arc_index = 0;
ArcIndex arc_index = ArcIndex(0);
double arc_length = 0.0;
NodeIndex from = 0;
NodeIndex from = NodeIndex(0);
int path_index = 0;
bool operator<(const IncomingArcPath& other) const {
@@ -580,18 +539,19 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
auto comp = std::greater<IncomingArcPath>();
for (const NodeIndex to : topological_order_) {
min_heap.clear();
if (is_source_[to]) {
min_heap.push_back({.arc_index = -1});
if (is_source_[static_cast<size_t>(to)]) {
min_heap.push_back({.arc_index = GraphType::kNilArc});
}
for (const ArcIndex reverse_arc_index : reverse_graph_.OutgoingArcs(to)) {
const ArcIndex arc_index = arc_indices.empty()
? reverse_arc_index
: arc_indices[reverse_arc_index];
const ArcIndex arc_index =
arc_indices.empty()
? reverse_arc_index
: arc_indices[static_cast<size_t>(reverse_arc_index)];
const NodeIndex from = graph_->Tail(arc_index);
const double arc_length = arc_lengths[arc_index];
const double arc_length = arc_lengths[static_cast<size_t>(arc_index)];
DCHECK(arc_length != -kInf);
const double path_length =
lengths_from_sources_.front()[from] + arc_length;
lengths_from_sources_.front()[static_cast<size_t>(from)] + arc_length;
if (path_length == kInf) {
continue;
}
@@ -608,17 +568,21 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
for (int k = 0; k < path_count_; ++k) {
std::pop_heap(min_heap.begin(), min_heap.end(), comp);
IncomingArcPath& incoming_arc_path = min_heap.back();
lengths_from_sources_[k][to] = incoming_arc_path.path_length;
incoming_shortest_paths_arc_[k][to] = incoming_arc_path.arc_index;
incoming_shortest_paths_index_[k][to] = incoming_arc_path.path_index;
if (incoming_arc_path.arc_index != -1 &&
lengths_from_sources_[k][static_cast<size_t>(to)] =
incoming_arc_path.path_length;
incoming_shortest_paths_arc_[k][static_cast<size_t>(to)] =
incoming_arc_path.arc_index;
incoming_shortest_paths_index_[k][static_cast<size_t>(to)] =
incoming_arc_path.path_index;
if (incoming_arc_path.arc_index != GraphType::kNilArc &&
incoming_arc_path.path_index < path_count_ - 1 &&
lengths_from_sources_[incoming_arc_path.path_index + 1]
[incoming_arc_path.from] < kInf) {
[static_cast<size_t>(incoming_arc_path.from)] <
kInf) {
++incoming_arc_path.path_index;
incoming_arc_path.path_length =
lengths_from_sources_[incoming_arc_path.path_index]
[incoming_arc_path.from] +
[static_cast<size_t>(incoming_arc_path.from)] +
incoming_arc_path.arc_length;
std::push_heap(min_heap.begin(), min_heap.end(), comp);
} else {
@@ -631,25 +595,22 @@ void KShortestPathsOnDagWrapper<GraphType>::RunKShortestPathOnDag(
}
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
bool KShortestPathsOnDagWrapper<GraphType>::IsReachable(NodeIndex node) const {
template <class GraphType, typename ArcLengths>
bool KShortestPathsOnDagWrapper<GraphType, ArcLengths>::IsReachable(
NodeIndex node) const {
CheckNodeIsValid(node, *graph_);
return lengths_from_sources_.front()[node] < kInf;
return lengths_from_sources_.front()[static_cast<size_t>(node)] < kInf;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
std::vector<double> KShortestPathsOnDagWrapper<GraphType>::LengthsTo(
template <class GraphType, typename ArcLengths>
std::vector<double>
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::LengthsTo(
NodeIndex node) const {
std::vector<double> lengths_to;
lengths_to.reserve(path_count_);
for (int k = 0; k < path_count_; ++k) {
const double length_to = lengths_from_sources_[k][node];
const double length_to =
lengths_from_sources_[k][static_cast<size_t>(node)];
if (length_to == kInf) {
break;
}
@@ -658,30 +619,30 @@ std::vector<double> KShortestPathsOnDagWrapper<GraphType>::LengthsTo(
return lengths_to;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
std::vector<std::vector<typename GraphType::ArcIndex>>
KShortestPathsOnDagWrapper<GraphType>::ArcPathsTo(NodeIndex node) const {
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::ArcPathsTo(
NodeIndex node) const {
std::vector<std::vector<ArcIndex>> arc_paths;
arc_paths.reserve(path_count_);
for (int k = 0; k < path_count_; ++k) {
if (lengths_from_sources_[k][node] == kInf) {
if (lengths_from_sources_[k][static_cast<size_t>(node)] == kInf) {
break;
}
std::vector<ArcIndex> arc_path;
int current_path_index = k;
NodeIndex current_node = node;
for (int i = 0; i < graph_->num_nodes(); ++i) {
for (NodeIndex i(0); i < graph_->num_nodes(); ++i) {
ArcIndex current_arc =
incoming_shortest_paths_arc_[current_path_index][current_node];
if (current_arc == -1) {
incoming_shortest_paths_arc_[current_path_index]
[static_cast<size_t>(current_node)];
if (current_arc == GraphType::kNilArc) {
break;
}
arc_path.push_back(current_arc);
current_path_index =
incoming_shortest_paths_index_[current_path_index][current_node];
incoming_shortest_paths_index_[current_path_index]
[static_cast<size_t>(current_node)];
current_node = graph_->Tail(current_arc);
}
absl::c_reverse(arc_path);
@@ -690,12 +651,10 @@ KShortestPathsOnDagWrapper<GraphType>::ArcPathsTo(NodeIndex node) const {
return arc_paths;
}
template <class GraphType>
#if __cplusplus >= 202002L
requires DagGraphType<GraphType>
#endif
template <class GraphType, typename ArcLengths>
std::vector<std::vector<typename GraphType::NodeIndex>>
KShortestPathsOnDagWrapper<GraphType>::NodePathsTo(NodeIndex node) const {
KShortestPathsOnDagWrapper<GraphType, ArcLengths>::NodePathsTo(
NodeIndex node) const {
const std::vector<std::vector<ArcIndex>> arc_paths = ArcPathsTo(node);
std::vector<std::vector<NodeIndex>> node_paths(arc_paths.size());
for (int k = 0; k < arc_paths.size(); ++k) {

View File

@@ -29,6 +29,8 @@
#include "gtest/gtest.h"
#include "ortools/base/dump_vars.h"
#include "ortools/base/gmock.h"
#include "ortools/base/strong_int.h"
#include "ortools/base/strong_vector.h"
#include "ortools/graph/graph.h"
#include "ortools/graph/graph_io.h"
#include "ortools/util/flat_matrix.h"
@@ -75,7 +77,7 @@ TEST(TopologicalOrderIsValidTest, ValidateTopologicalOrder) {
TEST(ShortestPathOnDagTest, EmptyGraph) {
EXPECT_DEATH(ShortestPathsOnDag(/*num_nodes=*/0, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/0),
"num_nodes\\(\\) > 0");
"num_nodes > 0");
}
TEST(ShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
@@ -89,7 +91,7 @@ TEST(ShortestPathOnDagTest, NonExistingSourceBecauseTooLarge) {
EXPECT_DEATH(
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
/*source=*/3, /*destination=*/1),
"num_nodes\\(\\)");
"num_nodes");
}
TEST(ShortestPathOnDagTest, NonExistingDestinationBecauseNegative) {
@@ -103,7 +105,7 @@ TEST(ShortestPathOnDagTest, NonExistingDestinationBecauseTooLarge) {
EXPECT_DEATH(
ShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
/*source=*/0, /*destination=*/3),
"num_nodes\\(\\)");
"num_nodes");
}
TEST(ShortestPathOnDagTest, Cycle) {
@@ -287,6 +289,37 @@ TEST(ShortestPathOnDagTest, UpdateCost) {
/*node_path=*/ElementsAre(source, b, destination)));
}
DEFINE_STRONG_INT_TYPE(NodeIndex, int32_t);
DEFINE_STRONG_INT_TYPE(ArcIndex, int32_t);
TEST(ShortestPathsOnDagWrapperTest, StrongIndices) {
const NodeIndex source_1(0);
const NodeIndex source_2(1);
const NodeIndex destination(2);
const NodeIndex num_nodes(3);
util::ListGraph<NodeIndex, ArcIndex> graph(num_nodes,
/*arc_capacity=*/ArcIndex(2));
using ArcLengths = util_intops::StrongVector<ArcIndex, double>;
ArcLengths arc_lengths;
graph.AddArc(source_1, destination);
arc_lengths.push_back(-6.0);
graph.AddArc(source_2, destination);
arc_lengths.push_back(3.0);
const std::vector<NodeIndex> topological_order = {source_2, source_1,
destination};
ShortestPathsOnDagWrapper<util::ListGraph<NodeIndex, ArcIndex>, ArcLengths>
shortest_path_on_dag(&graph, &arc_lengths, topological_order);
shortest_path_on_dag.RunShortestPathOnDag({source_1, source_2});
EXPECT_TRUE(shortest_path_on_dag.IsReachable(destination));
EXPECT_THAT(shortest_path_on_dag.LengthTo(destination), -6.0);
EXPECT_THAT(shortest_path_on_dag.ArcPathTo(destination),
ElementsAre(ArcIndex(0)));
EXPECT_THAT(shortest_path_on_dag.NodePathTo(destination),
ElementsAre(source_1, destination));
}
TEST(ShortestPathsOnDagWrapperTest, MultipleSources) {
const int source_1 = 0;
const int source_2 = 1;
@@ -634,7 +667,7 @@ TEST(KShortestPathOnDagTest, EmptyGraph) {
EXPECT_DEATH(
KShortestPathsOnDag(/*num_nodes=*/0, /*arcs_with_length=*/{},
/*source=*/0, /*destination=*/0, /*path_count=*/2),
"num_nodes\\(\\) > 0");
"num_nodes > 0");
}
TEST(KShortestPathOnDagTest, NonExistingSourceBecauseNegative) {
@@ -648,7 +681,7 @@ TEST(KShortestPathOnDagTest, NonExistingSourceBecauseTooLarge) {
EXPECT_DEATH(
KShortestPathsOnDag(/*num_nodes=*/2, /*arcs_with_length=*/{{0, 1, 0.0}},
/*source=*/3, /*destination=*/1, /*path_count=*/2),
"num_nodes\\(\\)");
"num_nodes");
}
TEST(KShortestPathOnDagTest, NonExistingDestinationBecauseNegative) {
@@ -936,6 +969,38 @@ TEST(KShortestPathOnDagTest, UpdateCost) {
/*node_path=*/ElementsAre(source, a, destination))));
}
TEST(KShortestPathsOnDagWrapperTest, StrongIndices) {
const NodeIndex source_1(0);
const NodeIndex source_2(1);
const NodeIndex destination(2);
const NodeIndex num_nodes(3);
util::ListGraph<NodeIndex, ArcIndex> graph(num_nodes,
/*arc_capacity=*/ArcIndex(2));
using ArcLengths = util_intops::StrongVector<ArcIndex, double>;
ArcLengths arc_lengths;
graph.AddArc(source_1, destination);
arc_lengths.push_back(-6.0);
graph.AddArc(source_2, destination);
arc_lengths.push_back(3.0);
const std::vector<NodeIndex> topological_order = {source_2, source_1,
destination};
const int path_count = 2;
KShortestPathsOnDagWrapper<util::ListGraph<NodeIndex, ArcIndex>, ArcLengths>
shortest_paths_on_dag(&graph, &arc_lengths, topological_order,
path_count);
shortest_paths_on_dag.RunKShortestPathOnDag({source_1, source_2});
EXPECT_TRUE(shortest_paths_on_dag.IsReachable(destination));
EXPECT_THAT(shortest_paths_on_dag.LengthsTo(destination),
ElementsAre(-6.0, 3.0));
EXPECT_THAT(shortest_paths_on_dag.ArcPathsTo(destination),
ElementsAre(ElementsAre(ArcIndex(0)), ElementsAre(ArcIndex(1))));
EXPECT_THAT(shortest_paths_on_dag.NodePathsTo(destination),
ElementsAre(ElementsAre(source_1, destination),
ElementsAre(source_2, destination)));
}
TEST(KShortestPathsOnDagWrapperTest, MultipleSources) {
const int source_1 = 0;
const int source_2 = 1;

View File

@@ -286,8 +286,12 @@ class BaseGraph {
// Constants that will never be a valid node or arc.
// They are the maximum possible node and arc capacity.
static const NodeIndexType kNilNode;
static const ArcIndexType kNilArc;
static_assert(std::numeric_limits<NodeIndexType>::is_specialized);
static constexpr NodeIndexType kNilNode =
std::numeric_limits<NodeIndexType>::max();
static_assert(std::numeric_limits<ArcIndexType>::is_specialized);
static constexpr ArcIndexType kNilArc =
std::numeric_limits<ArcIndexType>::max();
protected:
// Functions commented when defined because they are implementation details.
@@ -590,6 +594,26 @@ class SVector {
} // namespace internal
// Graph traits, to allow algorithms to manipulate graphs as adjacency lists.
// This works with any graph type, and any object that has:
// - a size() method returning the number of nodes.
// - an operator[] method taking a node index and returning a range of neighbour
// node indices.
// One common example is using `std::vector<std::vector<int>>` to represent
// adjacency lists.
template <typename Graph>
struct GraphTraits {
private:
// The type of the range returned by `operator[]`.
using NeighborRangeType = std::decay_t<
decltype(std::declval<Graph>()[std::declval<Graph>().size()])>;
public:
// The index type for nodes of the graph.
using NodeIndex =
std::decay_t<decltype(*(std::declval<NeighborRangeType>().begin()))>;
};
// Basic graph implementation without reverse arc. This class also serves as a
// documentation for the generic graph interface (minus the part related to
// reverse arcs).
@@ -1121,18 +1145,6 @@ BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::AllForwardArcs()
return IntegerRange<ArcIndexType>(ArcIndexType(0), num_arcs_);
}
template <typename NodeIndexType, typename ArcIndexType,
bool HasNegativeReverseArcs>
const NodeIndexType
BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::kNilNode =
std::numeric_limits<NodeIndexType>::max();
template <typename NodeIndexType, typename ArcIndexType,
bool HasNegativeReverseArcs>
const ArcIndexType
BaseGraph<NodeIndexType, ArcIndexType, HasNegativeReverseArcs>::kNilArc =
std::numeric_limits<ArcIndexType>::max();
template <typename NodeIndexType, typename ArcIndexType,
bool HasNegativeReverseArcs>
NodeIndexType BaseGraph<NodeIndexType, ArcIndexType,

View File

@@ -57,6 +57,8 @@ class BeginEndWrapper {
Iterator begin() const { return begin_; }
Iterator end() const { return end_; }
// Available only if `Iterator` is a random access iterator.
size_t size() const { return end_ - begin_; }
bool empty() const { return begin() == end(); }
@@ -127,8 +129,6 @@ class IntegerRangeIterator
#endif
{
public:
// TODO(b/385094969): This should be `IntegerType` for integers,
// `IntegerType:value_type` for strong signed integer types.
using difference_type = ptrdiff_t;
using value_type = IntegerType;
@@ -210,7 +210,7 @@ class IntegerRangeIterator
friend difference_type operator-(const IntegerRangeIterator l,
const IntegerRangeIterator r) {
return l.index_ - r.index_;
return static_cast<difference_type>(l.index_ - r.index_);
}
private:
@@ -248,9 +248,7 @@ class ChasingIterator
#endif
{
public:
// TODO(b/385094969): This should be `IntegerType` for integers,
// `IntegerType:value_type` for strong signed integer types.
using difference_type = std::ptrdiff_t;
using difference_type = ptrdiff_t;
using value_type = IndexT;
ChasingIterator() : index_(sentinel), next_(nullptr) {}

View File

@@ -30,6 +30,7 @@
#ifndef UTIL_GRAPH_TOPOLOGICALSORTER_H__
#define UTIL_GRAPH_TOPOLOGICALSORTER_H__
#include <cstddef>
#include <functional>
#include <limits>
#include <queue>
@@ -84,7 +85,9 @@ namespace graph {
// FastTopologicalSort(util::StaticGraph<>::FromArcs(num_nodes, arcs)));
//
template <class AdjacencyLists> // vector<vector<int>>, util::StaticGraph<>, ..
absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
absl::StatusOr<
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
FastTopologicalSort(const AdjacencyLists& adj);
// Finds a cycle in the directed graph given as argument: nodes are dense
// integers in 0..num_nodes-1, and (directed) arcs are pairs of nodes
@@ -93,7 +96,9 @@ absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
// if the cycle 1->4->3->1 exists.
// If the graph is acyclic, returns an empty vector.
template <class AdjacencyLists> // vector<vector<int>>, util::StaticGraph<>, ..
absl::StatusOr<std::vector<int>> FindCycleInGraph(const AdjacencyLists& adj);
absl::StatusOr<
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
FindCycleInGraph(const AdjacencyLists& adj);
} // namespace graph
@@ -615,38 +620,38 @@ std::vector<T> StableTopologicalSortOrDie(
}
template <class AdjacencyLists>
absl::StatusOr<std::vector<int>> FastTopologicalSort(
const AdjacencyLists& adj) {
const size_t num_nodes = adj.size();
if (num_nodes > std::numeric_limits<int>::max()) {
return absl::InvalidArgumentError("More than kint32max nodes");
absl::StatusOr<std::vector<typename GraphTraits<AdjacencyLists>::NodeIndex>>
FastTopologicalSort(const AdjacencyLists& adj) {
using NodeIndex = typename GraphTraits<AdjacencyLists>::NodeIndex;
if (adj.size() > std::numeric_limits<NodeIndex>::max()) {
return absl::InvalidArgumentError(
absl::StrFormat("Too many nodes: adj.size()=%v", adj.size()));
}
std::vector<int> indegree(num_nodes, 0);
std::vector<int> topo_order;
topo_order.reserve(num_nodes);
for (int from = 0; from < num_nodes; ++from) {
for (const int head : adj[from]) {
// We cast to unsigned int to test "head < 0 || head ≥ num_nodes" with a
// single test. Microbenchmarks showed a ~1% overall performance gain.
if (static_cast<uint32_t>(head) >= num_nodes) {
const NodeIndex num_nodes(adj.size());
std::vector<NodeIndex> indegree(static_cast<size_t>(num_nodes), NodeIndex(0));
std::vector<NodeIndex> topo_order;
topo_order.reserve(static_cast<size_t>(num_nodes));
for (NodeIndex from(0); from < num_nodes; ++from) {
for (const NodeIndex head : adj[from]) {
if (!(NodeIndex(0) <= head && head < num_nodes)) {
return absl::InvalidArgumentError(
absl::StrFormat("Invalid arc in adj[%d]: %d (num_nodes=%d)", from,
absl::StrFormat("Invalid arc in adj[%v]: %v (num_nodes=%v)", from,
head, num_nodes));
}
// NOTE(user): We could detect self-arcs here (head == from) and exit
// early, but microbenchmarks show a 2 to 4% slow-down if we do it, so we
// simply rely on self-arcs being detected as cycles in the topo sort.
++indegree[head];
++indegree[static_cast<size_t>(head)];
}
}
for (int i = 0; i < num_nodes; ++i) {
if (!indegree[i]) topo_order.push_back(i);
for (NodeIndex i(0); i < num_nodes; ++i) {
if (!indegree[static_cast<size_t>(i)]) topo_order.push_back(i);
}
size_t num_visited = 0;
while (num_visited < topo_order.size()) {
const int from = topo_order[num_visited++];
for (const int head : adj[from]) {
if (!--indegree[head]) topo_order.push_back(head);
const NodeIndex from = topo_order[num_visited++];
for (const NodeIndex head : adj[from]) {
if (!--indegree[static_cast<size_t>(head)]) topo_order.push_back(head);
}
}
if (topo_order.size() < static_cast<size_t>(num_nodes)) {
@@ -656,77 +661,99 @@ absl::StatusOr<std::vector<int>> FastTopologicalSort(
}
template <class AdjacencyLists>
absl::StatusOr<std::vector<int>> FindCycleInGraph(const AdjacencyLists& adj) {
const size_t num_nodes = adj.size();
if (num_nodes > std::numeric_limits<int>::max()) {
absl::StatusOr<
std::vector<typename util::GraphTraits<AdjacencyLists>::NodeIndex>>
FindCycleInGraph(const AdjacencyLists& adj) {
using NodeIndex = typename GraphTraits<AdjacencyLists>::NodeIndex;
if (adj.size() > std::numeric_limits<NodeIndex>::max()) {
return absl::InvalidArgumentError(
absl::StrFormat("Too many nodes: adj.size()=%d", adj.size()));
absl::StrFormat("Too many nodes: adj.size()=%v", adj.size()));
}
const NodeIndex num_nodes(adj.size());
// First pass to validate that inputs are valid.
for (NodeIndex node(0); node < NodeIndex(node); ++node) {
for (const NodeIndex head : adj[node]) {
if (head >= num_nodes) {
return absl::InvalidArgumentError(
absl::StrFormat("Invalid child %v in adj[%v]", head, node));
}
}
}
// To find a cycle, we start a DFS from each yet-unvisited node and
// try to find a cycle, if we don't find it then we know for sure that
// no cycle is reachable from any of the explored nodes (so, we don't
// explore them in later DFSs).
std::vector<bool> no_cycle_reachable_from(num_nodes, false);
std::vector<bool> no_cycle_reachable_from(static_cast<size_t>(num_nodes),
false);
// The DFS stack will contain a chain of nodes, from the root of the
// DFS to the current leaf.
struct DfsState {
int node;
NodeIndex node;
// Points at the first child node that we did *not* yet look at.
int adj_list_index;
explicit DfsState(int _node) : node(_node), adj_list_index(0) {}
decltype(adj[NodeIndex(0)].begin()) children;
decltype(adj[NodeIndex(0)].end()) children_end;
explicit DfsState(NodeIndex _node,
const decltype(adj[NodeIndex(0)])& neighbours)
: node(_node),
children(neighbours.begin()),
children_end(neighbours.end()) {}
};
std::vector<DfsState> dfs_stack;
std::vector<bool> in_cur_stack(num_nodes, false);
for (int start_node = 0; start_node < static_cast<int>(num_nodes);
std::vector<bool> visited(static_cast<size_t>(num_nodes), false);
for (NodeIndex start_node(0); start_node < NodeIndex(num_nodes);
++start_node) {
if (no_cycle_reachable_from[start_node]) continue;
if (no_cycle_reachable_from[static_cast<size_t>(start_node)]) continue;
// Start the DFS.
dfs_stack.push_back(DfsState(start_node));
in_cur_stack[start_node] = true;
visited[static_cast<size_t>(start_node)] = true;
dfs_stack.push_back(DfsState(start_node, adj[start_node]));
while (!dfs_stack.empty()) {
DfsState* cur_state = &dfs_stack.back();
if (static_cast<size_t>(cur_state->adj_list_index) >=
adj[cur_state->node].size()) {
no_cycle_reachable_from[cur_state->node] = true;
in_cur_stack[cur_state->node] = false;
DfsState* const cur_state = &dfs_stack.back();
while (
cur_state->children != cur_state->children_end &&
no_cycle_reachable_from[static_cast<size_t>(*cur_state->children)]) {
++cur_state->children;
}
if (cur_state->children == cur_state->children_end) {
no_cycle_reachable_from[static_cast<size_t>(cur_state->node)] = true;
dfs_stack.pop_back();
continue;
}
// Look at the current child, and increase the current state's
// adj_list_index.
// TODO(user): Caching adj[cur_state->node] in a local stack to improve
// locality and so that the [] operator is called exactly once per node.
const int child = adj[cur_state->node][cur_state->adj_list_index++];
if (static_cast<size_t>(child) >= num_nodes) {
return absl::InvalidArgumentError(absl::StrFormat(
"Invalid child %d in adj[%d]", child, cur_state->node));
}
if (no_cycle_reachable_from[child]) continue;
if (in_cur_stack[child]) {
const NodeIndex child = *cur_state->children;
// At that point the child is either:
// - visited and all finalized (all its children are visited). We know
// that it's not part of a cycle, otherwise we'd already have
// returned.
// - visited and not finalized (some of its children are not visited).
// That means that we've reached it again from a child, so we've found
// a cycle.
// - not visited. We push it on the stack and explore it.
if (no_cycle_reachable_from[static_cast<size_t>(child)]) continue;
if (visited[static_cast<size_t>(child)]) {
// We detected a cycle! It corresponds to the tail end of dfs_stack,
// in reverse order, until we find "child".
int cycle_start = dfs_stack.size() - 1;
size_t cycle_start = dfs_stack.size() - 1;
while (dfs_stack[cycle_start].node != child) --cycle_start;
const int cycle_size = dfs_stack.size() - cycle_start;
std::vector<int> cycle(cycle_size);
for (int c = 0; c < cycle_size; ++c) {
const size_t cycle_size = dfs_stack.size() - cycle_start;
std::vector<NodeIndex> cycle(cycle_size);
for (size_t c = 0; c < cycle_size; ++c) {
cycle[c] = dfs_stack[cycle_start + c].node;
}
return cycle;
}
// Push the child onto the stack.
dfs_stack.push_back(DfsState(child));
in_cur_stack[child] = true;
// Verify that its adjacency list seems valid.
if (adj[child].size() > std::numeric_limits<int>::max()) {
return absl::InvalidArgumentError(absl::StrFormat(
"Invalid adj[%d].size() = %d", child, adj[child].size()));
}
dfs_stack.push_back(DfsState(child, adj[child]));
visited[static_cast<size_t>(child)] = true;
}
}
// If we're here, then all the DFS stopped, and there is no cycle.
return std::vector<int>{};
return std::vector<NodeIndex>{};
}
} // namespace graph