[CP-SAT] improve presolve; new random-restart strategy; improve integer encoding

This commit is contained in:
Laurent Perron
2019-02-14 16:44:57 +01:00
parent 828d7df7f3
commit 8042642540
5 changed files with 165 additions and 135 deletions

View File

@@ -50,6 +50,7 @@ namespace sat {
int PresolveContext::NewIntVar(const Domain& domain) {
IntegerVariableProto* const var = working_model->add_variables();
FillDomainInProto(domain, var);
InitializeNewDomains();
return working_model->variables_size() - 1;
}
@@ -61,6 +62,7 @@ int PresolveContext::GetOrCreateConstantVar(int64 cst) {
IntegerVariableProto* const var_proto = working_model->add_variables();
var_proto->add_domain(cst);
var_proto->add_domain(cst);
InitializeNewDomains();
}
return constant_to_ref[cst];
}
@@ -342,6 +344,9 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) {
// TODO(user,user): use affine relation here.
const int var = PositiveRef(ref);
const int64 s_value = RefIsPositive(ref) ? value : -value;
if (!domains[var].Contains(s_value)) {
return GetOrCreateConstantVar(0);
}
std::pair<int, int64> key{var, s_value};
if (encoding.contains(key)) return encoding[key];
@@ -355,12 +360,12 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) {
if (domains[var].Size() == 2) {
const int64 var_min = MinOf(var);
const int64 var_max = MaxOf(var);
if (var_min == 0 && var_max == 1) {
encoding[std::make_pair(var, 0)] = NegatedRef(var);
encoding[std::make_pair(var, 1)] = var;
} else {
const int literal = NewBoolVar();
InitializeNewDomains();
encoding[std::make_pair(var, var_min)] = NegatedRef(literal);
encoding[std::make_pair(var, var_max)] = literal;
@@ -374,6 +379,7 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) {
lin->add_domain(var_min);
StoreAffineRelation(*ct, var, literal, var_max - var_min, var_min);
}
return gtl::FindOrDieNoPrint(encoding, key);
}
@@ -382,7 +388,6 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64 value) {
AddImplyInDomain(NegatedRef(literal), var, Domain(s_value).Complement());
encoding[key] = literal;
InitializeNewDomains();
return literal;
}
@@ -1584,48 +1589,50 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
int num_vars = 0;
bool all_constants = true;
absl::flat_hash_set<int64> constant_set;
bool all_included_in_target_domain = true;
bool reduced_index_domain = false;
if (context->IntersectDomainWith(index_ref,
Domain(0, ct->element().vars_size() - 1))) {
reduced_index_domain = true;
}
// Filter possible index values. Accumulate variable domains to build
// a possible target domain.
Domain infered_domain;
const Domain initial_index_domain = context->DomainOf(index_ref);
const Domain target_domain = context->DomainOf(target_ref);
for (const ClosedInterval interval : initial_index_domain) {
for (int value = interval.start; value <= interval.end; ++value) {
CHECK_GE(value, 0);
CHECK_LT(value, ct->element().vars_size());
const int ref = ct->element().vars(value);
const Domain domain = context->DomainOf(ref);
if (domain.IntersectionWith(target_domain).IsEmpty()) {
context->IntersectDomainWith(index_ref, Domain(value).Complement());
reduced_index_domain = true;
} else {
++num_vars;
if (domain.Min() == domain.Max()) {
constant_set.insert(domain.Min());
{
bool reduced_index_domain = false;
if (context->IntersectDomainWith(
index_ref, Domain(0, ct->element().vars_size() - 1))) {
reduced_index_domain = true;
}
// Filter possible index values. Accumulate variable domains to build
// a possible target domain.
Domain infered_domain;
const Domain initial_index_domain = context->DomainOf(index_ref);
Domain target_domain = context->DomainOf(target_ref);
for (const ClosedInterval interval : initial_index_domain) {
for (int value = interval.start; value <= interval.end; ++value) {
CHECK_GE(value, 0);
CHECK_LT(value, ct->element().vars_size());
const int ref = ct->element().vars(value);
const Domain domain = context->DomainOf(ref);
if (domain.IntersectionWith(target_domain).IsEmpty()) {
context->IntersectDomainWith(index_ref, Domain(value).Complement());
reduced_index_domain = true;
} else {
all_constants = false;
++num_vars;
if (domain.Min() == domain.Max()) {
constant_set.insert(domain.Min());
} else {
all_constants = false;
}
if (!domain.IsIncludedIn(target_domain)) {
all_included_in_target_domain = false;
}
infered_domain = infered_domain.UnionWith(domain);
}
if (!domain.IsIncludedIn(target_domain)) {
all_included_in_target_domain = false;
}
infered_domain = infered_domain.UnionWith(domain);
}
}
}
if (reduced_index_domain) {
context->UpdateRuleStats("element: reduced index domain");
}
if (context->IntersectDomainWith(target_ref, infered_domain)) {
if (context->DomainOf(target_ref).IsEmpty()) return true;
context->UpdateRuleStats("element: reduced target domain");
if (reduced_index_domain) {
context->UpdateRuleStats("element: reduced index domain");
}
if (context->IntersectDomainWith(target_ref, infered_domain)) {
if (context->DomainOf(target_ref).IsEmpty()) return true;
context->UpdateRuleStats("element: reduced target domain");
}
}
// If the index is fixed, this is a equality constraint.
@@ -1759,7 +1766,6 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
}
ct->mutable_element()->set_target(r_target.representative);
if (changed_values) {
context->InitializeNewDomains();
context->UpdateRuleStats("element: unscaled values from affine target");
}
if (index_domain.Size() > valid_index_values.size()) {
@@ -1790,10 +1796,31 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
}
}
context->UpdateRuleStats("element: expand fixed target element");
context->InitializeNewDomains();
return RemoveConstraint(ct, context);
}
if (target_ref == index_ref) {
// Filter impossible index values.
Domain index_domain = context->DomainOf(index_ref);
std::vector<int64> possible_indices;
for (const ClosedInterval& interval : index_domain) {
for (int64 value = interval.start; value <= interval.end; ++value) {
const int ref = ct->element().vars(value);
if (context->DomainContains(ref, value)) {
possible_indices.push_back(value);
}
}
}
if (possible_indices.size() < index_domain.Size()) {
context->IntersectDomainWith(index_ref,
Domain::FromValues(possible_indices));
}
context->UpdateRuleStats(
"element: reduce index domain when target equals index");
return true;
}
if (all_constants) {
const Domain index_domain = context->DomainOf(index_ref);
absl::flat_hash_map<int, std::vector<int>> supports;
@@ -1810,10 +1837,8 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
const int64 value = context->MinOf(ct->element().vars(v));
const int index_lit =
context->GetOrCreateVarValueEncoding(index_ref, v);
CHECK(context->DomainContains(target_ref, value))
<< "target " << context->DomainOf(target_ref)
<< ", value = " << value;
CHECK(context->DomainContains(target_ref, value));
const int target_lit =
context->GetOrCreateVarValueEncoding(target_ref, value);
context->AddImplication(index_lit, target_lit);
@@ -1844,7 +1869,6 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) {
}
context->UpdateRuleStats("element: expand fixed array element");
context->InitializeNewDomains();
return RemoveConstraint(ct, context);
}
@@ -2006,7 +2030,9 @@ bool PresolveTable(ConstraintProto* ct, PresolveContext* context) {
bool PresolveAllDiff(ConstraintProto* ct, PresolveContext* context) {
if (HasEnforcementLiteral(*ct)) return false;
const int size = ct->all_diff().vars_size();
AllDifferentConstraintProto& all_diff = *ct->mutable_all_diff();
const int size = all_diff.vars_size();
if (size == 0) {
context->UpdateRuleStats("all_diff: empty constraint");
return RemoveConstraint(ct, context);
@@ -2016,16 +2042,46 @@ bool PresolveAllDiff(ConstraintProto* ct, PresolveContext* context) {
return RemoveConstraint(ct, context);
}
bool contains_fixed_variable = false;
absl::flat_hash_set<int> fixed_variables;
for (int i = 0; i < size; ++i) {
if (context->IsFixed(ct->all_diff().vars(i))) {
contains_fixed_variable = true;
break;
if (!context->IsFixed(all_diff.vars(i))) continue;
fixed_variables.insert(i);
const int64 value = context->MinOf(all_diff.vars(i));
bool propagated = false;
for (int j = 0; j < size; ++j) {
if (i == j) continue;
if (context->DomainContains(all_diff.vars(j), value)) {
context->IntersectDomainWith(all_diff.vars(j),
Domain(value).Complement());
if (context->is_unsat) return true;
propagated = true;
}
}
if (propagated) {
context->UpdateRuleStats("all_diff: propagated fixed variables");
}
}
if (contains_fixed_variable) {
context->UpdateRuleStats("TODO all_diff: fixed variables");
if (!fixed_variables.empty()) {
std::vector<int> new_variables;
for (int i = 0; i < all_diff.vars_size(); ++i) {
// We cannot check the domain here, as it may have been fixed by the
// propagation loop above. In that case, it will be picked up by this
// presolve rule in the next iteration.
if (!gtl::ContainsKey(fixed_variables, i)) {
new_variables.push_back(all_diff.vars(i));
}
}
CHECK_EQ(all_diff.vars_size(),
new_variables.size() + fixed_variables.size());
all_diff.mutable_vars()->Clear();
for (const int var : new_variables) {
all_diff.add_vars(var);
}
context->UpdateRuleStats("all_diff: removed fixed variables");
return true;
}
return false;
}
@@ -2325,8 +2381,9 @@ bool PresolveCircuit(ConstraintProto* ct, PresolveContext* context) {
bool PresolveAutomaton(ConstraintProto* ct, PresolveContext* context) {
if (HasEnforcementLiteral(*ct)) return false;
AutomatonConstraintProto& proto = *ct->mutable_automaton();
if (proto.vars_size() == 0 || proto.transition_label_size() == 0)
if (proto.vars_size() == 0 || proto.transition_label_size() == 0) {
return false;
}
bool all_affine = true;
std::vector<AffineRelation::Relation> affine_relations;
@@ -3459,13 +3516,8 @@ void TryToSimplifyDomains(PresolveContext* context) {
if (domain.Size() == 2 && domain.NumIntervals() == 1 && domain.Min() != 0) {
// Shifted Boolean variable.
const int new_var_index = context->working_model->variables_size();
IntegerVariableProto* const var_proto =
context->working_model->add_variables();
var_proto->add_domain(0);
var_proto->add_domain(1);
const int new_var_index = context->NewBoolVar();
const int64 offset = domain.Min();
context->InitializeNewDomains();
ConstraintProto* const ct = context->working_model->add_constraints();
LinearConstraintProto* const lin = ct->mutable_linear();

View File

@@ -37,72 +37,23 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable var) {
CHECK(!VariableIsFullyEncoded(var));
CHECK_EQ(0, sat_solver_->CurrentDecisionLevel());
CHECK(!(*domains_)[var].IsEmpty()); // UNSAT. We don't deal with that here.
CHECK_LT((*domains_)[var].Size(), 100000)
<< "Domain too large for full encoding.";
std::vector<IntegerValue> values;
// TODO(user): Maybe we can optimize the literal creation order and their
// polarity as our default SAT heuristics initially depends on this.
for (const ClosedInterval interval : (*domains_)[var]) {
for (IntegerValue v(interval.start); v <= interval.end; ++v) {
values.push_back(v);
CHECK_LT(values.size(), 100000) << "Domain too large for full encoding.";
GetOrCreateLiteralAssociatedToEquality(var, v);
}
}
std::vector<Literal> literals;
if (values.size() == 1) {
literals.push_back(GetTrueLiteral());
} else if (values.size() == 2) {
literals.push_back(GetOrCreateAssociatedLiteral(
IntegerLiteral::LowerOrEqual(var, values[0])));
literals.push_back(literals.back().Negated());
} else {
for (int i = 0; i < values.size(); ++i) {
const std::pair<IntegerVariable, IntegerValue> key{var, values[i]};
if (gtl::ContainsKey(equality_to_associated_literal_, key)) {
literals.push_back(equality_to_associated_literal_[key]);
} else {
literals.push_back(Literal(sat_solver_->NewBooleanVariable(), true));
}
}
}
// Create the associated literal (<= and >=) in order (best for the
// implications between them). Note that we only create literals like this for
// value inside the domain. This is nice since these will be the only kind of
// literal pushed by Enqueue() (we look at the domain there).
for (int i = 0; i + 1 < literals.size(); ++i) {
const IntegerLiteral i_lit = IntegerLiteral::LowerOrEqual(var, values[i]);
const IntegerLiteral i_lit_negated =
IntegerLiteral::GreaterOrEqual(var, values[i + 1]);
if (i == 0) {
// Special case for the start.
HalfAssociateGivenLiteral(i_lit, literals[0]);
HalfAssociateGivenLiteral(i_lit_negated, literals[0].Negated());
} else if (i + 2 == literals.size()) {
// Special case for the end.
HalfAssociateGivenLiteral(i_lit, literals.back().Negated());
HalfAssociateGivenLiteral(i_lit_negated, literals.back());
} else {
// Normal case.
if (!LiteralIsAssociated(i_lit) || !LiteralIsAssociated(i_lit_negated)) {
const BooleanVariable b = sat_solver_->NewBooleanVariable();
HalfAssociateGivenLiteral(i_lit, Literal(b, true));
HalfAssociateGivenLiteral(i_lit_negated, Literal(b, false));
}
}
}
// Now that all literals are created, wire them together using
// (X == v) <=> (X >= v) and (X <= v).
//
// TODO(user): this is currently in O(n^2) which is potentially bad even if
// we do it only once per variable.
for (int i = 0; i < literals.size(); ++i) {
AssociateToIntegerEqualValue(literals[i], var, values[i]);
}
// Mark var and Negation(var) as fully encoded.
const int required_size = std::max(var, NegationOf(var)).value() + 1;
if (required_size > is_fully_encoded_.size()) {
is_fully_encoded_.resize(required_size, false);
{
const int required_size = std::max(var, NegationOf(var)).value() + 1;
if (required_size > is_fully_encoded_.size()) {
is_fully_encoded_.resize(required_size, false);
}
}
is_fully_encoded_[var] = true;
is_fully_encoded_[NegationOf(var)] = true;

View File

@@ -336,6 +336,10 @@ class IntegerEncoder {
void AddAllImplicationsBetweenAssociatedLiterals();
// Returns the IntegerLiterals that were associated with the given Literal.
//
// Note that more than one IntegerLiterals (possibly on different variables)
// may have been associated to the same literal. We also returns ">= value"
// and "<= value" if lit was associated to "== value".
const InlinedIntegerLiteralVector& GetIntegerLiterals(Literal lit) const {
if (lit.Index() >= reverse_encoding_.size()) {
return empty_integer_literal_vector_;

View File

@@ -137,21 +137,33 @@ std::function<LiteralIndex()> PseudoCost(Model* model) {
};
}
std::function<LiteralIndex()> RandomizeOnRestartSatSolverHeuristic(
Model* model) {
std::function<LiteralIndex()> RandomizeOnRestartHeuristic(Model* model) {
SatSolver* sat_solver = model->GetOrCreate<SatSolver>();
Trail* trail = model->GetOrCreate<Trail>();
SatDecisionPolicy* decision_policy = model->GetOrCreate<SatDecisionPolicy>();
return [sat_solver, trail, decision_policy, model] {
if (sat_solver->CurrentDecisionLevel() == 0) {
RandomizeDecisionHeuristic(model->GetOrCreate<ModelRandomGenerator>(),
model->GetOrCreate<SatParameters>());
decision_policy->ResetDecisionHeuristic();
}
const bool all_assigned = trail->Index() == sat_solver->NumVariables();
return all_assigned ? kNoLiteralIndex
: decision_policy->NextBranch().Index();
};
// The duplication increase the probability of the first heuristics. This is
// wanted because when we randomize the sat parameters, we have more than one
// heuristic for choosing the phase of the decision.
//
// TODO(user): Add other policy and perform more experiments.
std::function<LiteralIndex()> sat_policy = SatSolverHeuristic(model);
std::vector<std::function<LiteralIndex()>> policies{
sat_policy, sat_policy, ExploitLpSolution(sat_policy, model),
ExploitLpSolution(SequentialSearch({PseudoCost(model), sat_policy}),
model)};
int policy_index = 0;
return
[sat_solver, decision_policy, policies, policy_index, model]() mutable {
if (sat_solver->CurrentDecisionLevel() == 0) {
RandomizeDecisionHeuristic(model->GetOrCreate<ModelRandomGenerator>(),
model->GetOrCreate<SatParameters>());
decision_policy->ResetDecisionHeuristic();
std::uniform_int_distribution<int> dist(0, policies.size() - 1);
policy_index = dist(*(model->GetOrCreate<ModelRandomGenerator>()));
}
return policies[policy_index]();
};
}
std::function<LiteralIndex()> FollowHint(
@@ -343,7 +355,7 @@ SatSolver::Status SolveIntegerProblemWithLazyEncoding(
std::function<LiteralIndex()> search;
if (parameters.randomize_search()) {
search = SequentialSearch(
{RandomizeOnRestartSatSolverHeuristic(model), next_decision});
{RandomizeOnRestartHeuristic(model), next_decision});
} else {
search = SequentialSearch({SatSolverHeuristic(model), next_decision});
}
@@ -406,8 +418,7 @@ SatSolver::Status SolveIntegerProblemWithLazyEncoding(
model);
}
case SatParameters::PSEUDO_COST_SEARCH: {
std::function<LiteralIndex()> search;
search = SequentialSearch(
std::function<LiteralIndex()> search = SequentialSearch(
{PseudoCost(model), SatSolverHeuristic(model), next_decision});
if (parameters.exploit_integer_lp_solution() ||
parameters.exploit_all_lp_solution()) {
@@ -416,6 +427,13 @@ SatSolver::Status SolveIntegerProblemWithLazyEncoding(
return SolveProblemWithPortfolioSearch(
{search}, {SatSolverRestartPolicy(model)}, model);
}
case SatParameters::PORTFOLIO_WITH_QUICK_RESTART_SEARCH: {
std::function<LiteralIndex()> search =
SequentialSearch({RandomizeOnRestartHeuristic(model), next_decision});
return SolveProblemWithPortfolioSearch(
{search},
{RestartEveryKFailures(10, model->GetOrCreate<SatSolver>())}, model);
}
}
return SatSolver::LIMIT_REACHED;
}

View File

@@ -13,11 +13,11 @@
syntax = "proto2";
package operations_research.sat;
option java_package = "com.google.ortools.sat";
option java_multiple_files = true;
package operations_research.sat;
// Contains the definitions for all the sat algorithm parameters and their
// default values.
//
@@ -571,6 +571,11 @@ message SatParameters {
// If used, the solver uses the pseudo costs for branching.
PSEUDO_COST_SEARCH = 4;
// Mainly exposed here for testing. This quickly tries a lot of randomized
// heuristics with a low conflict limit. It usually provides a good first
// solution.
PORTFOLIO_WITH_QUICK_RESTART_SEARCH = 5;
}
optional SearchBranching search_branching = 82 [default = AUTOMATIC_SEARCH];