From 34b26eb5b09d5b4f34db02e2dca56009c90918f8 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Fri, 29 Oct 2021 14:02:25 +0200 Subject: [PATCH] [CP-SAT] reorganize linear code; tweak lb_tree_search code --- ortools/sat/cp_model_loader.cc | 22 +-- ortools/sat/integer_expr.cc | 65 +++++++-- ortools/sat/integer_expr.h | 6 + ortools/sat/lb_tree_search.cc | 140 ++++++++++++++++--- ortools/sat/lb_tree_search.h | 34 +++-- ortools/sat/linear_programming_constraint.cc | 2 +- ortools/sat/linear_programming_constraint.h | 7 + ortools/sat/linear_relaxation.cc | 110 ++++++--------- ortools/sat/linear_relaxation.h | 66 ++++----- 9 files changed, 285 insertions(+), 167 deletions(-) diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index 24b6bfc805..199236218b 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -1099,27 +1099,7 @@ void LoadAllDiffConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); const std::vector vars = mapping->Integers(ct.all_diff().vars()); - // If all variables are fully encoded and domains are not too large, use - // arc-consistent reasoning. Otherwise, use bounds-consistent reasoning. - IntegerTrail* integer_trail = m->GetOrCreate(); - IntegerEncoder* encoder = m->GetOrCreate(); - int num_fully_encoded = 0; - int64_t max_domain_size = 0; - for (const IntegerVariable variable : vars) { - if (encoder->VariableIsFullyEncoded(variable)) num_fully_encoded++; - - IntegerValue lb = integer_trail->LowerBound(variable); - IntegerValue ub = integer_trail->UpperBound(variable); - const int64_t domain_size = ub.value() - lb.value() + 1; - max_domain_size = std::max(max_domain_size, domain_size); - } - - if (num_fully_encoded == vars.size() && max_domain_size < 1024) { - m->Add(AllDifferentBinary(vars)); - m->Add(AllDifferentAC(vars)); - } else { - m->Add(AllDifferentOnBounds(vars)); - } + m->Add(AllDifferentOnBounds(vars)); } void LoadIntProdConstraint(const ConstraintProto& ct, Model* m) { diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 27550abd39..8942f8e835 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -76,6 +76,53 @@ void IntegerSumLE::FillIntegerReason() { } } +std::pair IntegerSumLE::ConditionalLb( + IntegerVariable bool_view, IntegerVariable target_var) const { + if (integer_trail_->LowerBound(bool_view) != 0 && + integer_trail_->UpperBound(bool_view) != 1) { + return {kMinIntegerValue, kMinIntegerValue}; + } + + // Recall that all our coefficient are positive. + bool bool_view_present = false; + bool bool_view_present_positively = false; + IntegerValue view_coeff; + + bool target_var_present_negatively = false; + IntegerValue target_coeff; + + // Compute the implied_lb excluding "- target_coeff * target". + IntegerValue implied_lb(-upper_bound_); + for (int i = 0; i < vars_.size(); ++i) { + const IntegerVariable var = vars_[i]; + const IntegerValue coeff = coeffs_[i]; + if (var == NegationOf(target_var)) { + target_coeff = coeff; + target_var_present_negatively = true; + continue; + } + + const IntegerValue lb = integer_trail_->LowerBound(var); + implied_lb += coeff * lb; + if (PositiveVariable(var) == PositiveVariable(bool_view)) { + view_coeff = coeff; + bool_view_present = true; + bool_view_present_positively = (var == bool_view); + } + } + if (!bool_view_present || !target_var_present_negatively) { + return {kMinIntegerValue, kMinIntegerValue}; + } + + if (bool_view_present_positively) { + return {CeilRatio(implied_lb, target_coeff), + CeilRatio(implied_lb + view_coeff, target_coeff)}; + } else { + return {CeilRatio(implied_lb + view_coeff, target_coeff), + CeilRatio(implied_lb, target_coeff)}; + } +} + bool IntegerSumLE::Propagate() { // Reified case: If any of the enforcement_literals are false, we ignore the // constraint. @@ -719,7 +766,7 @@ bool ProductPropagator::PropagateMaxOnPositiveProduct(AffineExpression a, IntegerValue min_p, IntegerValue max_p) { const IntegerValue max_a = integer_trail_->UpperBound(a); - DCHECK_GT(max_a, 0); + if (max_a <= 0) return true; DCHECK_GT(min_p, 0); if (max_a >= min_p) { @@ -844,8 +891,6 @@ bool ProductPropagator::Propagate() { const AffineExpression b = i == 0 ? b_ : a_; const IntegerValue max_b = integer_trail_->UpperBound(b); const IntegerValue min_b = integer_trail_->LowerBound(b); - const IntegerValue max_a = integer_trail_->UpperBound(a); - const IntegerValue min_a = integer_trail_->LowerBound(a); // If the domain of b contain zero, we can't propagate anything on a. // Because of CanonicalizeCases(), we just deal with min_b > 0 here. @@ -855,11 +900,15 @@ bool ProductPropagator::Propagate() { if (min_b < 0 && max_b > 0) { CHECK_GT(min_p, 0); // Because zero is not possible. - // This should be done on the next Propagate() call. - if (min_a >= 0 || max_a <= 0) continue; - - PropagateMaxOnPositiveProduct(a, b, min_p, max_p); - PropagateMaxOnPositiveProduct(a.Negated(), b.Negated(), min_p, max_p); + // If a is not across zero, we will deal with this on the next + // Propagate() call. + if (!PropagateMaxOnPositiveProduct(a, b, min_p, max_p)) { + return false; + } + if (!PropagateMaxOnPositiveProduct(a.Negated(), b.Negated(), min_p, + max_p)) { + return false; + } continue; } diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index d6ed699baa..602c73f3e0 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -71,6 +71,12 @@ class IntegerSumLE : public PropagatorInterface { // really late in the search tree. bool PropagateAtLevelZero(); + // This is a pretty usage specific function. Returns the implied lower bound + // on var if the bool_view take the value 0 or 1. If the variables do not + // appear both in the linear inequality, this returns two kMinIntegerValue. + std::pair ConditionalLb( + IntegerVariable bool_view, IntegerVariable target_var) const; + private: // Fills integer_reason_ with all the current lower_bounds. The real // explanation may require removing one of them, but as an optimization, we diff --git a/ortools/sat/lb_tree_search.cc b/ortools/sat/lb_tree_search.cc index 36b8c9b235..51f90ab8a0 100644 --- a/ortools/sat/lb_tree_search.cc +++ b/ortools/sat/lb_tree_search.cc @@ -24,6 +24,7 @@ LbTreeSearch::LbTreeSearch(Model* model) : time_limit_(model->GetOrCreate()), random_(model->GetOrCreate()), sat_solver_(model->GetOrCreate()), + integer_encoder_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), shared_response_(model->GetOrCreate()), sat_decision_(model->GetOrCreate()), @@ -37,6 +38,16 @@ LbTreeSearch::LbTreeSearch(Model* model) CHECK(objective != nullptr); objective_var_ = objective->objective_var; + // Identify an LP with the same objective variable. + // + // TODO(user): if we have many independent LP, this will find nothing. + for (LinearProgrammingConstraint* lp : + *model->GetOrCreate()) { + if (lp->ObjectiveVariable() == objective_var_) { + lp_constraint_ = lp; + } + } + // We use the normal SAT search but we will bump the variable activity // slightly differently. In addition to the conflicts, we also bump it each // time the objective lower bound increase in a sub-node. @@ -54,10 +65,10 @@ void LbTreeSearch::UpdateParentObjective(int level) { const NodeIndex child_index = current_branch_[level]; const Node& child = nodes_[child_index]; if (parent.true_child == child_index) { - parent.UpdateTrueObjective(child.objective_lb); + parent.UpdateTrueObjective(child.MinObjective()); } else { CHECK_EQ(parent.false_child, child_index); - parent.UpdateFalseObjective(child.objective_lb); + parent.UpdateFalseObjective(child.MinObjective()); } } @@ -67,6 +78,7 @@ void LbTreeSearch::UpdateObjectiveFromParent(int level) { if (level == 0) return; const NodeIndex parent_index = current_branch_[level - 1]; const Node& parent = nodes_[parent_index]; + CHECK_GE(parent.MinObjective(), current_objective_lb_); const NodeIndex child_index = current_branch_[level]; Node& child = nodes_[child_index]; if (parent.true_child == child_index) { @@ -77,6 +89,46 @@ void LbTreeSearch::UpdateObjectiveFromParent(int level) { } } +void LbTreeSearch::DebugDisplayTree(NodeIndex root) const { + int num_nodes = 0; + const IntegerValue root_lb = nodes_[root].MinObjective(); + const auto shifted_lb = [root_lb](IntegerValue lb) { + return std::max(0, (lb - root_lb).value()); + }; + + absl::StrongVector level(nodes_.size(), 0); + std::vector to_explore = {root}; + while (!to_explore.empty()) { + NodeIndex n = to_explore.back(); + to_explore.pop_back(); + + ++num_nodes; + const Node& node = nodes_[n]; + + std::string s(level[n], ' '); + absl::StrAppend(&s, "#", n.value()); + + if (node.true_child < nodes_.size()) { + absl::StrAppend(&s, " [t:#", node.true_child.value(), " ", + shifted_lb(node.true_objective), "]"); + to_explore.push_back(node.true_child); + level[node.true_child] = level[n] + 1; + } else { + absl::StrAppend(&s, " [t:## ", shifted_lb(node.true_objective), "]"); + } + if (node.false_child < nodes_.size()) { + absl::StrAppend(&s, " [f:#", node.false_child.value(), " ", + shifted_lb(node.false_objective), "]"); + to_explore.push_back(node.false_child); + level[node.false_child] = level[n] + 1; + } else { + absl::StrAppend(&s, " [f:## ", shifted_lb(node.false_objective), "]"); + } + LOG(INFO) << s; + } + LOG(INFO) << "num_nodes: " << num_nodes; +} + SatSolver::Status LbTreeSearch::Search( const std::function& feasible_solution_observer) { if (!sat_solver_->RestoreSolverToAssumptionLevel()) { @@ -148,17 +200,20 @@ SatSolver::Status LbTreeSearch::Search( for (int level = current_branch_.size(); --level > 0;) { UpdateParentObjective(level); } + nodes_[current_branch_[0]].UpdateObjective(current_objective_lb_); for (int level = 1; level < current_branch_.size(); ++level) { UpdateObjectiveFromParent(level); } // If the root lb increased, update global shared objective lb. - if (nodes_[current_branch_[0]].objective_lb > current_objective_lb_) { + const IntegerValue bound = nodes_[current_branch_[0]].MinObjective(); + if (bound > current_objective_lb_) { shared_response_->UpdateInnerObjectiveBounds( - absl::StrCat("lb_tree_search #nodes:", nodes_.size()), - nodes_[current_branch_[0]].objective_lb, - integer_trail_->LevelZeroUpperBound(objective_var_)); - current_objective_lb_ = nodes_[current_branch_[0]].objective_lb; + absl::StrCat("lb_tree_search #nodes:", nodes_.size(), + " #rc:", num_rc_detected_), + bound, integer_trail_->LevelZeroUpperBound(objective_var_)); + current_objective_lb_ = bound; + if (VLOG_IS_ON(2)) DebugDisplayTree(current_branch_[0]); } } @@ -199,10 +254,10 @@ SatSolver::Status LbTreeSearch::Search( // // TODO(user): If we remember how far we can backjump for both true/false // branch, we could be more efficient. - while ( - current_branch_.size() > sat_solver_->CurrentDecisionLevel() + 1 || - (current_branch_.size() > 1 && - nodes_[current_branch_.back()].objective_lb > current_objective_lb_)) { + while (current_branch_.size() > sat_solver_->CurrentDecisionLevel() + 1 || + (current_branch_.size() > 1 && + nodes_[current_branch_.back()].MinObjective() > + current_objective_lb_)) { current_branch_.pop_back(); } @@ -221,16 +276,15 @@ SatSolver::Status LbTreeSearch::Search( // Dive: Follow the branch with lowest objective. // Note that we do not creates new nodes here. while (current_branch_.size() == sat_solver_->CurrentDecisionLevel() + 1) { - // Note that node.objective_lb could be worse than the current best - // bound. const int level = current_branch_.size() - 1; CHECK_EQ(level, sat_solver_->CurrentDecisionLevel()); Node& node = nodes_[current_branch_[level]]; - node.UpdateObjective(integer_trail_->LowerBound(objective_var_)); - UpdateObjectiveFromParent(level); - if (node.objective_lb > current_objective_lb_) { + node.UpdateObjective(std::max( + current_objective_lb_, integer_trail_->LowerBound(objective_var_))); + if (node.MinObjective() > current_objective_lb_) { break; } + CHECK_EQ(node.MinObjective(), current_objective_lb_) << level; // This will be set to the next node index. NodeIndex n; @@ -261,7 +315,7 @@ SatSolver::Status LbTreeSearch::Search( nodes_[parent].false_child = n; nodes_[parent].UpdateFalseObjective(new_lb); } - if (nodes_[parent].objective_lb > current_objective_lb_) break; + if (nodes_[parent].MinObjective() > current_objective_lb_) break; } } else { // If both lower bound are the same, we pick a random sub-branch. @@ -342,19 +396,65 @@ SatSolver::Status LbTreeSearch::Search( // Note that the decision will be pushed to the solver on the next loop. const NodeIndex n(nodes_.size()); nodes_.emplace_back(Literal(decision), - integer_trail_->LowerBound(objective_var_)); + std::max(current_objective_lb_, + integer_trail_->LowerBound(objective_var_))); if (!current_branch_.empty()) { const NodeIndex parent = current_branch_.back(); if (sat_solver_->Assignment().LiteralIsTrue(nodes_[parent].literal)) { nodes_[parent].true_child = n; - nodes_[parent].UpdateTrueObjective(nodes_.back().objective_lb); + nodes_[parent].UpdateTrueObjective(nodes_.back().MinObjective()); } else { CHECK(sat_solver_->Assignment().LiteralIsFalse(nodes_[parent].literal)); nodes_[parent].false_child = n; - nodes_[parent].UpdateFalseObjective(nodes_.back().objective_lb); + nodes_[parent].UpdateFalseObjective(nodes_.back().MinObjective()); } } current_branch_.push_back(n); + + // Looking at the reduced costs, we can already have a bound for one of the + // branch. Increasing the corresponding objective can save some branches, + // and also allow for a more incremental LP solving since we do less back + // and forth. + // + // TODO(user): The code to recover that is a bit convoluted. + // TODO(user): Incorporate this in the heuristic so we choose more Boolean + // inside these LP explanations? + if (lp_constraint_ != nullptr) { + IntegerSumLE* last_rc = lp_constraint_->LatestOptimalConstraintOrNull(); + if (last_rc != nullptr) { + const IntegerVariable pos_view = + integer_encoder_->GetLiteralView(Literal(decision)); + if (pos_view != kNoIntegerVariable) { + const std::pair bounds = + last_rc->ConditionalLb(pos_view, objective_var_); + Node& node = nodes_[n]; + if (bounds.first > node.false_objective) { + ++num_rc_detected_; + node.UpdateFalseObjective(bounds.first); + } + if (bounds.second > node.true_objective) { + ++num_rc_detected_; + node.UpdateTrueObjective(bounds.second); + } + } + + const IntegerVariable neg_view = + integer_encoder_->GetLiteralView(Literal(decision).Negated()); + if (neg_view != kNoIntegerVariable) { + const std::pair bounds = + last_rc->ConditionalLb(neg_view, objective_var_); + Node& node = nodes_[n]; + if (bounds.first > node.true_objective) { + ++num_rc_detected_; + node.UpdateTrueObjective(bounds.second); + } + if (bounds.second > node.false_objective) { + ++num_rc_detected_; + node.UpdateFalseObjective(bounds.second); + } + } + } + } } return SatSolver::LIMIT_REACHED; diff --git a/ortools/sat/lb_tree_search.h b/ortools/sat/lb_tree_search.h index f26b7aa260..7610e2a296 100644 --- a/ortools/sat/lb_tree_search.h +++ b/ortools/sat/lb_tree_search.h @@ -19,6 +19,7 @@ #include "ortools/sat/integer.h" #include "ortools/sat/integer_search.h" +#include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" @@ -52,37 +53,29 @@ class LbTreeSearch { DEFINE_INT_TYPE(NodeIndex, int); struct Node { Node(Literal l, IntegerValue lb) - : literal(l), - objective_lb(lb), - true_objective(lb), - false_objective(lb) {} + : literal(l), true_objective(lb), false_objective(lb) {} + + // The objective lower bound at this node. + IntegerValue MinObjective() const { + return std::min(true_objective, false_objective); + } // Invariant: the objective bounds only increase. void UpdateObjective(IntegerValue v) { - objective_lb = std::max(objective_lb, v); true_objective = std::max(true_objective, v); false_objective = std::max(false_objective, v); } void UpdateTrueObjective(IntegerValue v) { true_objective = std::max(true_objective, v); - objective_lb = - std::max(objective_lb, std::min(true_objective, false_objective)); } void UpdateFalseObjective(IntegerValue v) { false_objective = std::max(false_objective, v); - objective_lb = - std::max(objective_lb, std::min(true_objective, false_objective)); } // The decision for the true and false branch under this node. /*const*/ Literal literal; - // The objective lower bound at this node. - IntegerValue objective_lb; - - // The objective lower bound in both branches. This should be the same - // (after a sync) as the objective_lb of the corresponding child node when - // that node is instantiated. + // The objective lower bound in both branches. IntegerValue true_objective; IntegerValue false_objective; @@ -91,6 +84,10 @@ class LbTreeSearch { NodeIndex false_child = NodeIndex(std::numeric_limits::max()); }; + // Display the current tree, this is mainly here to investigate ideas to + // improve the code. + void DebugDisplayTree(NodeIndex root) const; + // Updates the objective of the node in the current branch at level n from // the one at level n - 1. void UpdateObjectiveFromParent(int level); @@ -103,12 +100,17 @@ class LbTreeSearch { TimeLimit* time_limit_; ModelRandomGenerator* random_; SatSolver* sat_solver_; + IntegerEncoder* integer_encoder_; IntegerTrail* integer_trail_; SharedResponseManager* shared_response_; SatDecisionPolicy* sat_decision_; IntegerSearchHelper* search_helper_; IntegerVariable objective_var_; + // This can stay null. Otherwise it will be the lp constraint with + // objective_var_ as objective. + LinearProgrammingConstraint* lp_constraint_ = nullptr; + // We temporarily cache the shared_response_ objective lb here. IntegerValue current_objective_lb_; @@ -120,6 +122,8 @@ class LbTreeSearch { // Our heuristic used to explore the tree. See code for detail. std::function search_heuristic_; + + int64_t num_rc_detected_ = 0; }; } // namespace sat diff --git a/ortools/sat/linear_programming_constraint.cc b/ortools/sat/linear_programming_constraint.cc index 9f3046ed96..5cd59a3203 100644 --- a/ortools/sat/linear_programming_constraint.cc +++ b/ortools/sat/linear_programming_constraint.cc @@ -806,7 +806,7 @@ bool LinearProgrammingConstraint::AddCutFromConstraints( bool at_least_one_added = false; - // Try cover appraoch to find cut. + // Try cover approach to find cut. { if (cover_cut_helper_.TrySimpleKnapsack(cut_, tmp_lp_values_, tmp_var_lbs_, tmp_var_ubs_)) { diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index e8efc7d5c7..3ba2a345e2 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -145,6 +145,7 @@ class LinearProgrammingConstraint : public PropagatorInterface, // The main objective variable should be equal to the linear sum of // the arguments passed to SetObjectiveCoefficient(). void SetMainObjectiveVariable(IntegerVariable ivar) { objective_cp_ = ivar; } + IntegerVariable ObjectiveVariable() const { return objective_cp_; } // Register a new cut generator with this constraint. void AddCutGenerator(CutGenerator generator); @@ -224,6 +225,12 @@ class LinearProgrammingConstraint : public PropagatorInterface, // Returns some statistics about this LP. std::string Statistics() const; + // Important: this is only temporarily valid. + IntegerSumLE* LatestOptimalConstraintOrNull() const { + if (optimal_constraints_.empty()) return nullptr; + return optimal_constraints_.back().get(); + } + private: // Helper methods for branching. Returns true if branching on the given // variable helps with more propagation or finds a conflict. diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 5026b88567..aba476358f 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -953,14 +953,17 @@ void AppendLinearConstraintRelaxation(const ConstraintProto& ct, rhs_domain_max, *model, relaxation); } -// Add a linear relaxation of the CP constraint to the set of linear -// constraints. The highest linearization_level is, the more types of constraint -// we encode. This method should be called only for linearization_level > 0. -// -// Note: IntProd is linearized dynamically using the cut generators. +// Add a static and a dynamic linear relaxation of the CP constraint to the set +// of linear constraints. The highest linearization_level is, the more types of +// constraint we encode. This method should be called only for +// linearization_level > 0. The static part is just called a relaxation and is +// called at the root node of the search. The dynamic part is implemented +// through a set of linear cut generators that will be called throughout the +// search. // // TODO(user): In full generality, we could encode all the constraint as an LP. // TODO(user): Add unit tests for this method. +// TODO(user): Remove and merge with model loading. void TryToLinearizeConstraint(const CpModelProto& model_proto, const ConstraintProto& ct, int linearization_level, Model* model, @@ -989,11 +992,32 @@ void TryToLinearizeConstraint(const CpModelProto& model_proto, AppendExactlyOneRelaxation(ct, model, relaxation); break; } + case ConstraintProto::ConstraintCase::kIntProd: { + // No relaxation, just a cut generator . + AddIntProdCutGenerator(ct, linearization_level, model, relaxation); + break; + } case ConstraintProto::ConstraintCase::kLinMax: { AppendLinMaxRelaxationPart1(ct, model, relaxation); - if (LinMaxContainsOnlyOneVarInExpressions(ct)) { + const bool is_affine_max = LinMaxContainsOnlyOneVarInExpressions(ct); + if (is_affine_max) { AppendMaxAffineRelaxation(ct, model, relaxation); } + + // Add cut generators. + if (linearization_level > 1) { + if (is_affine_max) { + AddMaxAffineCutGenerator(ct, model, relaxation); + } else { + AddLinMaxCutGenerator(ct, model, relaxation); + } + } + break; + } + case ConstraintProto::ConstraintCase::kAllDiff: { + if (linearization_level > 1) { + AddAllDiffCutGenerator(ct, model, relaxation); + } break; } case ConstraintProto::ConstraintCase::kLinear: { @@ -1004,21 +1028,35 @@ void TryToLinearizeConstraint(const CpModelProto& model_proto, } case ConstraintProto::ConstraintCase::kCircuit: { AppendCircuitRelaxation(ct, model, relaxation); + if (linearization_level > 1) { + AddCircuitCutGenerator(ct, model, relaxation); + } break; } case ConstraintProto::ConstraintCase::kRoutes: { AppendRoutesRelaxation(ct, model, relaxation); + if (linearization_level > 1) { + AddRoutesCutGenerator(ct, model, relaxation); + } break; } case ConstraintProto::ConstraintCase::kNoOverlap: { if (linearization_level > 1) { AppendNoOverlapRelaxation(model_proto, ct, model, relaxation); + AddNoOverlapCutGenerator(ct, model, relaxation); } break; } case ConstraintProto::ConstraintCase::kCumulative: { if (linearization_level > 1) { AppendCumulativeRelaxation(model_proto, ct, model, relaxation); + AddCumulativeCutGenerator(ct, model, relaxation); + } + break; + } + case ConstraintProto::ConstraintCase::kNoOverlap2D: { + if (linearization_level > 1) { + AddNoOverlap2dCutGenerator(ct, model, relaxation); } break; } @@ -1237,65 +1275,6 @@ void AddLinMaxCutGenerator(const ConstraintProto& ct, Model* m, CreateLinMaxCutGenerator(target, exprs, z_vars, m)); } -// TODO(user): Remove and merge with model loading. -void TryToAddCutGenerators(const ConstraintProto& ct, int linearization_level, - Model* m, LinearRelaxation* relaxation) { - switch (ct.constraint_case()) { - case ConstraintProto::ConstraintCase::kCircuit: { - if (linearization_level > 1) { - AddCircuitCutGenerator(ct, m, relaxation); - } - break; - } - case ConstraintProto::ConstraintCase::kRoutes: { - if (linearization_level > 1) { - AddRoutesCutGenerator(ct, m, relaxation); - } - break; - } - case ConstraintProto::ConstraintCase::kIntProd: { - AddIntProdCutGenerator(ct, linearization_level, m, relaxation); - break; - } - case ConstraintProto::ConstraintCase::kAllDiff: { - if (linearization_level > 1) { - AddAllDiffCutGenerator(ct, m, relaxation); - } - break; - } - case ConstraintProto::ConstraintCase::kCumulative: { - if (linearization_level > 1) { - AddCumulativeCutGenerator(ct, m, relaxation); - } - break; - } - case ConstraintProto::ConstraintCase::kNoOverlap: { - if (linearization_level > 1) { - AddNoOverlapCutGenerator(ct, m, relaxation); - } - break; - } - case ConstraintProto::ConstraintCase::kNoOverlap2D: { - if (linearization_level > 1) { - AddNoOverlap2dCutGenerator(ct, m, relaxation); - } - break; - } - case ConstraintProto::ConstraintCase::kLinMax: { - if (linearization_level > 1) { - if (LinMaxContainsOnlyOneVarInExpressions(ct)) { - AddMaxAffineCutGenerator(ct, m, relaxation); - } else { - AddLinMaxCutGenerator(ct, m, relaxation); - } - } - break; - } - default: { - } - } -} - // If we have an exactly one between literals l_i, and each l_i => var == // value_i, then we can add a strong linear relaxation: var = sum l_i * value_i. // @@ -1364,7 +1343,6 @@ void ComputeLinearRelaxation(const CpModelProto& model_proto, for (const auto& ct : model_proto.constraints()) { TryToLinearizeConstraint(model_proto, ct, linearization_level, m, relaxation); - TryToAddCutGenerators(ct, linearization_level, m, relaxation); } // Linearize the encoding of variable that are fully encoded. diff --git a/ortools/sat/linear_relaxation.h b/ortools/sat/linear_relaxation.h index 836e90e64f..4c683468cf 100644 --- a/ortools/sat/linear_relaxation.h +++ b/ortools/sat/linear_relaxation.h @@ -75,6 +75,18 @@ void AppendPartialGreaterThanEncodingRelaxation(IntegerVariable var, std::vector CreateAlternativeLiteralsWithView( int num_literals, Model* model, LinearRelaxation* relaxation); +void AppendBoolOrRelaxation(const ConstraintProto& ct, Model* model, + LinearRelaxation* relaxation); + +void AppendBoolAndRelaxation(const ConstraintProto& ct, Model* model, + LinearRelaxation* relaxation); + +void AppendAtMostOneRelaxation(const ConstraintProto& ct, Model* model, + LinearRelaxation* relaxation); + +void AppendExactlyOneRelaxation(const ConstraintProto& ct, Model* model, + LinearRelaxation* relaxation); + // Adds linearization of int max constraints. Returns a vector of z vars such // that: z_vars[l] == 1 <=> target = exprs[l]. // @@ -96,23 +108,16 @@ std::vector CreateAlternativeLiteralsWithView( // TODO(user): Support linear expression as target. void AppendLinMaxRelaxationPart1(const ConstraintProto& ct, Model* model, LinearRelaxation* relaxation); + void AppendLinMaxRelaxationPart2( IntegerVariable target, const std::vector& alternative_literals, const std::vector& exprs, Model* model, LinearRelaxation* relaxation); -void AppendBoolOrRelaxation(const ConstraintProto& ct, Model* model, - LinearRelaxation* relaxation); - -void AppendBoolAndRelaxation(const ConstraintProto& ct, Model* model, - LinearRelaxation* relaxation); - -void AppendAtMostOneRelaxation(const ConstraintProto& ct, Model* model, +// Note: This only works if all affine expressions share the same variable. +void AppendMaxAffineRelaxation(const ConstraintProto& ct, Model* model, LinearRelaxation* relaxation); -void AppendExactlyOneRelaxation(const ConstraintProto& ct, Model* model, - LinearRelaxation* relaxation); - // Appends linear constraints to the relaxation. This also handles the // relaxation of linear constraints with enforcement literals. // A linear constraint lb <= ax <= ub with enforcement literals {ei} is relaxed @@ -132,9 +137,6 @@ void AppendCircuitRelaxation(const ConstraintProto& ct, Model* model, void AppendRoutesRelaxation(const ConstraintProto& ct, Model* model, LinearRelaxation* relaxation); -void AppendIntervalRelaxation(const ConstraintProto& ct, Model* model, - LinearRelaxation* relaxation); - // Adds linearization of no overlap constraints. // It adds an energetic equation linking the duration of all potential tasks to // the actual span of the no overlap constraint. @@ -149,25 +151,22 @@ void AppendCumulativeRelaxation(const CpModelProto& model_proto, const ConstraintProto& ct, Model* model, LinearRelaxation* relaxation); -// Adds linearization of different types of constraints. -void TryToLinearizeConstraint(const CpModelProto& model_proto, - const ConstraintProto& ct, - int linearization_level, Model* model, - LinearRelaxation* relaxation); - // Cut generators. +void AddIntProdCutGenerator(const ConstraintProto& ct, int linearization_level, + Model* m, LinearRelaxation* relaxation); + +void AddAllDiffCutGenerator(const ConstraintProto& ct, Model* m, + LinearRelaxation* relaxation); + +void AddLinMaxCutGenerator(const ConstraintProto& ct, Model* m, + LinearRelaxation* relaxation); + void AddCircuitCutGenerator(const ConstraintProto& ct, Model* m, LinearRelaxation* relaxation); void AddRoutesCutGenerator(const ConstraintProto& ct, Model* m, LinearRelaxation* relaxation); -void AddIntProdCutGenerator(const ConstraintProto& ct, Model* m, - LinearRelaxation* relaxation); - -void AddAllDiffCutGenerator(const ConstraintProto& ct, Model* m, - LinearRelaxation* relaxation); - void AddCumulativeCutGenerator(const ConstraintProto& ct, Model* m, LinearRelaxation* relaxation); @@ -177,18 +176,13 @@ void AddNoOverlapCutGenerator(const ConstraintProto& ct, Model* m, void AddNoOverlap2dCutGenerator(const ConstraintProto& ct, Model* m, LinearRelaxation* relaxation); -void AddLinMaxCutGenerator(const ConstraintProto& ct, Model* m, - LinearRelaxation* relaxation); +// Adds linearization of different types of constraints. +void TryToLinearizeConstraint(const CpModelProto& model_proto, + const ConstraintProto& ct, + int linearization_level, Model* model, + LinearRelaxation* relaxation); -// Note: This only work if all affine expressions share the same variable. -void AppendMaxAffineRelaxation(const ConstraintProto& ct, Model* model, - LinearRelaxation* relaxation); - -// Scan the model and add cut generators. -void TryToAddCutGenerators(const ConstraintProto& ct, int linearization_level, - Model* m, LinearRelaxation* relaxation); - -// Builds the linear relaxaton of a CpModelProto and stores it in the +// Builds the linear relaxation of a CpModelProto and stores it in the // LinearRelaxation container. void ComputeLinearRelaxation(const CpModelProto& model_proto, int linearization_level, Model* m,