missing stl

This commit is contained in:
Laurent Perron
2022-09-09 16:47:18 +02:00
parent 79aaf280bd
commit f1e85a3373
6 changed files with 275 additions and 107 deletions

View File

@@ -22,6 +22,7 @@ cc_library(
":iterators",
"//ortools/base",
"@com_google_absl//absl/debugging:leak_check",
"@com_google_absl//absl/types:span",
],
)
@@ -295,6 +296,7 @@ cc_library(
":max_flow",
"//ortools/base",
"//ortools/base:mathutil",
"//ortools/util:saturated_arithmetic",
"//ortools/util:stats",
"//ortools/util:zvector",
],
@@ -361,9 +363,12 @@ cc_library(
"//ortools/base",
"//ortools/base:container_logging",
"//ortools/base:map_util",
"//ortools/base:status_builder",
"//ortools/base:stl_util",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

View File

@@ -168,6 +168,7 @@
#include "absl/base/port.h"
#include "absl/debugging/leak_check.h"
#include "absl/types/span.h"
#include "ortools/base/integral_types.h"
#include "ortools/base/logging.h"
#include "ortools/base/macros.h"
@@ -203,8 +204,10 @@ class BaseGraph {
const_capacities_(false) {}
virtual ~BaseGraph() {}
// Returns the number of valid nodes in the graph.
// Returns the number of valid nodes in the graph. Prefer using num_nodes():
// the size() API is here to make Graph and vector<vector<int>> more alike.
NodeIndexType num_nodes() const { return num_nodes_; }
NodeIndexType size() const { return num_nodes_; } // Prefer num_nodes().
// Returns the number of valid arcs in the graph.
ArcIndexType num_arcs() const { return num_arcs_; }
@@ -417,6 +420,11 @@ class StaticGraph : public BaseGraph<NodeIndexType, ArcIndexType, false> {
this->AddNode(num_nodes - 1);
}
// Shortcut to directly create a finalized graph, i.e. Build() is called.
template <class ArcContainer> // e.g. vector<pair<int, int>>.
static StaticGraph FromArcs(NodeIndexType num_nodes,
const ArcContainer& arcs);
// Do not use directly. See instead the arc iteration functions below.
class OutgoingArcIterator;
@@ -430,7 +438,7 @@ class StaticGraph : public BaseGraph<NodeIndexType, ArcIndexType, false> {
// This loops over the heads of the OutgoingArcs(node). It is just a more
// convenient way to achieve this. Moreover this interface is used by some
// graph algorithms.
BeginEndWrapper<NodeIndexType const*> operator[](NodeIndexType node) const;
absl::Span<const NodeIndexType> operator[](NodeIndexType node) const;
void ReserveNodes(NodeIndexType bound) override;
void ReserveArcs(ArcIndexType bound) override;
@@ -598,7 +606,7 @@ class ReverseArcStaticGraph
// This loops over the heads of the OutgoingArcs(node). It is just a more
// convenient way to achieve this. Moreover this interface is used by some
// graph algorithms.
BeginEndWrapper<NodeIndexType const*> operator[](NodeIndexType node) const;
absl::Span<const NodeIndexType> operator[](NodeIndexType node) const;
ArcIndexType OppositeArc(ArcIndexType arc) const;
// TODO(user): support Head() and Tail() before Build(), like StaticGraph<>.
@@ -685,7 +693,7 @@ class ReverseArcMixedGraph
// This loops over the heads of the OutgoingArcs(node). It is just a more
// convenient way to achieve this. Moreover this interface is used by some
// graph algorithms.
BeginEndWrapper<NodeIndexType const*> operator[](NodeIndexType node) const;
absl::Span<const NodeIndexType> operator[](NodeIndexType node) const;
ArcIndexType OppositeArc(ArcIndexType arc) const;
// TODO(user): support Head() and Tail() before Build(), like StaticGraph<>.
@@ -1261,13 +1269,24 @@ class ListGraph<NodeIndexType, ArcIndexType>::OutgoingHeadIterator {
// StaticGraph implementation --------------------------------------------------
// Creates a graph sized for `num_nodes` nodes and `arcs.size()` arcs, adds one
// arc per (tail, head) element of `arcs` in iteration order, calls Build() and
// returns the result.
template <typename NodeIndexType, typename ArcIndexType>
template <class ArcContainer>
StaticGraph<NodeIndexType, ArcIndexType>
StaticGraph<NodeIndexType, ArcIndexType>::FromArcs(NodeIndexType num_nodes,
                                                   const ArcContainer& arcs) {
  StaticGraph graph(num_nodes, arcs.size());
  for (const auto& arc : arcs) {
    const auto& [tail, head] = arc;
    graph.AddArc(tail, head);
  }
  graph.Build();
  return graph;
}
DEFINE_RANGE_BASED_ARC_ITERATION(StaticGraph, Outgoing, DirectArcLimit(node));
template <typename NodeIndexType, typename ArcIndexType>
BeginEndWrapper<NodeIndexType const*>
absl::Span<const NodeIndexType>
StaticGraph<NodeIndexType, ArcIndexType>::operator[](NodeIndexType node) const {
return BeginEndWrapper<NodeIndexType const*>(
head_.data() + start_[node], head_.data() + DirectArcLimit(node));
return absl::Span<const NodeIndexType>(head_.data() + start_[node],
DirectArcLimit(node) - start_[node]);
}
template <typename NodeIndexType, typename ArcIndexType>
@@ -1716,11 +1735,11 @@ ArcIndexType ReverseArcStaticGraph<NodeIndexType, ArcIndexType>::InDegree(
}
template <typename NodeIndexType, typename ArcIndexType>
BeginEndWrapper<NodeIndexType const*>
absl::Span<const NodeIndexType>
ReverseArcStaticGraph<NodeIndexType, ArcIndexType>::operator[](
NodeIndexType node) const {
return BeginEndWrapper<NodeIndexType const*>(
head_.data() + start_[node], head_.data() + DirectArcLimit(node));
return absl::Span<const NodeIndexType>(head_.data() + start_[node],
DirectArcLimit(node) - start_[node]);
}
template <typename NodeIndexType, typename ArcIndexType>
@@ -1975,11 +1994,11 @@ ArcIndexType ReverseArcMixedGraph<NodeIndexType, ArcIndexType>::InDegree(
}
template <typename NodeIndexType, typename ArcIndexType>
BeginEndWrapper<NodeIndexType const*>
absl::Span<const NodeIndexType>
ReverseArcMixedGraph<NodeIndexType, ArcIndexType>::operator[](
NodeIndexType node) const {
return BeginEndWrapper<NodeIndexType const*>(
head_.data() + start_[node], head_.data() + DirectArcLimit(node));
return absl::Span<const NodeIndexType>(head_.data() + start_[node],
DirectArcLimit(node) - start_[node]);
}
template <typename NodeIndexType, typename ArcIndexType>

View File

@@ -26,6 +26,7 @@
#include "ortools/graph/graph.h"
#include "ortools/graph/graphs.h"
#include "ortools/graph/max_flow.h"
#include "ortools/util/saturated_arithmetic.h"
// TODO(user): Remove these flags and expose the parameters in the API.
// New clients, please do not use these flags!
@@ -59,7 +60,6 @@ GenericMinCostFlow<Graph, ArcFlowType, ArcScaledCostType>::GenericMinCostFlow(
alpha_(absl::GetFlag(FLAGS_min_cost_flow_alpha)),
cost_scaling_factor_(1),
scaled_arc_unit_cost_(),
total_flow_cost_(0),
status_(NOT_SOLVED),
initial_node_excess_(),
feasible_node_excess_(),
@@ -226,27 +226,44 @@ bool GenericMinCostFlow<Graph, ArcFlowType, ArcScaledCostType>::CheckResult()
// Verifies that the arc costs are small enough, relative to the number of
// nodes, for 3 * max(|arc cost|) * num_nodes to fit in an unsigned 64-bit
// integer.
//
// Returns true when the predicate holds (or when the graph has no node), and
// false otherwise. The final failure path additionally emits a LOG(DFATAL).
template <typename Graph, typename ArcFlowType, typename ArcScaledCostType>
bool GenericMinCostFlow<Graph, ArcFlowType, ArcScaledCostType>::CheckCostRange()
    const {
  using UnsignedCostValue = uint64_t;
  static_assert(sizeof(UnsignedCostValue) >= sizeof(CostValue), "");
  UnsignedCostValue max_cost_magnitude = 0;
  UnsignedCostValue min_cost_magnitude =
      std::numeric_limits<UnsignedCostValue>::max();
  // Traverse the initial arcs of the graph:
  for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) {
    const UnsignedCostValue cost_magnitude =
        static_cast<UnsignedCostValue>(std::abs(scaled_arc_unit_cost_[arc]));
    max_cost_magnitude = std::max(max_cost_magnitude, cost_magnitude);
    // Zero costs are ignored when tracking the smallest magnitude.
    if (cost_magnitude != 0) {
      min_cost_magnitude = std::min(min_cost_magnitude, cost_magnitude);
    }
  }
  VLOG(3) << "Min cost magnitude = " << min_cost_magnitude
          << ", Max cost magnitude = " << max_cost_magnitude;
  constexpr UnsignedCostValue kMaxCost =
      std::numeric_limits<UnsignedCostValue>::max();
  const UnsignedCostValue num_nodes = graph_->num_nodes();
  // The predicate we want to verify is:
  //   3 * max_cost_magnitude * num_nodes ≤ kMaxCost.
  // NOTE(user): The factor of 3 might be reduced to 2 or even 1 if we audited
  // the potential overflow-driving code, but it's not trivial. See cl/457335394
  // which changed the factor from 2 to 3 because it had detected overflows.
  //
  // To verify the above predicate without overflows, we use the fact that for
  // non-negative integers a, c and a positive integer b:
  //   a×b ≤ c ⇔ a ≤ ⌊c/b⌋.
  // (If a ≤ ⌊c/b⌋ then a×b ≤ ⌊c/b⌋×b ≤ c; conversely a×b ≤ c gives a ≤ c/b,
  // and a is an integer.) No condition on the remainder c % b is needed.
  if (num_nodes == 0) return true;
  const UnsignedCostValue quotient = kMaxCost / num_nodes;
  // First, we guard against overflows when computing 3 * max_cost_magnitude.
  if (max_cost_magnitude > kMaxCost / 3) return false;
  if (3 * max_cost_magnitude <= quotient) return true;
  LOG(DFATAL) << "max(3 * abs(arc cost)) * num_nodes overflows: "
              << "max_cost_magnitude: " << max_cost_magnitude
              << ", num_nodes: " << num_nodes << ", kMaxCost: " << kMaxCost;
  return false;
}
template <typename Graph, typename ArcFlowType, typename ArcScaledCostType>
@@ -520,19 +537,38 @@ bool GenericMinCostFlow<Graph, ArcFlowType, ArcScaledCostType>::Solve() {
UnscaleCosts();
if (status_ != OPTIMAL) {
LOG(DFATAL) << "Status != OPTIMAL";
total_flow_cost_ = 0;
return false;
}
total_flow_cost_ = 0;
for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) {
const FlowQuantity flow_on_arc = residual_arc_capacity_[Opposite(arc)];
total_flow_cost_ += scaled_arc_unit_cost_[arc] * flow_on_arc;
}
status_ = OPTIMAL;
IF_STATS_ENABLED(VLOG(1) << stats_.StatString());
return true;
}
// Sums, over all arcs, the arc's scaled unit cost times the flow read for that
// arc, using saturating arithmetic.
//
// Returns 0 when status_ is not OPTIMAL. If any per-arc product or any partial
// sum saturates (at either end of the CostValue range), the result is capped
// at std::numeric_limits<CostValue>::max().
template <typename Graph, typename ArcFlowType, typename ArcScaledCostType>
CostValue
GenericMinCostFlow<Graph, ArcFlowType, ArcScaledCostType>::GetOptimalCost() {
  if (status_ != OPTIMAL) return 0;

  const CostValue kCostCeiling = std::numeric_limits<CostValue>::max();
  const CostValue kCostFloor = std::numeric_limits<CostValue>::min();

  CostValue accumulated_cost = 0;
  for (ArcIndex arc = 0; arc < graph_->num_arcs(); ++arc) {
    // The flow on an arc is read from the residual capacity of its opposite.
    const CostValue arc_flow = residual_arc_capacity_[Opposite(arc)];
    const CostValue arc_cost = CapProd(scaled_arc_unit_cost_[arc], arc_flow);
    const bool product_saturated =
        arc_cost == kCostCeiling || arc_cost == kCostFloor;
    if (product_saturated) return kCostCeiling;

    accumulated_cost = CapAdd(arc_cost, accumulated_cost);
    const bool sum_saturated =
        accumulated_cost == kCostCeiling || accumulated_cost == kCostFloor;
    if (sum_saturated) return kCostCeiling;
  }
  return accumulated_cost;
}
template <typename Graph, typename ArcFlowType, typename ArcScaledCostType>
void GenericMinCostFlow<Graph, ArcFlowType,
ArcScaledCostType>::ResetFirstAdmissibleArcs() {

View File

@@ -381,8 +381,12 @@ class GenericMinCostFlow : public MinCostFlowBase {
// MakeFeasible returns false if CheckFeasibility() was not called before.
bool MakeFeasible();
// Returns the cost of the minimum-cost flow found by the algorithm.
CostValue GetOptimalCost() const { return total_flow_cost_; }
// Returns the cost of the minimum-cost flow found by the algorithm. This
// works in O(num_arcs). This will only work if the last Solve() call was
// successful and returned true, otherwise it will return 0. Note that the
// computation might overflow, in which case we will cap the cost at
// std::numeric_limits<CostValue>::max().
CostValue GetOptimalCost();
// Returns the flow on the given arc using the equations given in the
// comment on residual_arc_capacity_.
@@ -573,9 +577,6 @@ class GenericMinCostFlow : public MinCostFlowBase {
// An array representing the scaled unit cost for each arc in graph_.
ZVector<ArcScaledCostType> scaled_arc_unit_cost_;
// The total cost of the flow.
CostValue total_flow_cost_;
// The status of the problem.
Status status_;

View File

@@ -62,6 +62,35 @@ void DenseIntTopologicalSorterTpl<stable_sort>::AddNode(int node_index) {
// small adjacency lists in case there are repeated edges, I picked 16.
static const int kLazyDuplicateDetectionSizeThreshold = 16;
// Bulk version of AddEdge(): appends `to` to the adjacency list of `from` for
// every (from, to) pair of `edges`, in order, without duplicate detection.
// CHECK-fails if the traversal has already started.
template <bool stable_sort>
void DenseIntTopologicalSorterTpl<stable_sort>::AddEdges(
    const std::vector<std::pair<int, int>>& edges) {
  CHECK(!TraversalStarted()) << "Cannot add edges after starting traversal";

  // Make a first pass to detect the number of nodes.
  int max_node = -1;
  for (const auto& [from, to] : edges) {
    if (from > max_node) max_node = from;
    if (to > max_node) max_node = to;
  }
  // NOTE(review): presumably AddNode() grows adjacency_lists_ so that
  // max_node is a valid index — confirm against AddNode().
  if (max_node >= 0) AddNode(max_node);

  // Make a second pass to reserve the adjacency list sizes.
  // We use indegree_ as temporary node buffer to store the node out-degrees,
  // since it isn't being used yet.
  indegree_.assign(max_node + 1, 0);
  for (const auto& [from, to] : edges) ++indegree_[from];
  // The bound is inclusive: max_node itself may be the tail of some edges and
  // must get its adjacency list reserved too.
  for (int node = 0; node <= max_node; ++node) {
    adjacency_lists_[node].reserve(indegree_[node]);
  }
  indegree_.clear();

  // Finally, add edges to the adjacency lists in a third pass. Don't bother
  // doing the duplicate detection: in the bulk API, we assume that there isn't
  // much edge duplication.
  for (const auto& [from, to] : edges) adjacency_lists_[from].push_back(to);
}
template <bool stable_sort>
void DenseIntTopologicalSorterTpl<stable_sort>::AddEdge(int from, int to) {
CHECK(!TraversalStarted()) << "Cannot add edges after starting traversal";

View File

@@ -11,56 +11,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// TopologicalSorter provides topologically sorted traversal of the
// nodes of a directed acyclic graph (DAG) with up to INT_MAX nodes.
// It sorts ancestor nodes before their descendants.
// This file provides topologically sorted traversal of the nodes of a directed
// acyclic graph (DAG) with up to INT_MAX nodes.
// It sorts ancestor nodes before their descendants. Multi-arcs are fine.
//
// If your graph is not a DAG and you're reading this, you are probably
// looking for ortools/graph/strongly_connected_components.h which does
// the topological decomposition of a directed graph.
//
// EXAMPLE:
//
// vector<int> result;
// vector<string> nodes = {"a", "b", "c"};
// vector<pair<string, string>> arcs = {{"a", "c"}, {"a", "b"}, {"b", "c"}};
// if (util::StableTopologicalSort(num_nodes, arcs, &result)) {
// LOG(INFO) << "The topological order is: " << gtl::LogContainer(result);
// } else {
// LOG(INFO) << "The graph is cyclic.";
// // Note: you can extract a cycle with the TopologicalSorter class, or
// // with the API defined in circularity_detector.h.
// }
// // This will be successful and result will be equal to {"a", "b", "c"}.
//
// There are 8 flavors of topological sort, from these 3 bits:
// - There are OrDie() versions that directly return the topological order, or
// crash if a cycle is detected (and LOG the cycle).
// - There are type-generic versions that can take any node type (including
// non-dense integers), but slower, or the "dense int" versions which requires
// nodes to be a dense interval [0..num_nodes-1]. Note that the type must
// be compatible with LOG << T if you're using the OrDie() version.
// - The sorting can be either stable or not. "Stable" essentially means that it
// will preserve the order of nodes, if possible. More precisely, the returned
// topological order will be the lexicographically minimal valid order, where
// "lexicographic" applies to the indices of the nodes.
//
// TopologicalSort()
// TopologicalSortOrDie()
// StableTopologicalSort()
// StableTopologicalSortOrDie()
// DenseIntTopologicalSort()
// DenseIntTopologicalSortOrDie()
// DenseIntStableTopologicalSort()
// DenseIntStableTopologicalSortOrDie()
//
// If you need more control, or a step-by-step topological sort, see the
// TopologicalSorter classes below.
// USAGE:
// - If performance matters, use FastTopologicalSort().
// - If your nodes are non-integers, or you need to break topological ties by
// node index (like "stable_sort"), use one of the DenseIntTopologicalSort()
// or TopologicalSort variants (see below).
// - If you need more control (cycle extraction?), or a step-by-step topological
// sort, see the TopologicalSorter classes below.
#ifndef UTIL_GRAPH_TOPOLOGICALSORTER_H__
#define UTIL_GRAPH_TOPOLOGICALSORTER_H__
#include <functional>
#include <limits>
#include <queue>
#include <type_traits>
#include <utility>
@@ -68,25 +39,54 @@
#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "ortools/base/container_logging.h"
#include "ortools/base/logging.h"
#include "ortools/base/macros.h"
#include "ortools/base/map_util.h"
#include "ortools/base/status_builder.h"
#include "ortools/base/stl_util.h"
namespace util {
namespace graph {
// Returns true if the graph was a DAG, and outputs the topological order in
// "topological_order". Returns false if the graph is cyclic.
// Works in O(num_nodes + arcs.size()), and is pretty fast.
ABSL_MUST_USE_RESULT inline bool DenseIntTopologicalSort(
int num_nodes, const std::vector<std::pair<int, int>>& arcs,
std::vector<int>* topological_order);
// This is the recommended API when performance matters. It's also very simple.
// AdjacencyList is any type that lets you iterate over the neighbors of
// node with the [] operator, for example vector<vector<int>> or util::Graph.
//
// If you don't already have an adjacency list representation, build one using
// StaticGraph<> in ./graph.h: FastTopologicalSort() can take any such graph as
// input.
//
// ERRORS: returns InvalidArgumentError if the input is broken (negative or
// out-of-bounds integers) or if the graph is cyclic. In the latter case, the
// error message will contain "cycle". Note that if cycles may occur in your
// input, you can probably assume that your input isn't broken, and thus rely
// on failures to detect that the graph is cyclic.
//
// TIE BREAKING: the returned topological order is deterministic and fixed, and
// corresponds to iterating on nodes in a FIFO (Breadth-first) order.
//
// Benchmark: gpaste/6147236302946304, 4-10x faster than util_graph::TopoSort().
//
// EXAMPLES:
// std::vector<std::vector<int>> adj = {{..}, {..}, ..};
// ASSIGN_OR_RETURN(std::vector<int> topo_order, FastTopologicalSort(adj));
//
// or
// util::StaticGraph<> graph(/*num_nodes=*/10, /*num_edges=*/42);
// graph.AddArc(1, 3);
// ...
// graph.Build();
// ASSIGN_OR_RETURN(std::vector<int> topo_order, FastTopologicalSort(graph));
//
template <class AdjacencyLists>
absl::StatusOr<std::vector<int>> FastTopologicalSort(const AdjacencyLists& adj);
// Like DenseIntTopologicalSort, but stable.
ABSL_MUST_USE_RESULT inline bool DenseIntStableTopologicalSort(
int num_nodes, const std::vector<std::pair<int, int>>& arcs,
std::vector<int>* topological_order);
} // namespace graph
// Finds a cycle in the directed graph given as argument: nodes are dense
// integers in 0..num_nodes-1, and (directed) arcs are pairs of nodes
@@ -94,33 +94,60 @@ ABSL_MUST_USE_RESULT inline bool DenseIntStableTopologicalSort(
// The returned cycle is a list of nodes that form a cycle, eg. {1, 4, 3}
// if the cycle 1->4->3->1 exists.
// If the graph is acyclic, returns an empty vector.
// TODO(user): Deprecate this version and promote an adjacency-list based one.
ABSL_MUST_USE_RESULT std::vector<int> FindCycleInDenseIntGraph(
int num_nodes, const std::vector<std::pair<int, int>>& arcs);
// Like the two above, but with generic node types. The nodes must be provided.
// Can be significantly slower, but still linear.
// [Stable]TopologicalSort[OrDie]:
//
// These variants are much slower than FastTopologicalSort(), but support
// non-integer (or integer, but sparse) nodes.
// Note that if performance matters, you're probably better off building your
// own mapping from node to dense index with a flat_hash_map and calling
// FastTopologicalSort().
// Returns true if the graph was a DAG, and outputs the topological order in
// "topological_order". Returns false if the graph is cyclic.
template <typename T>
ABSL_MUST_USE_RESULT bool TopologicalSort(
const std::vector<T>& nodes, const std::vector<std::pair<T, T>>& arcs,
std::vector<T>* topological_order);
// OrDie() variant of the above.
template <typename T>
std::vector<T> TopologicalSortOrDie(const std::vector<T>& nodes,
const std::vector<std::pair<T, T>>& arcs);
// The "Stable" variants are a little slower but preserve the input order of
// nodes, if possible. More precisely, the returned topological order will be
// the lexicographically minimal valid order, where "lexicographic" applies to
// the indices of the nodes.
template <typename T>
ABSL_MUST_USE_RESULT bool StableTopologicalSort(
const std::vector<T>& nodes, const std::vector<std::pair<T, T>>& arcs,
std::vector<T>* topological_order);
// "OrDie()" versions of the 4 functions above. Those directly return the
// topological order, which makes their API even simpler.
inline std::vector<int> DenseIntTopologicalSortOrDie(
int num_nodes, const std::vector<std::pair<int, int>>& arcs);
inline std::vector<int> DenseIntStableTopologicalSortOrDie(
int num_nodes, const std::vector<std::pair<int, int>>& arcs);
template <typename T>
std::vector<T> TopologicalSortOrDie(const std::vector<T>& nodes,
const std::vector<std::pair<T, T>>& arcs);
template <typename T>
std::vector<T> StableTopologicalSortOrDie(
const std::vector<T>& nodes, const std::vector<std::pair<T, T>>& arcs);
// ______________________ END OF THE RECOMMENDED API ___________________________
// DEPRECATED: DenseInt[Stable]TopologicalSort[OrDie].
// Kept here for legacy reasons, but most new users should use
// FastTopologicalSort():
// - If your input is a list of edges, build your own StaticGraph<> (see
// ./graph.h) and pass it to FastTopologicalSort().
// - If you need the "stable sort" bit, contact viger@ and/or or-core-team@
// to see if they can create FastStableTopologicalSort().
ABSL_MUST_USE_RESULT inline bool DenseIntTopologicalSort(
int num_nodes, const std::vector<std::pair<int, int>>& arcs,
std::vector<int>* topological_order);
inline std::vector<int> DenseIntStableTopologicalSortOrDie(
int num_nodes, const std::vector<std::pair<int, int>>& arcs);
ABSL_MUST_USE_RESULT inline bool DenseIntStableTopologicalSort(
int num_nodes, const std::vector<std::pair<int, int>>& arcs,
std::vector<int>* topological_order);
inline std::vector<int> DenseIntTopologicalSortOrDie(
int num_nodes, const std::vector<std::pair<int, int>>& arcs);
namespace internal {
// Internal wrapper around the *TopologicalSort classes.
template <typename T, typename Sorter>
@@ -144,7 +171,7 @@ template <bool stable_sort = false>
class DenseIntTopologicalSorterTpl {
public:
// To store the adjacency lists efficiently.
typedef std::vector<int> AdjacencyList;
typedef absl::InlinedVector<int, 4> AdjacencyList;
// For efficiency, it is best to specify how many nodes are required
// by using the next constructor.
@@ -169,8 +196,13 @@ class DenseIntTopologicalSorterTpl {
// it will be faster and use less memory.
void AddNode(int node_index);
// Performs AddEdge() in bulk. Much faster if you add *all* edges at once.
void AddEdges(const std::vector<std::pair<int, int>>& edges);
// Performs in constant amortized time. Calling this will make all
// node indices in [0, max(from, to)] be valid node indices.
// THIS IS MUCH SLOWER than calling AddEdges() if you already have all the
// edges.
void AddEdge(int from, int to);
// Performs in O(average degree) in average. If a cycle is detected
@@ -283,6 +315,11 @@ class TopologicalSorter {
// or if more than INT_MAX nodes are being added.
void AddNode(const T& node) { int_sorter_.AddNode(LookupOrInsertNode(node)); }
// Shortcut to AddEdge() in bulk. Not optimized.
void AddEdges(const std::vector<std::pair<T, T>>& edges) {
for (const auto& [from, to] : edges) AddEdge(from, to);
}
// Adds a directed edge with the given endpoints to the graph. There
// is no requirement (nor is it an error) to call AddNode() for the
// endpoints. Dies with a fatal error if called after a traversal
@@ -391,9 +428,7 @@ ABSL_MUST_USE_RESULT bool RunTopologicalSorter(
Sorter* sorter, const std::vector<std::pair<T, T>>& arcs,
std::vector<T>* topological_order, std::vector<T>* cycle) {
topological_order->clear();
for (const auto& arc : arcs) {
sorter->AddEdge(arc.first, arc.second);
}
sorter->AddEdges(arcs);
bool cyclic = false;
sorter->StartTraversal();
T next;
@@ -408,6 +443,7 @@ ABSL_MUST_USE_RESULT bool DenseIntTopologicalSortImpl(
int num_nodes, const std::vector<std::pair<int, int>>& arcs,
std::vector<int>* topological_order) {
DenseIntTopologicalSorterTpl<stable_sort> sorter(num_nodes);
topological_order->reserve(num_nodes);
return RunTopologicalSorter<int, decltype(sorter)>(
&sorter, arcs, topological_order, nullptr);
}
@@ -427,8 +463,9 @@ ABSL_MUST_USE_RESULT bool TopologicalSortImpl(
// Now, the OrDie() versions, which directly return the topological order.
template <typename T, typename Sorter>
std::vector<T> RunTopologicalSorterOrDie(
Sorter* sorter, const std::vector<std::pair<T, T>>& arcs) {
Sorter* sorter, int num_nodes, const std::vector<std::pair<T, T>>& arcs) {
std::vector<T> topo_order;
topo_order.reserve(num_nodes);
CHECK(RunTopologicalSorter(sorter, arcs, &topo_order, &topo_order))
<< "Found cycle: " << gtl::LogContainer(topo_order);
return topo_order;
@@ -438,7 +475,7 @@ template <bool stable_sort = false>
std::vector<int> DenseIntTopologicalSortOrDieImpl(
int num_nodes, const std::vector<std::pair<int, int>>& arcs) {
DenseIntTopologicalSorterTpl<stable_sort> sorter(num_nodes);
return RunTopologicalSorterOrDie(&sorter, arcs);
return RunTopologicalSorterOrDie(&sorter, num_nodes, arcs);
}
template <typename T, bool stable_sort = false>
@@ -448,7 +485,7 @@ std::vector<T> TopologicalSortOrDieImpl(
for (const T& node : nodes) {
sorter.AddNode(node);
}
return RunTopologicalSorterOrDie(&sorter, arcs);
return RunTopologicalSorterOrDie(&sorter, nodes.size(), arcs);
}
} // namespace internal
@@ -535,6 +572,47 @@ std::vector<T> StableTopologicalSortOrDie(
return ::util::StableTopologicalSortOrDie<T>(nodes, arcs);
}
// Topologically sorts the graph described by `adj`, where adj[node] iterates
// over the heads of the arcs leaving `node` and adj.size() is the node count.
//
// Returns the nodes in topological order, or InvalidArgumentError when there
// are more than INT_MAX nodes, when an arc head is outside [0, num_nodes), or
// when the graph has a cycle.
template <class AdjacencyLists>
absl::StatusOr<std::vector<int>> FastTopologicalSort(
    const AdjacencyLists& adj) {
  const size_t num_nodes = adj.size();
  if (num_nodes > std::numeric_limits<int>::max()) {
    return absl::InvalidArgumentError("More than kint32max nodes");
  }

  std::vector<int> pending_parents(num_nodes, 0);
  std::vector<int> order;
  order.reserve(num_nodes);

  // Pass 1: validate every arc head and count the incoming arcs of each node.
  for (int tail = 0; tail < num_nodes; ++tail) {
    for (const int head : adj[tail]) {
      // The unsigned cast folds "head < 0 || head >= num_nodes" into a single
      // comparison: negative values wrap around to large unsigned ones.
      const bool head_in_range = static_cast<uint32_t>(head) < num_nodes;
      if (!head_in_range) {
        return absl::InvalidArgumentError(
            absl::StrFormat("Invalid arc in adj[%d]: %d (num_nodes=%d)", tail,
                            head, num_nodes));
      }
      // Self-arcs are deliberately not special-cased here: they surface as
      // cycles in the final size check below.
      ++pending_parents[head];
    }
  }

  // Pass 2: nodes without any incoming arc seed the order.
  for (int node = 0; node < num_nodes; ++node) {
    if (pending_parents[node] == 0) order.push_back(node);
  }

  // Pass 3: `order` doubles as the work queue. Each processed node releases
  // its children; a child is appended once its last parent has been processed.
  for (size_t next = 0; next < order.size(); ++next) {
    const int tail = order[next];
    for (const int head : adj[tail]) {
      --pending_parents[head];
      if (pending_parents[head] == 0) order.push_back(head);
    }
  }

  // Any node never appended still had an unprocessed parent.
  if (order.size() < static_cast<size_t>(num_nodes)) {
    return absl::InvalidArgumentError("The graph has a cycle");
  }
  return order;
}
} // namespace graph
} // namespace util