diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index bd6a42c395..d77e0e1a9a 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -1090,6 +1090,7 @@ cc_library( "//ortools/graph", "//ortools/graph:topologicalsorter", "//ortools/util:bitset", + "//ortools/util:logging", "//ortools/util:strong_integers", "//ortools/util:time_limit", "@com_google_absl//absl/cleanup", diff --git a/ortools/sat/cp_constraints.cc b/ortools/sat/cp_constraints.cc index 76256e8999..79f037a3be 100644 --- a/ortools/sat/cp_constraints.cc +++ b/ortools/sat/cp_constraints.cc @@ -121,6 +121,10 @@ bool GreaterThanAtLeastOneOfPropagator::Propagate() { literal_reason_.push_back(l.Negated()); } for (int i = 0; i < exprs_.size(); ++i) { + // If the level zero bounds is good enough, no reason needed. + if (integer_trail_->LevelZeroLowerBound(exprs_[i]) >= target_min) { + continue; + } if (trail_->Assignment().LiteralIsFalse(selectors_[i])) { literal_reason_.push_back(selectors_[i]); } else { @@ -139,7 +143,11 @@ void GreaterThanAtLeastOneOfPropagator::RegisterWith( const int id = watcher->Register(this); for (const Literal l : selectors_) watcher->WatchLiteral(l.Negated(), id); for (const Literal l : enforcements_) watcher->WatchLiteral(l, id); - for (const AffineExpression e : exprs_) watcher->WatchLowerBound(e, id); + for (const AffineExpression e : exprs_) { + if (!e.IsConstant()) { + watcher->WatchLowerBound(e, id); + } + } } } // namespace sat diff --git a/ortools/sat/cp_model.cc b/ortools/sat/cp_model.cc index 02d53e4f2d..5c9da884f8 100644 --- a/ortools/sat/cp_model.cc +++ b/ortools/sat/cp_model.cc @@ -497,7 +497,7 @@ Constraint Constraint::WithName(absl::string_view name) { return *this; } -const std::string& Constraint::Name() const { return proto_->name(); } +absl::string_view Constraint::Name() const { return proto_->name(); } Constraint Constraint::OnlyEnforceIf(absl::Span literals) { for (const BoolVar& var : literals) { diff --git a/ortools/sat/cp_model.h b/ortools/sat/cp_model.h index 29c6cb1d39..70123343a0 100644 --- a/ortools/sat/cp_model.h +++ b/ortools/sat/cp_model.h @@ -556,7 +556,7 @@ class Constraint { Constraint WithName(absl::string_view name); /// Returns the name of the constraint (or the empty string if not set). - const std::string& Name() const; + absl::string_view Name() const; /// Returns the underlying protobuf object (useful for testing). const ConstraintProto& Proto() const { return *proto_; } diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index bd640620ec..731f136ebf 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -460,6 +460,44 @@ void ExpandInverse(ConstraintProto* ct, PresolveContext* context) { context->UpdateRuleStats("inverse: expanded"); } +void ExpandLinMaxWithTwoTerms(ConstraintProto* ct, PresolveContext* context) { + CHECK_EQ(ct->lin_max().exprs().size(), 2); + + // We will create 4 constraints for target = max(a, b). + // First. + // - target >= a. + // - target >= b. + for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { + LinearConstraintProto* lin = + context->working_model->add_constraints()->mutable_linear(); + lin->add_domain(0); + lin->add_domain(std::numeric_limits::max()); + AddLinearExpressionToLinearConstraint(ct->lin_max().target(), 1, lin); + AddLinearExpressionToLinearConstraint(expr, -1, lin); + } + + // And then, a new boolean b, and + // - b => target == a + // - not(b) => target == b + const int new_bool = context->NewBoolVar(); + bool first_loop = true; + for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { + ConstraintProto* new_ct = context->working_model->add_constraints(); + new_ct->add_enforcement_literal(first_loop ? new_bool + : NegatedRef(new_bool)); + first_loop = false; + + LinearConstraintProto* lin = new_ct->mutable_linear(); + lin->add_domain(0); + lin->add_domain(0); + AddLinearExpressionToLinearConstraint(ct->lin_max().target(), 1, lin); + AddLinearExpressionToLinearConstraint(expr, -1, lin); + } + + ct->Clear(); + context->UpdateRuleStats("lin_max: expanded lin_max with two terms"); +} + // A[V] == V means for all i, V == i => A_i == i void ExpandElementWithTargetEqualIndex(ConstraintProto* ct, PresolveContext* context) { @@ -2227,6 +2265,12 @@ void ExpandCpModel(PresolveContext* context) { ExpandPositiveTable(ct, context); } break; + case ConstraintProto::kLinMax: + if (context->params().expand_binary_lin_max() && + ct->lin_max().exprs().size() == 2) { + ExpandLinMaxWithTwoTerms(ct, context); + } + break; case ConstraintProto::kAllDiff: has_all_diffs = true; skip = true; diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index 32b72dd37f..09c9b13d7b 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -1253,6 +1253,27 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { max_sum += std::max(term_a, term_b); } + // Load conditional precedences. + const SatParameters& params = *m->GetOrCreate(); + if (params.auto_detect_greater_than_at_least_one_of() && + ct.enforcement_literal().size() == 1 && vars.size() <= 2) { + // To avoid overflow in the code below, we tighten the bounds. + int64_t rhs_min = ct.linear().domain(0); + int64_t rhs_max = ct.linear().domain(ct.linear().domain().size() - 1); + rhs_min = std::max(rhs_min, min_sum.value()); + rhs_max = std::min(rhs_max, max_sum.value()); + + auto* detector = m->GetOrCreate(); + const Literal lit = mapping->Literal(ct.enforcement_literal(0)); + const Domain domain = ReadDomainFromProto(ct.linear()); + if (vars.size() == 1) { + detector->Add(lit, {vars[0], coeffs[0]}, {}, rhs_min, rhs_max); + } else if (vars.size() == 2) { + detector->Add(lit, {vars[0], coeffs[0]}, {vars[1], coeffs[1]}, rhs_min, + rhs_max); + } + } + // Load precedences. if (!HasEnforcementLiteral(ct)) { auto* precedences = m->GetOrCreate(); @@ -1311,7 +1332,6 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { } } - const SatParameters& params = *m->GetOrCreate(); const IntegerValue domain_size_limit( params.max_domain_size_when_encoding_eq_neq_constraints()); if (ct.linear().vars_size() == 2 && !integer_trail->IsFixed(vars[0]) && diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 6d2a3d0989..f1d5a4b003 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -1637,9 +1637,8 @@ void LoadCpModel(const CpModelProto& model_proto, Model* model) { // Note that we do that before we finish loading the problem (objective and // LP relaxation), because propagation will be faster at this point and it // should be enough for the purpose of this auto-detection. - if (model->Mutable() != nullptr && - parameters.auto_detect_greater_than_at_least_one_of()) { - model->Mutable() + if (parameters.auto_detect_greater_than_at_least_one_of()) { + model->GetOrCreate() ->AddGreaterThanAtLeastOneOfConstraints(model); if (!sat_solver->FinishPropagation()) return unsat(); } diff --git a/ortools/sat/drat_checker.cc b/ortools/sat/drat_checker.cc index caab5ef664..d1bd09736d 100644 --- a/ortools/sat/drat_checker.cc +++ b/ortools/sat/drat_checker.cc @@ -604,7 +604,7 @@ bool AddInferedAndDeletedClauses(const std::string& file_path, } bool PrintClauses(const std::string& file_path, SatFormat format, - const std::vector>& clauses, + absl::Span> clauses, int num_variables) { std::ofstream output_stream(file_path, std::ofstream::out); if (format == DIMACS) { diff --git a/ortools/sat/drat_checker.h b/ortools/sat/drat_checker.h index 11a5b11095..9069285ab4 100644 --- a/ortools/sat/drat_checker.h +++ b/ortools/sat/drat_checker.h @@ -336,7 +336,7 @@ enum SatFormat { // Prints the given clauses in the file at the given path, using the given file // format. Returns true iff the file was successfully written. bool PrintClauses(const std::string& file_path, SatFormat format, - const std::vector>& clauses, + absl::Span> clauses, int num_variables); } // namespace sat diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index a4629893bd..42f0968268 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -300,9 +300,9 @@ struct AffineExpression { AffineExpression(IntegerVariable v) // NOLINT(runtime/explicit) : var(v), coeff(1) {} AffineExpression(IntegerVariable v, IntegerValue c) - : var(c > 0 ? v : NegationOf(v)), coeff(IntTypeAbs(c)) {} + : var(c >= 0 ? v : NegationOf(v)), coeff(IntTypeAbs(c)) {} AffineExpression(IntegerVariable v, IntegerValue c, IntegerValue cst) - : var(c > 0 ? v : NegationOf(v)), coeff(IntTypeAbs(c)), constant(cst) {} + : var(c >= 0 ? v : NegationOf(v)), coeff(IntTypeAbs(c)), constant(cst) {} // Returns the integer literal corresponding to expression >= value or // expression <= value. diff --git a/ortools/sat/lb_tree_search.cc b/ortools/sat/lb_tree_search.cc index 7d7e69db45..59a855454c 100644 --- a/ortools/sat/lb_tree_search.cc +++ b/ortools/sat/lb_tree_search.cc @@ -47,7 +47,8 @@ namespace operations_research { namespace sat { LbTreeSearch::LbTreeSearch(Model* model) - : time_limit_(model->GetOrCreate()), + : name_(model->Name().empty() ? "lb_tree_search" : model->Name()), + time_limit_(model->GetOrCreate()), random_(model->GetOrCreate()), sat_solver_(model->GetOrCreate()), integer_encoder_(model->GetOrCreate()), @@ -298,8 +299,8 @@ SatSolver::Status LbTreeSearch::Search( const IntegerValue bound = nodes_[current_branch_[0]].MinObjective(); if (bound > current_objective_lb_) { shared_response_->UpdateInnerObjectiveBounds( - absl::StrCat("lb_tree_search (", SmallProgressString(), ") "), - bound, integer_trail_->LevelZeroUpperBound(objective_var_)); + absl::StrCat(name_, " (", SmallProgressString(), ") "), bound, + integer_trail_->LevelZeroUpperBound(objective_var_)); current_objective_lb_ = bound; if (VLOG_IS_ON(3)) DebugDisplayTree(current_branch_[0]); } diff --git a/ortools/sat/lb_tree_search.h b/ortools/sat/lb_tree_search.h index 8a72301aeb..60d7779b5b 100644 --- a/ortools/sat/lb_tree_search.h +++ b/ortools/sat/lb_tree_search.h @@ -138,6 +138,7 @@ class LbTreeSearch { std::string SmallProgressString() const; // Model singleton class used here. + const std::string name_; TimeLimit* time_limit_; ModelRandomGenerator* random_; SatSolver* sat_solver_; diff --git a/ortools/sat/linear_constraint.h b/ortools/sat/linear_constraint.h index f385e2d670..8b343343cb 100644 --- a/ortools/sat/linear_constraint.h +++ b/ortools/sat/linear_constraint.h @@ -15,6 +15,7 @@ #define OR_TOOLS_SAT_LINEAR_CONSTRAINT_H_ #include +#include #include #include #include diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index e3dd67785a..125d3ca856 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include diff --git a/ortools/sat/linear_propagation.cc b/ortools/sat/linear_propagation.cc index 87c5fc7f33..007b87a154 100644 --- a/ortools/sat/linear_propagation.cc +++ b/ortools/sat/linear_propagation.cc @@ -549,13 +549,15 @@ bool LinearPropagator::Propagate() { // propagator might have pushed the same variable further. // // Empty FIFO queue. + bool result = true; + num_terms_for_dtime_update_ = 0; const int saved_index = trail_->Index(); while (!propagation_queue_.empty()) { const int id = propagation_queue_.Pop(); in_queue_[id] = false; if (!PropagateOneConstraint(id)) { - modified_vars_.ClearAndResize(integer_trail_->NumIntegerVariables()); - return false; + result = false; + break; } if (trail_->Index() > saved_index) { @@ -565,8 +567,10 @@ bool LinearPropagator::Propagate() { } // Clean-up modified_vars_ to do as little as possible on the next call. + time_limit_->AdvanceDeterministicTime( + static_cast(num_terms_for_dtime_update_) * 1e-9); modified_vars_.ClearAndResize(integer_trail_->NumIntegerVariables()); - return true; + return result; } // Adds a new constraint to the propagator. @@ -671,7 +675,11 @@ bool LinearPropagator::AddConstraint( } // Propagate this new constraint. - return PropagateOneConstraint(id); + num_terms_for_dtime_update_ = 0; + const bool result = PropagateOneConstraint(id); + time_limit_->AdvanceDeterministicTime( + static_cast(num_terms_for_dtime_update_) * 1e-9); + return result; } absl::Span LinearPropagator::GetCoeffs( @@ -689,8 +697,8 @@ absl::Span LinearPropagator::GetVariables( void LinearPropagator::CanonicalizeConstraint(int id) { const ConstraintInfo& info = infos_[id]; - auto coeffs = GetCoeffs(info); - auto vars = GetVariables(info); + const auto coeffs = GetCoeffs(info); + const auto vars = GetVariables(info); for (int i = 0; i < vars.size(); ++i) { if (coeffs[i] < 0) { coeffs[i] = -coeffs[i]; @@ -732,34 +740,60 @@ bool LinearPropagator::PropagateOneConstraint(int id) { // Compute the slack and max_variations_ of each variables. // We also filter out fixed variables in a reversible way. IntegerValue implied_lb(0); - auto vars = GetVariables(info); - auto coeffs = GetCoeffs(info); + const auto vars = GetVariables(info); IntegerValue max_variation(0); bool first_change = true; - time_limit_->AdvanceDeterministicTime(static_cast(info.rev_size) * - 1e-9); - for (int i = 0; i < info.rev_size;) { - const IntegerVariable var = vars[i]; - const IntegerValue coeff = coeffs[i]; - const IntegerValue lb = integer_trail_->LowerBound(var); - const IntegerValue ub = integer_trail_->UpperBound(var); - if (lb == ub) { - if (first_change) { - // Note that we can save at most one state per fixed var. Also at - // level zero we don't save anything. - rev_int_repository_->SaveState(&info.rev_size); - rev_integer_value_repository_->SaveState(&info.rev_rhs); - first_change = false; + num_terms_for_dtime_update_ += info.rev_size; + IntegerValue* max_variations = max_variations_.data(); + if (info.all_coeffs_are_one) { + // TODO(user): Avoid duplication? + for (int i = 0; i < info.rev_size;) { + const IntegerVariable var = vars[i]; + const IntegerValue lb = integer_trail_->LowerBound(var); + const IntegerValue ub = integer_trail_->UpperBound(var); + if (lb == ub) { + if (first_change) { + // Note that we can save at most one state per fixed var. Also at + // level zero we don't save anything. + rev_int_repository_->SaveState(&info.rev_size); + rev_integer_value_repository_->SaveState(&info.rev_rhs); + first_change = false; + } + info.rev_size--; + std::swap(vars[i], vars[info.rev_size]); + info.rev_rhs -= lb; + } else { + implied_lb += lb; + max_variations[i] = (ub - lb); + max_variation = std::max(max_variation, max_variations[i]); + ++i; + } + } + } else { + const auto coeffs = GetCoeffs(info); + for (int i = 0; i < info.rev_size;) { + const IntegerVariable var = vars[i]; + const IntegerValue coeff = coeffs[i]; + const IntegerValue lb = integer_trail_->LowerBound(var); + const IntegerValue ub = integer_trail_->UpperBound(var); + if (lb == ub) { + if (first_change) { + // Note that we can save at most one state per fixed var. Also at + // level zero we don't save anything. + rev_int_repository_->SaveState(&info.rev_size); + rev_integer_value_repository_->SaveState(&info.rev_rhs); + first_change = false; + } + info.rev_size--; + std::swap(vars[i], vars[info.rev_size]); + std::swap(coeffs[i], coeffs[info.rev_size]); + info.rev_rhs -= coeff * lb; + } else { + implied_lb += coeff * lb; + max_variations[i] = (ub - lb) * coeff; + max_variation = std::max(max_variation, max_variations[i]); + ++i; } - info.rev_size--; - std::swap(vars[i], vars[info.rev_size]); - std::swap(coeffs[i], coeffs[info.rev_size]); - info.rev_rhs -= coeff * lb; - } else { - implied_lb += coeff * lb; - max_variations_[i] = (ub - lb) * coeff; - max_variation = std::max(max_variation, max_variations_[i]); - ++i; } } const IntegerValue slack = info.rev_rhs - implied_lb; @@ -770,6 +804,7 @@ bool LinearPropagator::PropagateOneConstraint(int id) { // Fill integer reason. integer_reason_.clear(); reason_coeffs_.clear(); + const auto coeffs = GetCoeffs(info); for (int i = 0; i < info.initial_size; ++i) { const IntegerVariable var = vars[i]; if (!integer_trail_->VariableLowerBoundIsFromLevelZero(var)) { @@ -794,8 +829,9 @@ bool LinearPropagator::PropagateOneConstraint(int id) { // The lower bound of all the variables except one can be used to update the // upper bound of the last one. int num_pushed = 0; + const auto coeffs = GetCoeffs(info); for (int i = 0; i < info.rev_size; ++i) { - if (max_variations_[i] <= slack) continue; + if (max_variations[i] <= slack) continue; // TODO(user): If the new ub fall into an hole of the variable, we can // actually relax the reason more by computing a better slack. @@ -817,8 +853,8 @@ bool LinearPropagator::PropagateOneConstraint(int id) { literal_reason); reason_coeffs_.clear(); - auto coeffs = GetCoeffs(info); - auto vars = GetVariables(info); + const auto coeffs = GetCoeffs(info); + const auto vars = GetVariables(info); for (int i = 0; i < info.initial_size; ++i) { const IntegerVariable var = vars[i]; if (PositiveVariable(var) == PositiveVariable(i_lit.var)) { @@ -873,8 +909,8 @@ bool LinearPropagator::PropagateOneConstraint(int id) { std::string LinearPropagator::ConstraintDebugString(int id) { std::string result; const ConstraintInfo& info = infos_[id]; - auto coeffs = GetCoeffs(info); - auto vars = GetVariables(info); + const auto coeffs = GetCoeffs(info); + const auto vars = GetVariables(info); IntegerValue implied_lb(0); IntegerValue rhs_correction(0); for (int i = 0; i < info.initial_size; ++i) { @@ -908,8 +944,8 @@ bool LinearPropagator::ReportConflictingCycle() { const ConstraintInfo& info = infos_[id]; enforcement_propagator_->AddEnforcementReason(info.enf_id, &literal_reason_); - auto coeffs = GetCoeffs(info); - auto vars = GetVariables(info); + const auto coeffs = GetCoeffs(info); + const auto vars = GetVariables(info); IntegerValue rhs_correction(0); for (int i = 0; i < info.initial_size; ++i) { if (i >= info.rev_size) { @@ -1033,7 +1069,7 @@ bool LinearPropagator::DisassembleSubtree(int root_id, int num_pushed) { disassemble_branch_.clear(); { const ConstraintInfo& info = infos_[root_id]; - auto vars = GetVariables(info); + const auto vars = GetVariables(info); for (int i = 0; i < num_pushed; ++i) { disassemble_queue_.push_back({root_id, NegationOf(vars[i])}); } @@ -1041,6 +1077,7 @@ bool LinearPropagator::DisassembleSubtree(int root_id, int num_pushed) { // Note that all var should be unique since there is only one propagated_by_ // for each one. And each time we explore an id, we disassemble the tree. + absl::Span id_to_count = absl::MakeSpan(id_to_propagation_count_); while (!disassemble_queue_.empty()) { const auto [prev_id, var] = disassemble_queue_.back(); if (!disassemble_branch_.empty() && @@ -1081,16 +1118,16 @@ bool LinearPropagator::DisassembleSubtree(int root_id, int num_pushed) { // variation in slack might be big enough to push a variable twice and // thus push a lower coeff. const ConstraintInfo& info = infos_[id]; - auto coeffs = GetCoeffs(info); - auto vars = GetVariables(info); + const auto coeffs = GetCoeffs(info); + const auto vars = GetVariables(info); IntegerValue root_coeff(0); IntegerValue var_coeff(0); for (int i = 0; i < info.initial_size; ++i) { if (vars[i] == var) var_coeff = coeffs[i]; if (vars[i] == NegationOf(root_var)) root_coeff = coeffs[i]; } - CHECK_NE(root_coeff, 0); - CHECK_NE(var_coeff, 0); + DCHECK_NE(root_coeff, 0); + DCHECK_NE(var_coeff, 0); if (var_coeff >= root_coeff) { return ReportConflictingCycle(); } else { @@ -1099,15 +1136,15 @@ bool LinearPropagator::DisassembleSubtree(int root_id, int num_pushed) { } } - if (id_to_propagation_count_[id] == 0) continue; // Didn't push. + if (id_to_count[id] == 0) continue; // Didn't push. disassemble_to_reorder_.Set(id); // The constraint pushed some variable. Identify which ones will be pushed // further. Disassemble the whole info since we are about to propagate // this constraint again. Any pushed variable must be before the rev_size. const ConstraintInfo& info = infos_[id]; - auto coeffs = GetCoeffs(info); - auto vars = GetVariables(info); + const auto coeffs = GetCoeffs(info); + const auto vars = GetVariables(info); IntegerValue var_coeff(0); disassemble_candidates_.clear(); ++num_explored_in_disassemble_; @@ -1124,7 +1161,7 @@ bool LinearPropagator::DisassembleSubtree(int root_id, int num_pushed) { // We will propagate var again later, so clear all this for now. propagated_by_[next_var] = -1; - id_to_propagation_count_[id]--; + id_to_count[id]--; } } for (const auto [next_var, coeff] : disassemble_candidates_) { @@ -1152,7 +1189,7 @@ bool LinearPropagator::DisassembleSubtree(int root_id, int num_pushed) { tmp_to_reorder_.push_back(id); } - // TODO(user): Reordering can be sloe since require sort and can touch many + // TODO(user): Reordering can be slow since require sort and can touch many // entries. Investigate alternatives. We could probably optimize this a bit // more. if (tmp_to_reorder_.empty()) return true; diff --git a/ortools/sat/linear_propagation.h b/ortools/sat/linear_propagation.h index 17d7aa6e08..0163a8acd1 100644 --- a/ortools/sat/linear_propagation.h +++ b/ortools/sat/linear_propagation.h @@ -254,9 +254,6 @@ class LinearPropagator : public PropagatorInterface, ReversibleInterface { std::deque infos_; // Buffer of the constraints data. - // - // TODO(user): A lot of constrains have all their coeffs at one, we could - // exploit this. std::vector variables_buffer_; std::vector coeffs_buffer_; std::vector buffer_of_ones_; @@ -315,6 +312,9 @@ class LinearPropagator : public PropagatorInterface, ReversibleInterface { SparseBitset id_scanned_at_least_once_; int64_t num_extra_scans_ = 0; + // This is used to update the deterministic time. + int64_t num_terms_for_dtime_update_ = 0; + // Stats. int64_t num_pushes_ = 0; int64_t num_enforcement_pushes_ = 0; diff --git a/ortools/sat/optimization.cc b/ortools/sat/optimization.cc index 72f0eaace0..19efe38d67 100644 --- a/ortools/sat/optimization.cc +++ b/ortools/sat/optimization.cc @@ -666,9 +666,8 @@ bool CoreBasedOptimizer::CoverOptimization() { } SatSolver::Status CoreBasedOptimizer::OptimizeWithSatEncoding( - const std::vector& literals, - const std::vector& vars, - const std::vector& coefficients, Coefficient offset) { + absl::Span literals, absl::Span vars, + absl::Span coefficients, Coefficient offset) { // Create one initial nodes per variables with cost. // TODO(user): We could create EncodingNode out of IntegerVariable. // diff --git a/ortools/sat/optimization.h b/ortools/sat/optimization.h index 5820783906..03657da0c1 100644 --- a/ortools/sat/optimization.h +++ b/ortools/sat/optimization.h @@ -114,9 +114,9 @@ class CoreBasedOptimizer { // - Support resuming for interleaved search. // - Implement all core heurisitics. SatSolver::Status OptimizeWithSatEncoding( - const std::vector& literals, - const std::vector& vars, - const std::vector& coefficients, Coefficient offset); + absl::Span literals, + absl::Span vars, + absl::Span coefficients, Coefficient offset); private: CoreBasedOptimizer(const CoreBasedOptimizer&) = delete; diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index 291cf2fff4..50e01cf5ad 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -42,6 +42,7 @@ #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" #include "ortools/util/bitset.h" +#include "ortools/util/logging.h" #include "ortools/util/strong_integers.h" #include "ortools/util/time_limit.h" @@ -1021,112 +1022,174 @@ bool PrecedencesPropagator::BellmanFordTarjan(Trail* trail) { return true; } -int PrecedencesPropagator::AddGreaterThanAtLeastOneOfConstraintsFromClause( - const absl::Span clause, Model* model) { +void GreaterThanAtLeastOneOfDetector::Add(Literal lit, LinearTerm a, + LinearTerm b, IntegerValue lhs, + IntegerValue rhs) { + Relation r; + r.enforcement = lit; + r.a = a; + r.b = b; + r.lhs = lhs; + r.rhs = rhs; + + // We shall only consider positive variable here. + if (r.a.var != kNoIntegerVariable && !VariableIsPositive(r.a.var)) { + r.a.var = NegationOf(r.a.var); + r.a.coeff = -r.a.coeff; + } + if (r.b.var != kNoIntegerVariable && !VariableIsPositive(r.b.var)) { + r.b.var = NegationOf(r.b.var); + r.b.coeff = -r.b.coeff; + } + + const int index = relations_.size(); + relations_.push_back(std::move(r)); + + if (lit.Index() >= lit_to_relations_.size()) { + lit_to_relations_.resize(lit.Index() + 1); + } + lit_to_relations_[lit.Index()].push_back(index); +} + +bool GreaterThanAtLeastOneOfDetector::AddRelationFromIndices( + IntegerVariable var, absl::Span clause, + absl::Span indices, Model* model) { + std::vector exprs; + std::vector selectors; + absl::flat_hash_set used; + auto* integer_trail = model->GetOrCreate(); + + const IntegerValue var_lb = integer_trail->LevelZeroLowerBound(var); + for (const int index : indices) { + Relation r = relations_[index]; + if (r.a.var != PositiveVariable(var)) std::swap(r.a, r.b); + CHECK_EQ(r.a.var, PositiveVariable(var)); + + if ((r.a.coeff == 1) == VariableIsPositive(var)) { + // a + b >= lhs + if (r.lhs <= kMinIntegerValue) continue; + exprs.push_back(AffineExpression(r.b.var, -r.b.coeff, r.lhs)); + } else { + // -a + b <= rhs. + if (r.rhs >= kMaxIntegerValue) continue; + exprs.push_back(AffineExpression(r.b.var, r.b.coeff, -r.rhs)); + } + + // Ignore this entry if it is always true. + if (var_lb >= integer_trail->LevelZeroUpperBound(exprs.back())) { + exprs.pop_back(); + continue; + } + + // Note that duplicate selector are supported. + selectors.push_back(r.enforcement); + used.insert(r.enforcement); + } + + // The enforcement of the new constraint are simply the literal not used + // above. + std::vector enforcements; + for (const Literal l : clause) { + if (!used.contains(l.Index())) { + enforcements.push_back(l.Negated()); + } + } + + // No point adding a constraint if there is not at least two different + // literals in selectors. + if (used.size() <= 1) return false; + + // Add the constraint. + GreaterThanAtLeastOneOfPropagator* constraint = + new GreaterThanAtLeastOneOfPropagator(var, exprs, selectors, enforcements, + model); + constraint->RegisterWith(model->GetOrCreate()); + model->TakeOwnership(constraint); + return true; +} + +int GreaterThanAtLeastOneOfDetector:: + AddGreaterThanAtLeastOneOfConstraintsFromClause( + const absl::Span clause, Model* model) { CHECK_EQ(model->GetOrCreate()->CurrentDecisionLevel(), 0); if (clause.size() < 2) return 0; - // Collect all arcs impacted by this clause. - std::vector infos; + // Collect all relations impacted by this clause. + std::vector> infos; for (const Literal l : clause) { - if (l.Index() >= literal_to_new_impacted_arcs_.size()) continue; - for (const ArcIndex arc_index : literal_to_new_impacted_arcs_[l.Index()]) { - const ArcInfo& arc = arcs_[arc_index]; - if (arc.presence_literals.size() != 1) continue; - - // TODO(user): Support variable offset. - if (arc.offset_var != kNoIntegerVariable) continue; - infos.push_back(arc); + if (l.Index() >= lit_to_relations_.size()) continue; + for (const int index : lit_to_relations_[l.Index()]) { + const Relation& r = relations_[index]; + if (r.a.var != kNoIntegerVariable && IntTypeAbs(r.a.coeff) == 1) { + infos.push_back({r.a.var, index}); + } + if (r.b.var != kNoIntegerVariable && IntTypeAbs(r.b.coeff) == 1) { + infos.push_back({r.b.var, index}); + } } } if (infos.size() <= 1) return 0; - // Stable sort by head_var so that for a same head_var, the entry are sorted - // by Literal as they appear in clause. - std::stable_sort(infos.begin(), infos.end(), - [](const ArcInfo& a, const ArcInfo& b) { - return a.head_var < b.head_var; - }); + // Stable sort to regroup by var. + std::stable_sort(infos.begin(), infos.end()); - // We process ArcInfo with the same head_var toghether. + // We process the info with same variable together. int num_added_constraints = 0; - auto* solver = model->GetOrCreate(); + std::vector indices; for (int i = 0; i < infos.size();) { const int start = i; - const IntegerVariable head_var = infos[start].head_var; - for (i++; i < infos.size() && infos[i].head_var == head_var; ++i) { - } - const absl::Span arcs(&infos[start], i - start); + const IntegerVariable var = infos[start].first; - // Skip single arcs since it will already be fully propagated. - if (arcs.size() < 2) continue; - - // Heuristic. Look for full or almost full clauses. We could add - // GreaterThanAtLeastOneOf() with more enforcement literals. TODO(user): - // experiments. - if (arcs.size() + 1 < clause.size()) continue; - - std::vector vars; - std::vector offsets; - std::vector selectors; - std::vector enforcements; - - int j = 0; - for (const Literal l : clause) { - bool added = false; - for (; j < arcs.size() && l == arcs[j].presence_literals.front(); ++j) { - added = true; - vars.push_back(arcs[j].tail_var); - offsets.push_back(arcs[j].offset); - - // Note that duplicate selector are supported. - // - // TODO(user): If we support variable offset, we should regroup the arcs - // into one (tail + offset <= head) though, instead of having too - // identical entries. - selectors.push_back(l); - } - if (!added) { - enforcements.push_back(l.Negated()); - } + indices.clear(); + for (; i < infos.size() && infos[i].first == var; ++i) { + indices.push_back(infos[i].second); } - // No point adding a constraint if there is not at least two different - // literals in selectors. - if (enforcements.size() + 1 == clause.size()) continue; + // Skip single relations, we are not interested in these. + if (indices.size() < 2) continue; - ++num_added_constraints; - model->Add(GreaterThanAtLeastOneOf(head_var, vars, offsets, selectors, - enforcements)); - if (!solver->FinishPropagation()) return num_added_constraints; + // Heuristic. Look for full or almost full clauses. + // + // TODO(user): We could add GreaterThanAtLeastOneOf() with more enforcement + // literals. Experiment. + if (indices.size() + 1 < clause.size()) continue; + + if (AddRelationFromIndices(var, clause, indices, model)) { + ++num_added_constraints; + } + if (AddRelationFromIndices(NegationOf(var), clause, indices, model)) { + ++num_added_constraints; + } } return num_added_constraints; } -int PrecedencesPropagator:: +int GreaterThanAtLeastOneOfDetector:: AddGreaterThanAtLeastOneOfConstraintsWithClauseAutoDetection(Model* model) { auto* time_limit = model->GetOrCreate(); auto* solver = model->GetOrCreate(); - // Fill the set of incoming conditional arcs for each variables. - absl::StrongVector> incoming_arcs_; - for (ArcIndex arc_index(0); arc_index < arcs_.size(); ++arc_index) { - const ArcInfo& arc = arcs_[arc_index]; - - // Only keep arc that have a fixed offset and a single presence_literals. - if (arc.offset_var != kNoIntegerVariable) continue; - if (arc.tail_var == arc.head_var) continue; - if (arc.presence_literals.size() != 1) continue; - - if (arc.head_var >= incoming_arcs_.size()) { - incoming_arcs_.resize(arc.head_var.value() + 1); + // Fill the set of interesting relations for each variables. + absl::StrongVector> var_to_relations; + for (int index = 0; index < relations_.size(); ++index) { + const Relation& r = relations_[index]; + if (r.a.var != kNoIntegerVariable && IntTypeAbs(r.a.coeff) == 1) { + if (r.a.var >= var_to_relations.size()) { + var_to_relations.resize(r.a.var + 1); + } + var_to_relations[r.a.var].push_back(index); + } + if (r.b.var != kNoIntegerVariable && IntTypeAbs(r.b.coeff) == 1) { + if (r.b.var >= var_to_relations.size()) { + var_to_relations.resize(r.b.var + 1); + } + var_to_relations[r.b.var].push_back(index); } - incoming_arcs_[arc.head_var].push_back(arc_index); } int num_added_constraints = 0; - for (IntegerVariable target(0); target < incoming_arcs_.size(); ++target) { - if (incoming_arcs_[target].size() <= 1) continue; + for (IntegerVariable target(0); target < var_to_relations.size(); ++target) { + if (var_to_relations[target].size() <= 1) continue; if (time_limit->LimitReached()) return num_added_constraints; // Detect set of incoming arcs for which at least one must be present. @@ -1135,55 +1198,56 @@ int PrecedencesPropagator:: solver->Backtrack(0); if (solver->ModelIsUnsat()) return num_added_constraints; std::vector clause; - for (const ArcIndex arc_index : incoming_arcs_[target]) { - const Literal literal = arcs_[arc_index].presence_literals.front(); + for (const int index : var_to_relations[target]) { + const Literal literal = relations_[index].enforcement; if (solver->Assignment().LiteralIsFalse(literal)) continue; const SatSolver::Status status = solver->EnqueueDecisionAndBacktrackOnConflict(literal.Negated()); if (status == SatSolver::INFEASIBLE) return num_added_constraints; if (status == SatSolver::ASSUMPTIONS_UNSAT) { + // We need to invert it, since a clause is not all false. clause = solver->GetLastIncompatibleDecisions(); + for (Literal& ref : clause) ref = ref.Negated(); break; } } solver->Backtrack(0); + if (clause.size() <= 1) continue; - if (clause.size() > 1) { - // Extract the set of arc for which at least one must be present. - const absl::btree_set clause_set(clause.begin(), clause.end()); - std::vector arcs_in_clause; - for (const ArcIndex arc_index : incoming_arcs_[target]) { - const Literal literal(arcs_[arc_index].presence_literals.front()); - if (clause_set.contains(literal.Negated())) { - arcs_in_clause.push_back(arc_index); - } + // Recover the indices corresponding to this clause. + const absl::btree_set clause_set(clause.begin(), clause.end()); + + std::vector indices; + for (const int index : var_to_relations[target]) { + const Literal literal = relations_[index].enforcement; + if (clause_set.contains(literal)) { + indices.push_back(index); } + } - VLOG(2) << arcs_in_clause.size() << "/" << incoming_arcs_[target].size(); - + // Try both direction. + if (AddRelationFromIndices(target, clause, indices, model)) { + ++num_added_constraints; + } + if (AddRelationFromIndices(NegationOf(target), clause, indices, model)) { ++num_added_constraints; - std::vector vars; - std::vector offsets; - std::vector selectors; - for (const ArcIndex a : arcs_in_clause) { - vars.push_back(arcs_[a].tail_var); - offsets.push_back(arcs_[a].offset); - selectors.push_back(Literal(arcs_[a].presence_literals.front())); - } - model->Add(GreaterThanAtLeastOneOf(target, vars, offsets, selectors, {})); - if (!solver->FinishPropagation()) return num_added_constraints; } } + solver->Backtrack(0); return num_added_constraints; } -int PrecedencesPropagator::AddGreaterThanAtLeastOneOfConstraints(Model* model) { - VLOG(1) << "Detecting GreaterThanAtLeastOneOf() constraints..."; +int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints( + Model* model, bool auto_detect_clauses) { auto* time_limit = model->GetOrCreate(); auto* solver = model->GetOrCreate(); auto* clauses = model->GetOrCreate(); + auto* logger = model->GetOrCreate(); + int num_added_constraints = 0; + SOLVER_LOG(logger, "[Precedences] num_relations=", relations_.size(), + " num_clauses=", clauses->AllClausesInCreationOrder().size()); // We have two possible approaches. For now, we prefer the first one except if // there is too many clauses in the problem. @@ -1191,7 +1255,8 @@ int PrecedencesPropagator::AddGreaterThanAtLeastOneOfConstraints(Model* model) { // TODO(user): Do more extensive experiment. Remove the second approach as // it is more time consuming? or identify when it make sense. Note that the // first approach also allows to use "incomplete" at least one between arcs. - if (clauses->AllClausesInCreationOrder().size() < 1e6) { + if (!auto_detect_clauses && + clauses->AllClausesInCreationOrder().size() < 1e6) { // TODO(user): This does not take into account clause of size 2 since they // are stored in the BinaryImplicationGraph instead. Some ideas specific // to size 2: @@ -1229,10 +1294,14 @@ int PrecedencesPropagator::AddGreaterThanAtLeastOneOfConstraints(Model* model) { } if (num_added_constraints > 0) { - SOLVER_LOG(model->GetOrCreate(), "[Precedences] Added ", - num_added_constraints, + SOLVER_LOG(logger, "[Precedences] Added ", num_added_constraints, " GreaterThanAtLeastOneOf() constraints."); } + + // Release the memory, it is not longer needed. + gtl::STLClearObject(&relations_); + gtl::STLClearObject(&lit_to_relations_); + return num_added_constraints; } diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index 6b329fd9df..41245d5395 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -234,16 +234,6 @@ class PrecedencesPropagator : public SatPropagator, PropagatorInterface { void ComputePartialPrecedences(const std::vector& vars, std::vector* output); - // Advanced usage. To be called once all the constraints have been added to - // the model. This will loop over all "node" in this class, and if one of its - // optional incoming arcs must be chosen, it will add a corresponding - // GreaterThanAtLeastOneOfConstraint(). Returns the number of added - // constraint. - // - // TODO(user): This can be quite slow, add some kind of deterministic limit - // so that we can use it all the time. - int AddGreaterThanAtLeastOneOfConstraints(Model* model); - // If known, return an offset such that we have a + offset <= b. // Note that this only cover the case where this was conditionned by a single // literal. @@ -262,19 +252,6 @@ class PrecedencesPropagator : public SatPropagator, PropagatorInterface { DEFINE_STRONG_INDEX_TYPE(ArcIndex); DEFINE_STRONG_INDEX_TYPE(OptionalArcIndex); - // Given an existing clause, sees if it can be used to add "greater than at - // least one of" type of constraints. Returns the number of such constraint - // added. - int AddGreaterThanAtLeastOneOfConstraintsFromClause( - absl::Span clause, Model* model); - - // Another approach for AddGreaterThanAtLeastOneOfConstraints(), this one - // might be a bit slow as it relies on the propagation engine to detect - // clauses between incoming arcs presence literals. - // Returns the number of added constraints. - int AddGreaterThanAtLeastOneOfConstraintsWithClauseAutoDetection( - Model* model); - // Information about an individual arc. struct ArcInfo { IntegerVariable tail_var; @@ -432,6 +409,66 @@ class PrecedencesPropagator : public SatPropagator, PropagatorInterface { int64_t num_enforcement_pushes_ = 0; }; +// Similar to AffineExpression, but with a zero constant. +// If coeff is zero, then this is always zero and var is ignored. +struct LinearTerm { + IntegerVariable var = kNoIntegerVariable; + IntegerValue coeff = IntegerValue(0); +}; + +// This collect all enforced linear of size 2 or 1 and detect if at least one of +// a subset touching the same variable must be true. When this is the case +// we add a new propagator to propagate that fact. +// +// TODO(user): Shall we do that on the main thread before the workers are +// spawned? note that the probing version need the model to be loaded though. +class GreaterThanAtLeastOneOfDetector { + public: + // Adds a relation lit => a + b \in [lhs, rhs]. + void Add(Literal lit, LinearTerm a, LinearTerm b, IntegerValue lhs, + IntegerValue rhs); + + // Advanced usage. To be called once all the constraints have been added to + // the model. This will detect GreaterThanAtLeastOneOfConstraint(). + // Returns the number of added constraint. + // + // TODO(user): This can be quite slow, add some kind of deterministic limit + // so that we can use it all the time. + int AddGreaterThanAtLeastOneOfConstraints(Model* model, + bool auto_detect_clauses = false); + + private: + // Given an existing clause, sees if it can be used to add "greater than at + // least one of" type of constraints. Returns the number of such constraint + // added. + int AddGreaterThanAtLeastOneOfConstraintsFromClause( + absl::Span clause, Model* model); + + // Another approach for AddGreaterThanAtLeastOneOfConstraints(), this one + // might be a bit slow as it relies on the propagation engine to detect + // clauses between incoming arcs presence literals. + // Returns the number of added constraints. + int AddGreaterThanAtLeastOneOfConstraintsWithClauseAutoDetection( + Model* model); + + // Once we identified a clause and relevant indices, this build the + // constraint. Returns true if we actually add it. + bool AddRelationFromIndices(IntegerVariable var, + absl::Span clause, + absl::Span indices, Model* model); + + struct Relation { + Literal enforcement; + LinearTerm a; + LinearTerm b; + IntegerValue lhs; + IntegerValue rhs; + }; + + std::vector relations_; + absl::StrongVector> lit_to_relations_; +}; + // ============================================================================= // Implementation of the small API functions below. // ============================================================================= diff --git a/ortools/sat/probing.cc b/ortools/sat/probing.cc index c661ae2ef4..d4d24ef56a 100644 --- a/ortools/sat/probing.cc +++ b/ortools/sat/probing.cc @@ -300,7 +300,7 @@ bool Prober::ProbeBooleanVariables( } bool Prober::ProbeDnf(absl::string_view name, - const std::vector>& dnf) { + absl::Span> dnf) { if (dnf.size() <= 1) return true; // Reset the solver in case it was already used. diff --git a/ortools/sat/probing.h b/ortools/sat/probing.h index 56def373d2..801951896b 100644 --- a/ortools/sat/probing.h +++ b/ortools/sat/probing.h @@ -87,7 +87,7 @@ class Prober { // the conjunction must be true, we might be able to fix literal or improve // integer bounds if all conjunction propagate the same thing. bool ProbeDnf(absl::string_view name, - const std::vector>& dnf); + absl::Span> dnf); // Statistics. // They are reset each time ProbleBooleanVariables() is called. diff --git a/ortools/sat/python/cp_model.py b/ortools/sat/python/cp_model.py index 23c2ffcba0..76b1e1e9bf 100644 --- a/ortools/sat/python/cp_model.py +++ b/ortools/sat/python/cp_model.py @@ -66,6 +66,7 @@ from typing import ( ) import warnings +import numpy as np import pandas as pd from ortools.sat import cp_model_pb2 @@ -119,16 +120,48 @@ PARTIAL_FIXED_SEARCH = sat_parameters_pb2.SatParameters.PARTIAL_FIXED_SEARCH RANDOMIZED_SEARCH = sat_parameters_pb2.SatParameters.RANDOMIZED_SEARCH # Type aliases -# We need to add int to numbers.Integral -IntegralT = Union[numbers.Integral, int] -# We need to add int and float, otherwise type checkers complain. -NumberT = Union[int, numbers.Number, float] +IntegralT = Union[int, np.int8, np.uint8, np.int32, np.uint32, np.int64_t, np.uint64] +IntegralTypes = ( + int, + np.int8, + np.uint8, + np.int32, + np.uint32, + np.int64_t, + np.uint64_t, +) +NumberT = Union[ + int, + float, + np.int8, + np.uint8, + np.int32, + np.uint32, + np.int64_t, + np.uint64_t, + np.double, +] +NumberTypes = ( + int, + float, + np.int8, + np.uint8, + np.int32, + np.uint32, + np.int64_t, + np.uint64_t, + np.double, +) + LiteralT = Union["IntVar", "_NotBooleanVariable", IntegralT, bool] BoolVarT = Union["IntVar", "_NotBooleanVariable"] VariableT = Union["IntVar", IntegralT] -LinearExprT = Union["LinearExpr", IntegralT] + +# We need to add 'IntVar' for pytype. +LinearExprT = Union["LinearExpr", "IntVar", IntegralT] ObjLinearExprT = Union["LinearExpr", NumberT] BoundedLinearExprT = Union["BoundedLinearExpression", bool] + ArcT = Tuple[IntegralT, IntegralT, LiteralT] _IndexOrSeries = Union[pd.Index, pd.Series] @@ -311,14 +344,16 @@ class LinearExpr: else: return _WeightedSum(variables, coeffs, offset) - def get_integer_var_value_map(self) -> Tuple[Dict["IntVar", IntegralT], int]: + def get_integer_var_value_map(self) -> Tuple[Dict["IntVar", int], int]: """Scans the expression, and returns (var_coef_map, constant).""" - coeffs = collections.defaultdict(int) + coeffs: Dict["IntVar", int] = collections.defaultdict(int) constant = 0 - to_process: List[Tuple[LinearExprT, IntegralT]] = [(self, 1)] + to_process: List[Tuple[LinearExprT, int]] = [(self, 1)] while to_process: # Flatten to avoid recursion. + expr: LinearExprT + coeff: int expr, coeff = to_process.pop() - if isinstance(expr, numbers.Integral): + if isinstance(expr, IntegralTypes): constant += coeff * int(expr) elif isinstance(expr, _ProductCst): to_process.append((expr.expression(), coeff * expr.coefficient())) @@ -347,14 +382,14 @@ class LinearExpr: self, ) -> Tuple[Dict["IntVar", float], float, bool]: """Scans the expression. Returns (var_coef_map, constant, is_integer).""" - coeffs = {} - constant = 0 - to_process: List[Tuple[LinearExprT, Union[IntegralT, float]]] = [(self, 1)] + coeffs: Dict["IntVar", Union[int, float]] = {} + constant: Union[int, float] = 0 + to_process: List[Tuple[LinearExprT, Union[int, float]]] = [(self, 1)] while to_process: # Flatten to avoid recursion. expr, coeff = to_process.pop() - if isinstance(expr, numbers.Integral): # Keep integrality. + if isinstance(expr, IntegralTypes): # Keep integrality. constant += coeff * int(expr) - elif isinstance(expr, numbers.Number): + elif isinstance(expr, NumberTypes): constant += coeff * float(expr) elif isinstance(expr, _ProductCst): to_process.append((expr.expression(), coeff * expr.coefficient())) @@ -382,10 +417,10 @@ class LinearExpr: coeffs[expr.negated()] = -coeff else: raise TypeError("Unrecognized linear expression: " + str(expr)) - is_integer = isinstance(constant, numbers.Integral) + is_integer = isinstance(constant, IntegralTypes) if is_integer: for coeff in coeffs.values(): - if not isinstance(coeff, numbers.Integral): + if not isinstance(coeff, IntegralTypes): is_integer = False break return coeffs, constant, is_integer @@ -421,9 +456,7 @@ class LinearExpr: ... def __radd__(self, arg): - if cmh.is_zero(arg): - return self - return _Sum(self, arg) + return self.__add__(arg) @overload def __sub__(self, arg: "LinearExpr") -> "LinearExpr": @@ -436,7 +469,7 @@ class LinearExpr: def __sub__(self, arg): if cmh.is_zero(arg): return self - if isinstance(arg, numbers.Number): + if isinstance(arg, NumberTypes): arg = cmh.assert_is_a_number(arg) return _Sum(self, -arg) else: @@ -478,12 +511,7 @@ class LinearExpr: ... def __rmul__(self, arg): - arg = cmh.assert_is_a_number(arg) - if cmh.is_one(arg): - return self - elif cmh.is_zero(arg): - return 0 - return _ProductCst(self, arg) + return self.__mul__(arg) def __div__(self, _) -> NoReturn: raise NotImplementedError( @@ -545,31 +573,33 @@ class LinearExpr: "Evaluating a LinearExpr instance as a Boolean is not implemented." ) - def __eq__(self, arg: LinearExprT) -> BoundedLinearExprT: + def __eq__(self, arg: LinearExprT) -> BoundedLinearExprT: # type: ignore[override] if arg is None: return False - if isinstance(arg, numbers.Integral): + if isinstance(arg, IntegralTypes): arg = cmh.assert_is_int64(arg) return BoundedLinearExpression(self, [arg, arg]) - else: + elif isinstance(arg, LinearExpr): return BoundedLinearExpression(self - arg, [0, 0]) + else: + return False - def __ge__(self, arg: LinearExprT) -> BoundedLinearExprT: - if isinstance(arg, numbers.Integral): + def __ge__(self, arg: LinearExprT) -> "BoundedLinearExpression": + if isinstance(arg, IntegralTypes): arg = cmh.assert_is_int64(arg) return BoundedLinearExpression(self, [arg, INT_MAX]) else: return BoundedLinearExpression(self - arg, [0, INT_MAX]) - def __le__(self, arg: LinearExprT) -> BoundedLinearExprT: - if isinstance(arg, numbers.Integral): + def __le__(self, arg: LinearExprT) -> "BoundedLinearExpression": + if isinstance(arg, IntegralTypes): arg = cmh.assert_is_int64(arg) return BoundedLinearExpression(self, [INT_MIN, arg]) else: return BoundedLinearExpression(self - arg, [INT_MIN, 0]) - def __lt__(self, arg: LinearExprT) -> BoundedLinearExprT: - if isinstance(arg, numbers.Integral): + def __lt__(self, arg: LinearExprT) -> "BoundedLinearExpression": + if isinstance(arg, IntegralTypes): arg = cmh.assert_is_int64(arg) if arg == INT_MIN: raise ArithmeticError("< INT_MIN is not supported") @@ -577,8 +607,8 @@ class LinearExpr: else: return BoundedLinearExpression(self - arg, [INT_MIN, -1]) - def __gt__(self, arg: LinearExprT) -> BoundedLinearExprT: - if isinstance(arg, numbers.Integral): + def __gt__(self, arg: LinearExprT) -> "BoundedLinearExpression": + if isinstance(arg, IntegralTypes): arg = cmh.assert_is_int64(arg) if arg == INT_MAX: raise ArithmeticError("> INT_MAX is not supported") @@ -586,10 +616,10 @@ class LinearExpr: else: return BoundedLinearExpression(self - arg, [1, INT_MAX]) - def __ne__(self, arg: LinearExprT) -> BoundedLinearExprT: + def __ne__(self, arg: LinearExprT) -> BoundedLinearExprT: # type: ignore[override] if arg is None: return True - if isinstance(arg, numbers.Integral): + if isinstance(arg, IntegralTypes): arg = cmh.assert_is_int64(arg) if arg == INT_MAX: return BoundedLinearExpression(self, [INT_MIN, INT_MAX - 1]) @@ -599,8 +629,10 @@ class LinearExpr: return BoundedLinearExpression( self, [INT_MIN, arg - 1, arg + 1, INT_MAX] ) - else: + elif isinstance(arg, LinearExpr): return BoundedLinearExpression(self - arg, [INT_MIN, -1, 1, INT_MAX]) + else: + return True # Compatibility with pre PEP8 # pylint: disable=invalid-name @@ -663,7 +695,7 @@ class _Sum(LinearExpr): def __init__(self, left, right): for x in [left, right]: - if not isinstance(x, (numbers.Number, LinearExpr)): + if not isinstance(x, (NumberTypes, LinearExpr)): raise TypeError("not an linear expression: " + str(x)) self.__left = left self.__right = right @@ -716,7 +748,7 @@ class _SumArray(LinearExpr): self.__expressions = [] self.__constant = constant for x in expressions: - if isinstance(x, numbers.Number): + if isinstance(x, NumberTypes): if cmh.is_zero(x): continue x = cmh.assert_is_a_number(x) @@ -762,7 +794,7 @@ class _WeightedSum(LinearExpr): c = cmh.assert_is_a_number(c) if cmh.is_zero(c): continue - if isinstance(e, numbers.Number): + if isinstance(e, NumberTypes): e = cmh.assert_is_a_number(e) self.__constant += e * c elif isinstance(e, LinearExpr): @@ -829,7 +861,7 @@ class IntVar(LinearExpr): def __init__( self, model: cp_model_pb2.CpModelProto, - domain: Union[int, Domain], + domain: Union[int, sorted_interval_list.Domain], name: Optional[str], ) -> None: """See CpModel.new_int_var below.""" @@ -841,13 +873,15 @@ class IntVar(LinearExpr): # model is a CpModelProto, domain is a Domain, and name is a string. # case 2: # model is a CpModelProto, domain is an index (int), and name is None. - if isinstance(domain, numbers.Integral) and name is None: + if isinstance(domain, IntegralTypes) and name is None: self.__index: int = int(domain) self.__var: cp_model_pb2.IntegerVariableProto = model.variables[domain] else: self.__index: int = len(model.variables) self.__var: cp_model_pb2.IntegerVariableProto = model.variables.add() - self.__var.domain.extend(cast(Domain, domain).flattened_intervals()) + self.__var.domain.extend( + cast(sorted_interval_list.Domain, domain).flattened_intervals() + ) self.__var.name = name @property @@ -1090,12 +1124,12 @@ class Constraint: """ for lit in expand_generator_or_tuple(boolvar): if (cmh.is_boolean(lit) and lit) or ( - isinstance(lit, numbers.Integral) and lit == 1 + isinstance(lit, IntegralTypes) and lit == 1 ): # Always true. Do nothing. pass elif (cmh.is_boolean(lit) and not lit) or ( - isinstance(lit, numbers.Integral) and lit == 0 + isinstance(lit, IntegralTypes) and lit == 0 ): self.__constraint.enforcement_literal.append( self.__cp_model.new_constant(0).index @@ -1275,7 +1309,7 @@ def object_is_a_true_literal(literal: LiteralT) -> bool: if isinstance(literal, _NotBooleanVariable): proto = literal.negated().proto return len(proto.domain) == 2 and proto.domain[0] == 0 and proto.domain[1] == 0 - if isinstance(literal, numbers.Integral): + if isinstance(literal, IntegralTypes): return int(literal) == 1 return False @@ -1288,7 +1322,7 @@ def object_is_a_false_literal(literal: LiteralT) -> bool: if isinstance(literal, _NotBooleanVariable): proto = literal.negated().proto return len(proto.domain) == 2 and proto.domain[0] == 1 and proto.domain[1] == 1 - if isinstance(literal, numbers.Integral): + if isinstance(literal, IntegralTypes): return int(literal) == 0 return False @@ -1304,7 +1338,7 @@ class CpModel: def __init__(self) -> None: self.__model: cp_model_pb2.CpModelProto = cp_model_pb2.CpModelProto() - self.__constant_map = {} + self.__constant_map: Dict[IntegralT, int] = {} # Naming. @property @@ -1337,9 +1371,11 @@ class CpModel: a variable whose domain is [lb, ub]. """ - return IntVar(self.__model, Domain(lb, ub), name) + return IntVar(self.__model, sorted_interval_list.Domain(lb, ub), name) - def new_int_var_from_domain(self, domain: Domain, name: str) -> IntVar: + def new_int_var_from_domain( + self, domain: sorted_interval_list.Domain, name: str + ) -> IntVar: """Create an integer variable from a domain. A domain is a set of integers specified by a collection of intervals. @@ -1357,7 +1393,7 @@ class CpModel: def new_bool_var(self, name: str) -> IntVar: """Creates a 0-1 variable with the given name.""" - return IntVar(self.__model, Domain(0, 1), name) + return IntVar(self.__model, sorted_interval_list.Domain(0, 1), name) def new_constant(self, value: IntegralT) -> IntVar: """Declares a constant integer.""" @@ -1397,8 +1433,8 @@ class CpModel: if not name.isidentifier(): raise ValueError("name={} is not a valid identifier".format(name)) if ( - isinstance(lower_bounds, numbers.Integral) - and isinstance(upper_bounds, numbers.Integral) + isinstance(lower_bounds, IntegralTypes) + and isinstance(upper_bounds, IntegralTypes) and lower_bounds > upper_bounds ): raise ValueError( @@ -1419,7 +1455,9 @@ class CpModel: IntVar( model=self.__model, name=f"{name}[{i}]", - domain=Domain(lower_bounds[i], upper_bounds[i]), + domain=sorted_interval_list.Domain( + lower_bounds[i], upper_bounds[i] + ), ) for i in index ], @@ -1453,10 +1491,12 @@ class CpModel: self, linear_expr: LinearExprT, lb: IntegralT, ub: IntegralT ) -> Constraint: """Adds the constraint: `lb <= linear_expr <= ub`.""" - return self.add_linear_expression_in_domain(linear_expr, Domain(lb, ub)) + return self.add_linear_expression_in_domain( + linear_expr, sorted_interval_list.Domain(lb, ub) + ) def add_linear_expression_in_domain( - self, linear_expr: LinearExprT, domain: Domain + self, linear_expr: LinearExprT, domain: sorted_interval_list.Domain ) -> Constraint: """Adds the constraint: `linear_expr` in `domain`.""" if isinstance(linear_expr, LinearExpr): @@ -1476,7 +1516,7 @@ class CpModel: ] ) return ct - if isinstance(linear_expr, numbers.Integral): + if isinstance(linear_expr, IntegralTypes): if not domain.contains(int(linear_expr)): return self.add_bool_or([]) # Evaluate to false. else: @@ -1489,7 +1529,15 @@ class CpModel: + ")" ) - def add(self, ct: Union[BoundedLinearExpression, bool]) -> Constraint: + @overload + def add(self, ct: BoundedLinearExpression) -> Constraint: + ... + + @overload + def add(self, ct: Union[bool, np.bool_]) -> Constraint: + ... + + def add(self, ct): """Adds a `BoundedLinearExpression` to the model. Args: @@ -1500,7 +1548,8 @@ class CpModel: """ if isinstance(ct, BoundedLinearExpression): return self.add_linear_expression_in_domain( - ct.expression(), Domain.from_flat_intervals(ct.bounds()) + ct.expression(), + sorted_interval_list.Domain.from_flat_intervals(ct.bounds()), ) if ct and cmh.is_boolean(ct): return self.add_bool_or([True]) @@ -1554,8 +1603,9 @@ class CpModel: if not variables: raise ValueError("add_element expects a non-empty variables array") - if isinstance(index, numbers.Integral): - return self.add(list(variables)[int(index)] == target) + if isinstance(index, IntegralTypes): + variable: VariableT = list(variables)[int(index)] + return self.add(variable == target) ct = Constraint(self) model_ct = self.__model.constraints[ct.index] @@ -2725,7 +2775,7 @@ class CpModel: and arg.coefficient() == -1 ): return -arg.expression().index - 1 - if isinstance(arg, numbers.Integral): + if isinstance(arg, IntegralTypes): arg = cmh.assert_is_int64(arg) return self.get_or_make_index_from_constant(arg) raise TypeError("NotSupported: model.get_or_make_index(" + str(arg) + ")") @@ -2738,7 +2788,7 @@ class CpModel: if isinstance(arg, _NotBooleanVariable): self.assert_is_boolean_variable(arg.negated()) return arg.index - if isinstance(arg, numbers.Integral): + if isinstance(arg, IntegralTypes): arg = cmh.assert_is_zero_or_one(arg) return self.get_or_make_index_from_constant(arg) if cmh.is_boolean(arg): @@ -2774,7 +2824,7 @@ class CpModel: cp_model_pb2.LinearExpressionProto() ) mult = -1 if negate else 1 - if isinstance(linear_expr, numbers.Integral): + if isinstance(linear_expr, IntegralTypes): result.offset = int(linear_expr) * mult return result @@ -2826,7 +2876,7 @@ class CpModel: for v, c in coeffs_map.items(): self.__model.floating_point_objective.coeffs.append(c) self.__model.floating_point_objective.vars.append(v.index) - elif isinstance(obj, numbers.Integral): + elif isinstance(obj, IntegralTypes): self.__model.objective.offset = int(obj) self.__model.objective.scaling_factor = 1 else: @@ -3024,7 +3074,7 @@ def expand_generator_or_tuple(args): if hasattr(args, "__len__"): # Tuple if len(args) != 1: return args - if isinstance(args[0], (numbers.Number, LinearExpr)): + if isinstance(args[0], (NumberTypes, LinearExpr)): return args # Generator return args[0] @@ -3034,7 +3084,7 @@ def evaluate_linear_expr( expression: LinearExprT, solution: cp_model_pb2.CpSolverResponse ) -> int: """Evaluate a linear expression against a solution.""" - if isinstance(expression, numbers.Integral): + if isinstance(expression, IntegralTypes): return int(expression) if not isinstance(expression, LinearExpr): raise TypeError("Cannot interpret %s as a linear expression." % expression) @@ -3043,7 +3093,7 @@ def evaluate_linear_expr( to_process = [(expression, 1)] while to_process: expr, coeff = to_process.pop() - if isinstance(expr, numbers.Integral): + if isinstance(expr, IntegralTypes): value += int(expr) * coeff elif isinstance(expr, _ProductCst): to_process.append((expr.expression(), coeff * expr.coefficient())) @@ -3072,7 +3122,7 @@ def evaluate_boolean_expression( literal: LiteralT, solution: cp_model_pb2.CpSolverResponse ) -> bool: """Evaluate a boolean expression against a solution.""" - if isinstance(literal, numbers.Integral): + if isinstance(literal, IntegralTypes): return bool(literal) elif isinstance(literal, IntVar) or isinstance(literal, _NotBooleanVariable): index: int = cast(Union[IntVar, _NotBooleanVariable], literal).index @@ -3411,7 +3461,7 @@ class CpSolverSolutionCallback(swig_helper.SolutionCallback): """ if not self.has_response(): raise RuntimeError("solve() has not been called.") - if isinstance(lit, numbers.Integral): + if isinstance(lit, IntegralTypes): return bool(lit) if isinstance(lit, IntVar) or isinstance(lit, _NotBooleanVariable): return self.SolutionBooleanValue( @@ -3441,7 +3491,7 @@ class CpSolverSolutionCallback(swig_helper.SolutionCallback): to_process = [(expression, 1)] while to_process: expr, coeff = to_process.pop() - if isinstance(expr, numbers.Integral): + if isinstance(expr, IntegralTypes): value += int(expr) * coeff elif isinstance(expr, _ProductCst): to_process.append((expr.expression(), coeff * expr.coefficient())) @@ -3681,7 +3731,7 @@ def _convert_to_integral_series_and_validate_index( TypeError: If the type of `value_or_series` is not recognized. ValueError: If the index does not match. """ - if isinstance(value_or_series, numbers.Integral): + if isinstance(value_or_series, IntegralTypes): result = pd.Series(data=value_or_series, index=index) elif isinstance(value_or_series, pd.Series): if value_or_series.index.equals(index): @@ -3709,7 +3759,7 @@ def _convert_to_linear_expr_series_and_validate_index( TypeError: If the type of `value_or_series` is not recognized. ValueError: If the index does not match. """ - if isinstance(value_or_series, numbers.Integral): + if isinstance(value_or_series, IntegralTypes): result = pd.Series(data=value_or_series, index=index) elif isinstance(value_or_series, pd.Series): if value_or_series.index.equals(index): @@ -3737,7 +3787,7 @@ def _convert_to_literal_series_and_validate_index( TypeError: If the type of `value_or_series` is not recognized. ValueError: If the index does not match. """ - if isinstance(value_or_series, numbers.Integral): + if isinstance(value_or_series, IntegralTypes): result = pd.Series(data=value_or_series, index=index) elif isinstance(value_or_series, pd.Series): if value_or_series.index.equals(index): diff --git a/ortools/sat/python/cp_model_helper.py b/ortools/sat/python/cp_model_helper.py index 662e1cbd38..c3bdde0044 100644 --- a/ortools/sat/python/cp_model_helper.py +++ b/ortools/sat/python/cp_model_helper.py @@ -13,8 +13,8 @@ """helpers methods for the cp_model module.""" -import numbers from typing import Any, Union +import numbers import numpy as np @@ -37,7 +37,7 @@ def is_zero(x: Any) -> bool: """Checks if the x is 0 or 0.0.""" if isinstance(x, numbers.Integral): return int(x) == 0 - if isinstance(x, numbers.Number): + if isinstance(x, numbers.Real): return float(x) == 0.0 return False @@ -46,7 +46,7 @@ def is_one(x: Any) -> bool: """Checks if x is 1 or 1.0.""" if isinstance(x, numbers.Integral): return int(x) == 1 - if isinstance(x, numbers.Number): + if isinstance(x, numbers.Real): return float(x) == 1.0 return False @@ -55,7 +55,7 @@ def is_minus_one(x: Any) -> bool: """Checks if x is -1 or -1.0 .""" if isinstance(x, numbers.Integral): return int(x) == -1 - if isinstance(x, numbers.Number): + if isinstance(x, numbers.Real): return float(x) == -1.0 return False @@ -89,7 +89,7 @@ def assert_is_a_number(x: Any) -> Union[int, float]: """Asserts that x is a number and returns it casted to an int or a float.""" if isinstance(x, numbers.Integral): return int(x) - if isinstance(x, numbers.Number): + if isinstance(x, numbers.Real): return float(x) raise TypeError("Not a number: %s" % x) diff --git a/ortools/sat/rins.cc b/ortools/sat/rins.cc index e9576ee1b2..edff8f5048 100644 --- a/ortools/sat/rins.cc +++ b/ortools/sat/rins.cc @@ -104,8 +104,8 @@ struct VarWeight { bool operator<(const VarWeight& o) const { return weight < o.weight; } }; -void FillRinsNeighborhood(const std::vector& solution, - const std::vector& relaxation_values, +void FillRinsNeighborhood(absl::Span solution, + absl::Span relaxation_values, double difficulty, absl::BitGenRef random, ReducedDomainNeighborhood& reduced_domains) { std::vector var_lp_gap_pairs; diff --git a/ortools/sat/routing_cuts.cc b/ortools/sat/routing_cuts.cc index 4c1497c0eb..237d04f725 100644 --- a/ortools/sat/routing_cuts.cc +++ b/ortools/sat/routing_cuts.cc @@ -89,7 +89,7 @@ class OutgoingCutHelper { // Given a subset of nodes, it is easy to identify the best subset A of edge // to consider. bool TryBlossomSubsetCut(std::string name, - const std::vector& symmetrized_edges, + absl::Span symmetrized_edges, absl::Span subset); private: @@ -271,7 +271,7 @@ bool OutgoingCutHelper::TrySubsetCut(std::string name, } bool OutgoingCutHelper::TryBlossomSubsetCut( - std::string name, const std::vector& symmetrized_edges, + std::string name, absl::Span symmetrized_edges, absl::Span subset) { DCHECK_GE(subset.size(), 1); DCHECK_LT(subset.size(), num_nodes_); @@ -715,7 +715,7 @@ namespace { // Returns for each literal its integer view, or the view of its negation. std::vector GetAssociatedVariables( - const std::vector& literals, Model* model) { + absl::Span literals, Model* model) { auto* encoder = model->GetOrCreate(); std::vector result; for (const Literal l : literals) { @@ -792,8 +792,8 @@ CutGenerator CreateCVRPCutGenerator(int num_nodes, std::vector tails, // This is really similar to SeparateSubtourInequalities, see the reference // there. void SeparateFlowInequalities( - int num_nodes, const std::vector& tails, const std::vector& heads, - const std::vector& arc_capacities, + int num_nodes, absl::Span tails, absl::Span heads, + absl::Span arc_capacities, std::function& in_subset, IntegerValue* min_incoming_flow, IntegerValue* min_outgoing_flow)> diff --git a/ortools/sat/samples/bin_packing_sat.py b/ortools/sat/samples/bin_packing_sat.py index 680f1949aa..f30e0f9382 100644 --- a/ortools/sat/samples/bin_packing_sat.py +++ b/ortools/sat/samples/bin_packing_sat.py @@ -121,8 +121,8 @@ def main() -> None: for b in active_bins: print(f"Bin {b}") - items_in_bin = x_values.xs(b, level="bin").loc[lambda x: x].index - for item in items_in_bin: + items_in_active_bin = x_values.xs(b, level="bin").loc[lambda x: x].index + for item in items_in_active_bin: print(f" Item {item} - weight {items.loc[item].weight}") print(f" Packed items weight: {items.loc[items_in_bin].sum().to_string()}") print() diff --git a/ortools/sat/samples/schedule_requests_sat.py b/ortools/sat/samples/schedule_requests_sat.py index f87789c64c..4518a75eda 100644 --- a/ortools/sat/samples/schedule_requests_sat.py +++ b/ortools/sat/samples/schedule_requests_sat.py @@ -15,6 +15,8 @@ # [START program] """Nurse scheduling problem with shift requests.""" # [START import] +from typing import Union + from ortools.sat.python import cp_model # [END import] @@ -80,7 +82,7 @@ def main() -> None: else: max_shifts_per_nurse = min_shifts_per_nurse + 1 for n in all_nurses: - num_shifts_worked = 0 + num_shifts_worked: Union[cp_model.LinearExpr, int] = 0 for d in all_days: for s in all_shifts: num_shifts_worked += shifts[(n, d, s)] diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index 81dda8ca2e..46bff6984e 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -23,7 +23,7 @@ option csharp_namespace = "Google.OrTools.Sat"; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 280 +// NEXT TAG: 281 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -464,6 +464,13 @@ message SatParameters { // possible precedences between event and encoding the constraint. optional bool expand_reservoir_constraints = 182 [default = true]; + // If true, replace target = max(x, y) by linear constraint with the + // introduction of a new boolean b such that b => target == x and not(b) => + // target == y. + // + // This is mainly for experimenting compared to a custom lin_max propagator. + optional bool expand_binary_lin_max = 280 [default = false]; + // If true, it disable all constraint expansion. // This should only be used to test the presolve of expanded constraints. optional bool disable_constraint_expansion = 181 [default = false];