polish/reindent graph python wrappers and samples

This commit is contained in:
Laurent Perron
2022-10-25 20:30:41 +02:00
parent 5a127cbf3f
commit 2caae2985c
6 changed files with 37 additions and 24 deletions

View File

@@ -54,6 +54,12 @@
#include "ortools/base/ptr_util.h"
namespace util {
// Generic version of GetConnectedComponents() (see below) that supports other
// integer types, e.g. int64_t for huge graphs with more than 2^31 nodes.
template <class UndirectedGraph, class NodeType>
std::vector<NodeType> GetConnectedComponentsTpl(NodeType num_nodes,
const UndirectedGraph& graph);
// Finds the connected components of the graph, using BFS internally.
// Works on any *undirected* graph class whose nodes are dense integers and that
// supports the [] operator for adjacency lists: graph[x] must be an integer
@@ -71,7 +77,9 @@ namespace util {
// GetConnectedComponents(graph); // returns [0, 0, 1, 0, 1, 0].
template <class UndirectedGraph>
std::vector<int> GetConnectedComponents(int num_nodes,
const UndirectedGraph& graph);
const UndirectedGraph& graph) {
return GetConnectedComponentsTpl(num_nodes, graph);
}
} // namespace util
// NOTE(user): The rest of the functions below should also be in namespace
@@ -320,20 +328,23 @@ class ConnectedComponentsFinder {
// Implementations of the method templates
// =============================================================================
namespace util {
template <class UndirectedGraph>
std::vector<int> GetConnectedComponents(int num_nodes,
const UndirectedGraph& graph) {
std::vector<int> component_of_node(num_nodes, -1);
std::vector<int> bfs_queue;
int num_components = 0;
for (int src = 0; src < num_nodes; ++src) {
if (component_of_node[src] >= 0) continue;
template <class UndirectedGraph, typename NodeType>
std::vector<NodeType> GetConnectedComponentsTpl(NodeType num_nodes,
const UndirectedGraph& graph) {
// We use 'num_nodes' as special component id meaning 'unknown', because
// it's of the right type, and -1 is tricky to use with unsigned ints.
std::vector<NodeType> component_of_node(num_nodes, num_nodes);
std::vector<NodeType> bfs_queue;
NodeType num_components = 0;
for (NodeType src = 0; src < num_nodes; ++src) {
if (component_of_node[src] != num_nodes) continue;
bfs_queue.push_back(src);
component_of_node[src] = num_components;
for (int num_visited = 0; num_visited < bfs_queue.size(); ++num_visited) {
const int node = bfs_queue[num_visited];
for (const int neighbor : graph[node]) {
if (component_of_node[neighbor] >= 0) continue;
for (size_t num_visited = 0; num_visited < bfs_queue.size();
++num_visited) {
const NodeType node = bfs_queue[num_visited];
for (const NodeType neighbor : graph[node]) {
if (component_of_node[neighbor] != num_nodes) continue;
component_of_node[neighbor] = num_components;
bfs_queue.push_back(neighbor);
}
@@ -343,6 +354,7 @@ std::vector<int> GetConnectedComponents(int num_nodes,
}
return component_of_node;
}
} // namespace util
#endif // UTIL_GRAPH_CONNECTED_COMPONENTS_H_

View File

@@ -12,7 +12,6 @@
// limitations under the License.
#include "ortools/graph/assignment.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

View File

@@ -47,8 +47,8 @@ PYBIND11_MODULE(min_cost_flow, m) {
smcf.def("optimal_cost", &SimpleMinCostFlow::OptimalCost);
smcf.def("maximum_flow", &SimpleMinCostFlow::MaximumFlow);
smcf.def("flow", &SimpleMinCostFlow::Flow, arg("arc"));
smcf.def("flows",
pybind11::vectorize(&SimpleMinCostFlow::Flow));
smcf.def("flows", pybind11::vectorize(&SimpleMinCostFlow::Flow));
pybind11::enum_<SimpleMinCostFlow::Status>(smcf, "Status")
.value("BAD_COST_RANGE", MinCostFlowBase::Status::BAD_COST_RANGE)
.value("BAD_RESULT", MinCostFlowBase::Status::BAD_RESULT)

View File

@@ -15,9 +15,9 @@
# [START program]
"""Solve assignment problem using linear assignment solver."""
# [START import]
from ortools.graph.python import linear_sum_assignment
import numpy as np
from ortools.graph.python import linear_sum_assignment
# [END import]
@@ -35,8 +35,10 @@ def main():
[45, 110, 95, 115],
])
# Let's transform this into 3 parallel vectors (start_nodes, end_nodes, arc_costs)
end_nodes_unraveled, start_nodes_unraveled = np.meshgrid(np.arange(costs.shape[1]),np.arange(costs.shape[0]))
# Let's transform this into 3 parallel vectors (start_nodes, end_nodes,
# arc_costs)
end_nodes_unraveled, start_nodes_unraveled = np.meshgrid(
np.arange(costs.shape[1]), np.arange(costs.shape[0]))
start_nodes = start_nodes_unraveled.ravel()
end_nodes = end_nodes_unraveled.ravel()
arc_costs = costs.ravel()

View File

@@ -15,9 +15,9 @@
# [START program]
"""From Taha 'Introduction to Operations Research', example 6.4-2."""
# [START import]
from ortools.graph.python import max_flow
import numpy as np
from ortools.graph.python import max_flow
# [END import]

View File

@@ -15,9 +15,9 @@
# [START program]
"""From Bradley, Hax and Maganti, 'Applied Mathematical Programming', figure 8.1."""
# [START import]
from ortools.graph.python import min_cost_flow
import numpy as np
from ortools.graph.python import min_cost_flow
# [END import]