From fa2473affe655a1a0fc20c27863b028322529159 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Sun, 14 Apr 2024 10:58:12 +0200 Subject: [PATCH] [CP-SAT] improve presolve for affine_max, improve work sharing; improved linear code --- ortools/sat/BUILD.bazel | 2 + ortools/sat/cp_model_presolve.cc | 335 ++++++++++++++++++++++----- ortools/sat/cp_model_presolve.h | 6 +- ortools/sat/cp_model_utils.cc | 1 + ortools/sat/cp_model_utils.h | 17 ++ ortools/sat/diffn_util.cc | 2 +- ortools/sat/diffn_util.h | 2 +- ortools/sat/linear_propagation.cc | 49 +++- ortools/sat/linear_propagation.h | 6 + ortools/sat/linear_relaxation.cc | 21 +- ortools/sat/parameters_validation.cc | 2 + ortools/sat/python/cp_model.py | 4 +- ortools/sat/sat_parameters.proto | 3 +- ortools/sat/work_assignment.cc | 36 ++- ortools/sat/work_assignment.h | 6 +- 15 files changed, 385 insertions(+), 107 deletions(-) diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 3db7378c9f..4b11f8bd0a 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -570,6 +570,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/numeric:int128", @@ -1143,6 +1144,7 @@ cc_library( ":sat_base", ":sat_solver", ":synchronization", + ":util", "//ortools/base:stl_util", "//ortools/base:strong_vector", "//ortools/util:bitset", diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 1d29544102..03d7366330 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -762,27 +762,185 @@ bool CpModelPresolver::DivideLinMaxByGcd(int c, ConstraintProto* ct) { return true; } -bool CpModelPresolver::PresolveLinMax(ConstraintProto* ct) { - if (context_->ModelIsUnsat()) return false; - if (HasEnforcementLiteral(*ct)) return false; - const LinearExpressionProto& target = ct->lin_max().target(); +namespace { - // x = max(x, xi...) => forall i, x >= xi. - for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { - if (LinearExpressionProtosAreEqual(expr, target)) { - for (const LinearExpressionProto& e : ct->lin_max().exprs()) { - if (LinearExpressionProtosAreEqual(e, target)) continue; - LinearConstraintProto* prec = - context_->working_model->add_constraints()->mutable_linear(); - prec->add_domain(0); - prec->add_domain(std::numeric_limits::max()); - AddLinearExpressionToLinearConstraint(target, 1, prec); - AddLinearExpressionToLinearConstraint(e, -1, prec); - } - context_->UpdateRuleStats("lin_max: x = max(x, ...)"); - return RemoveConstraint(ct); +int64_t EvaluateSingleVariableExpression(const LinearExpressionProto& expr, + int var, int64_t value) { + int64_t result = expr.offset(); + for (int i = 0; i < expr.vars().size(); ++i) { + CHECK_EQ(expr.vars(i), var); + result += expr.coeffs(i) * value; + } + return result; +} + +template +int GetFirstVar(ExpressionList exprs) { + for (const LinearExpressionProto& expr : exprs) { + for (const int var : expr.vars()) { + DCHECK(RefIsPositive(var)); + return var; } } + return -1; +} + +bool IsAffineIntAbs(const ConstraintProto& ct) { + if (ct.constraint_case() != ConstraintProto::kLinMax || + ct.lin_max().exprs_size() != 2 || ct.lin_max().target().vars_size() > 1 || + ct.lin_max().exprs(0).vars_size() != 1 || + ct.lin_max().exprs(1).vars_size() != 1) { + return false; + } + + const LinearArgumentProto& lin_max = ct.lin_max(); + if (lin_max.exprs(0).offset() != -lin_max.exprs(1).offset()) return false; + if (PositiveRef(lin_max.exprs(0).vars(0)) != + PositiveRef(lin_max.exprs(1).vars(0))) { + return false; + } + + const int64_t left_coeff = RefIsPositive(lin_max.exprs(0).vars(0)) + ? lin_max.exprs(0).coeffs(0) + : -lin_max.exprs(0).coeffs(0); + const int64_t right_coeff = RefIsPositive(lin_max.exprs(1).vars(0)) + ? lin_max.exprs(1).coeffs(0) + : -lin_max.exprs(1).coeffs(0); + return left_coeff == -right_coeff; +} + +} // namespace + +bool CpModelPresolver::PropagateAndReduceAffineMax(ConstraintProto* ct) { + // Get the unique variable appearing in the expressions. + const int unique_var = GetFirstVar(ct->lin_max().exprs()); + + const auto& lin_max = ct->lin_max(); + const int num_exprs = lin_max.exprs_size(); + const auto& target = lin_max.target(); + std::vector num_wins(num_exprs, 0); + std::vector reachable_target_values; + std::vector valid_variable_values; + std::vector tmp_values(num_exprs); + + const bool target_has_same_unique_var = + target.vars_size() == 1 && target.vars(0) == unique_var; + + CHECK_LE(context_->DomainOf(unique_var).Size(), 1000); + + for (const int64_t value : context_->DomainOf(unique_var).Values()) { + int64_t current_max = std::numeric_limits::min(); + + // Fill tmp_values and compute current_max; + for (int i = 0; i < num_exprs; ++i) { + const int64_t v = + EvaluateSingleVariableExpression(lin_max.exprs(i), unique_var, value); + current_max = std::max(current_max, v); + tmp_values[i] = v; + } + + // Check if any expr produced a value compatible with the target. + if (!context_->DomainContains(target, current_max)) continue; + + // Special case: affine(x) == max(exprs(x)). We can check if the affine() + // and the max(exprs) are compatible. + if (target_has_same_unique_var && + EvaluateSingleVariableExpression(target, unique_var, value) != + current_max) { + continue; + } + + valid_variable_values.push_back(value); + reachable_target_values.push_back(current_max); + for (int i = 0; i < num_exprs; ++i) { + DCHECK_LE(tmp_values[i], current_max); + if (tmp_values[i] == current_max) { + num_wins[i]++; + } + } + } + + if (reachable_target_values.empty() || valid_variable_values.empty()) { + context_->UpdateRuleStats("lin_max: infeasible affine_max constraint"); + return MarkConstraintAsFalse(ct); + } + + { + bool reduced = false; + if (!context_->IntersectDomainWith( + target, Domain::FromValues(reachable_target_values), &reduced)) { + return true; + } + if (reduced) { + context_->UpdateRuleStats("lin_max: affine_max target domain reduced"); + } + } + + { + bool reduced = false; + if (!context_->IntersectDomainWith( + unique_var, Domain::FromValues(valid_variable_values), &reduced)) { + return true; + } + if (reduced) { + context_->UpdateRuleStats( + "lin_max: unique affine_max var domain reduced"); + } + } + + // If one expression always wins, even tied, we can eliminate all the others. + for (int i = 0; i < num_exprs; ++i) { + if (num_wins[i] == valid_variable_values.size()) { + const LinearExpressionProto winner_expr = lin_max.exprs(i); + ct->mutable_lin_max()->clear_exprs(); + *ct->mutable_lin_max()->add_exprs() = winner_expr; + break; + } + } + + bool changed = false; + if (ct->lin_max().exprs_size() > 1) { + int new_size = 0; + for (int i = 0; i < num_exprs; ++i) { + if (num_wins[i] == 0) continue; + *ct->mutable_lin_max()->mutable_exprs(new_size) = ct->lin_max().exprs(i); + new_size++; + } + if (new_size < ct->lin_max().exprs_size()) { + context_->UpdateRuleStats("lin_max: removed affine_max exprs"); + google::protobuf::util::Truncate(ct->mutable_lin_max()->mutable_exprs(), + new_size); + changed = true; + } + } + + if (context_->IsFixed(target)) { + context_->UpdateRuleStats("lin_max: fixed affine_max target"); + return RemoveConstraint(ct); + } + + if (target_has_same_unique_var) { + context_->UpdateRuleStats("lin_max: target_affine(x) = max(affine_i(x))"); + return RemoveConstraint(ct); + } + + // Remove the affine_max constraint if the target is removable and if domains + // have been propagated without loss. For now, we known that there is no loss + // if the target is a single ref. Since all the expression are affine, in this + // case we are fine. + if (ExpressionContainsSingleRef(target) && + context_->VariableIsUniqueAndRemovable(target.vars(0))) { + context_->MarkVariableAsRemoved(target.vars(0)); + *context_->mapping_model->add_constraints() = *ct; + context_->UpdateRuleStats("lin_max: unused affine_max target"); + return RemoveConstraint(ct); + } + + return changed; +} + +bool CpModelPresolver::PropagateAndReduceLinMax(ConstraintProto* ct) { + const LinearExpressionProto& target = ct->lin_max().target(); // Compute the infered min/max of the target. // Update target domain (if it is not a complex expression). @@ -819,7 +977,6 @@ bool CpModelPresolver::PresolveLinMax(ConstraintProto* ct) { // Filter the expressions which are smaller than target_min. const int64_t target_min = context_->MinOf(target); - const int64_t target_max = context_->MaxOf(target); bool changed = false; { // If one expression is >= target_min, @@ -864,6 +1021,59 @@ bool CpModelPresolver::PresolveLinMax(ConstraintProto* ct) { } } + return changed; +} + +bool CpModelPresolver::PresolveLinMax(ConstraintProto* ct) { + if (context_->ModelIsUnsat()) return false; + if (HasEnforcementLiteral(*ct)) return false; + const LinearExpressionProto& target = ct->lin_max().target(); + + // x = max(x, xi...) => forall i, x >= xi. + for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { + if (LinearExpressionProtosAreEqual(expr, target)) { + for (const LinearExpressionProto& e : ct->lin_max().exprs()) { + if (LinearExpressionProtosAreEqual(e, target)) continue; + LinearConstraintProto* prec = + context_->working_model->add_constraints()->mutable_linear(); + prec->add_domain(0); + prec->add_domain(std::numeric_limits::max()); + AddLinearExpressionToLinearConstraint(target, 1, prec); + AddLinearExpressionToLinearConstraint(e, -1, prec); + } + context_->UpdateRuleStats("lin_max: x = max(x, ...)"); + return RemoveConstraint(ct); + } + } + + const bool is_one_var_affine_max = + ExpressionsContainsOnlyOneVar(ct->lin_max().exprs()) && + ct->lin_max().target().vars_size() <= 1; + bool unique_var_is_small_enough = false; + const bool is_int_abs = IsAffineIntAbs(*ct); + + if (is_one_var_affine_max) { + const int unique_var = GetFirstVar(ct->lin_max().exprs()); + unique_var_is_small_enough = context_->DomainOf(unique_var).Size() <= 1000; + } + + // This is a test.12y + + bool changed; + if (is_one_var_affine_max && unique_var_is_small_enough) { + changed = PropagateAndReduceAffineMax(ct); + } else if (is_int_abs) { + changed = PropagateAndReduceIntAbs(ct); + } else { + changed = PropagateAndReduceLinMax(ct); + } + + if (context_->ModelIsUnsat()) return false; + if (ct->constraint_case() != ConstraintProto::kLinMax) { + // The constraint was removed by the propagate helpers. + return changed; + } + if (ct->lin_max().exprs().empty()) { context_->UpdateRuleStats("lin_max: no exprs"); return MarkConstraintAsFalse(ct); @@ -895,6 +1105,8 @@ bool CpModelPresolver::PresolveLinMax(ConstraintProto* ct) { // Cut everything above the max if possible. // If one of the linear expression has many term and is above the max, we // abort early since none of the other rule can be applied. + const int64_t target_min = context_->MinOf(target); + const int64_t target_max = context_->MaxOf(target); { bool abort = false; for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { @@ -1133,8 +1345,9 @@ bool CpModelPresolver::PresolveLinMaxWhenAllBoolean(ConstraintProto* ct) { return RemoveConstraint(ct); } -// This presolve expect that the constraint only contains affine expressions. -bool CpModelPresolver::PresolveIntAbs(ConstraintProto* ct) { +// This presolve expect that the constraint only contains 1-var affine +// expressions. +bool CpModelPresolver::PropagateAndReduceIntAbs(ConstraintProto* ct) { CHECK_EQ(ct->enforcement_literal_size(), 0); if (context_->ModelIsUnsat()) return false; const LinearExpressionProto& target_expr = ct->lin_max().target(); @@ -1153,11 +1366,11 @@ bool CpModelPresolver::PresolveIntAbs(ConstraintProto* ct) { return false; } if (expr_domain.IsFixed()) { - context_->UpdateRuleStats("int_abs: fixed expression"); + context_->UpdateRuleStats("lin_max: fixed expression in int_abs"); return RemoveConstraint(ct); } if (target_domain_modified) { - context_->UpdateRuleStats("int_abs: propagate domain from x to abs(x)"); + context_->UpdateRuleStats("lin_max: propagate domain from x to abs(x)"); } } @@ -1176,17 +1389,17 @@ bool CpModelPresolver::PresolveIntAbs(ConstraintProto* ct) { // This is the only reason why we don't support fully generic linear // expression. if (context_->IsFixed(target_expr)) { - context_->UpdateRuleStats("int_abs: fixed target"); + context_->UpdateRuleStats("lin_max: fixed abs target"); return RemoveConstraint(ct); } if (expr_domain_modified) { - context_->UpdateRuleStats("int_abs: propagate domain from abs(x) to x"); + context_->UpdateRuleStats("lin_max: propagate domain from abs(x) to x"); } } // Convert to equality if the sign of expr is fixed. if (context_->MinOf(expr) >= 0) { - context_->UpdateRuleStats("int_abs: converted to equality"); + context_->UpdateRuleStats("lin_max: converted abs to equality"); ConstraintProto* new_ct = context_->working_model->add_constraints(); new_ct->set_name(ct->name()); auto* arg = new_ct->mutable_linear(); @@ -1200,7 +1413,7 @@ bool CpModelPresolver::PresolveIntAbs(ConstraintProto* ct) { } if (context_->MaxOf(expr) <= 0) { - context_->UpdateRuleStats("int_abs: converted to equality"); + context_->UpdateRuleStats("lin_max: converted abs to equality"); ConstraintProto* new_ct = context_->working_model->add_constraints(); new_ct->set_name(ct->name()); auto* arg = new_ct->mutable_linear(); @@ -1221,7 +1434,7 @@ bool CpModelPresolver::PresolveIntAbs(ConstraintProto* ct) { context_->VariableIsUniqueAndRemovable(target_expr.vars(0))) { context_->MarkVariableAsRemoved(target_expr.vars(0)); *context_->mapping_model->add_constraints() = *ct; - context_->UpdateRuleStats("int_abs: unused target"); + context_->UpdateRuleStats("lin_max: unused abs target"); return RemoveConstraint(ct); } @@ -1514,7 +1727,7 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { return RemoveConstraint(ct); } -bool CpModelPresolver::PresolveIntDiv(ConstraintProto* ct) { +bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; const LinearExpressionProto target = ct->int_div().target(); @@ -1535,6 +1748,34 @@ bool CpModelPresolver::PresolveIntDiv(ConstraintProto* ct) { return RemoveConstraint(ct); } + // Sometimes we have only a single variable appearing in the whole constraint. + // If the domain is small enough, we can just restrict the domain and remove + // the constraint. + if (ct->enforcement_literal().empty() && + context_->ConstraintToVars(c).size() == 1) { + const int var = context_->ConstraintToVars(c)[0]; + if (context_->DomainOf(var).Size() >= 100) { + context_->UpdateRuleStats( + "TODO int_div: single variable with large domain"); + } else { + std::vector possible_values; + for (const int64_t v : context_->DomainOf(var).Values()) { + const int64_t target_v = + EvaluateSingleVariableExpression(target, var, v); + const int64_t expr_v = EvaluateSingleVariableExpression(expr, var, v); + const int64_t div_v = EvaluateSingleVariableExpression(div, var, v); + if (div_v == 0) continue; + if (target_v == expr_v / div_v) { + possible_values.push_back(v); + } + } + (void)context_->IntersectDomainWith(var, + Domain::FromValues(possible_values)); + context_->UpdateRuleStats("int_div: single variable"); + return RemoveConstraint(ct); + } + } + // For now, we only presolve the case where the divisor is constant. if (!context_->IsFixed(div)) return false; @@ -7796,34 +8037,6 @@ void CpModelPresolver::TransformIntoMaxCliques() { } } -namespace { - -bool IsAffineIntAbs(const ConstraintProto& ct) { - if (ct.constraint_case() != ConstraintProto::kLinMax || - ct.lin_max().exprs_size() != 2 || ct.lin_max().target().vars_size() > 1 || - ct.lin_max().exprs(0).vars_size() != 1 || - ct.lin_max().exprs(1).vars_size() != 1) { - return false; - } - - const LinearArgumentProto& lin_max = ct.lin_max(); - if (lin_max.exprs(0).offset() != -lin_max.exprs(1).offset()) return false; - if (PositiveRef(lin_max.exprs(0).vars(0)) != - PositiveRef(lin_max.exprs(1).vars(0))) { - return false; - } - - const int64_t left_coeff = RefIsPositive(lin_max.exprs(0).vars(0)) - ? lin_max.exprs(0).coeffs(0) - : -lin_max.exprs(0).coeffs(0); - const int64_t right_coeff = RefIsPositive(lin_max.exprs(1).vars(0)) - ? lin_max.exprs(1).coeffs(0) - : -lin_max.exprs(1).coeffs(0); - return left_coeff == -right_coeff; -} - -} // namespace - bool CpModelPresolver::PresolveOneConstraint(int c) { if (context_->ModelIsUnsat()) return false; ConstraintProto* ct = context_->working_model->mutable_constraints(c); @@ -7855,11 +8068,7 @@ bool CpModelPresolver::PresolveOneConstraint(int c) { context_->UpdateConstraintVariableUsage(c); } if (!DivideLinMaxByGcd(c, ct)) return false; - if (IsAffineIntAbs(*ct)) { - return PresolveIntAbs(ct); - } else { - return PresolveLinMax(ct); - } + return PresolveLinMax(ct); case ConstraintProto::kIntProd: if (CanonicalizeLinearArgument(*ct, ct->mutable_int_prod())) { context_->UpdateConstraintVariableUsage(c); @@ -7869,7 +8078,7 @@ bool CpModelPresolver::PresolveOneConstraint(int c) { if (CanonicalizeLinearArgument(*ct, ct->mutable_int_div())) { context_->UpdateConstraintVariableUsage(c); } - return PresolveIntDiv(ct); + return PresolveIntDiv(c, ct); case ConstraintProto::kIntMod: if (CanonicalizeLinearArgument(*ct, ct->mutable_int_mod())) { context_->UpdateConstraintVariableUsage(c); diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index d8d1741cbd..b5b5f817ea 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -113,8 +113,7 @@ class CpModelPresolver { bool PresolveAllDiff(ConstraintProto* ct); bool PresolveAutomaton(ConstraintProto* ct); bool PresolveElement(ConstraintProto* ct); - bool PresolveIntAbs(ConstraintProto* ct); - bool PresolveIntDiv(ConstraintProto* ct); + bool PresolveIntDiv(int c, ConstraintProto* ct); bool PresolveIntMod(int c, ConstraintProto* ct); bool PresolveIntProd(ConstraintProto* ct); bool PresolveInterval(int c, ConstraintProto* ct); @@ -122,6 +121,9 @@ class CpModelPresolver { bool DivideLinMaxByGcd(int c, ConstraintProto* ct); bool PresolveLinMax(ConstraintProto* ct); bool PresolveLinMaxWhenAllBoolean(ConstraintProto* ct); + bool PropagateAndReduceAffineMax(ConstraintProto* ct); + bool PropagateAndReduceIntAbs(ConstraintProto* ct); + bool PropagateAndReduceLinMax(ConstraintProto* ct); bool PresolveTable(ConstraintProto* ct); void DetectDuplicateIntervals( int c, google::protobuf::RepeatedField* intervals); diff --git a/ortools/sat/cp_model_utils.cc b/ortools/sat/cp_model_utils.cc index fa58d26611..3fa2df9aba 100644 --- a/ortools/sat/cp_model_utils.cc +++ b/ortools/sat/cp_model_utils.cc @@ -27,6 +27,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" #include "ortools/base/stl_util.h" diff --git a/ortools/sat/cp_model_utils.h b/ortools/sat/cp_model_utils.h index 2542b5c190..61411d644a 100644 --- a/ortools/sat/cp_model_utils.h +++ b/ortools/sat/cp_model_utils.h @@ -224,6 +224,23 @@ bool LinearExpressionProtosAreEqual(const LinearExpressionProto& a, const LinearExpressionProto& b, int64_t b_scaling = 1); +// Returns true if there exactly one variable appearing in all the expressions. +template +bool ExpressionsContainsOnlyOneVar(const ExpressionList& exprs) { + int unique_var = -1; + for (const LinearExpressionProto& expr : exprs) { + for (const int var : expr.vars()) { + CHECK(RefIsPositive(var)); + if (unique_var == -1) { + unique_var = var; + } else if (var != unique_var) { + return false; + } + } + } + return unique_var != -1; +} + // Default seed for fingerprints. constexpr uint64_t kDefaultFingerprintSeed = 0xa5b85c5e198ed849; diff --git a/ortools/sat/diffn_util.cc b/ortools/sat/diffn_util.cc index 6a44c7d7f6..fb48a1fd01 100644 --- a/ortools/sat/diffn_util.cc +++ b/ortools/sat/diffn_util.cc @@ -571,7 +571,7 @@ void AppendPairwiseRestriction(const ItemForPairwiseRestriction& item1, } // namespace void AppendPairwiseRestrictions( - const std::vector& items, + absl::Span items, std::vector* result) { for (int i1 = 0; i1 + 1 < items.size(); ++i1) { for (int i2 = i1 + 1; i2 < items.size(); ++i2) { diff --git a/ortools/sat/diffn_util.h b/ortools/sat/diffn_util.h index 4cd568edab..bd7bfef3f4 100644 --- a/ortools/sat/diffn_util.h +++ b/ortools/sat/diffn_util.h @@ -248,7 +248,7 @@ struct PairwiseRestriction { // Find pair of items that are either in conflict or could have their range // shrinked to avoid conflict. void AppendPairwiseRestrictions( - const std::vector& items, + absl::Span items, std::vector* result); // Same as above, but test `items` against `other_items` and append the diff --git a/ortools/sat/linear_propagation.cc b/ortools/sat/linear_propagation.cc index 79b3f9bcb7..1b3f699be9 100644 --- a/ortools/sat/linear_propagation.cc +++ b/ortools/sat/linear_propagation.cc @@ -39,6 +39,7 @@ #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" #include "ortools/util/bitset.h" #include "ortools/util/strong_integers.h" #include "ortools/util/time_limit.h" @@ -477,6 +478,7 @@ LinearPropagator::LinearPropagator(Model* model) rev_integer_value_repository_( model->GetOrCreate()), precedences_(model->GetOrCreate()), + random_(model->GetOrCreate()), shared_stats_(model->GetOrCreate()), watcher_id_(watcher_->Register(this)) { // Note that we need this class always in sync. @@ -760,7 +762,7 @@ void LinearPropagator::CanonicalizeConstraint(int id) { } // TODO(user): template everything for the case info.all_coeffs_are_one ? -bool LinearPropagator::PropagateOneConstraint(int id) { +std::pair LinearPropagator::AnalyzeConstraint(int id) { // This is here for development purpose, it is a bit too slow to check by // default though, even VLOG_IS_ON(1) so we disable it. if (/* DISABLES CODE */ (false)) { @@ -803,7 +805,7 @@ bool LinearPropagator::PropagateOneConstraint(int id) { unenforced_constraints_.push_back(id); } ++num_ignored_; - return true; + return {0, 0}; } // Compute the slack and max_variations_ of each variables. @@ -866,16 +868,45 @@ bool LinearPropagator::PropagateOneConstraint(int id) { } } } + + // What we call slack here is the "room" between the implied_lb and the rhs. + // Note that we use slack in other context in this file too. const IntegerValue slack = info.rev_rhs - implied_lb; // Negative slack means the constraint is false. - if (max_variation <= slack) return true; + // Note that if max_variation > slack, we are sure to propagate something + // except if the constraint is enforced and the slack is non-negative. + if (slack < 0 || max_variation <= slack) return {slack, 0}; + if (enf_status == EnforcementStatus::IS_ENFORCED) { + // Swap the variable(s) that will be pushed at the beginning. + int num_to_push = 0; + const auto coeffs = GetCoeffs(info); + for (int i = 0; i < info.rev_size; ++i) { + if (max_variations[i] <= slack) continue; + std::swap(vars[i], vars[num_to_push]); + std::swap(coeffs[i], coeffs[num_to_push]); + ++num_to_push; + } + return {slack, num_to_push}; + } + return {slack, 0}; +} + +bool LinearPropagator::PropagateOneConstraint(int id) { + // The slack is const after this. + const auto [slack, num_to_push] = AnalyzeConstraint(id); + if (slack >= 0 && num_to_push == 0) return true; + + // We are sure to propagate something at this stage. id_propagated_something_[id] = true; + const ConstraintInfo& info = infos_[id]; + const auto vars = GetVariables(info); + const auto coeffs = GetCoeffs(info); + if (slack < 0) { // Fill integer reason. integer_reason_.clear(); reason_coeffs_.clear(); - const auto coeffs = GetCoeffs(info); for (int i = 0; i < info.initial_size; ++i) { const IntegerVariable var = vars[i]; if (!integer_trail_->VariableLowerBoundIsFromLevelZero(var)) { @@ -893,17 +924,13 @@ bool LinearPropagator::PropagateOneConstraint(int id) { } // We can only propagate more if all the enforcement literals are true. - if (info.enf_status != static_cast(EnforcementStatus::IS_ENFORCED)) { - return true; - } + // But this should have been checked by SkipConstraint(). + CHECK_EQ(info.enf_status, static_cast(EnforcementStatus::IS_ENFORCED)); // The lower bound of all the variables except one can be used to update the // upper bound of the last one. int num_pushed = 0; - const auto coeffs = GetCoeffs(info); - for (int i = 0; i < info.rev_size; ++i) { - if (max_variations[i] <= slack) continue; - + for (int i = 0; i < num_to_push; ++i) { // TODO(user): If the new ub fall into an hole of the variable, we can // actually relax the reason more by computing a better slack. ++num_pushes_; diff --git a/ortools/sat/linear_propagation.h b/ortools/sat/linear_propagation.h index 8fccd247c9..582b14b807 100644 --- a/ortools/sat/linear_propagation.h +++ b/ortools/sat/linear_propagation.h @@ -233,6 +233,11 @@ class LinearPropagator : public PropagatorInterface, ReversibleInterface { ABSL_MUST_USE_RESULT bool ReportConflictingCycle(); ABSL_MUST_USE_RESULT bool DisassembleSubtree(int root_id, int num_pushed); + // Returns (slack, num_to_push) of the given constraint. + // If slack < 0 we have a conflict or might push the enforcement. + // If slack >= 0 the first num_to_push variables can be pushed. + std::pair AnalyzeConstraint(int id); + void ClearPropagatedBy(); void CanonicalizeConstraint(int id); void AddToQueueIfNeeded(int id); @@ -250,6 +255,7 @@ class LinearPropagator : public PropagatorInterface, ReversibleInterface { RevIntRepository* rev_int_repository_; RevIntegerValueRepository* rev_integer_value_repository_; PrecedenceRelations* precedences_; + ModelRandomGenerator* random_; SharedStatistics* shared_stats_ = nullptr; const int watcher_id_; diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 6404bdca28..c1bb145c51 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -97,22 +97,6 @@ std::pair GetMinAndMaxNotEncoded( return {min, max}; } -bool LinMaxContainsOnlyOneVarInExpressions(const ConstraintProto& ct) { - CHECK_EQ(ct.constraint_case(), ConstraintProto::ConstraintCase::kLinMax); - int current_var = -1; - for (const LinearExpressionProto& expr : ct.lin_max().exprs()) { - if (expr.vars().empty()) continue; - if (expr.vars().size() > 1) return false; - const int var = PositiveRef(expr.vars(0)); - if (current_var == -1) { - current_var = var; - } else if (var != current_var) { - return false; - } - } - return true; -} - // Collect all the affines expressions in a LinMax constraint. // It checks that these are indeed affine expressions, and that they all share // the same variable. @@ -121,7 +105,7 @@ bool LinMaxContainsOnlyOneVarInExpressions(const ConstraintProto& ct) { void CollectAffineExpressionWithSingleVariable( const ConstraintProto& ct, CpModelMapping* mapping, IntegerVariable* var, std::vector>* affines) { - DCHECK(LinMaxContainsOnlyOneVarInExpressions(ct)); + DCHECK(ExpressionsContainsOnlyOneVar(ct.lin_max().exprs())); CHECK_EQ(ct.constraint_case(), ConstraintProto::ConstraintCase::kLinMax); *var = kNoIntegerVariable; affines->clear(); @@ -1381,7 +1365,8 @@ void TryToLinearizeConstraint(const CpModelProto& /*model_proto*/, } case ConstraintProto::ConstraintCase::kLinMax: { AppendLinMaxRelaxationPart1(ct, model, relaxation); - const bool is_affine_max = LinMaxContainsOnlyOneVarInExpressions(ct); + const bool is_affine_max = + ExpressionsContainsOnlyOneVar(ct.lin_max().exprs()); if (is_affine_max) { AppendMaxAffineRelaxation(ct, model, relaxation); } diff --git a/ortools/sat/parameters_validation.cc b/ortools/sat/parameters_validation.cc index b091cf80fc..4feb6c34c7 100644 --- a/ortools/sat/parameters_validation.cc +++ b/ortools/sat/parameters_validation.cc @@ -105,6 +105,8 @@ std::string ValidateParameters(const SatParameters& params) { TEST_IN_RANGE(min_num_lns_workers, 0, kMaxReasonableParallelism); TEST_IN_RANGE(shared_tree_num_workers, 0, kMaxReasonableParallelism); TEST_IN_RANGE(interleave_batch_size, 0, kMaxReasonableParallelism); + TEST_IN_RANGE(shared_tree_open_leaves_per_worker, 1, + kMaxReasonableParallelism); // TODO(user): Consider using annotations directly in the proto for these // validation. It is however not open sourced. diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index d8ed4ec3dd..f09baffbce 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -324,7 +324,9 @@ class LinearExpr: if num_elements == 0: return offset elif num_elements == 1: - return IntVar(model, proto.vars[0], None) * proto.coeffs[0] + offset + return ( + IntVar(model, proto.vars[0], None) * proto.coeffs[0] + offset + ) # pytype: disable=bad-return-type else: variables = [] coeffs = [] diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index 02ab2e509d..22f8e6f0ff 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -1048,7 +1048,8 @@ message SatParameters { // bounds are equal. This rule allows twice as many workers to work in the // preferred subtree as non-preferred. SPLIT_STRATEGY_DISCREPANCY = 1; - // Only split nodes with an objective lb equal to the global lb. + // Only split nodes with an objective lb equal to the global lb. If there is + // no objective, this is equivalent to SPLIT_STRATEGY_FIRST_PROPOSAL. SPLIT_STRATEGY_OBJECTIVE_LB = 2; // Attempt to keep the shared tree balanced. SPLIT_STRATEGY_BALANCED_TREE = 3; diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index 3bcd070b85..6fbc96182f 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -42,6 +42,7 @@ #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" #include "ortools/util/strong_integers.h" #include "ortools/util/time_limit.h" @@ -214,8 +215,17 @@ int SharedTreeManager::NumNodes() const { int SharedTreeManager::SplitsToGeneratePerWorker() const { absl::MutexLock mutex_lock(&mu_); - return std::min(num_splits_wanted_ / 2 + 1, - max_nodes_ - static_cast(nodes_.size())); + const int max_additional_nodes = max_nodes_ - static_cast(nodes_.size()); + const int total_splits_wanted = + std::min(num_splits_wanted_, + // Each split generates 2 nodes, so divide by 2, rounding up. + CeilOfRatio(max_additional_nodes, 2)); + // We want workers to propose too many splits as we expect to reject some, + // and it's more efficient to generate several splits on the same worker + // restart so we don't want to divide by num_workers_. + // But we also don't want more than half the splits to come from a single + // restart on a single worker so we divide by 2. + return CeilOfRatio(total_splits_wanted, 2); } bool SharedTreeManager::SyncTree(ProtoTrail& path) { @@ -246,7 +256,7 @@ bool SharedTreeManager::SyncTree(ProtoTrail& path) { } // Restart after processing updates - we might learn a new objective bound. if (++num_syncs_since_restart_ / num_workers_ > kSyncsPerWorkerPerRestart && - (num_restarts_ < kNumInitialRestarts || nodes_.size() >= max_nodes_)) { + num_restarts_ < kNumInitialRestarts) { RestartLockHeld(); path.Clear(); return false; @@ -278,6 +288,8 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) { VLOG(2) << "Enough splits for now"; return; } + const int num_desired_leaves = + params_.shared_tree_open_leaves_per_worker() * num_workers_; if (params_.shared_tree_split_strategy() == SatParameters::SPLIT_STRATEGY_DISCREPANCY || params_.shared_tree_split_strategy() == @@ -293,7 +305,7 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) { // TODO(user): Need to write up the shape this creates. // This rule will allow twice as many leaves in the preferred subtree. if (discrepancy + path.MaxLevel() > - MaxAllowedDiscrepancyPlusDepth(num_workers_)) { + MaxAllowedDiscrepancyPlusDepth(num_desired_leaves)) { VLOG(2) << "Too high discrepancy to accept split"; return; } @@ -307,7 +319,7 @@ void SharedTreeManager::ProposeSplit(ProtoTrail& path, ProtoLiteral decision) { } } else if (params_.shared_tree_split_strategy() == SatParameters::SPLIT_STRATEGY_BALANCED_TREE) { - if (path.MaxLevel() + 1 > log2(num_workers_)) { + if (path.MaxLevel() + 1 > log2(num_desired_leaves)) { VLOG(2) << "Tree too unbalanced to accept split"; return; } @@ -375,12 +387,15 @@ SharedTreeManager::Node* SharedTreeManager::MakeSubtree(Node* parent, } void SharedTreeManager::ProcessNodeChanges() { + int num_newly_closed = 0; while (!to_close_.empty()) { Node* node = to_close_.back(); CHECK_NE(node, nullptr); to_close_.pop_back(); // Iterate over open parents while each sibling is closed. while (node != nullptr && !node->closed) { + ++num_newly_closed; + ++num_closed_nodes_; node->closed = true; node->objective_lb = kMaxIntegerValue; // If we are closing a leaf, try to maintain the same number of leaves; @@ -404,6 +419,13 @@ void SharedTreeManager::ProcessNodeChanges() { to_update_.push_back(node->parent); } } + if (num_newly_closed > 0) { + shared_response_manager_->LogMessageWithThrottling( + "Tree", absl::StrCat("nodes:", nodes_.size(), "/", max_nodes_, + " closed:", num_closed_nodes_, + " unassigned:", unassigned_leaves_.size(), + " restarts:", num_restarts_)); + } bool root_updated = false; while (!to_update_.empty()) { Node* node = to_update_.back(); @@ -491,6 +513,7 @@ void SharedTreeManager::RestartLockHeld() { unassigned_leaves_.clear(); num_splits_wanted_ = num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1; + num_closed_nodes_ = 0; num_restarts_ += 1; num_syncs_since_restart_ = 0; } @@ -663,9 +686,8 @@ void SharedTreeWorker::MaybeProposeSplit() { CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); manager_->ProposeSplit(assigned_tree_, *encoded); if (assigned_tree_.MaxLevel() > assigned_tree_literals_.size()) { - assigned_tree_literals_.push_back(split_decision); --splits_wanted_; - CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); + assigned_tree_literals_.push_back(split_decision); } CHECK_EQ(assigned_tree_literals_.size(), assigned_tree_.MaxLevel()); } diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h index 3e39f0c94a..c11712dd5a 100644 --- a/ortools/sat/work_assignment.h +++ b/ortools/sat/work_assignment.h @@ -209,10 +209,11 @@ class SharedTreeManager { // How many splits we should generate now to keep the desired number of // leaves. - int num_splits_wanted_; + int num_splits_wanted_ ABSL_GUARDED_BY(mu_); // We limit the total nodes generated per restart to cap the RAM usage and - // communication overhead. If we exceed this, we will restart the shared tree. + // communication overhead. If we exceed this, workers become portfolio + // workers when no unassigned leaves are available. const int max_nodes_; int num_leaves_assigned_ ABSL_GUARDED_BY(mu_) = 0; @@ -223,6 +224,7 @@ class SharedTreeManager { int64_t num_restarts_ ABSL_GUARDED_BY(mu_) = 0; int64_t num_syncs_since_restart_ ABSL_GUARDED_BY(mu_) = 0; + int num_closed_nodes_ ABSL_GUARDED_BY(mu_) = 0; }; class SharedTreeWorker {