split and rewrite CP-SAT search

This commit is contained in:
Laurent Perron
2019-04-16 09:25:34 -07:00
parent 639a4a8c9c
commit 0d443c3569
7 changed files with 325 additions and 246 deletions

View File

@@ -20,11 +20,13 @@
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "ortools/base/int_type.h"
#include "ortools/base/logging.h"
#include "ortools/base/map_util.h"
#include "ortools/base/stl_util.h"
#include "ortools/sat/sat_base.h"
#include "ortools/sat/sat_solver.h"
#include "ortools/util/sorted_interval_list.h"
@@ -94,16 +96,38 @@ void ProcessOneColumn(
}
}
void AddRegularPositiveTable(
void AddPositiveTable(
const std::vector<IntegerVariable>& vars,
const std::vector<std::vector<int64>>& tuples,
const std::vector<absl::flat_hash_set<int64>> values_per_var,
int64 any_value, Model* model) {
IntegerTrail* const integer_trail = model->GetOrCreate<IntegerTrail>();
int64 any_value, bool prefix_mode, Model* model) {
const int n = vars.size();
// Domains have been propagated. Fully encode variables.
std::vector<absl::flat_hash_map<IntegerValue, Literal>> encodings(n);
for (int i = 0; i < n; ++i) {
if (values_per_var.size() > 1) {
model->Add(FullyEncodeVariable(vars[i]));
encodings[i] = GetEncoding(vars[i], model);
}
}
// Create one Boolean variable per tuple to indicate if it can still be
// selected or not. Note that we don't enforce exactly one tuple to be
// selected because these variables are just used by this constraint, so
// only the information "can't be selected" is important.
//
// TODO(user): If a value in one column is unique, we don't need to create
// a new BooleanVariable corresponding to this line since we can use the one
// corresponding to this value in that column.
//
// Note that if there is just one tuple, there is no need to create such
// variables since they are not used.
std::vector<Literal> tuple_literals;
tuple_literals.reserve(tuples.size());
if (tuples.size() == 2) {
if (tuples.size() == 1) {
tuple_literals.push_back(Literal(kTrueLiteralIndex));
} else if (tuples.size() == 2) {
tuple_literals.emplace_back(model->Add(NewBooleanVariable()), true);
tuple_literals.emplace_back(tuple_literals[0].Negated());
} else if (tuples.size() > 2) {
@@ -118,19 +142,14 @@ void AddRegularPositiveTable(
std::vector<IntegerValue> active_values;
std::vector<Literal> any_tuple_literals;
for (int i = 0; i < n; ++i) {
if (values_per_var[i].size() == 1) continue;
active_tuple_literals.clear();
active_values.clear();
any_tuple_literals.clear();
const int64 first = tuples[0][i];
bool all_equals = true;
for (int j = 0; j < tuple_literals.size(); ++j) {
const int64 v = tuples[j][i];
if (v != first) {
all_equals = false;
}
if (v == any_value) {
any_tuple_literals.push_back(tuple_literals[j]);
} else {
@@ -139,51 +158,43 @@ void AddRegularPositiveTable(
}
}
if (all_equals && any_tuple_literals.empty() && first != any_value) {
model->Add(Equality(vars[i], first));
} else if (!active_tuple_literals.empty()) {
const std::vector<int64> reached_values(values_per_var[i].begin(),
values_per_var[i].end());
integer_trail->UpdateInitialDomain(vars[i],
Domain::FromValues(reached_values));
model->Add(FullyEncodeVariable(vars[i]));
if (!active_tuple_literals.empty()) {
ProcessOneColumn(active_tuple_literals, active_values,
GetEncoding(vars[i], model), any_tuple_literals, model);
}
}
}
void AddFullyPrefixedPositiveTable(
const std::vector<IntegerVariable>& vars,
const std::vector<std::vector<int64>>& tuples,
const std::vector<absl::flat_hash_set<int64>> values_per_var,
int64 any_value, Model* model) {
const int n = vars.size();
std::vector<absl::flat_hash_map<IntegerValue, Literal>> encodings(n);
for (int i = 0; i < n; ++i) {
model->Add(FullyEncodeVariable(vars[i]));
encodings[i] = GetEncoding(vars[i], model);
}
if (prefix_mode) {
// For each tuple, we add a clause prefix => last value.
std::vector<Literal> clause;
for (int j = 0; j < tuples.size(); ++j) {
clause.clear();
bool tuple_is_valid = true;
for (int i = 0; i + 1 < n; ++i) {
// Ignore fixed variables.
if (values_per_var[i].size() == 1) continue;
std::vector<Literal> clause;
for (int j = 0; j < tuples.size(); ++j) {
clause.clear();
bool tuple_is_valid = true;
for (int i = 0; i + 1 < n; ++i) {
const int64 v = tuples[j][i];
if (v == any_value) continue;
if (!encodings[i].contains(IntegerValue(v))) {
tuple_is_valid = false;
break;
const int64 v = tuples[j][i];
// Ignore the 'any' values created during compression.
if (v == any_value) continue;
// Check the validity of the tuple.
const IntegerValue value(v);
if (!encodings[i].contains(value)) {
tuple_is_valid = false;
break;
}
clause.push_back(gtl::FindOrDie(encodings[i], value).Negated());
}
// Add the target of the implication.
const IntegerValue target_value = IntegerValue(tuples[j][n - 1]);
if (tuple_is_valid && encodings[n - 1].contains(target_value)) {
const Literal target_literal =
gtl::FindOrDie(encodings[n - 1], target_value);
clause.push_back(target_literal);
model->Add(ClauseConstraint(clause));
}
clause.push_back(gtl::FindOrDie(encodings[i], IntegerValue(v)).Negated());
}
const IntegerValue target_value = IntegerValue(tuples[j][n - 1]);
if (tuple_is_valid && encodings[n - 1].contains(target_value)) {
const Literal target_literal =
gtl::FindOrDie(encodings[n - 1], target_value);
clause.push_back(target_literal);
model->Add(ClauseConstraint(clause));
}
}
}
@@ -197,7 +208,6 @@ void CompressTuples(const std::vector<int64>& domain_sizes, int64 any_value,
// Remove duplicates if any.
gtl::STLSortAndRemoveDuplicates(tuples);
const int initial_num_tuples = tuples->size();
const int num_vars = (*tuples)[0].size();
std::vector<int> to_remove;
@@ -228,10 +238,6 @@ void CompressTuples(const std::vector<int64>& domain_sizes, int64 any_value,
tuples->pop_back();
}
}
if (initial_num_tuples != tuples->size()) {
VLOG(1) << "Compressed tuples from " << initial_num_tuples << " to "
<< tuples->size();
}
}
// Makes a static decomposition of a table constraint into clauses.
@@ -243,8 +249,10 @@ void AddTableConstraint(const std::vector<IntegerVariable>& vars,
std::vector<std::vector<int64>> tuples, Model* model) {
const int n = vars.size();
IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
const int num_original_tuples = tuples.size();
// Compute the set of possible values for each variable (from the table).
// Remove invalid tuples along the way.
std::vector<absl::flat_hash_set<int64>> values_per_var(n);
int index = 0;
while (index < tuples.size()) {
@@ -267,12 +275,15 @@ void AddTableConstraint(const std::vector<IntegerVariable>& vars,
index++;
}
}
const int num_valid_tuples = tuples.size();
if (tuples.empty()) {
model->GetOrCreate<SatSolver>()->NotifyThatModelIsUnsat();
return;
}
// Detect the case where the prefixes formed by the first n - 1 columns are
// all different. The table then encodes the implication
// (prefix tuple of size n - 1) => value of the last variable.
absl::flat_hash_set<std::vector<int64>> prefixes;
std::vector<int64> prefix(n);
for (const std::vector<int64>& tuple : tuples) {
@@ -280,19 +291,14 @@ void AddTableConstraint(const std::vector<IntegerVariable>& vars,
prefix.pop_back();
prefixes.insert(prefix);
}
double prefix_space_size = 1.0;
const int num_prefix_tuples = prefixes.size();
// Compute the maximum number of such prefix tuples.
double max_num_prefix_tuples = 1.0;
for (int i = 0; i + 1 < n; ++i) {
prefix_space_size *= values_per_var[i].size();
}
const bool prefix_is_key = prefixes.size() == tuples.size();
const bool prefix_is_covering_domain = prefixes.size() == prefix_space_size;
if (prefix_is_key && !prefix_is_covering_domain) {
VLOG(2) << "tuples = " << prefixes.size()
<< " space = " << prefix_space_size << " | " << n;
} else if (prefix_is_key && prefix_is_covering_domain) {
VLOG(2) << "prefix = " << prefixes.size() << " | " << n;
max_num_prefix_tuples *= values_per_var[i].size();
}
// Detect if prefix tuples are all different.
const bool prefix_are_all_different = num_prefix_tuples == num_valid_tuples;
// Compress tuples.
const int64 any_value = kint64min;
@@ -301,23 +307,64 @@ void AddTableConstraint(const std::vector<IntegerVariable>& vars,
domain_sizes.push_back(values_per_var[i].size());
}
CompressTuples(domain_sizes, any_value, &tuples);
const int num_compressed_tuples = tuples.size();
// Create one Boolean variable per tuple to indicate if it can still be
// selected or not. Note that we don't enforce exactly one tuple to be
// selected because these variables are just used by this constraint, so
// only the information "can't be selected" is important.
//
// TODO(user): If a value in one column is unique, we don't need to create
// a new BooleanVariable corresponding to this line since we can use the one
// corresponding to this value in that column.
//
// Note that if there is just one tuple, there is no need to create such
// variables since they are not used.
if (prefix_is_key && prefix_is_covering_domain) {
AddFullyPrefixedPositiveTable(vars, tuples, values_per_var, any_value,
model);
} else {
AddRegularPositiveTable(vars, tuples, values_per_var, any_value, model);
if (VLOG_IS_ON(2)) {
std::string message = absl::StrCat(
"Table: ", n, " variables, original tuples = ", num_original_tuples);
if (num_valid_tuples != num_original_tuples) {
absl::StrAppend(&message, ", valid tuples = ", num_valid_tuples);
}
if (prefix_are_all_different) {
if (num_prefix_tuples < max_num_prefix_tuples) {
absl::StrAppend(&message, ", partial prefix = ", num_prefix_tuples, "/",
max_num_prefix_tuples);
} else {
absl::StrAppend(&message, ", full prefix = true");
}
} else {
absl::StrAppend(&message, ", num prefix tuples = ", prefixes.size());
}
if (num_compressed_tuples != num_valid_tuples) {
absl::StrAppend(&message,
", compressed tuples = ", num_compressed_tuples);
}
LOG(INFO) << message;
}
AddPositiveTable(vars, tuples, values_per_var, any_value,
prefix_are_all_different, model);
if (prefix_are_all_different && num_prefix_tuples < max_num_prefix_tuples) {
// If we have a table of 'unique prefix' => value tuples.
// This table will likely not be negated, as the density of tuples will be
// less than 1 / size of the domain of the last variable.
// Still, just on the prefix part, it can be close to complete.
// For each missing prefix, we can add its negation as a valid clause.
// For this, we negate the prefix tuples, and add a negative table
// constraint on these.
std::vector<std::vector<int64>> var_domains(n - 1);
for (int j = 0; j + 1 < n; ++j) {
var_domains[j].assign(values_per_var[j].begin(), values_per_var[j].end());
std::sort(var_domains[j].begin(), var_domains[j].end());
}
std::vector<std::vector<int64>> negated_tuples;
std::vector<int64> tmp_tuple;
for (int i = 0; i < max_num_prefix_tuples; ++i) {
tmp_tuple.assign(n - 1, 0);
int index = i;
for (int j = 0; j + 1 < n; ++j) {
tmp_tuple[j] = var_domains[j][index % var_domains[j].size()];
index /= var_domains[j].size();
}
if (!gtl::ContainsKey(prefixes, tmp_tuple)) {
negated_tuples.push_back(tmp_tuple);
}
}
std::vector<IntegerVariable> prefix_vars = vars;
prefix_vars.pop_back();
VLOG(2) << " . add negated table with " << negated_tuples.size()
<< " tuples";
AddNegatedTableConstraint(prefix_vars, negated_tuples, model);
}
}