diff --git a/makefiles/Makefile.gen.mk b/makefiles/Makefile.gen.mk index 274fcacc65..2ec43c8f61 100644 --- a/makefiles/Makefile.gen.mk +++ b/makefiles/Makefile.gen.mk @@ -23,6 +23,7 @@ BASE_DEPS = \ $(SRC_DIR)/ortools/base/map_util.h \ $(SRC_DIR)/ortools/base/mathutil.h \ $(SRC_DIR)/ortools/base/murmur.h \ + $(SRC_DIR)/ortools/base/protobuf_util.h \ $(SRC_DIR)/ortools/base/protoutil.h \ $(SRC_DIR)/ortools/base/ptr_util.h \ $(SRC_DIR)/ortools/base/python-swig.h \ @@ -31,6 +32,7 @@ BASE_DEPS = \ $(SRC_DIR)/ortools/base/small_map.h \ $(SRC_DIR)/ortools/base/small_ordered_set.h \ $(SRC_DIR)/ortools/base/status.h \ + $(SRC_DIR)/ortools/base/status_macros.h \ $(SRC_DIR)/ortools/base/statusor.h \ $(SRC_DIR)/ortools/base/stl_util.h \ $(SRC_DIR)/ortools/base/sysinfo.h \ @@ -173,7 +175,8 @@ objs/util/cached_log.$O: ortools/util/cached_log.cc \ objs/util/file_util.$O: ortools/util/file_util.cc ortools/util/file_util.h \ ortools/base/file.h ortools/base/integral_types.h ortools/base/logging.h \ - ortools/base/macros.h ortools/base/status.h ortools/base/recordio.h | $(OBJ_DIR)/util + ortools/base/macros.h ortools/base/status.h ortools/base/recordio.h \ + ortools/base/statusor.h ortools/base/status_macros.h | $(OBJ_DIR)/util $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sutil$Sfile_util.cc $(OBJ_OUT)$(OBJ_DIR)$Sutil$Sfile_util.$O objs/util/fp_utils.$O: ortools/util/fp_utils.cc ortools/util/fp_utils.h \ @@ -472,9 +475,7 @@ objs/lp_data/matrix_utils.$O: ortools/lp_data/matrix_utils.cc \ $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Slp_data$Smatrix_utils.cc $(OBJ_OUT)$(OBJ_DIR)$Slp_data$Smatrix_utils.$O objs/lp_data/model_reader.$O: ortools/lp_data/model_reader.cc \ - ortools/lp_data/model_reader.h \ - ortools/gen/ortools/linear_solver/linear_solver.pb.h \ - ortools/gen/ortools/util/optional_boolean.pb.h ortools/lp_data/lp_data.h \ + ortools/lp_data/model_reader.h ortools/lp_data/lp_data.h \ ortools/base/hash.h ortools/base/basictypes.h \ ortools/base/integral_types.h ortools/base/logging.h \ ortools/base/macros.h ortools/base/int_type.h \ @@ -485,25 +486,29 @@ objs/lp_data/model_reader.$O: ortools/lp_data/model_reader.cc \ ortools/util/return_macros.h ortools/lp_data/sparse_column.h \ ortools/lp_data/sparse_vector.h ortools/graph/iterators.h \ ortools/util/fp_utils.h ortools/base/file.h ortools/base/status.h \ - ortools/lp_data/mps_reader.h ortools/base/commandlineflags.h \ - ortools/base/map_util.h ortools/lp_data/proto_utils.h \ - ortools/util/file_util.h ortools/base/recordio.h | $(OBJ_DIR)/lp_data + ortools/gen/ortools/linear_solver/linear_solver.pb.h \ + ortools/gen/ortools/util/optional_boolean.pb.h \ + ortools/lp_data/proto_utils.h ortools/util/file_util.h \ + ortools/base/recordio.h ortools/base/statusor.h | $(OBJ_DIR)/lp_data $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Slp_data$Smodel_reader.cc $(OBJ_OUT)$(OBJ_DIR)$Slp_data$Smodel_reader.$O objs/lp_data/mps_reader.$O: ortools/lp_data/mps_reader.cc \ - ortools/lp_data/mps_reader.h ortools/base/commandlineflags.h \ - ortools/base/hash.h ortools/base/basictypes.h \ - ortools/base/integral_types.h ortools/base/logging.h \ - ortools/base/macros.h ortools/base/int_type.h \ + ortools/lp_data/mps_reader.h ortools/base/protobuf_util.h \ + ortools/base/logging.h ortools/base/integral_types.h \ + ortools/base/macros.h ortools/base/canonical_errors.h \ + ortools/base/status.h ortools/base/commandlineflags.h \ + ortools/base/filelineiter.h ortools/base/file.h ortools/base/hash.h \ + ortools/base/basictypes.h ortools/base/int_type.h \ ortools/base/int_type_indexed_vector.h ortools/base/map_util.h \ - ortools/lp_data/lp_data.h ortools/gen/ortools/glop/parameters.pb.h \ - ortools/lp_data/lp_types.h ortools/util/bitset.h \ - ortools/lp_data/sparse.h ortools/lp_data/permutation.h \ - ortools/base/random.h ortools/util/return_macros.h \ - ortools/lp_data/sparse_column.h ortools/lp_data/sparse_vector.h \ - ortools/graph/iterators.h ortools/util/fp_utils.h ortools/base/file.h \ - ortools/base/status.h ortools/base/filelineiter.h \ - ortools/lp_data/lp_print_utils.h | $(OBJ_DIR)/lp_data + ortools/base/status_macros.h ortools/base/statusor.h \ + ortools/gen/ortools/linear_solver/linear_solver.pb.h \ + ortools/gen/ortools/util/optional_boolean.pb.h ortools/lp_data/lp_data.h \ + ortools/gen/ortools/glop/parameters.pb.h ortools/lp_data/lp_types.h \ + ortools/util/bitset.h ortools/lp_data/sparse.h \ + ortools/lp_data/permutation.h ortools/base/random.h \ + ortools/util/return_macros.h ortools/lp_data/sparse_column.h \ + ortools/lp_data/sparse_vector.h ortools/graph/iterators.h \ + ortools/util/fp_utils.h | $(OBJ_DIR)/lp_data $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Slp_data$Smps_reader.cc $(OBJ_OUT)$(OBJ_DIR)$Slp_data$Smps_reader.$O objs/lp_data/proto_utils.$O: ortools/lp_data/proto_utils.cc \ @@ -673,7 +678,8 @@ objs/glop/lp_solver.$O: ortools/glop/lp_solver.cc ortools/glop/lp_solver.h \ ortools/lp_data/matrix_scaler.h ortools/lp_data/proto_utils.h \ ortools/gen/ortools/linear_solver/linear_solver.pb.h \ ortools/gen/ortools/util/optional_boolean.pb.h ortools/util/file_util.h \ - ortools/base/file.h ortools/base/status.h ortools/base/recordio.h | $(OBJ_DIR)/glop + ortools/base/file.h ortools/base/status.h ortools/base/recordio.h \ + ortools/base/statusor.h | $(OBJ_DIR)/glop $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sglop$Slp_solver.cc $(OBJ_OUT)$(OBJ_DIR)$Sglop$Slp_solver.$O objs/glop/lu_factorization.$O: ortools/glop/lu_factorization.cc \ @@ -2617,6 +2623,7 @@ LP_DEPS = \ $(SRC_DIR)/ortools/linear_solver/linear_expr.h \ $(SRC_DIR)/ortools/linear_solver/linear_solver.h \ $(SRC_DIR)/ortools/linear_solver/model_exporter.h \ + $(SRC_DIR)/ortools/linear_solver/model_exporter_swig_helper.h \ $(SRC_DIR)/ortools/linear_solver/model_validator.h \ $(GEN_DIR)/ortools/linear_solver/linear_solver.pb.h @@ -2696,9 +2703,9 @@ objs/linear_solver/cplex_interface.$O: \ $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Slinear_solver$Scplex_interface.cc $(OBJ_OUT)$(OBJ_DIR)$Slinear_solver$Scplex_interface.$O objs/linear_solver/glop_interface.$O: \ - ortools/linear_solver/glop_interface.cc ortools/base/integral_types.h \ - ortools/base/logging.h ortools/base/macros.h ortools/base/hash.h \ - ortools/base/basictypes.h ortools/glop/lp_solver.h \ + ortools/linear_solver/glop_interface.cc ortools/base/hash.h \ + ortools/base/basictypes.h ortools/base/integral_types.h \ + ortools/base/logging.h ortools/base/macros.h ortools/glop/lp_solver.h \ ortools/gen/ortools/glop/parameters.pb.h ortools/glop/preprocessor.h \ ortools/glop/revised_simplex.h ortools/glop/basis_representation.h \ ortools/glop/lu_factorization.h ortools/glop/markowitz.h \ @@ -2767,21 +2774,24 @@ objs/linear_solver/linear_solver.$O: \ ortools/linear_solver/linear_expr.h \ ortools/gen/ortools/linear_solver/linear_solver.pb.h \ ortools/gen/ortools/util/optional_boolean.pb.h \ - ortools/port/proto_utils.h ortools/port/file.h \ - ortools/base/accurate_sum.h ortools/base/canonical_errors.h \ - ortools/base/map_util.h ortools/base/stl_util.h \ - ortools/linear_solver/model_exporter.h ortools/base/hash.h \ - ortools/linear_solver/model_validator.h ortools/util/fp_utils.h | $(OBJ_DIR)/linear_solver + ortools/port/proto_utils.h ortools/base/accurate_sum.h \ + ortools/base/canonical_errors.h ortools/base/map_util.h \ + ortools/base/status_macros.h ortools/base/statusor.h \ + ortools/base/stl_util.h ortools/linear_solver/model_exporter.h \ + ortools/base/hash.h ortools/linear_solver/model_validator.h \ + ortools/port/file.h ortools/util/fp_utils.h | $(OBJ_DIR)/linear_solver $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Slinear_solver$Slinear_solver.cc $(OBJ_OUT)$(OBJ_DIR)$Slinear_solver$Slinear_solver.$O objs/linear_solver/model_exporter.$O: \ ortools/linear_solver/model_exporter.cc \ ortools/linear_solver/model_exporter.h ortools/base/hash.h \ ortools/base/basictypes.h ortools/base/integral_types.h \ - ortools/base/logging.h ortools/base/macros.h \ - ortools/base/commandlineflags.h ortools/base/map_util.h \ + ortools/base/logging.h ortools/base/macros.h ortools/base/statusor.h \ + ortools/base/status.h \ ortools/gen/ortools/linear_solver/linear_solver.pb.h \ - ortools/gen/ortools/util/optional_boolean.pb.h ortools/util/fp_utils.h | $(OBJ_DIR)/linear_solver + ortools/gen/ortools/util/optional_boolean.pb.h \ + ortools/base/canonical_errors.h ortools/base/commandlineflags.h \ + ortools/base/map_util.h ortools/util/fp_utils.h | $(OBJ_DIR)/linear_solver $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Slinear_solver$Smodel_exporter.cc $(OBJ_OUT)$(OBJ_DIR)$Slinear_solver$Smodel_exporter.$O objs/linear_solver/model_validator.$O: \ diff --git a/ortools/sat/table.cc b/ortools/sat/table.cc index 13ca5fa444..b8a26bff1b 100644 --- a/ortools/sat/table.cc +++ b/ortools/sat/table.cc @@ -83,9 +83,147 @@ void ProcessOneColumn( } } +void AddSizeTwoTable( + absl::Span vars, + const std::vector>& tuples, + const std::vector>& values_per_var, + Model* model) { + const int n = vars.size(); + IntegerTrail* const integer_trail = model->GetOrCreate(); + + std::vector> encodings(n); + for (int i = 0; i < n; ++i) { + const std::vector reached_values(values_per_var[i].begin(), + values_per_var[i].end()); + integer_trail->UpdateInitialDomain(vars[i], + Domain::FromValues(reached_values)); + if (values_per_var.size() > 1) { + model->Add(FullyEncodeVariable(vars[i])); + encodings[i] = GetEncoding(vars[i], model); + } + } + + // One variable is fixed. Propagation is complete. + if (values_per_var[0].size() == 1 || values_per_var[1].size() == 1) return; + + absl::flat_hash_map> + left_to_right; + absl::flat_hash_map> + right_to_left; + + for (const auto& tuple : tuples) { + const IntegerValue left_value(tuple[0]); + const IntegerValue right_value(tuple[1]); + if (!encodings[0].contains(left_value) || + !encodings[1].contains(right_value)) { + continue; + } + + Literal left = gtl::FindOrDie(encodings[0], left_value); + Literal right = gtl::FindOrDie(encodings[1], right_value); + left_to_right[left.Index()].insert(right.Index()); + right_to_left[right.Index()].insert(left.Index()); + } + + int implications = 0; + int clause_added = 0; + int large_clause_added = 0; + std::vector clauses; + auto add_support = + [model, &clause_added, &large_clause_added, &implications, &clauses]( + LiteralIndex lit, const absl::flat_hash_set& supports, + int max_support_size) { + if (supports.size() == max_support_size) return; + if (supports.size() == 1) { + model->Add(Implication(Literal(lit), Literal(*supports.begin()))); + implications++; + } else { + clauses.clear(); + for (const LiteralIndex index : supports) { + clauses.push_back(Literal(index)); + } + clauses.push_back(Literal(lit).Negated()); + model->Add(ClauseConstraint(clauses)); + clause_added++; + if (supports.size() > max_support_size / 2) { + large_clause_added++; + } + } + }; + + for (const auto& it : left_to_right) { + add_support(it.first, it.second, values_per_var[1].size()); + } + for (const auto& it : right_to_left) { + add_support(it.first, it.second, values_per_var[0].size()); + } + VLOG(2) << "Table: 2 variables, " << tuples.size() << " tuples encoded using " + << clause_added << " clauses, " << large_clause_added + << " large clauses, " << implications << " implications"; +} + +void ExplorePrefixes(const std::vector>& tuples, + const std::vector>& var_domains, + absl::Span vars, Model* model) { + auto explore_prefix_span = [&](int start, int end) { + // Compute the maximum number of such prefix tuples. + int64 max_num_prefix_tuples = 1; + for (int i = start; i <= end; ++i) { + max_num_prefix_tuples = + CapProd(max_num_prefix_tuples, var_domains[i].size()); + } + + // Abort early. + if (max_num_prefix_tuples > 2 * tuples.size()) return; + + absl::flat_hash_set> prefixes; + for (const std::vector& tuple : tuples) { + prefixes.insert(absl::MakeSpan(&tuple[start], end - start + 1)); + if (prefixes.size() == max_num_prefix_tuples) return; + } + const int num_prefix_tuples = prefixes.size(); + + std::vector> negated_tuples; + + int created = 0; + if (num_prefix_tuples < max_num_prefix_tuples && + max_num_prefix_tuples < num_prefix_tuples * 2) { + std::vector tmp_tuple; + for (int i = 0; i < max_num_prefix_tuples; ++i) { + tmp_tuple.clear(); + int index = i; + for (int j = start; j <= end; ++j) { + tmp_tuple.push_back(var_domains[j][index % var_domains[j].size()]); + index /= var_domains[j].size(); + } + if (!prefixes.contains(tmp_tuple)) { + negated_tuples.push_back(tmp_tuple); + created++; + } + } + AddNegatedTableConstraint(vars.subspan(start, end - start + 1), + negated_tuples, model); + VLOG(2) << " created = " << created << " for " << start << " .. " << end; + } + }; + + for (int end = 1; end < var_domains.size(); ++end) { + explore_prefix_span(0, end); + } + for (int start = 1; start + 1 < var_domains.size(); ++start) { + explore_prefix_span(start, start + 1); + } + for (int start = 1; start + 2 < var_domains.size(); ++start) { + explore_prefix_span(start, start + 2); + } + for (int start = 1; start + 3 < var_domains.size(); ++start) { + explore_prefix_span(start, start + 3); + } +} + } // namespace -void CompressTuples(const std::vector& domain_sizes, int64 any_value, +void CompressTuples(absl::Span domain_sizes, int64 any_value, std::vector>* tuples) { if (tuples->empty()) return; @@ -129,7 +267,7 @@ void CompressTuples(const std::vector& domain_sizes, int64 any_value, // For every column col, and every value val of that column, // the decomposition uses clauses corresponding to the equivalence: // (\/_{row | tuples[row][col] = val} tuple_literals[row]) <=> (vars[col] = val) -void AddTableConstraint(const std::vector& vars, +void AddTableConstraint(absl::Span vars, std::vector> tuples, Model* model) { const int n = vars.size(); IntegerTrail* integer_trail = model->GetOrCreate(); @@ -167,14 +305,9 @@ void AddTableConstraint(const std::vector& vars, // Detect the case when the first n-1 columns are all different. // This encodes the implication table (tuple of size n - 1) implies value. - absl::flat_hash_set> prefixes; - // We cannot use absl::Span() as the tuples will be compressed after this - // step. - std::vector prefix(n); + absl::flat_hash_set> prefixes; for (const std::vector& tuple : tuples) { - prefix = tuple; - prefix.pop_back(); - prefixes.insert(prefix); + prefixes.insert(absl::MakeSpan(tuple.data(), n - 1)); } const int num_prefix_tuples = prefixes.size(); // Compute the maximum number of such prefix tuples. @@ -183,9 +316,40 @@ void AddTableConstraint(const std::vector& vars, max_num_prefix_tuples = CapProd(max_num_prefix_tuples, values_per_var[i].size()); } + if (n == 2) { + AddSizeTwoTable(vars, tuples, values_per_var, model); + return; + } + + std::vector> var_domains(n); + for (int j = 0; j < n; ++j) { + var_domains[j].assign(values_per_var[j].begin(), values_per_var[j].end()); + std::sort(var_domains[j].begin(), var_domains[j].end()); + } + if (vars.size() > 2) { + ExplorePrefixes(tuples, var_domains, vars, model); + } + // Detect if prefix tuples are all different. const bool prefixes_are_all_different = num_prefix_tuples == num_valid_tuples; + // The variable domains have been computed. Fully encode variables. + // Note that in some corner cases (like duplicate vars), as we call + // UpdateInitialDomain(), the domain of other variable could become more + // restricted that values_per_var. For now, we do not try to reach a fixed + // point here. + std::vector> encodings(n); + for (int i = 0; i < n; ++i) { + const std::vector reached_values(values_per_var[i].begin(), + values_per_var[i].end()); + integer_trail->UpdateInitialDomain(vars[i], + Domain::FromValues(reached_values)); + if (values_per_var.size() > 1) { + model->Add(FullyEncodeVariable(vars[i])); + encodings[i] = GetEncoding(vars[i], model); + } + } + // Compress tuples. const int64 any_value = kint64min; std::vector domain_sizes; @@ -218,23 +382,6 @@ void AddTableConstraint(const std::vector& vars, VLOG(2) << message; } - // The variable domains have been computed. Fully encode variables. - // Note that in some corner cases (like duplicate vars), as we call - // UpdateInitialDomain(), the domain of other variable could become more - // restricted that values_per_var. For now, we do not try to reach a fixed - // point here. - std::vector> encodings(n); - for (int i = 0; i < n; ++i) { - const std::vector reached_values(values_per_var[i].begin(), - values_per_var[i].end()); - integer_trail->UpdateInitialDomain(vars[i], - Domain::FromValues(reached_values)); - if (values_per_var.size() > 1) { - model->Add(FullyEncodeVariable(vars[i])); - encodings[i] = GetEncoding(vars[i], model); - } - } - if (tuples.size() == 1) { // Nothing more to do. return; @@ -245,9 +392,9 @@ void AddTableConstraint(const std::vector& vars, // selected because these variables are just used by this constraint, so // only the information "can't be selected" is important. // - // TODO(user): If a value in one column is unique, we don't need to create - // a new BooleanVariable corresponding to this line since we can use the one - // corresponding to this value in that column. + // TODO(user): If a value in one column is unique, we don't need to + // create a new BooleanVariable corresponding to this line since we can use + // the one corresponding to this value in that column. // // Note that if there is just one tuple, there is no need to create such // variables since they are not used. @@ -263,7 +410,6 @@ void AddTableConstraint(const std::vector& vars, model->Add(ClauseConstraint(tuple_literals)); } - // Fully encode the variables using all the values appearing in the tuples. std::vector active_tuple_literals; std::vector active_values; std::vector any_tuple_literals; @@ -324,45 +470,9 @@ void AddTableConstraint(const std::vector& vars, model->Add(ClauseConstraint(clause)); } } - - // This is optional, as it will not propagate more. - // It seems to give better explanation though. - if (prefixes_are_all_different && num_prefix_tuples < max_num_prefix_tuples && - max_num_prefix_tuples <= 2 * num_prefix_tuples) { - // If we have a table of 'unique prefix' => value tuples. - // This table will likely not be negated, as the density of tuples will be - // less than 1 / size of the domain of the last variable. - // Still, just on the prefix part, it can be close to complete. - // For each missing prefix, we can add their negation as a valid clause. - // For this, we complement the prefix tuples, and add a negative table - // constraint on these. - std::vector> var_domains(n - 1); - for (int j = 0; j + 1 < n; ++j) { - var_domains[j].assign(values_per_var[j].begin(), values_per_var[j].end()); - std::sort(var_domains[j].begin(), var_domains[j].end()); - } - std::vector> negated_tuples; - std::vector tmp_tuple; - for (int i = 0; i < max_num_prefix_tuples; ++i) { - tmp_tuple.assign(n - 1, 0); - int index = i; - for (int j = 0; j + 1 < n; ++j) { - tmp_tuple[j] = var_domains[j][index % var_domains[j].size()]; - index /= var_domains[j].size(); - } - if (!prefixes.contains(tmp_tuple)) { - negated_tuples.push_back(tmp_tuple); - } - } - std::vector prefix_vars = vars; - prefix_vars.pop_back(); - VLOG(2) << " - add negated table with " << negated_tuples.size() - << " tuples"; - AddNegatedTableConstraint(prefix_vars, negated_tuples, model); - } } -void AddNegatedTableConstraint(const std::vector& vars, +void AddNegatedTableConstraint(absl::Span vars, std::vector> tuples, Model* model) { const int n = vars.size(); diff --git a/ortools/sat/table.h b/ortools/sat/table.h index 5aa943a70b..4250f77650 100644 --- a/ortools/sat/table.h +++ b/ortools/sat/table.h @@ -17,6 +17,7 @@ #include #include +#include "absl/types/span.h" #include "ortools/base/integral_types.h" #include "ortools/sat/integer.h" #include "ortools/sat/model.h" @@ -28,14 +29,14 @@ namespace sat { // Enforces that the given tuple of variables is equal to one of the given // tuples. All the tuples must have the same size as var.size(), this is // Checked. -void AddTableConstraint(const std::vector& vars, +void AddTableConstraint(absl::Span vars, std::vector> tuples, Model* model); // Enforces that none of the given tuple appear. // // TODO(user): we could propagate more than what we currently do which is simply // adding one clause per tuples. -void AddNegatedTableConstraint(const std::vector& vars, +void AddNegatedTableConstraint(absl::Span vars, std::vector> tuples, Model* model); @@ -54,7 +55,7 @@ std::function LiteralTableConstraint( // regexps. // // This method is exposed for testing purposes. -void CompressTuples(const std::vector& domain_sizes, int64 any_value, +void CompressTuples(absl::Span domain_sizes, int64 any_value, std::vector>* tuples); // Given an automaton defined by a set of 3-tuples: