195#ifndef OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
196#define OR_TOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
208#include "absl/strings/str_format.h"
226template <
typename GraphType>
249 DCHECK(graph_ ==
nullptr);
// Accessor for the graph this object operates on. Returns a const reference
// obtained by dereferencing graph_ (declared as const GraphType*); no null
// check is performed here.
// NOTE(review): presumably callers must ensure a graph has been attached
// before calling this -- confirm against the constructors / SetGraph.
282 inline const GraphType&
Graph()
const {
return *graph_; }
293 DCHECK_EQ(0, scaled_arc_cost_[
arc] % cost_scaling_factor_);
294 return scaled_arc_cost_[
arc] / cost_scaling_factor_;
325 if (graph_ ==
nullptr) {
330 return graph_->num_nodes();
341 return matched_arc_[left_node];
354 DCHECK_NE(GraphType::kNilArc, matching_arc);
355 return Head(matching_arc);
// Returns the textual summary of total_stats_, i.e. whatever
// total_stats_.StatsString() produces. total_stats_ is the object that
// iteration_stats_ is merged into via total_stats_.Add(iteration_stats_).
358 std::string
StatsString()
const {
return total_stats_.StatsString(); }
363 : num_left_nodes_(num_left_nodes), node_iterator_(0) {}
366 : num_left_nodes_(assignment.
NumLeftNodes()), node_iterator_(0) {}
// Returns true while the iterator's position (node_iterator_) is still
// strictly below num_left_nodes_; false once it has been advanced to or past
// that bound.
370 bool Ok()
const {
return node_iterator_ < num_left_nodes_; }
// Advances the iterator by incrementing node_iterator_ by one. No bound check
// is done here; Ok() is the call that compares against num_left_nodes_.
372 void Next() { ++node_iterator_; }
// Default constructor: starts all four counters (pushes_, double_pushes_,
// relabelings_, refinements_) at zero.
381 Stats() : pushes_(0), double_pushes_(0), relabelings_(0), refinements_(0) {}
388 void Add(
const Stats& that) {
389 pushes_ += that.pushes_;
390 double_pushes_ += that.double_pushes_;
391 relabelings_ += that.relabelings_;
392 refinements_ += that.refinements_;
395 return absl::StrFormat(
396 "%d refinements; %d relabelings; "
397 "%d double pushes; %d pushes",
398 refinements_, relabelings_, double_pushes_, pushes_);
// Event counters. Each is summed field-by-field in Add() and printed by
// StatsString().
// Incremented in DoublePush() (iteration_stats_.double_pushes_ += 1).
401 int64_t double_pushes_;
// Incremented in DoublePush() (iteration_stats_.relabelings_ += 1).
402 int64_t relabelings_;
// Incremented in Refine() (iteration_stats_.refinements_ += 1).
403 int64_t refinements_;
407 class ActiveNodeContainerInterface {
// Virtual destructor so that an implementation held through a base pointer
// (active_nodes_ is a std::unique_ptr<ActiveNodeContainerInterface>) is
// destroyed through its own destructor.
409 virtual ~ActiveNodeContainerInterface() {}
// Pure virtual: implementations report whether they currently hold no nodes.
410 virtual bool Empty()
const = 0;
415 class ActiveNodeStack :
public ActiveNodeContainerInterface {
// Empty override of the interface's virtual destructor; the body does nothing
// itself.
417 ~ActiveNodeStack()
override {}
// Returns true when the backing std::vector v_ holds no nodes.
419 bool Empty()
const override {
return v_.empty(); }
// Appends `node` to the back of the backing std::vector v_.
// NOTE(review): the matching Get() is not visible in this view; presumably it
// removes from the back as well, giving LIFO order -- confirm.
421 void Add(
NodeIndex node)
override { v_.push_back(node); }
431 std::vector<NodeIndex> v_;
434 class ActiveNodeQueue :
public ActiveNodeContainerInterface {
// Empty override of the interface's virtual destructor; the body does nothing
// itself.
436 ~ActiveNodeQueue()
override {}
// Returns true when the backing std::deque q_ holds no nodes.
438 bool Empty()
const override {
return q_.empty(); }
// Inserts `node` at the front of the backing std::deque q_.
// NOTE(review): the matching Get() is not visible in this view; presumably it
// removes from the opposite end (the back), giving FIFO order -- confirm.
440 void Add(
NodeIndex node)
override { q_.push_front(node); }
450 std::deque<NodeIndex> q_;
// Pair of (arc index, cost value). It is the return type of BestArcAndGap(),
// which builds it with std::make_pair(best_arc, gap); DoublePush() reads the
// arc back through `.first`.
461 typedef std::pair<ArcIndex, CostValue> ImplicitPriceSummary;
// Consistency check that walks every left node and each of its outgoing arcs,
// comparing reduced costs against epsilon_ and 0 (definition only partly
// visible in this view). In the visible code it is invoked only from DCHECKs.
465 bool EpsilonOptimal()
const;
// Loops over all nodes of the graph and tests IsActiveForDebugging() on each.
// NOTE(review): presumably returns false as soon as an active (unmatched)
// node is found -- the return statements are outside this view.
469 bool AllMatched()
const;
// For an active left node, scans its outgoing arcs tracking the smallest and
// second-smallest PartialReducedCost(), and returns (best_arc, gap) where gap
// is (second minimum - minimum) capped at slack_relabeling_price_ - epsilon_.
// NOTE(review): the line updating best_arc inside the loop is outside this
// view; presumably best_arc is the arc achieving the minimum -- confirm.
476 inline ImplicitPriceSummary BestArcAndGap(
NodeIndex left_node)
const;
// Adds iteration_stats_ into total_stats_, logs the iteration stats at
// VLOG(3), then clears iteration_stats_.
480 void ReportAndAccumulateStats();
// Computes the next epsilon via NewEpsilon(epsilon_), refreshes
// slack_relabeling_price_ from PriceChangeBound(old, new, nullptr), and then
// stores the new epsilon_. The returned value is not visible in this view.
491 bool UpdateEpsilon();
// True when matched_arc_[left_node] is GraphType::kNilArc, i.e. the left node
// currently has no matching arc recorded.
495 inline bool IsActive(
NodeIndex left_node)
const;
// Like IsActive() but accepts any node: for node < num_left_nodes_ it defers
// to IsActive(); otherwise it tests matched_node_[node] == kNilNode.
502 inline bool IsActiveForDebugging(
NodeIndex node)
const;
// DCHECKs that active_nodes_ is empty, then iterates over the left nodes and
// adds every one for which IsActive() holds.
509 void InitializeActiveNodeContainer();
// Iterates over the left nodes; the visible part of the body resets
// matched_arc_[node] to kNilArc and matched_node_[mate] to kNilNode.
// NOTE(review): the conditions guarding those resets are partly outside this
// view -- confirm the exact behavior against the full definition.
517 void SaturateNegativeArcs();
526 return scaled_arc_cost_[
arc] - price_[
Head(
arc)];
// Pointer to the graph being solved on. It is compared against nullptr in
// several places and dereferenced (e.g. *graph_, graph_->num_nodes()).
// NOTE(review): being a raw pointer, it is presumably non-owning -- confirm.
531 const GraphType* graph_;
// Set to true at the start of setup and flipped to false inside the loop over
// left nodes there; later folded into `ok` before the main solve loop runs
// (ok = ok && incidence_precondition_satisfied_).
// NOTE(review): presumably false means some left node has no outgoing arc --
// the guarding condition is outside this view.
540 bool incidence_precondition_satisfied_;
867 bool* in_range)
const {
880 const double result =
881 static_cast<double>(std::max<CostValue>(1, n / 2 - 1)) *
882 (
static_cast<double>(old_epsilon) +
static_cast<double>(new_epsilon));
885 if (result > limit) {
887 if (in_range !=
nullptr) *in_range =
false;
// Running maximum of scaled cost magnitudes (updated with std::max as costs
// are set); used to seed epsilon_ as
// std::max(largest_scaled_cost_magnitude_, kMinEpsilon + 1).
908 CostValue largest_scaled_cost_magnitude_;
// Node prices, read and written at right-side node indices (price_[Head(arc)],
// price_[new_mate]). Constructed with (num_left_nodes, 2 * num_left_nodes - 1).
// NOTE(review): presumably ZVector is indexable over that inclusive index
// range -- confirm against ZVector's definition.
921 ZVector<CostValue> price_;
// Per left node: the arc currently matching it, or GraphType::kNilArc when the
// node is unmatched (see IsActive()). Sized num_left_nodes.
926 std::vector<ArcIndex> matched_arc_;
// Per right node: the left node currently matched to it, or
// GraphType::kNilNode when none. Constructed with the same
// (num_left_nodes, 2 * num_left_nodes - 1) arguments as price_.
934 ZVector<NodeIndex> matched_node_;
// Arc costs indexed by arc. ArcCost() divides entries by
// cost_scaling_factor_ after DCHECKing divisibility, so stored values are
// expected to be multiples of that factor.
939 std::vector<CostValue> scaled_arc_cost_;
// Container of active nodes to process; the constructors create either an
// ActiveNodeStack or an ActiveNodeQueue depending on
// FLAGS_assignment_stack_order.
944 std::unique_ptr<ActiveNodeContainerInterface> active_nodes_;
// Statistics for the current iteration; merged into total_stats_ and cleared
// by ReportAndAccumulateStats().
952 Stats iteration_stats_;
// Out-of-class definition of the static member kMinEpsilon, fixed at 1.
// Elsewhere in this file it serves as the floor in
// std::max(current_epsilon / alpha_, kMinEpsilon) and as the threshold of the
// scaling loop `while (ok && epsilon_ > kMinEpsilon)`.
960template <
typename GraphType>
961const CostValue LinearSumAssignment<GraphType>::kMinEpsilon = 1;
963template <
typename GraphType>
965 const GraphType& graph,
const NodeIndex num_left_nodes)
967 num_left_nodes_(num_left_nodes),
969 cost_scaling_factor_(1 + num_left_nodes),
970 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
972 price_lower_bound_(0),
973 slack_relabeling_price_(0),
974 largest_scaled_cost_magnitude_(0),
976 price_(num_left_nodes, 2 * num_left_nodes - 1),
977 matched_arc_(num_left_nodes, 0),
978 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
979 scaled_arc_cost_(graph.max_end_arc_index(), 0),
980 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
981 ? static_cast<ActiveNodeContainerInterface*>(
982 new ActiveNodeStack())
983 : static_cast<ActiveNodeContainerInterface*>(
984 new ActiveNodeQueue())) {}
986template <
typename GraphType>
990 num_left_nodes_(num_left_nodes),
992 cost_scaling_factor_(1 + num_left_nodes),
993 alpha_(
absl::GetFlag(FLAGS_assignment_alpha)),
995 price_lower_bound_(0),
996 slack_relabeling_price_(0),
997 largest_scaled_cost_magnitude_(0),
999 price_(num_left_nodes, 2 * num_left_nodes - 1),
1000 matched_arc_(num_left_nodes, 0),
1001 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
1002 scaled_arc_cost_(num_arcs, 0),
1003 active_nodes_(
absl::GetFlag(FLAGS_assignment_stack_order)
1004 ? static_cast<ActiveNodeContainerInterface*>(
1005 new ActiveNodeStack())
1006 : static_cast<ActiveNodeContainerInterface*>(
1007 new ActiveNodeQueue())) {}
1009template <
typename GraphType>
1011 if (graph_ !=
nullptr) {
1017 cost *= cost_scaling_factor_;
1019 largest_scaled_cost_magnitude_ =
1020 std::max(largest_scaled_cost_magnitude_, cost_magnitude);
1024template <
typename ArcIndexType>
1028 : temp_(0), cost_(
cost) {}
1031 temp_ = (*cost_)[source];
1035 ArcIndexType destination)
const override {
1036 (*cost_)[destination] = (*cost_)[source];
1040 (*cost_)[destination] = temp_;
1047 std::vector<CostValue>*
const cost_;
1057template <
typename GraphType>
1067 return ((graph_.Tail(
a) < graph_.Tail(
b)) ||
1068 ((graph_.Tail(
a) == graph_.Tail(
b)) &&
1069 (graph_.Head(
a) < graph_.Head(
b))));
1073 const GraphType& graph_;
1081template <
typename GraphType>
1082PermutationCycleHandler<typename GraphType::ArcIndex>*
1088template <
typename GraphType>
1099 graph->GroupForwardArcsByFunctor(compare, &cycle_handler);
1103template <
typename GraphType>
1105 const CostValue current_epsilon)
const {
1106 return std::max(current_epsilon / alpha_, kMinEpsilon);
1109template <
typename GraphType>
1110bool LinearSumAssignment<GraphType>::UpdateEpsilon() {
1111 CostValue new_epsilon = NewEpsilon(epsilon_);
1112 slack_relabeling_price_ = PriceChangeBound(epsilon_, new_epsilon,
nullptr);
1113 epsilon_ = new_epsilon;
1114 VLOG(3) <<
"Updated: epsilon_ == " << epsilon_;
1115 VLOG(4) <<
"slack_relabeling_price_ == " << slack_relabeling_price_;
1124template <
typename GraphType>
1125inline bool LinearSumAssignment<GraphType>::IsActive(
1128 return matched_arc_[left_node] == GraphType::kNilArc;
1134template <
typename GraphType>
1135inline bool LinearSumAssignment<GraphType>::IsActiveForDebugging(
1137 if (node < num_left_nodes_) {
1138 return IsActive(node);
1140 return matched_node_[node] == GraphType::kNilNode;
1144template <
typename GraphType>
1145void LinearSumAssignment<GraphType>::InitializeActiveNodeContainer() {
1146 DCHECK(active_nodes_->Empty());
1147 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1148 node_it.Ok(); node_it.Next()) {
1150 if (IsActive(node)) {
1151 active_nodes_->Add(node);
1166template <
typename GraphType>
1167void LinearSumAssignment<GraphType>::SaturateNegativeArcs() {
1169 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1170 node_it.Ok(); node_it.Next()) {
1172 if (IsActive(node)) {
1180 matched_arc_[node] = GraphType::kNilArc;
1181 matched_node_[mate] = GraphType::kNilNode;
1187template <
typename GraphType>
1188bool LinearSumAssignment<GraphType>::DoublePush(
NodeIndex source) {
1190 DCHECK(IsActive(source)) <<
"Node " << source
1191 <<
"must be active (unmatched)!";
1192 ImplicitPriceSummary summary = BestArcAndGap(source);
1193 const ArcIndex best_arc = summary.first;
1198 if (best_arc == GraphType::kNilArc) {
1201 const NodeIndex new_mate = Head(best_arc);
1202 const NodeIndex to_unmatch = matched_node_[new_mate];
1203 if (to_unmatch != GraphType::kNilNode) {
1206 matched_arc_[to_unmatch] = GraphType::kNilArc;
1207 active_nodes_->Add(to_unmatch);
1209 iteration_stats_.double_pushes_ += 1;
1214 iteration_stats_.pushes_ += 1;
1216 matched_arc_[source] = best_arc;
1217 matched_node_[new_mate] = source;
1219 iteration_stats_.relabelings_ += 1;
1220 const CostValue new_price = price_[new_mate] - gap - epsilon_;
1221 price_[new_mate] = new_price;
1222 return new_price >= price_lower_bound_;
1225template <
typename GraphType>
1226bool LinearSumAssignment<GraphType>::Refine() {
1227 SaturateNegativeArcs();
1228 InitializeActiveNodeContainer();
1229 while (total_excess_ > 0) {
1232 const NodeIndex node = active_nodes_->Get();
1233 if (!DoublePush(node)) {
1241 LOG_IF(DFATAL, total_stats_.refinements_ > 0)
1242 <<
"Infeasibility detection triggered after first iteration found "
1243 <<
"a feasible assignment!";
1247 DCHECK(active_nodes_->Empty());
1248 iteration_stats_.refinements_ += 1;
1266template <
typename GraphType>
1267inline typename LinearSumAssignment<GraphType>::ImplicitPriceSummary
1268LinearSumAssignment<GraphType>::BestArcAndGap(
NodeIndex left_node)
const {
1269 DCHECK(IsActive(left_node))
1270 <<
"Node " << left_node <<
" must be active (unmatched)!";
1272 typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1273 ArcIndex best_arc = arc_it.Index();
1274 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1280 const CostValue max_gap = slack_relabeling_price_ - epsilon_;
1281 CostValue second_min_partial_reduced_cost =
1282 min_partial_reduced_cost + max_gap;
1283 for (arc_it.Next(); arc_it.Ok(); arc_it.Next()) {
1285 const CostValue partial_reduced_cost = PartialReducedCost(
arc);
1286 if (partial_reduced_cost < second_min_partial_reduced_cost) {
1287 if (partial_reduced_cost < min_partial_reduced_cost) {
1289 second_min_partial_reduced_cost = min_partial_reduced_cost;
1290 min_partial_reduced_cost = partial_reduced_cost;
1292 second_min_partial_reduced_cost = partial_reduced_cost;
1296 const CostValue gap = std::min<CostValue>(
1297 second_min_partial_reduced_cost - min_partial_reduced_cost, max_gap);
1299 return std::make_pair(best_arc, gap);
1306template <
typename GraphType>
1307inline CostValue LinearSumAssignment<GraphType>::ImplicitPrice(
1311 typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1314 ArcIndex best_arc = arc_it.Index();
1315 if (best_arc == matched_arc_[left_node]) {
1318 best_arc = arc_it.Index();
1321 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1327 return -(min_partial_reduced_cost + slack_relabeling_price_);
1329 for (arc_it.Next(); arc_it.Ok(); arc_it.Next()) {
1331 if (
arc != matched_arc_[left_node]) {
1332 const CostValue partial_reduced_cost = PartialReducedCost(
arc);
1333 if (partial_reduced_cost < min_partial_reduced_cost) {
1334 min_partial_reduced_cost = partial_reduced_cost;
1338 return -min_partial_reduced_cost;
1342template <
typename GraphType>
1343bool LinearSumAssignment<GraphType>::AllMatched()
const {
1344 for (
NodeIndex node = 0; node < graph_->num_nodes(); ++node) {
1345 if (IsActiveForDebugging(node)) {
1353template <
typename GraphType>
1354bool LinearSumAssignment<GraphType>::EpsilonOptimal()
const {
1355 for (BipartiteLeftNodeIterator node_it(*graph_, num_left_nodes_);
1356 node_it.Ok(); node_it.Next()) {
1357 const NodeIndex left_node = node_it.Index();
1360 CostValue left_node_price = ImplicitPrice(left_node);
1361 for (
typename GraphType::OutgoingArcIterator arc_it(*graph_, left_node);
1362 arc_it.Ok(); arc_it.Next()) {
1364 const CostValue reduced_cost = left_node_price + PartialReducedCost(
arc);
1369 if (matched_arc_[left_node] ==
arc) {
1373 if (reduced_cost > epsilon_) {
1379 if (reduced_cost < 0) {
1388template <
typename GraphType>
1390 incidence_precondition_satisfied_ =
true;
1394 epsilon_ =
std::max(largest_scaled_cost_magnitude_, kMinEpsilon + 1);
1395 VLOG(2) <<
"Largest given cost magnitude: "
1396 << largest_scaled_cost_magnitude_ / cost_scaling_factor_;
1399 for (
NodeIndex node = 0; node < num_left_nodes_; ++node) {
1400 matched_arc_[node] = GraphType::kNilArc;
1401 typename GraphType::OutgoingArcIterator arc_it(*graph_, node);
1403 incidence_precondition_satisfied_ =
false;
1408 for (
NodeIndex node = num_left_nodes_; node < graph_->num_nodes(); ++node) {
1410 matched_node_[node] = GraphType::kNilNode;
1412 bool in_range =
true;
1413 double double_price_lower_bound = 0.0;
1415 CostValue old_error_parameter = epsilon_;
1417 new_error_parameter = NewEpsilon(old_error_parameter);
1418 double_price_lower_bound -=
1419 2.0 *
static_cast<double>(PriceChangeBound(
1420 old_error_parameter, new_error_parameter, &in_range));
1421 old_error_parameter = new_error_parameter;
1422 }
while (new_error_parameter != kMinEpsilon);
1423 const double limit =
1425 if (double_price_lower_bound < limit) {
1429 price_lower_bound_ =
static_cast<CostValue>(double_price_lower_bound);
1431 VLOG(4) <<
"price_lower_bound_ == " << price_lower_bound_;
1434 LOG(
WARNING) <<
"Price change bound exceeds range of representable "
1435 <<
"costs; arithmetic overflow is not ruled out and "
1436 <<
"infeasibility might go undetected.";
1441template <
typename GraphType>
1443 total_stats_.Add(iteration_stats_);
1444 VLOG(3) <<
"Iteration stats: " << iteration_stats_.StatsString();
1445 iteration_stats_.Clear();
1448template <
typename GraphType>
1450 CHECK(graph_ !=
nullptr);
1451 bool ok = graph_->num_nodes() == 2 * num_left_nodes_;
1452 if (!ok)
return false;
1459 ok = ok && incidence_precondition_satisfied_;
1460 DCHECK(!ok || EpsilonOptimal());
1461 while (ok && epsilon_ > kMinEpsilon) {
1462 ok = ok && UpdateEpsilon();
1463 ok = ok && Refine();
1464 ReportAndAccumulateStats();
1465 DCHECK(!ok || EpsilonOptimal());
1466 DCHECK(!ok || AllMatched());
1469 VLOG(1) <<
"Overall stats: " << total_stats_.StatsString();
1473template <
typename GraphType>
1480 cost += GetAssignmentCost(node_it.Index());
#define LOG_IF(severity, condition)
#define DCHECK_LE(val1, val2)
#define DCHECK_NE(val1, val2)
#define DCHECK_GE(val1, val2)
#define DCHECK_GT(val1, val2)
#define DCHECK_LT(val1, val2)
#define DCHECK(condition)
#define DCHECK_EQ(val1, val2)
#define VLOG(verboselevel)
ArcIndexOrderingByTailNode(const GraphType &graph)
bool operator()(typename GraphType::ArcIndex a, typename GraphType::ArcIndex b) const
~CostValueCycleHandler() override
void SetIndexFromIndex(ArcIndexType source, ArcIndexType destination) const override
CostValueCycleHandler(std::vector< CostValue > *cost)
void SetTempFromIndex(ArcIndexType source) override
void SetIndexFromTemp(ArcIndexType destination) const override
BipartiteLeftNodeIterator(const GraphType &graph, NodeIndex num_left_nodes)
BipartiteLeftNodeIterator(const LinearSumAssignment &assignment)
NodeIndex NumLeftNodes() const
std::string StatsString() const
ArcIndex GetAssignmentArc(NodeIndex left_node) const
GraphType::ArcIndex ArcIndex
NodeIndex NumNodes() const
CostValue GetCost() const
const GraphType & Graph() const
void SetArcCost(ArcIndex arc, CostValue cost)
void OptimizeGraphLayout(GraphType *graph)
NodeIndex Head(ArcIndex arc) const
CostValue ArcCost(ArcIndex arc) const
CostValue GetAssignmentCost(NodeIndex node) const
operations_research::PermutationCycleHandler< typename GraphType::ArcIndex > * ArcAnnotationCycleHandler()
GraphType::NodeIndex NodeIndex
void SetCostScalingDivisor(CostValue factor)
NodeIndex GetMate(NodeIndex left_node) const
void SetGraph(const GraphType *graph)
LinearSumAssignment(const GraphType &graph, NodeIndex num_left_nodes)
bool BuildTailArrayFromAdjacencyListsIfForwardGraph() const
void ReleaseTailArrayIfForwardGraph() const
ABSL_DECLARE_FLAG(int64_t, assignment_alpha)
Collection of objects used to extend the Constraint Solver library.