diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 8942f8e835..bbeb69a581 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -77,16 +77,11 @@ void IntegerSumLE::FillIntegerReason() { } std::pair IntegerSumLE::ConditionalLb( - IntegerVariable bool_view, IntegerVariable target_var) const { - if (integer_trail_->LowerBound(bool_view) != 0 && - integer_trail_->UpperBound(bool_view) != 1) { - return {kMinIntegerValue, kMinIntegerValue}; - } - + IntegerLiteral integer_literal, IntegerVariable target_var) const { // Recall that all our coefficient are positive. - bool bool_view_present = false; - bool bool_view_present_positively = false; - IntegerValue view_coeff; + bool literal_var_present = false; + bool literal_var_present_positively = false; + IntegerValue var_coeff; bool target_var_present_negatively = false; IntegerValue target_coeff; @@ -104,21 +99,34 @@ std::pair IntegerSumLE::ConditionalLb( const IntegerValue lb = integer_trail_->LowerBound(var); implied_lb += coeff * lb; - if (PositiveVariable(var) == PositiveVariable(bool_view)) { - view_coeff = coeff; - bool_view_present = true; - bool_view_present_positively = (var == bool_view); + if (PositiveVariable(var) == PositiveVariable(integer_literal.var)) { + var_coeff = coeff; + literal_var_present = true; + literal_var_present_positively = (var == integer_literal.var); } } - if (!bool_view_present || !target_var_present_negatively) { + if (!literal_var_present || !target_var_present_negatively) { return {kMinIntegerValue, kMinIntegerValue}; } - if (bool_view_present_positively) { + // A literal means var >= bound. + if (literal_var_present_positively) { + // We have var_coeff * var in the expression, the literal is var >= bound. + // When it is false, it is not relevant as implied_lb used var >= lb. + // When it is true, the diff is bound - lb. + const IntegerValue diff = std::max( + IntegerValue(0), integer_literal.bound - + integer_trail_->LowerBound(integer_literal.var)); return {CeilRatio(implied_lb, target_coeff), - CeilRatio(implied_lb + view_coeff, target_coeff)}; + CeilRatio(implied_lb + var_coeff * diff, target_coeff)}; } else { - return {CeilRatio(implied_lb + view_coeff, target_coeff), + // We have var_coeff * -var in the expression, the literal is var >= bound. + // When it is true, it is not relevant as implied_lb used -var >= -ub. + // And when it is false it means var < bound, so -var >= -bound + 1 + const IntegerValue diff = std::max( + IntegerValue(0), integer_trail_->UpperBound(integer_literal.var) - + integer_literal.bound + 1); + return {CeilRatio(implied_lb + var_coeff * diff, target_coeff), CeilRatio(implied_lb, target_coeff)}; } } diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index 602c73f3e0..222ed86a32 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -72,10 +72,11 @@ class IntegerSumLE : public PropagatorInterface { bool PropagateAtLevelZero(); // This is a pretty usage specific function. Returns the implied lower bound - // on var if the bool_view take the value 0 or 1. If the variables do not - // appear both in the linear inequality, this returns two kMinIntegerValue. + // on target_var if the given integer literal is false (resp. true). If the + // variables do not appear both in the linear inequality, this returns two + // kMinIntegerValue. std::pair ConditionalLb( - IntegerVariable bool_view, IntegerVariable target_var) const; + IntegerLiteral integer_literal, IntegerVariable target_var) const; private: // Fills integer_reason_ with all the current lower_bounds. The real diff --git a/ortools/sat/lb_tree_search.cc b/ortools/sat/lb_tree_search.cc index 51f90ab8a0..b355762743 100644 --- a/ortools/sat/lb_tree_search.cc +++ b/ortools/sat/lb_tree_search.cc @@ -416,42 +416,41 @@ SatSolver::Status LbTreeSearch::Search( // and also allow for a more incremental LP solving since we do less back // and forth. // - // TODO(user): The code to recover that is a bit convoluted. + // TODO(user): The code to recover that is a bit convoluted. Alternatively + // Maybe we should do a "fast" propagation without the LP in each branch. + // That will work as long as we keep these optimal LP constraints around + // and propagate them. + // // TODO(user): Incorporate this in the heuristic so we choose more Boolean // inside these LP explanations? if (lp_constraint_ != nullptr) { - IntegerSumLE* last_rc = lp_constraint_->LatestOptimalConstraintOrNull(); - if (last_rc != nullptr) { - const IntegerVariable pos_view = - integer_encoder_->GetLiteralView(Literal(decision)); - if (pos_view != kNoIntegerVariable) { - const std::pair bounds = - last_rc->ConditionalLb(pos_view, objective_var_); - Node& node = nodes_[n]; - if (bounds.first > node.false_objective) { - ++num_rc_detected_; - node.UpdateFalseObjective(bounds.first); - } - if (bounds.second > node.true_objective) { - ++num_rc_detected_; - node.UpdateTrueObjective(bounds.second); - } - } + // Note that this return literal EQUIVALENT to the decision, not just + // implied by it. We need that for correctness. + int num_tests = 0; + for (const IntegerLiteral integer_literal : + integer_encoder_->GetIntegerLiterals(Literal(decision))) { + if (integer_trail_->IsCurrentlyIgnored(integer_literal.var)) continue; - const IntegerVariable neg_view = - integer_encoder_->GetLiteralView(Literal(decision).Negated()); - if (neg_view != kNoIntegerVariable) { - const std::pair bounds = - last_rc->ConditionalLb(neg_view, objective_var_); - Node& node = nodes_[n]; - if (bounds.first > node.true_objective) { - ++num_rc_detected_; - node.UpdateTrueObjective(bounds.second); - } - if (bounds.second > node.false_objective) { - ++num_rc_detected_; - node.UpdateFalseObjective(bounds.second); - } + // To avoid bad corner case. Not sure it ever triggers. + if (++num_tests > 10) break; + + // TODO(user): we could consider earlier constraint instead of just + // looking at the last one, but experiments didn't really show a big + // gain. + const auto& cts = lp_constraint_->OptimalConstraints(); + if (cts.empty()) continue; + + const std::unique_ptr& rc = cts.back(); + const std::pair bounds = + rc->ConditionalLb(integer_literal, objective_var_); + Node& node = nodes_[n]; + if (bounds.first > node.false_objective) { + ++num_rc_detected_; + node.UpdateFalseObjective(bounds.first); + } + if (bounds.second > node.true_objective) { + ++num_rc_detected_; + node.UpdateTrueObjective(bounds.second); } } } diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index 3ba2a345e2..467625dee7 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -231,6 +231,10 @@ class LinearProgrammingConstraint : public PropagatorInterface, return optimal_constraints_.back().get(); } + const std::vector>& OptimalConstraints() const { + return optimal_constraints_; + } + private: // Helper methods for branching. Returns true if branching on the given // variable helps with more propagation or finds a conflict.