diff --git a/ortools/sat/table.cc b/ortools/sat/table.cc
index 1e3bf58993..646b164cad 100644
--- a/ortools/sat/table.cc
+++ b/ortools/sat/table.cc
@@ -32,22 +32,6 @@ namespace sat {
 namespace {
 
-// Transposes the given "matrix".
-std::vector<std::vector<int64>> Transpose(
-    const std::vector<std::vector<int64>>& tuples) {
-  CHECK(!tuples.empty());
-  const int n = tuples.size();
-  const int m = tuples[0].size();
-  std::vector<std::vector<int64>> transpose(m, std::vector<int64>(n));
-  for (int i = 0; i < n; ++i) {
-    CHECK_EQ(m, tuples[i].size());
-    for (int j = 0; j < m; ++j) {
-      transpose[j][i] = tuples[i][j];
-    }
-  }
-  return transpose;
-}
-
 // Converts the vector representation returned by FullDomainEncoding() to a map.
 absl::flat_hash_map<IntegerValue, Literal> GetEncoding(IntegerVariable var,
                                                        Model* model) {
@@ -79,7 +63,8 @@ void FilterValues(IntegerVariable var, Model* model,
 void ProcessOneColumn(
     const std::vector<Literal>& line_literals,
     const std::vector<IntegerValue>& values,
-    const absl::flat_hash_map<IntegerValue, Literal>& encoding, Model* model) {
+    const absl::flat_hash_map<IntegerValue, Literal>& encoding,
+    const std::vector<Literal>& tuples_with_any, Model* model) {
   CHECK_EQ(line_literals.size(), values.size());
   absl::flat_hash_map<IntegerValue, std::vector<Literal>>
       value_to_list_of_line_literals;
@@ -88,6 +73,7 @@ void ProcessOneColumn(
   // is false too (i.e not possible).
   for (int i = 0; i < values.size(); ++i) {
     const IntegerValue v = values[i];
+
     if (!gtl::ContainsKey(encoding, v)) {
       model->Add(ClauseConstraint({line_literals[i].Negated()}));
     } else {
@@ -101,11 +87,61 @@ void ProcessOneColumn(
   // false too.
   for (const auto& entry : value_to_list_of_line_literals) {
     std::vector<Literal> clause = entry.second;
+    for (const Literal any_tuple_literal : tuples_with_any) {
+      clause.push_back(any_tuple_literal);
+    }
+
     clause.push_back(gtl::FindOrDie(encoding, entry.first).Negated());
     model->Add(ClauseConstraint(clause));
   }
 }
 
+void CompressTuples(const std::vector<int64>& domain_sizes,
+                    std::vector<std::vector<int64>>* tuples, int64* any_value) {
+  *any_value = -kint64min;  // Check not conflicting.
+  if (tuples->empty()) return;
+  const int initial_num_tuples = tuples->size();
+  const int num_vars = (*tuples)[0].size();
+
+  const auto remove_tuple = [tuples](int pos) {
+    if (pos < tuples->size() - 1) {
+      (*tuples)[pos] = tuples->back();
+    }
+    tuples->pop_back();
+  };
+
+  std::vector<int> to_remove;
+  for (int i = 0; i < num_vars; ++i) {
+    const int domain_size = domain_sizes[i];
+    if (domain_size == 1) continue;
+    absl::flat_hash_map<std::vector<int64>, std::vector<int>>
+        masked_tuples_to_indices;
+    for (int t = 0; t < tuples->size(); ++t) {
+      std::vector<int64> masked_copy = (*tuples)[t];
+      if (masked_copy[i] == *any_value) continue;
+      masked_copy[i] = *any_value;
+      masked_tuples_to_indices[masked_copy].push_back(t);
+    }
+    to_remove.clear();
+    for (const auto& it : masked_tuples_to_indices) {
+      if (it.second.size() == domain_size) {
+        (*tuples)[it.second.front()] = it.first;
+        for (int j = 1; j < it.second.size(); j++) {
+          to_remove.push_back(it.second[j]);
+        }
+      }
+    }
+    std::sort(to_remove.begin(), to_remove.end(), std::greater<int>());
+    for (const int t : to_remove) {
+      remove_tuple(t);
+    }
+  }
+  if (initial_num_tuples != tuples->size()) {
+    VLOG(1) << "Compressed tuples from " << initial_num_tuples << " to "
+            << tuples->size();
+  }
+}
+
 }  // namespace
 
 // Makes a static decomposition of a table constraint into clauses.
@@ -155,6 +191,14 @@ std::function<void(Model*)> TableConstraint(
   // Remove duplicates if any.
   gtl::STLSortAndRemoveDuplicates(&new_tuples);
 
+  // Compress tuples.
+  int64 any_value = kint64min;
+  std::vector<int64> domain_sizes;
+  for (int i = 0; i < n; ++i) {
+    domain_sizes.push_back(values_per_var[i].size());
+  }
+  CompressTuples(domain_sizes, &new_tuples, &any_value);
+
   // Create one Boolean variable per tuple to indicate if it can still be
   // selected or not. Note that we don't enforce exactly one tuple to be
   // selected because these variables are just used by this constraint, so
@@ -175,24 +219,49 @@ std::function<void(Model*)> TableConstraint(
     for (int i = 0; i < new_tuples.size(); ++i) {
       tuple_literals.emplace_back(model->Add(NewBooleanVariable()), true);
     }
+    model->Add(ClauseConstraint(tuple_literals));
   }
 
   // Fully encode the variables using all the values appearing in the tuples.
   IntegerTrail* integer_trail = model->GetOrCreate<IntegerTrail>();
-  const std::vector<std::vector<int64>> tr_tuples = Transpose(new_tuples);
+  std::vector<Literal> active_tuple_literals;
+  std::vector<IntegerValue> active_values;
+  std::vector<Literal> any_tuple_literals;
   for (int i = 0; i < n; ++i) {
-    const int64 first = tr_tuples[i].front();
-    if (std::all_of(tr_tuples[i].begin(), tr_tuples[i].end(),
-                    [first](int64 v) { return v == first; })) {
+    bool has_any = false;
+    bool all_equals = true;
+    active_tuple_literals.clear();
+    active_values.clear();
+    any_tuple_literals.clear();
+    const int64 first = new_tuples[0][i];
+
+    for (int j = 0; j < tuple_literals.size(); ++j) {
+      const int64 v = new_tuples[j][i];
+
+      if (v != first) {
+        all_equals = false;
+      }
+
+      if (v == any_value) {
+        has_any = true;
+        any_tuple_literals.push_back(tuple_literals[j]);
+      } else {
+        active_tuple_literals.push_back(tuple_literals[j]);
+        active_values.push_back(IntegerValue(v));
+      }
+    }
+
+    if (all_equals && !has_any) {
       model->Add(Equality(vars[i], first));
-    } else {
+    } else if (!active_tuple_literals.empty()) {
+      const std::vector<int64> reached_values(values_per_var[i].begin(),
+                                              values_per_var[i].end());
       integer_trail->UpdateInitialDomain(vars[i],
-                                         Domain::FromValues(tr_tuples[i]));
+                                         Domain::FromValues(reached_values));
       model->Add(FullyEncodeVariable(vars[i]));
-      ProcessOneColumn(
-          tuple_literals,
-          std::vector<IntegerValue>(tr_tuples[i].begin(), tr_tuples[i].end()),
-          GetEncoding(vars[i], model), model);
+      ProcessOneColumn(active_tuple_literals, active_values,
+                       GetEncoding(vars[i], model), any_tuple_literals,
+                       model);
     }
   }
 };
@@ -234,11 +303,24 @@ std::function<void(Model*)> NegatedTableConstraintWithoutFullEncoding(
   const int n = vars.size();
   IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
   std::vector<Literal> clause;
-  for (const std::vector<int64>& tuple : tuples) {
+
+  std::vector<int64> domain_sizes;
+  for (int i = 0; i < n; ++i) {
+    const int64 lb = model->Get(LowerBound(vars[i]));
+    const int64 ub = model->Get(UpperBound(vars[i]));
+    domain_sizes.push_back(CapAdd(1, CapSub(ub, lb)));
+  }
+
+  std::vector<std::vector<int64>> new_tuples = tuples;
+  int64 any_value = kint64min;
+  CompressTuples(domain_sizes, &new_tuples, &any_value);
+
+  for (const std::vector<int64>& tuple : new_tuples) {
    clause.clear();
     bool add = true;
     for (int i = 0; i < n; ++i) {
       const int64 value = tuple[i];
+      if (value == any_value) continue;
       const int64 lb = model->Get(LowerBound(vars[i]));
       const int64 ub = model->Get(UpperBound(vars[i]));
       // TODO(user): test the full initial domain instead of just checking
@@ -460,13 +542,14 @@ std::function<void(Model*)> TransitionConstraint(
     // because it is already implicitely encoded since we have exactly one
     // transition value.
     if (!in_encoding.empty()) {
-      ProcessOneColumn(tuple_literals, in_states, in_encoding, model);
+      ProcessOneColumn(tuple_literals, in_states, in_encoding, {}, model);
     }
     if (!encoding.empty()) {
-      ProcessOneColumn(tuple_literals, transition_values, encoding, model);
+      ProcessOneColumn(tuple_literals, transition_values, encoding, {},
+                       model);
     }
     if (!out_encoding.empty()) {
-      ProcessOneColumn(tuple_literals, out_states, out_encoding, model);
+      ProcessOneColumn(tuple_literals, out_states, out_encoding, {}, model);
     }
     in_encoding = out_encoding;
   }