From 80426425408ddf9b9cc852bf7b9231658b9726a0 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Thu, 14 Feb 2019 16:44:57 +0100 Subject: [PATCH] [CP-SAT] improve presolve; new random-restart strategy; improve integer encoding --- ortools/sat/cp_model_presolve.cc | 168 ++++++++++++++++++++----------- ortools/sat/integer.cc | 69 ++----------- ortools/sat/integer.h | 4 + ortools/sat/integer_search.cc | 50 ++++++--- ortools/sat/sat_parameters.proto | 9 +- 5 files changed, 165 insertions(+), 135 deletions(-) diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 8e27f13b06..4842737c7f 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -50,6 +50,7 @@ namespace sat { int PresolveContext::NewIntVar(const Domain& domain) { IntegerVariableProto* const var = working_model->add_variables(); FillDomainInProto(domain, var); + InitializeNewDomains(); return working_model->variables_size() - 1; } @@ -61,6 +62,7 @@ int PresolveContext::GetOrCreateConstantVar(int64 cst) { IntegerVariableProto* const var_proto = working_model->add_variables(); var_proto->add_domain(cst); var_proto->add_domain(cst); + InitializeNewDomains(); } return constant_to_ref[cst]; } @@ -342,6 +344,9 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) { // TODO(user,user): use affine relation here. const int var = PositiveRef(ref); const int64 s_value = RefIsPositive(ref) ? value : -value; + if (!domains[var].Contains(s_value)) { + return GetOrCreateConstantVar(0); + } std::pair key{var, s_value}; if (encoding.contains(key)) return encoding[key]; @@ -355,12 +360,12 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) { if (domains[var].Size() == 2) { const int64 var_min = MinOf(var); const int64 var_max = MaxOf(var); + if (var_min == 0 && var_max == 1) { encoding[std::make_pair(var, 0)] = NegatedRef(var); encoding[std::make_pair(var, 1)] = var; } else { const int literal = NewBoolVar(); - InitializeNewDomains(); encoding[std::make_pair(var, var_min)] = NegatedRef(literal); encoding[std::make_pair(var, var_max)] = literal; @@ -374,6 +379,7 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) { lin->add_domain(var_min); StoreAffineRelation(*ct, var, literal, var_max - var_min, var_min); } + return gtl::FindOrDieNoPrint(encoding, key); } @@ -382,7 +388,6 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) { AddImplyInDomain(NegatedRef(literal), var, Domain(s_value).Complement()); encoding[key] = literal; - InitializeNewDomains(); return literal; } @@ -1584,48 +1589,50 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { int num_vars = 0; bool all_constants = true; absl::flat_hash_set constant_set; - bool all_included_in_target_domain = true; - bool reduced_index_domain = false; - if (context->IntersectDomainWith(index_ref, - Domain(0, ct->element().vars_size() - 1))) { - reduced_index_domain = true; - } - // Filter possible index values. Accumulate variable domains to build - // a possible target domain. - Domain infered_domain; - const Domain initial_index_domain = context->DomainOf(index_ref); - const Domain target_domain = context->DomainOf(target_ref); - for (const ClosedInterval interval : initial_index_domain) { - for (int value = interval.start; value <= interval.end; ++value) { - CHECK_GE(value, 0); - CHECK_LT(value, ct->element().vars_size()); - const int ref = ct->element().vars(value); - const Domain domain = context->DomainOf(ref); - if (domain.IntersectionWith(target_domain).IsEmpty()) { - context->IntersectDomainWith(index_ref, Domain(value).Complement()); - reduced_index_domain = true; - } else { - ++num_vars; - if (domain.Min() == domain.Max()) { - constant_set.insert(domain.Min()); + { + bool reduced_index_domain = false; + if (context->IntersectDomainWith( + index_ref, Domain(0, ct->element().vars_size() - 1))) { + reduced_index_domain = true; + } + + // Filter possible index values. Accumulate variable domains to build + // a possible target domain. + Domain infered_domain; + const Domain initial_index_domain = context->DomainOf(index_ref); + Domain target_domain = context->DomainOf(target_ref); + for (const ClosedInterval interval : initial_index_domain) { + for (int value = interval.start; value <= interval.end; ++value) { + CHECK_GE(value, 0); + CHECK_LT(value, ct->element().vars_size()); + const int ref = ct->element().vars(value); + const Domain domain = context->DomainOf(ref); + if (domain.IntersectionWith(target_domain).IsEmpty()) { + context->IntersectDomainWith(index_ref, Domain(value).Complement()); + reduced_index_domain = true; } else { - all_constants = false; + ++num_vars; + if (domain.Min() == domain.Max()) { + constant_set.insert(domain.Min()); + } else { + all_constants = false; + } + if (!domain.IsIncludedIn(target_domain)) { + all_included_in_target_domain = false; + } + infered_domain = infered_domain.UnionWith(domain); } - if (!domain.IsIncludedIn(target_domain)) { - all_included_in_target_domain = false; - } - infered_domain = infered_domain.UnionWith(domain); } } - } - if (reduced_index_domain) { - context->UpdateRuleStats("element: reduced index domain"); - } - if (context->IntersectDomainWith(target_ref, infered_domain)) { - if (context->DomainOf(target_ref).IsEmpty()) return true; - context->UpdateRuleStats("element: reduced target domain"); + if (reduced_index_domain) { + context->UpdateRuleStats("element: reduced index domain"); + } + if (context->IntersectDomainWith(target_ref, infered_domain)) { + if (context->DomainOf(target_ref).IsEmpty()) return true; + context->UpdateRuleStats("element: reduced target domain"); + } } // If the index is fixed, this is a equality constraint. @@ -1759,7 +1766,6 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { } ct->mutable_element()->set_target(r_target.representative); if (changed_values) { - context->InitializeNewDomains(); context->UpdateRuleStats("element: unscaled values from affine target"); } if (index_domain.Size() > valid_index_values.size()) { @@ -1790,10 +1796,31 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { } } context->UpdateRuleStats("element: expand fixed target element"); - context->InitializeNewDomains(); return RemoveConstraint(ct, context); } + if (target_ref == index_ref) { + // Filter impossible index values. + Domain index_domain = context->DomainOf(index_ref); + std::vector possible_indices; + for (const ClosedInterval& interval : index_domain) { + for (int64 value = interval.start; value <= interval.end; ++value) { + const int ref = ct->element().vars(value); + if (context->DomainContains(ref, value)) { + possible_indices.push_back(value); + } + } + } + if (possible_indices.size() < index_domain.Size()) { + context->IntersectDomainWith(index_ref, + Domain::FromValues(possible_indices)); + } + context->UpdateRuleStats( + "element: reduce index domain when target equals index"); + + return true; + } + if (all_constants) { const Domain index_domain = context->DomainOf(index_ref); absl::flat_hash_map> supports; @@ -1810,10 +1837,8 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { const int64 value = context->MinOf(ct->element().vars(v)); const int index_lit = context->GetOrCreateVarValueEncoding(index_ref, v); - CHECK(context->DomainContains(target_ref, value)) - << "target " << context->DomainOf(target_ref) - << ", value = " << value; + CHECK(context->DomainContains(target_ref, value)); const int target_lit = context->GetOrCreateVarValueEncoding(target_ref, value); context->AddImplication(index_lit, target_lit); @@ -1844,7 +1869,6 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { } context->UpdateRuleStats("element: expand fixed array element"); - context->InitializeNewDomains(); return RemoveConstraint(ct, context); } @@ -2006,7 +2030,9 @@ bool PresolveTable(ConstraintProto* ct, PresolveContext* context) { bool PresolveAllDiff(ConstraintProto* ct, PresolveContext* context) { if (HasEnforcementLiteral(*ct)) return false; - const int size = ct->all_diff().vars_size(); + AllDifferentConstraintProto& all_diff = *ct->mutable_all_diff(); + + const int size = all_diff.vars_size(); if (size == 0) { context->UpdateRuleStats("all_diff: empty constraint"); return RemoveConstraint(ct, context); @@ -2016,16 +2042,46 @@ bool PresolveAllDiff(ConstraintProto* ct, PresolveContext* context) { return RemoveConstraint(ct, context); } - bool contains_fixed_variable = false; + absl::flat_hash_set fixed_variables; for (int i = 0; i < size; ++i) { - if (context->IsFixed(ct->all_diff().vars(i))) { - contains_fixed_variable = true; - break; + if (!context->IsFixed(all_diff.vars(i))) continue; + fixed_variables.insert(i); + const int64 value = context->MinOf(all_diff.vars(i)); + bool propagated = false; + for (int j = 0; j < size; ++j) { + if (i == j) continue; + if (context->DomainContains(all_diff.vars(j), value)) { + context->IntersectDomainWith(all_diff.vars(j), + Domain(value).Complement()); + if (context->is_unsat) return true; + propagated = true; + } + } + if (propagated) { + context->UpdateRuleStats("all_diff: propagated fixed variables"); } } - if (contains_fixed_variable) { - context->UpdateRuleStats("TODO all_diff: fixed variables"); + + if (!fixed_variables.empty()) { + std::vector new_variables; + for (int i = 0; i < all_diff.vars_size(); ++i) { + // We cannot check the domain here, as it may have been fixed by the + // propagation loop above. In that case, it will be picked up by this + // presolve rule in the next iteration. + if (!gtl::ContainsKey(fixed_variables, i)) { + new_variables.push_back(all_diff.vars(i)); + } + } + CHECK_EQ(all_diff.vars_size(), + new_variables.size() + fixed_variables.size()); + all_diff.mutable_vars()->Clear(); + for (const int var : new_variables) { + all_diff.add_vars(var); + } + context->UpdateRuleStats("all_diff: removed fixed variables"); + return true; } + return false; } @@ -2325,8 +2381,9 @@ bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) { bool PresolveAutomaton(ConstraintProto* ct, PresolveContext* context) { if (HasEnforcementLiteral(*ct)) return false; AutomatonConstraintProto& proto = *ct->mutable_automaton(); - if (proto.vars_size() == 0 || proto.transition_label_size() == 0) + if (proto.vars_size() == 0 || proto.transition_label_size() == 0) { return false; + } bool all_affine = true; std::vector affine_relations; @@ -3459,13 +3516,8 @@ void TryToSimplifyDomains(PresolveContext* context) { if (domain.Size() == 2 && domain.NumIntervals() == 1 && domain.Min() != 0) { // Shifted Boolean variable. - const int new_var_index = context->working_model->variables_size(); - IntegerVariableProto* const var_proto = - context->working_model->add_variables(); - var_proto->add_domain(0); - var_proto->add_domain(1); + const int new_var_index = context->NewBoolVar(); const int64 offset = domain.Min(); - context->InitializeNewDomains(); ConstraintProto* const ct = context->working_model->add_constraints(); LinearConstraintProto* const lin = ct->mutable_linear(); diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 6623fac0fe..0f66b45a11 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -37,72 +37,23 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable var) { CHECK(!VariableIsFullyEncoded(var)); CHECK_EQ(0, sat_solver_->CurrentDecisionLevel()); CHECK(!(*domains_)[var].IsEmpty()); // UNSAT. We don't deal with that here. + CHECK_LT((*domains_)[var].Size(), 100000) + << "Domain too large for full encoding."; - std::vector values; + // TODO(user): Maybe we can optimize the literal creation order and their + // polarity as our default SAT heuristics initially depends on this. for (const ClosedInterval interval : (*domains_)[var]) { for (IntegerValue v(interval.start); v <= interval.end; ++v) { - values.push_back(v); - CHECK_LT(values.size(), 100000) << "Domain too large for full encoding."; + GetOrCreateLiteralAssociatedToEquality(var, v); } } - std::vector literals; - if (values.size() == 1) { - literals.push_back(GetTrueLiteral()); - } else if (values.size() == 2) { - literals.push_back(GetOrCreateAssociatedLiteral( - IntegerLiteral::LowerOrEqual(var, values[0]))); - literals.push_back(literals.back().Negated()); - } else { - for (int i = 0; i < values.size(); ++i) { - const std::pair key{var, values[i]}; - if (gtl::ContainsKey(equality_to_associated_literal_, key)) { - literals.push_back(equality_to_associated_literal_[key]); - } else { - literals.push_back(Literal(sat_solver_->NewBooleanVariable(), true)); - } - } - } - - // Create the associated literal (<= and >=) in order (best for the - // implications between them). Note that we only create literals like this for - // value inside the domain. This is nice since these will be the only kind of - // literal pushed by Enqueue() (we look at the domain there). - for (int i = 0; i + 1 < literals.size(); ++i) { - const IntegerLiteral i_lit = IntegerLiteral::LowerOrEqual(var, values[i]); - const IntegerLiteral i_lit_negated = - IntegerLiteral::GreaterOrEqual(var, values[i + 1]); - if (i == 0) { - // Special case for the start. - HalfAssociateGivenLiteral(i_lit, literals[0]); - HalfAssociateGivenLiteral(i_lit_negated, literals[0].Negated()); - } else if (i + 2 == literals.size()) { - // Special case for the end. - HalfAssociateGivenLiteral(i_lit, literals.back().Negated()); - HalfAssociateGivenLiteral(i_lit_negated, literals.back()); - } else { - // Normal case. - if (!LiteralIsAssociated(i_lit) || !LiteralIsAssociated(i_lit_negated)) { - const BooleanVariable b = sat_solver_->NewBooleanVariable(); - HalfAssociateGivenLiteral(i_lit, Literal(b, true)); - HalfAssociateGivenLiteral(i_lit_negated, Literal(b, false)); - } - } - } - - // Now that all literals are created, wire them together using - // (X == v) <=> (X >= v) and (X <= v). - // - // TODO(user): this is currently in O(n^2) which is potentially bad even if - // we do it only once per variable. - for (int i = 0; i < literals.size(); ++i) { - AssociateToIntegerEqualValue(literals[i], var, values[i]); - } - // Mark var and Negation(var) as fully encoded. - const int required_size = std::max(var, NegationOf(var)).value() + 1; - if (required_size > is_fully_encoded_.size()) { - is_fully_encoded_.resize(required_size, false); + { + const int required_size = std::max(var, NegationOf(var)).value() + 1; + if (required_size > is_fully_encoded_.size()) { + is_fully_encoded_.resize(required_size, false); + } } is_fully_encoded_[var] = true; is_fully_encoded_[NegationOf(var)] = true; diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 3084074076..b80ec8c0f0 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -336,6 +336,10 @@ class IntegerEncoder { void AddAllImplicationsBetweenAssociatedLiterals(); // Returns the IntegerLiterals that were associated with the given Literal. + // + // Note that more than one IntegerLiterals (possibly on different variables) + // may have been associated to the same literal. We also returns ">= value" + // and "<= value" if lit was associated to "== value". const InlinedIntegerLiteralVector& GetIntegerLiterals(Literal lit) const { if (lit.Index() >= reverse_encoding_.size()) { return empty_integer_literal_vector_; diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 80584ae73b..d932bc0c41 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -137,21 +137,33 @@ std::function PseudoCost(Model* model) { }; } -std::function RandomizeOnRestartSatSolverHeuristic( - Model* model) { +std::function RandomizeOnRestartHeuristic(Model* model) { SatSolver* sat_solver = model->GetOrCreate(); - Trail* trail = model->GetOrCreate(); SatDecisionPolicy* decision_policy = model->GetOrCreate(); - return [sat_solver, trail, decision_policy, model] { - if (sat_solver->CurrentDecisionLevel() == 0) { - RandomizeDecisionHeuristic(model->GetOrCreate(), - model->GetOrCreate()); - decision_policy->ResetDecisionHeuristic(); - } - const bool all_assigned = trail->Index() == sat_solver->NumVariables(); - return all_assigned ? kNoLiteralIndex - : decision_policy->NextBranch().Index(); - }; + + // The duplication increase the probability of the first heuristics. This is + // wanted because when we randomize the sat parameters, we have more than one + // heuristic for choosing the phase of the decision. + // + // TODO(user): Add other policy and perform more experiments. + std::function sat_policy = SatSolverHeuristic(model); + std::vector> policies{ + sat_policy, sat_policy, ExploitLpSolution(sat_policy, model), + ExploitLpSolution(SequentialSearch({PseudoCost(model), sat_policy}), + model)}; + + int policy_index = 0; + return + [sat_solver, decision_policy, policies, policy_index, model]() mutable { + if (sat_solver->CurrentDecisionLevel() == 0) { + RandomizeDecisionHeuristic(model->GetOrCreate(), + model->GetOrCreate()); + decision_policy->ResetDecisionHeuristic(); + std::uniform_int_distribution dist(0, policies.size() - 1); + policy_index = dist(*(model->GetOrCreate())); + } + return policies[policy_index](); + }; } std::function FollowHint( @@ -343,7 +355,7 @@ SatSolver::Status SolveIntegerProblemWithLazyEncoding( std::function search; if (parameters.randomize_search()) { search = SequentialSearch( - {RandomizeOnRestartSatSolverHeuristic(model), next_decision}); + {RandomizeOnRestartHeuristic(model), next_decision}); } else { search = SequentialSearch({SatSolverHeuristic(model), next_decision}); } @@ -406,8 +418,7 @@ SatSolver::Status SolveIntegerProblemWithLazyEncoding( model); } case SatParameters::PSEUDO_COST_SEARCH: { - std::function search; - search = SequentialSearch( + std::function search = SequentialSearch( {PseudoCost(model), SatSolverHeuristic(model), next_decision}); if (parameters.exploit_integer_lp_solution() || parameters.exploit_all_lp_solution()) { @@ -416,6 +427,13 @@ SatSolver::Status SolveIntegerProblemWithLazyEncoding( return SolveProblemWithPortfolioSearch( {search}, {SatSolverRestartPolicy(model)}, model); } + case SatParameters::PORTFOLIO_WITH_QUICK_RESTART_SEARCH: { + std::function search = + SequentialSearch({RandomizeOnRestartHeuristic(model), next_decision}); + return SolveProblemWithPortfolioSearch( + {search}, + {RestartEveryKFailures(10, model->GetOrCreate())}, model); + } } return SatSolver::LIMIT_REACHED; } diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index b6d2bae41f..a9d6598303 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -13,11 +13,11 @@ syntax = "proto2"; +package operations_research.sat; + option java_package = "com.google.ortools.sat"; option java_multiple_files = true; -package operations_research.sat; - // Contains the definitions for all the sat algorithm parameters and their // default values. // @@ -571,6 +571,11 @@ message SatParameters { // If used, the solver uses the pseudo costs for branching. PSEUDO_COST_SEARCH = 4; + + // Mainly exposed here for testing. This quickly tries a lot of randomized + // heuristics with a low conflict limit. It usually provides a good first + // solution. + PORTFOLIO_WITH_QUICK_RESTART_SEARCH = 5; } optional SearchBranching search_branching = 82 [default = AUTOMATIC_SEARCH];