split and rewrite CP-SAT search
This commit is contained in:
@@ -20,11 +20,13 @@
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "ortools/base/int_type.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/base/map_util.h"
|
||||
#include "ortools/base/stl_util.h"
|
||||
#include "ortools/sat/sat_base.h"
|
||||
#include "ortools/sat/sat_solver.h"
|
||||
#include "ortools/util/sorted_interval_list.h"
|
||||
|
||||
@@ -94,16 +96,38 @@ void ProcessOneColumn(
|
||||
}
|
||||
}
|
||||
|
||||
void AddRegularPositiveTable(
|
||||
void AddPositiveTable(
|
||||
const std::vector<IntegerVariable>& vars,
|
||||
const std::vector<std::vector<int64>>& tuples,
|
||||
const std::vector<absl::flat_hash_set<int64>> values_per_var,
|
||||
int64 any_value, Model* model) {
|
||||
IntegerTrail* const integer_trail = model->GetOrCreate<IntegerTrail>();
|
||||
int64 any_value, bool prefix_mode, Model* model) {
|
||||
const int n = vars.size();
|
||||
|
||||
// Domains have been propagated. Fully encode variables.
|
||||
std::vector<absl::flat_hash_map<IntegerValue, Literal>> encodings(n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (values_per_var.size() > 1) {
|
||||
model->Add(FullyEncodeVariable(vars[i]));
|
||||
encodings[i] = GetEncoding(vars[i], model);
|
||||
}
|
||||
}
|
||||
|
||||
// Create one Boolean variable per tuple to indicate if it can still be
|
||||
// selected or not. Note that we don't enforce exactly one tuple to be
|
||||
// selected because these variables are just used by this constraint, so
|
||||
// only the information "can't be selected" is important.
|
||||
//
|
||||
// TODO(user): If a value in one column is unique, we don't need to create
|
||||
// a new BooleanVariable corresponding to this line since we can use the one
|
||||
// corresponding to this value in that column.
|
||||
//
|
||||
// Note that if there is just one tuple, there is no need to create such
|
||||
// variables since they are not used.
|
||||
std::vector<Literal> tuple_literals;
|
||||
tuple_literals.reserve(tuples.size());
|
||||
if (tuples.size() == 2) {
|
||||
if (tuples.size() == 1) {
|
||||
tuple_literals.push_back(Literal(kTrueLiteralIndex));
|
||||
} else if (tuples.size() == 2) {
|
||||
tuple_literals.emplace_back(model->Add(NewBooleanVariable()), true);
|
||||
tuple_literals.emplace_back(tuple_literals[0].Negated());
|
||||
} else if (tuples.size() > 2) {
|
||||
@@ -118,19 +142,14 @@ void AddRegularPositiveTable(
|
||||
std::vector<IntegerValue> active_values;
|
||||
std::vector<Literal> any_tuple_literals;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (values_per_var[i].size() == 1) continue;
|
||||
|
||||
active_tuple_literals.clear();
|
||||
active_values.clear();
|
||||
any_tuple_literals.clear();
|
||||
const int64 first = tuples[0][i];
|
||||
bool all_equals = true;
|
||||
|
||||
for (int j = 0; j < tuple_literals.size(); ++j) {
|
||||
const int64 v = tuples[j][i];
|
||||
|
||||
if (v != first) {
|
||||
all_equals = false;
|
||||
}
|
||||
|
||||
if (v == any_value) {
|
||||
any_tuple_literals.push_back(tuple_literals[j]);
|
||||
} else {
|
||||
@@ -139,51 +158,43 @@ void AddRegularPositiveTable(
|
||||
}
|
||||
}
|
||||
|
||||
if (all_equals && any_tuple_literals.empty() && first != any_value) {
|
||||
model->Add(Equality(vars[i], first));
|
||||
} else if (!active_tuple_literals.empty()) {
|
||||
const std::vector<int64> reached_values(values_per_var[i].begin(),
|
||||
values_per_var[i].end());
|
||||
integer_trail->UpdateInitialDomain(vars[i],
|
||||
Domain::FromValues(reached_values));
|
||||
model->Add(FullyEncodeVariable(vars[i]));
|
||||
if (!active_tuple_literals.empty()) {
|
||||
ProcessOneColumn(active_tuple_literals, active_values,
|
||||
GetEncoding(vars[i], model), any_tuple_literals, model);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AddFullyPrefixedPositiveTable(
|
||||
const std::vector<IntegerVariable>& vars,
|
||||
const std::vector<std::vector<int64>>& tuples,
|
||||
const std::vector<absl::flat_hash_set<int64>> values_per_var,
|
||||
int64 any_value, Model* model) {
|
||||
const int n = vars.size();
|
||||
std::vector<absl::flat_hash_map<IntegerValue, Literal>> encodings(n);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
model->Add(FullyEncodeVariable(vars[i]));
|
||||
encodings[i] = GetEncoding(vars[i], model);
|
||||
}
|
||||
if (prefix_mode) {
|
||||
// For each tuple, we add a clause prefix => last value.
|
||||
std::vector<Literal> clause;
|
||||
for (int j = 0; j < tuples.size(); ++j) {
|
||||
clause.clear();
|
||||
bool tuple_is_valid = true;
|
||||
for (int i = 0; i + 1 < n; ++i) {
|
||||
// Ignore fixed variables.
|
||||
if (values_per_var[i].size() == 1) continue;
|
||||
|
||||
std::vector<Literal> clause;
|
||||
for (int j = 0; j < tuples.size(); ++j) {
|
||||
clause.clear();
|
||||
bool tuple_is_valid = true;
|
||||
for (int i = 0; i + 1 < n; ++i) {
|
||||
const int64 v = tuples[j][i];
|
||||
if (v == any_value) continue;
|
||||
if (!encodings[i].contains(IntegerValue(v))) {
|
||||
tuple_is_valid = false;
|
||||
break;
|
||||
const int64 v = tuples[j][i];
|
||||
// Ignore the 'any' values created during compression.
|
||||
if (v == any_value) continue;
|
||||
|
||||
// Check the validity of the tuple.
|
||||
const IntegerValue value(v);
|
||||
if (!encodings[i].contains(value)) {
|
||||
tuple_is_valid = false;
|
||||
break;
|
||||
}
|
||||
clause.push_back(gtl::FindOrDie(encodings[i], value).Negated());
|
||||
}
|
||||
|
||||
// Add the target of the implication.
|
||||
const IntegerValue target_value = IntegerValue(tuples[j][n - 1]);
|
||||
if (tuple_is_valid && encodings[n - 1].contains(target_value)) {
|
||||
const Literal target_literal =
|
||||
gtl::FindOrDie(encodings[n - 1], target_value);
|
||||
clause.push_back(target_literal);
|
||||
model->Add(ClauseConstraint(clause));
|
||||
}
|
||||
clause.push_back(gtl::FindOrDie(encodings[i], IntegerValue(v)).Negated());
|
||||
}
|
||||
const IntegerValue target_value = IntegerValue(tuples[j][n - 1]);
|
||||
if (tuple_is_valid && encodings[n - 1].contains(target_value)) {
|
||||
const Literal target_literal =
|
||||
gtl::FindOrDie(encodings[n - 1], target_value);
|
||||
clause.push_back(target_literal);
|
||||
model->Add(ClauseConstraint(clause));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -197,7 +208,6 @@ void CompressTuples(const std::vector<int64>& domain_sizes, int64 any_value,
|
||||
// Remove duplicates if any.
|
||||
gtl::STLSortAndRemoveDuplicates(tuples);
|
||||
|
||||
const int initial_num_tuples = tuples->size();
|
||||
const int num_vars = (*tuples)[0].size();
|
||||
|
||||
std::vector<int> to_remove;
|
||||
@@ -228,10 +238,6 @@ void CompressTuples(const std::vector<int64>& domain_sizes, int64 any_value,
|
||||
tuples->pop_back();
|
||||
}
|
||||
}
|
||||
if (initial_num_tuples != tuples->size()) {
|
||||
VLOG(1) << "Compressed tuples from " << initial_num_tuples << " to "
|
||||
<< tuples->size();
|
||||
}
|
||||
}
|
||||
|
||||
// Makes a static decomposition of a table constraint into clauses.
|
||||
@@ -243,8 +249,10 @@ void AddTableConstraint(const std::vector<IntegerVariable>& vars,
|
||||
std::vector<std::vector<int64>> tuples, Model* model) {
|
||||
const int n = vars.size();
|
||||
IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
|
||||
const int num_original_tuples = tuples.size();
|
||||
|
||||
// Compute the set of possible values for each variable (from the table).
|
||||
// Remove invalid tuples along the way.
|
||||
std::vector<absl::flat_hash_set<int64>> values_per_var(n);
|
||||
int index = 0;
|
||||
while (index < tuples.size()) {
|
||||
@@ -267,12 +275,15 @@ void AddTableConstraint(const std::vector<IntegerVariable>& vars,
|
||||
index++;
|
||||
}
|
||||
}
|
||||
const int num_valid_tuples = tuples.size();
|
||||
|
||||
if (tuples.empty()) {
|
||||
model->GetOrCreate<SatSolver>()->NotifyThatModelIsUnsat();
|
||||
return;
|
||||
}
|
||||
|
||||
// Detect the case when the first n-1 columns are all different.
|
||||
// This encodes the implication table (tuple of size n - 1) implies value.
|
||||
absl::flat_hash_set<std::vector<int64>> prefixes;
|
||||
std::vector<int64> prefix(n);
|
||||
for (const std::vector<int64>& tuple : tuples) {
|
||||
@@ -280,19 +291,14 @@ void AddTableConstraint(const std::vector<IntegerVariable>& vars,
|
||||
prefix.pop_back();
|
||||
prefixes.insert(prefix);
|
||||
}
|
||||
double prefix_space_size = 1.0;
|
||||
const int num_prefix_tuples = prefixes.size();
|
||||
// Compute the maximum number of such prefix tuples.
|
||||
double max_num_prefix_tuples = 1.0;
|
||||
for (int i = 0; i + 1 < n; ++i) {
|
||||
prefix_space_size *= values_per_var[i].size();
|
||||
}
|
||||
const bool prefix_is_key = prefixes.size() == tuples.size();
|
||||
const bool prefix_is_covering_domain = prefixes.size() == prefix_space_size;
|
||||
|
||||
if (prefix_is_key && !prefix_is_covering_domain) {
|
||||
VLOG(2) << "tuples = " << prefixes.size()
|
||||
<< " space = " << prefix_space_size << " | " << n;
|
||||
} else if (prefix_is_key && prefix_is_covering_domain) {
|
||||
VLOG(2) << "prefix = " << prefixes.size() << " | " << n;
|
||||
max_num_prefix_tuples *= values_per_var[i].size();
|
||||
}
|
||||
// Detect if prefix tuples are all different.
|
||||
const bool prefix_are_all_different = num_prefix_tuples == num_valid_tuples;
|
||||
|
||||
// Compress tuples.
|
||||
const int64 any_value = kint64min;
|
||||
@@ -301,23 +307,64 @@ void AddTableConstraint(const std::vector<IntegerVariable>& vars,
|
||||
domain_sizes.push_back(values_per_var[i].size());
|
||||
}
|
||||
CompressTuples(domain_sizes, any_value, &tuples);
|
||||
const int num_compressed_tuples = tuples.size();
|
||||
|
||||
// Create one Boolean variable per tuple to indicate if it can still be
|
||||
// selected or not. Note that we don't enforce exactly one tuple to be
|
||||
// selected because these variables are just used by this constraint, so
|
||||
// only the information "can't be selected" is important.
|
||||
//
|
||||
// TODO(user): If a value in one column is unique, we don't need to create
|
||||
// a new BooleanVariable corresponding to this line since we can use the one
|
||||
// corresponding to this value in that column.
|
||||
//
|
||||
// Note that if there is just one tuple, there is no need to create such
|
||||
// variables since they are not used.
|
||||
if (prefix_is_key && prefix_is_covering_domain) {
|
||||
AddFullyPrefixedPositiveTable(vars, tuples, values_per_var, any_value,
|
||||
model);
|
||||
} else {
|
||||
AddRegularPositiveTable(vars, tuples, values_per_var, any_value, model);
|
||||
if (VLOG_IS_ON(2)) {
|
||||
std::string message = absl::StrCat(
|
||||
"Table: ", n, " variables, original tuples = ", num_original_tuples);
|
||||
if (num_valid_tuples != num_original_tuples) {
|
||||
absl::StrAppend(&message, ", valid tuples = ", num_valid_tuples);
|
||||
}
|
||||
if (prefix_are_all_different) {
|
||||
if (num_prefix_tuples < max_num_prefix_tuples) {
|
||||
absl::StrAppend(&message, ", partial prefix = ", num_prefix_tuples, "/",
|
||||
max_num_prefix_tuples);
|
||||
} else {
|
||||
absl::StrAppend(&message, ", full prefix = true");
|
||||
}
|
||||
} else {
|
||||
absl::StrAppend(&message, ", num prefix tuples = ", prefixes.size());
|
||||
}
|
||||
if (num_compressed_tuples != num_valid_tuples) {
|
||||
absl::StrAppend(&message,
|
||||
", compressed tuples = ", num_compressed_tuples);
|
||||
}
|
||||
LOG(INFO) << message;
|
||||
}
|
||||
AddPositiveTable(vars, tuples, values_per_var, any_value,
|
||||
prefix_are_all_different, model);
|
||||
|
||||
if (prefix_are_all_different && num_prefix_tuples < max_num_prefix_tuples) {
|
||||
// If we have a table of 'unique prefix' => value tuples.
|
||||
// This table will likely not be negated, as the density of tuples will be
|
||||
// less than 1 / size of the domain of the last variable.
|
||||
// Still, just on the prefix part, it can be close to complete.
|
||||
// For each missing prefix, we can add its negation as a valid clause.
|
||||
// For this, we negate the prefix tuples, and add a negative table
|
||||
// constraint on these.
|
||||
std::vector<std::vector<int64>> var_domains(n - 1);
|
||||
for (int j = 0; j + 1 < n; ++j) {
|
||||
var_domains[j].assign(values_per_var[j].begin(), values_per_var[j].end());
|
||||
std::sort(var_domains[j].begin(), var_domains[j].end());
|
||||
}
|
||||
std::vector<std::vector<int64>> negated_tuples;
|
||||
std::vector<int64> tmp_tuple;
|
||||
for (int i = 0; i < max_num_prefix_tuples; ++i) {
|
||||
tmp_tuple.assign(n - 1, 0);
|
||||
int index = i;
|
||||
for (int j = 0; j + 1 < n; ++j) {
|
||||
tmp_tuple[j] = var_domains[j][index % var_domains[j].size()];
|
||||
index /= var_domains[j].size();
|
||||
}
|
||||
if (!gtl::ContainsKey(prefixes, tmp_tuple)) {
|
||||
negated_tuples.push_back(tmp_tuple);
|
||||
}
|
||||
}
|
||||
std::vector<IntegerVariable> prefix_vars = vars;
|
||||
prefix_vars.pop_back();
|
||||
VLOG(2) << " . add negated table with " << negated_tuples.size()
|
||||
<< " tuples";
|
||||
AddNegatedTableConstraint(prefix_vars, negated_tuples, model);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user