compress tables in CP-SAT
This commit is contained in:
@@ -32,22 +32,6 @@ namespace sat {
|
||||
|
||||
namespace {
|
||||
|
||||
// Transposes the given "matrix".
|
||||
std::vector<std::vector<int64>> Transpose(
|
||||
const std::vector<std::vector<int64>>& tuples) {
|
||||
CHECK(!tuples.empty());
|
||||
const int n = tuples.size();
|
||||
const int m = tuples[0].size();
|
||||
std::vector<std::vector<int64>> transpose(m, std::vector<int64>(n));
|
||||
for (int i = 0; i < n; ++i) {
|
||||
CHECK_EQ(m, tuples[i].size());
|
||||
for (int j = 0; j < m; ++j) {
|
||||
transpose[j][i] = tuples[i][j];
|
||||
}
|
||||
}
|
||||
return transpose;
|
||||
}
|
||||
|
||||
// Converts the vector representation returned by FullDomainEncoding() to a map.
|
||||
absl::flat_hash_map<IntegerValue, Literal> GetEncoding(IntegerVariable var,
|
||||
Model* model) {
|
||||
@@ -79,7 +63,8 @@ void FilterValues(IntegerVariable var, Model* model,
|
||||
void ProcessOneColumn(
|
||||
const std::vector<Literal>& line_literals,
|
||||
const std::vector<IntegerValue>& values,
|
||||
const absl::flat_hash_map<IntegerValue, Literal>& encoding, Model* model) {
|
||||
const absl::flat_hash_map<IntegerValue, Literal>& encoding,
|
||||
const std::vector<Literal>& tuples_with_any, Model* model) {
|
||||
CHECK_EQ(line_literals.size(), values.size());
|
||||
absl::flat_hash_map<IntegerValue, std::vector<Literal>>
|
||||
value_to_list_of_line_literals;
|
||||
@@ -88,6 +73,7 @@ void ProcessOneColumn(
|
||||
// is false too (i.e not possible).
|
||||
for (int i = 0; i < values.size(); ++i) {
|
||||
const IntegerValue v = values[i];
|
||||
|
||||
if (!gtl::ContainsKey(encoding, v)) {
|
||||
model->Add(ClauseConstraint({line_literals[i].Negated()}));
|
||||
} else {
|
||||
@@ -101,11 +87,61 @@ void ProcessOneColumn(
|
||||
// false too.
|
||||
for (const auto& entry : value_to_list_of_line_literals) {
|
||||
std::vector<Literal> clause = entry.second;
|
||||
for (const Literal any_tuple_literal : tuples_with_any) {
|
||||
clause.push_back(any_tuple_literal);
|
||||
}
|
||||
|
||||
clause.push_back(gtl::FindOrDie(encoding, entry.first).Negated());
|
||||
model->Add(ClauseConstraint(clause));
|
||||
}
|
||||
}
|
||||
|
||||
void CompressTuples(const std::vector<int64>& domain_sizes,
|
||||
std::vector<std::vector<int64>>* tuples, int64* any_value) {
|
||||
*any_value = -kint64min; // Check not conflicting.
|
||||
if (tuples->empty()) return;
|
||||
const int initial_num_tuples = tuples->size();
|
||||
const int num_vars = (*tuples)[0].size();
|
||||
|
||||
const auto remove_tuple = [tuples](int pos) {
|
||||
if (pos < tuples->size() - 1) {
|
||||
(*tuples)[pos] = tuples->back();
|
||||
}
|
||||
tuples->pop_back();
|
||||
};
|
||||
|
||||
std::vector<int> to_remove;
|
||||
for (int i = 0; i < num_vars; ++i) {
|
||||
const int domain_size = domain_sizes[i];
|
||||
if (domain_size == 1) continue;
|
||||
absl::flat_hash_map<const std::vector<int64>, std::vector<int>>
|
||||
masked_tuples_to_indices;
|
||||
for (int t = 0; t < tuples->size(); ++t) {
|
||||
std::vector<int64> masked_copy = (*tuples)[t];
|
||||
if (masked_copy[i] == *any_value) continue;
|
||||
masked_copy[i] = *any_value;
|
||||
masked_tuples_to_indices[masked_copy].push_back(t);
|
||||
}
|
||||
to_remove.clear();
|
||||
for (const auto& it : masked_tuples_to_indices) {
|
||||
if (it.second.size() == domain_size) {
|
||||
(*tuples)[it.second.front()] = it.first;
|
||||
for (int j = 1; j < it.second.size(); j++) {
|
||||
to_remove.push_back(it.second[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::sort(to_remove.begin(), to_remove.end(), std::greater<int>());
|
||||
for (const int t : to_remove) {
|
||||
remove_tuple(t);
|
||||
}
|
||||
}
|
||||
if (initial_num_tuples != tuples->size()) {
|
||||
VLOG(1) << "Compressed tuples from " << initial_num_tuples << " to "
|
||||
<< tuples->size();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Makes a static decomposition of a table constraint into clauses.
|
||||
@@ -155,6 +191,14 @@ std::function<void(Model*)> TableConstraint(
|
||||
// Remove duplicates if any.
|
||||
gtl::STLSortAndRemoveDuplicates(&new_tuples);
|
||||
|
||||
// Compress tuples.
|
||||
int64 any_value = kint64min;
|
||||
std::vector<int64> domain_sizes;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
domain_sizes.push_back(values_per_var[i].size());
|
||||
}
|
||||
CompressTuples(domain_sizes, &new_tuples, &any_value);
|
||||
|
||||
// Create one Boolean variable per tuple to indicate if it can still be
|
||||
// selected or not. Note that we don't enforce exactly one tuple to be
|
||||
// selected because these variables are just used by this constraint, so
|
||||
@@ -175,24 +219,49 @@ std::function<void(Model*)> TableConstraint(
|
||||
for (int i = 0; i < new_tuples.size(); ++i) {
|
||||
tuple_literals.emplace_back(model->Add(NewBooleanVariable()), true);
|
||||
}
|
||||
model->Add(ClauseConstraint(tuple_literals));
|
||||
}
|
||||
|
||||
// Fully encode the variables using all the values appearing in the tuples.
|
||||
IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
|
||||
const std::vector<std::vector<int64>> tr_tuples = Transpose(new_tuples);
|
||||
std::vector<Literal> active_tuple_literals;
|
||||
std::vector<IntegerValue> active_values;
|
||||
std::vector<Literal> any_tuple_literals;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const int64 first = tr_tuples[i].front();
|
||||
if (std::all_of(tr_tuples[i].begin(), tr_tuples[i].end(),
|
||||
[first](int64 v) { return v == first; })) {
|
||||
bool has_any = false;
|
||||
bool all_equals = true;
|
||||
active_tuple_literals.clear();
|
||||
active_values.clear();
|
||||
any_tuple_literals.clear();
|
||||
const int64 first = new_tuples[0][i];
|
||||
|
||||
for (int j = 0; j < tuple_literals.size(); ++j) {
|
||||
const int64 v = new_tuples[j][i];
|
||||
|
||||
if (v != first) {
|
||||
all_equals = false;
|
||||
}
|
||||
|
||||
if (v == any_value) {
|
||||
has_any = true;
|
||||
any_tuple_literals.push_back(tuple_literals[j]);
|
||||
} else {
|
||||
active_tuple_literals.push_back(tuple_literals[j]);
|
||||
active_values.push_back(IntegerValue(v));
|
||||
}
|
||||
}
|
||||
|
||||
if (all_equals && !has_any) {
|
||||
model->Add(Equality(vars[i], first));
|
||||
} else {
|
||||
} else if (!active_tuple_literals.empty()) {
|
||||
const std::vector<int64> reached_values(values_per_var[i].begin(),
|
||||
values_per_var[i].end());
|
||||
integer_trail->UpdateInitialDomain(vars[i],
|
||||
Domain::FromValues(tr_tuples[i]));
|
||||
Domain::FromValues(reached_values));
|
||||
model->Add(FullyEncodeVariable(vars[i]));
|
||||
ProcessOneColumn(
|
||||
tuple_literals,
|
||||
std::vector<IntegerValue>(tr_tuples[i].begin(), tr_tuples[i].end()),
|
||||
GetEncoding(vars[i], model), model);
|
||||
ProcessOneColumn(active_tuple_literals, active_values,
|
||||
GetEncoding(vars[i], model), any_tuple_literals,
|
||||
model);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -234,11 +303,24 @@ std::function<void(Model*)> NegatedTableConstraintWithoutFullEncoding(
|
||||
const int n = vars.size();
|
||||
IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
|
||||
std::vector<Literal> clause;
|
||||
for (const std::vector<int64>& tuple : tuples) {
|
||||
|
||||
std::vector<int64> domain_sizes;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const int64 lb = model->Get(LowerBound(vars[i]));
|
||||
const int64 ub = model->Get(UpperBound(vars[i]));
|
||||
domain_sizes.push_back(CapAdd(1, CapSub(ub, lb)));
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64>> new_tuples = tuples;
|
||||
int64 any_value = kint64min;
|
||||
CompressTuples(domain_sizes, &new_tuples, &any_value);
|
||||
|
||||
for (const std::vector<int64>& tuple : new_tuples) {
|
||||
clause.clear();
|
||||
bool add = true;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
const int64 value = tuple[i];
|
||||
if (value == any_value) continue;
|
||||
const int64 lb = model->Get(LowerBound(vars[i]));
|
||||
const int64 ub = model->Get(UpperBound(vars[i]));
|
||||
// TODO(user): test the full initial domain instead of just checking
|
||||
@@ -460,13 +542,14 @@ std::function<void(Model*)> TransitionConstraint(
|
||||
// because it is already implicitely encoded since we have exactly one
|
||||
// transition value.
|
||||
if (!in_encoding.empty()) {
|
||||
ProcessOneColumn(tuple_literals, in_states, in_encoding, model);
|
||||
ProcessOneColumn(tuple_literals, in_states, in_encoding, {}, model);
|
||||
}
|
||||
if (!encoding.empty()) {
|
||||
ProcessOneColumn(tuple_literals, transition_values, encoding, model);
|
||||
ProcessOneColumn(tuple_literals, transition_values, encoding, {},
|
||||
model);
|
||||
}
|
||||
if (!out_encoding.empty()) {
|
||||
ProcessOneColumn(tuple_literals, out_states, out_encoding, model);
|
||||
ProcessOneColumn(tuple_literals, out_states, out_encoding, {}, model);
|
||||
}
|
||||
in_encoding = out_encoding;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user