From 08f556c52018bb684ddac845ca82ca124f4df385 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Thu, 22 Sep 2016 13:55:16 +0200 Subject: [PATCH] more work on sat; initial connection to the flatzinc interpreter --- makefiles/Makefile.cpp.mk | 5 + makefiles/Makefile.gen.mk | 27 +++- src/flatzinc/constraints.cc | 24 ++-- src/flatzinc/fz.cc | 11 +- src/flatzinc/presolve.cc | 6 +- src/flatzinc/sat_fz_solver.cc | 228 +++++++++++++++++++++++++++++++ src/flatzinc/sat_fz_solver.h | 28 ++++ src/graph/util.h | 184 ++++++++++++++++++++----- src/sat/boolean_problem.cc | 3 +- src/sat/disjunctive.cc | 152 ++++++++++----------- src/sat/disjunctive.h | 39 ++++-- src/sat/integer.cc | 247 ++++++++++++++++++++++++++++++---- src/sat/integer.h | 95 ++++++++++++- src/sat/integer_expr.cc | 8 +- src/sat/intervals.h | 3 - src/sat/precedences.cc | 2 +- src/sat/sat_base.h | 2 +- src/sat/sat_solver.cc | 2 +- src/sat/sat_solver.h | 39 +++++- src/sat/table.cc | 242 +++++++++++++++++++++++++++++++++ src/sat/table.h | 48 +++++++ src/util/rev.h | 129 ++++++++++++++++++ 22 files changed, 1328 insertions(+), 196 deletions(-) create mode 100644 src/flatzinc/sat_fz_solver.cc create mode 100644 src/flatzinc/sat_fz_solver.h create mode 100644 src/sat/table.cc create mode 100644 src/sat/table.h create mode 100644 src/util/rev.h diff --git a/makefiles/Makefile.cpp.mk b/makefiles/Makefile.cpp.mk index 7cb87e18e6..b636f50b68 100755 --- a/makefiles/Makefile.cpp.mk +++ b/makefiles/Makefile.cpp.mk @@ -96,6 +96,7 @@ FLATZINC_DEPS = \ $(SRC_DIR)/flatzinc/presolve.h \ $(SRC_DIR)/flatzinc/reporting.h \ $(SRC_DIR)/flatzinc/sat_constraint.h \ + $(SRC_DIR)/flatzinc/sat_fz_solver.h \ $(SRC_DIR)/flatzinc/solver_data.h \ $(SRC_DIR)/flatzinc/solver.h \ $(SRC_DIR)/flatzinc/solver_util.h \ @@ -191,6 +192,7 @@ FLATZINC_OBJS=\ $(OBJ_DIR)/flatzinc/presolve.$O \ $(OBJ_DIR)/flatzinc/reporting.$O \ $(OBJ_DIR)/flatzinc/sat_constraint.$O \ + $(OBJ_DIR)/flatzinc/sat_fz_solver.$O \ $(OBJ_DIR)/flatzinc/solver.$O \ $(OBJ_DIR)/flatzinc/solver_data.$O \ $(OBJ_DIR)/flatzinc/solver_util.$O @@ -234,6 +236,9 @@ $(OBJ_DIR)/flatzinc/reporting.$O: $(SRC_DIR)/flatzinc/reporting.cc $(FLATZINC_DE $(OBJ_DIR)/flatzinc/sat_constraint.$O: $(SRC_DIR)/flatzinc/sat_constraint.cc $(FLATZINC_DEPS) $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sflatzinc$Ssat_constraint.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssat_constraint.$O +$(OBJ_DIR)/flatzinc/sat_fz_solver.$O: $(SRC_DIR)/flatzinc/sat_fz_solver.cc $(FLATZINC_DEPS) + $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sflatzinc$Ssat_fz_solver.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssat_fz_solver.$O + $(OBJ_DIR)/flatzinc/solver.$O: $(SRC_DIR)/flatzinc/solver.cc $(FLATZINC_DEPS) $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sflatzinc$Ssolver.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssolver.$O diff --git a/makefiles/Makefile.gen.mk b/makefiles/Makefile.gen.mk index 7fe463448e..d137c1a39c 100644 --- a/makefiles/Makefile.gen.mk +++ b/makefiles/Makefile.gen.mk @@ -128,6 +128,9 @@ $(SRC_DIR)/base/status.h: \ $(SRC_DIR)/base/statusor.h: \ $(SRC_DIR)/base/status.h +$(SRC_DIR)/base/stringpiece_utils.h: \ + $(SRC_DIR)/base/stringpiece.h + $(SRC_DIR)/base/stringprintf.h: \ $(SRC_DIR)/base/stringpiece.h @@ -150,9 +153,6 @@ $(SRC_DIR)/base/sysinfo.h: \ $(SRC_DIR)/base/thorough_hash.h: \ $(SRC_DIR)/base/integral_types.h -$(SRC_DIR)/base/threadpool.h: \ - $(SRC_DIR)/base/callback.h - $(SRC_DIR)/base/timer.h: \ $(SRC_DIR)/base/basictypes.h \ $(SRC_DIR)/base/logging.h \ @@ -350,6 +350,10 @@ $(SRC_DIR)/util/range_query_function.h: \ $(SRC_DIR)/util/rational_approximation.h: \ $(SRC_DIR)/base/integral_types.h +$(SRC_DIR)/util/rev.h: \ + $(SRC_DIR)/base/logging.h \ + $(SRC_DIR)/base/map_util.h + $(SRC_DIR)/util/running_stat.h: \ $(SRC_DIR)/base/logging.h \ $(SRC_DIR)/base/macros.h @@ -1154,8 +1158,8 @@ $(SRC_DIR)/graph/shortestpaths.h: \ $(SRC_DIR)/graph/util.h: \ $(SRC_DIR)/graph/graph.h \ - $(SRC_DIR)/base/hash.h \ $(SRC_DIR)/base/join.h \ + $(SRC_DIR)/base/map_util.h \ $(SRC_DIR)/base/murmur.h \ $(SRC_DIR)/base/numbers.h \ $(SRC_DIR)/base/split.h \ @@ -1444,6 +1448,7 @@ SAT_LIB_OBJS = \ $(OBJ_DIR)/sat/sat_solver.$O \ $(OBJ_DIR)/sat/simplification.$O \ $(OBJ_DIR)/sat/symmetry.$O \ + $(OBJ_DIR)/sat/table.$O \ $(OBJ_DIR)/sat/util.$O \ $(OBJ_DIR)/sat/boolean_problem.pb.$O \ $(OBJ_DIR)/sat/sat_parameters.pb.$O @@ -1498,9 +1503,11 @@ $(SRC_DIR)/sat/integer.h: \ $(SRC_DIR)/sat/sat_solver.h \ $(SRC_DIR)/base/int_type.h \ $(SRC_DIR)/base/join.h \ + $(SRC_DIR)/base/map_util.h \ $(SRC_DIR)/base/port.h \ $(SRC_DIR)/util/bitset.h \ $(SRC_DIR)/util/iterators.h \ + $(SRC_DIR)/util/rev.h \ $(SRC_DIR)/util/saturated_arithmetic.h $(SRC_DIR)/sat/intervals.h: \ @@ -1584,6 +1591,10 @@ $(SRC_DIR)/sat/symmetry.h: \ $(SRC_DIR)/util/stats.h \ $(SRC_DIR)/algorithms/sparse_permutation.h +$(SRC_DIR)/sat/table.h: \ + $(SRC_DIR)/sat/integer.h \ + $(SRC_DIR)/sat/model.h + $(SRC_DIR)/sat/util.h: \ $(GEN_DIR)/sat/sat_parameters.pb.h \ $(SRC_DIR)/base/random.h @@ -1708,6 +1719,13 @@ $(OBJ_DIR)/sat/symmetry.$O: \ $(SRC_DIR)/sat/symmetry.h $(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/symmetry.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Ssymmetry.$O +$(OBJ_DIR)/sat/table.$O: \ + $(SRC_DIR)/sat/table.cc \ + $(SRC_DIR)/sat/table.h \ + $(SRC_DIR)/base/map_util.h \ + $(SRC_DIR)/base/stl_util.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/table.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Stable.$O + $(OBJ_DIR)/sat/util.$O: \ $(SRC_DIR)/sat/util.cc \ $(SRC_DIR)/sat/util.h @@ -3038,3 +3056,4 @@ $(GEN_DIR)/constraint_solver/solver_parameters.pb.h: $(GEN_DIR)/constraint_solve $(OBJ_DIR)/constraint_solver/solver_parameters.pb.$O: $(GEN_DIR)/constraint_solver/solver_parameters.pb.cc $(CCC) $(CFLAGS) -c $(GEN_DIR)/constraint_solver/solver_parameters.pb.cc $(OBJ_OUT)$(OBJ_DIR)$Sconstraint_solver$Ssolver_parameters.pb.$O + diff --git a/src/flatzinc/constraints.cc b/src/flatzinc/constraints.cc index 02c236d44d..524c6282b3 100644 --- a/src/flatzinc/constraints.cc +++ b/src/flatzinc/constraints.cc @@ -1596,6 +1596,18 @@ void ParseShortIntLin(fz::SolverData* data, fz::Constraint* ct, IntExpr** left, *left = nullptr; *right = nullptr; + if (fzvars.empty() && size != 0) { + // We have a constant array. + CHECK_EQ(ct->arguments[1].values.size(), size); + int64 result = 0; + for (int i = 0; i < size; ++i) { + result += coefficients[i] * ct->arguments[1].values[i]; + } + *left = solver->MakeIntConst(result); + *right = solver->MakeIntConst(rhs); + return; + } + switch (size) { case 0: { *left = solver->MakeIntConst(0); @@ -1727,7 +1739,8 @@ bool AreAllVariablesBoolean(fz::SolverData* data, fz::Constraint* ct) { bool ExtractLinAsShort(fz::SolverData* data, fz::Constraint* ct) { const int size = ct->arguments[0].values.size(); if (ct->arguments[1].variables.empty()) { - return false; + // Constant linear scalprods will be treated correctly by ParseShortLin. + return true; } switch (size) { case 0: @@ -2086,7 +2099,6 @@ void ExtractIntLinLeReif(fz::SolverData* data, fz::Constraint* ct) { void ExtractIntLinNe(fz::SolverData* data, fz::Constraint* ct) { Solver* const solver = data->solver(); - const int size = ct->arguments[0].values.size(); if (ExtractLinAsShort(data, ct)) { IntExpr* left = nullptr; IntExpr* right = nullptr; @@ -2097,12 +2109,8 @@ void ExtractIntLinNe(fz::SolverData* data, fz::Constraint* ct) { std::vector coeffs; int64 rhs = 0; ParseLongIntLin(data, ct, &vars, &coeffs, &rhs); - if (AreAllBooleans(vars) && AreAllOnes(coeffs)) { - PostBooleanSumInRange(data->Sat(), solver, vars, rhs, size); - } else { - AddConstraint(solver, ct, solver->MakeNonEquality( - solver->MakeScalProd(vars, coeffs), rhs)); - } + AddConstraint(solver, ct, solver->MakeNonEquality( + solver->MakeScalProd(vars, coeffs), rhs)); } } diff --git a/src/flatzinc/fz.cc b/src/flatzinc/fz.cc index 660a99cfe1..4e0f9d12b9 100644 --- a/src/flatzinc/fz.cc +++ b/src/flatzinc/fz.cc @@ -33,6 +33,7 @@ #include "flatzinc/parser.h" #include "flatzinc/presolve.h" #include "flatzinc/reporting.h" +#include "flatzinc/sat_fz_solver.h" #include "flatzinc/solver.h" #include "flatzinc/solver_util.h" @@ -62,6 +63,8 @@ DEFINE_bool( "Increase verbosity of the impact based search when used in free search."); DEFINE_bool(verbose_mt, false, "Verbose Multi-Thread."); +DEFINE_bool(use_fz_sat, false, "Use the SAT/CP solver."); + DECLARE_bool(log_prefix); DECLARE_bool(use_sat); @@ -300,6 +303,12 @@ int main(int argc, char** argv) { operations_research::fz::Model model = operations_research::fz::ParseFlatzincModel(input, !FLAGS_read_from_stdin); - operations_research::fz::Solve(model); + + if (FLAGS_use_fz_sat) { + operations_research::sat::SolveWithSat( + model, operations_research::fz::SingleThreadParameters()); + } else { + operations_research::fz::Solve(model); + } return EXIT_SUCCESS; } diff --git a/src/flatzinc/presolve.cc b/src/flatzinc/presolve.cc index 4b12e8ae38..7429cafeca 100644 --- a/src/flatzinc/presolve.cc +++ b/src/flatzinc/presolve.cc @@ -1649,10 +1649,10 @@ bool Presolver::SimplifyIntNeReif(Constraint* ct, std::string* log) { ContainsKey(int_eq_reif_map_[ct->arguments[0].Var()], ct->arguments[1].Var())) { log->append("merge constraint with opposite constraint"); - IntegerVariable* const opposite = + IntegerVariable* const opposite_boolvar = int_eq_reif_map_[ct->arguments[0].Var()][ct->arguments[1].Var()]; - ct->arguments[0] = Argument::IntVarRef(opposite); - ct->arguments[1] = Argument::IntVarRef(ct->arguments[1].Var()); + ct->arguments[0] = Argument::IntVarRef(opposite_boolvar); + ct->arguments[1] = Argument::IntVarRef(ct->arguments[2].Var()); ct->RemoveArg(2); ct->type = "bool_not"; return true; diff --git a/src/flatzinc/sat_fz_solver.cc b/src/flatzinc/sat_fz_solver.cc new file mode 100644 index 0000000000..ad4a2758df --- /dev/null +++ b/src/flatzinc/sat_fz_solver.cc @@ -0,0 +1,228 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "flatzinc/sat_fz_solver.h" + +#include "base/map_util.h" +#include "flatzinc/logging.h" +#include "sat/disjunctive.h" +#include "sat/integer.h" +#include "sat/integer_expr.h" +#include "sat/intervals.h" +#include "sat/model.h" +#include "sat/sat_solver.h" +#include "sat/table.h" + +namespace operations_research { +namespace sat { + +// TODO(user): deal with constant variables. +// In a first version we could create constant IntegerVariable in the solver. +IntegerVariable LookupVar( + const hash_map& var_map, + const fz::Argument& argument) { + CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF); + return FindOrDie(var_map, argument.variables[0]); +} + +// TODO(user): deal with constant variables. +// In a first version we could create constant IntegerVariable in the solver. +std::vector LookupVars( + const hash_map& var_map, + const fz::Argument& argument) { + CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF_ARRAY); + std::vector result; + for (fz::IntegerVariable* var : argument.variables) { + result.push_back(FindOrDie(var_map, var)); + } + return result; +} + +void ExtractIntMin( + const fz::Constraint& ct, + const hash_map& var_map, + Model* sat_model) { + const IntegerVariable a = LookupVar(var_map, ct.arguments[0]); + const IntegerVariable b = LookupVar(var_map, ct.arguments[1]); + const IntegerVariable c = LookupVar(var_map, ct.arguments[2]); + sat_model->Add(IsEqualToMinOf(c, {a, b})); +} + +void ExtractIntAbs( + const fz::Constraint& ct, + const hash_map& var_map, + Model* sat_model) { + const IntegerVariable v = LookupVar(var_map, ct.arguments[0]); + const IntegerVariable abs = LookupVar(var_map, ct.arguments[1]); + sat_model->Add(IsEqualToMaxOf(abs, {v, NegationOf(v)})); +} + +void ExtractRegular( + const fz::Constraint& ct, + const hash_map& var_map, + Model* sat_model) { + const std::vector vars = LookupVars(var_map, ct.arguments[0]); + const int64 num_states = ct.arguments[1].Value(); + const int64 num_values = ct.arguments[2].Value(); + + const std::vector& next = ct.arguments[3].values; + std::vector> transitions; + int count = 0; + for (int i = 1; i <= num_states; ++i) { + for (int j = 1; j <= num_values; ++j) { + transitions.push_back({i, j, next[count++]}); + } + } + + const int64 initial_state = ct.arguments[4].Value(); + + std::vector final_states; + switch (ct.arguments[5].type) { + case fz::Argument::INT_VALUE: { + final_states.push_back(ct.arguments[5].values[0]); + break; + } + case fz::Argument::INT_INTERVAL: { + for (int v = ct.arguments[5].values[0]; v <= ct.arguments[5].values[1]; + ++v) { + final_states.push_back(v); + } + break; + } + case fz::Argument::INT_LIST: { + final_states = ct.arguments[5].values; + break; + } + default: { LOG(FATAL) << "Wrong constraint " << ct.DebugString(); } + } + + sat_model->Add( + TransitionConstraint(vars, transitions, initial_state, final_states)); +} + +// The format is fixed in the flatzinc specification. +std::string SolutionString( + const Model& sat_model, + const hash_map& var_map, + const fz::SolutionOutputSpecs& output) { + if (output.variable != nullptr) { + const int64 value = + sat_model.Get(Value(FindOrDie(var_map, output.variable))); + if (output.display_as_boolean) { + return StringPrintf("%s = %s;", output.name.c_str(), + value == 1 ? "true" : "false"); + } else { + return StringPrintf("%s = %" GG_LL_FORMAT "d;", output.name.c_str(), + value); + } + } else { + const int bound_size = output.bounds.size(); + std::string result = + StringPrintf("%s = array%dd(", output.name.c_str(), bound_size); + for (int i = 0; i < bound_size; ++i) { + if (output.bounds[i].max_value != 0) { + result.append(StringPrintf("%" GG_LL_FORMAT "d..%" GG_LL_FORMAT "d, ", + output.bounds[i].min_value, + output.bounds[i].max_value)); + } else { + result.append("{},"); + } + } + result.append("["); + for (int i = 0; i < output.flat_variables.size(); ++i) { + const int64 value = + sat_model.Get(Value(FindOrDie(var_map, output.flat_variables[i]))); + if (output.display_as_boolean) { + result.append(StringPrintf(value ? "true" : "false")); + } else { + result.append(StringPrintf("%" GG_LL_FORMAT "d", value)); + } + if (i != output.flat_variables.size() - 1) { + result.append(", "); + } + } + result.append("]);"); + return result; + } + return ""; +} + +void SolveWithSat(const fz::Model& model, const fz::FlatzincParameters& p) { + Model sat_model; + + // Correspondance between a fz::IntegerVariable and a sat::IntegerVariable. + hash_map var_map; + + // Extract all the variables. + FZLOG << "Extracting " << model.variables().size() << " variables. " + << FZENDL; + for (fz::IntegerVariable* var : model.variables()) { + if (!var->active) continue; + var_map[var] = + sat_model.Add(NewIntegerVariable(var->domain.Min(), var->domain.Max())); + + // TODO(user): encode if it is a list of value? This way constraint using + // it will get the intersection with the proper interval. + } + + // Extract all the constraints. + FZLOG << "Extracting " << model.constraints().size() << " constraints. " + << FZENDL; + std::set unsupported_types; + for (fz::Constraint* ct : model.constraints()) { + if (ct != nullptr && ct->active) { + if (ct->type == "int_min") { + ExtractIntMin(*ct, var_map, &sat_model); + } else if (ct->type == "int_abs") { + ExtractIntAbs(*ct, var_map, &sat_model); + } else if (ct->type == "regular") { + ExtractRegular(*ct, var_map, &sat_model); + } else { + unsupported_types.insert(ct->type); + } + } + } + if (!unsupported_types.empty()) { + FZLOG << "There is unsuported constraints types in this model: " << FZENDL; + for (const std::string& type : unsupported_types) { + FZLOG << " - " << type; + } + return; + } + + // For now assume a decision problem! + // + // TODO(user): deal with other kind of search (optim, all solutions, ...). + // + // TODO(user): Encode IntegerVariable that are not fixed at the end of the + // search. + CHECK(model.objective() == nullptr); + CHECK_EQ(1, p.num_solutions); + FZLOG << "Solving..." << FZENDL; + const SatSolver::Status status = sat_model.GetOrCreate()->Solve(); + FZLOG << "Status: " << status << FZENDL; + + // Output! + std::string solution_string; + for (const fz::SolutionOutputSpecs& output : model.output()) { + solution_string.append(SolutionString(sat_model, var_map, output)); + solution_string.append("\n"); + } + + // Print the solution. + // The "----------" is needed by minizinc. + std::cout << solution_string << "----------" << std::endl; +} + +} // namespace sat +} // namespace operations_research diff --git a/src/flatzinc/sat_fz_solver.h b/src/flatzinc/sat_fz_solver.h new file mode 100644 index 0000000000..a8ca1ce234 --- /dev/null +++ b/src/flatzinc/sat_fz_solver.h @@ -0,0 +1,28 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_ +#define OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_ + +#include "flatzinc/model.h" +#include "flatzinc/solver.h" + +namespace operations_research { +namespace sat { + +void SolveWithSat(const fz::Model& model, const fz::FlatzincParameters& p); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_ diff --git a/src/graph/util.h b/src/graph/util.h index 67f11570b1..d959bddbbb 100644 --- a/src/graph/util.h +++ b/src/graph/util.h @@ -21,10 +21,12 @@ #include #include -#include "base/numbers.h" #include "base/hash.h" +#include "base/join.h" +#include "base/numbers.h" #include "base/split.h" #include "base/join.h" +#include "base/map_util.h" #include "base/murmur.h" #include "graph/graph.h" #include "util/filelineiter.h" @@ -41,22 +43,68 @@ bool GraphIsSymmetric(const Graph& graph); // Creates a remapped copy of graph "graph", where node i becomes node // new_node_index[i]. The caller takes ownership of the returned graph. -// "new_node_index" must be a valid permutation of [0..num_nodes-1], or -// else the return StatusOr will be in an error state. +// "new_node_index" must be a valid permutation of [0..num_nodes-1] or the +// behavior is undefined (it may die). +// Note that you can call IsValidPermutation() to check it yourself. template -util::StatusOr RemapGraph(const Graph& graph, +std::unique_ptr RemapGraph(const Graph& graph, const std::vector& new_node_index); -// Returns a std::string representation of a graph: one arc per line. Eg.: -// "1->2\n3->3" for a graph with 4 nodes and 2 arcs (1->2) and (3->3). -// Arcs are sorted by their tail, then by the order of OutgoingArcs(). +// Returns true iff the given vector is a permutation of [0..size()-1]. +bool IsValidPermutation(const std::vector& v); + +// Returns a std::string representation of a graph. +enum GraphToStringFormat { + // One arc per line, eg. "3->1". + PRINT_GRAPH_ARCS, + + // One space-separated adjacency list per line, eg. "5 1 3 1". + // Nodes with no outgoing arc get an empty line. + PRINT_GRAPH_ADJACENCY_LISTS, + + // Ditto, but the adjacency lists are sorted. + PRINT_GRAPH_ADJACENCY_LISTS_SORTED, +}; template -std::string GraphToString(const Graph& graph); +std::string GraphToString(const Graph& graph, GraphToStringFormat format); // Returns a copy of "graph", without self-arcs and duplicate arcs. template std::unique_ptr RemoveSelfArcsAndDuplicateArcs(const Graph& graph); +// Given an arc path, changes it to a sub-path with the same source and +// destination but without any cycle. Nothing happen if the path was already +// without cycle. +// +// The graph class should support Tail(arc) and Head(arc). They should both +// return an integer representing the corresponding tail/head of the passed arc. +// +// TODO(user): In some cases, there is more than one possible solution. We could +// take some arc costs and return the cheapest path instead. Or return the +// shortest path in term of number of arcs. +template +void RemoveCyclesFromPath(const Graph& graph, std::vector* arc_path); + +// Returns true iff the given path contains a cycle. +template +bool PathHasCycle(const Graph& graph, const std::vector& arc_path); + +// Returns a vector representing a mapping from arcs to arcs such that each arc +// is mapped to another arc with its (tail, head) flipped, if such an arc +// exists (otherwise it is mapped to -1). +// If the graph is symmetric, the returned mapping is bijective and reflexive, +// i.e. out[out[arc]] = arc for all "arc", where "out" is the returned vector. +// If "die_if_not_symmetric" is true, this function CHECKs() that the graph +// is symmetric. +// +// Self-arcs are always mapped to themselves. +// +// Note that since graphs may have multi-arcs, the mapping isn't necessarily +// unique, hence the function name. +template +std::vector ComputeOnePossibleReverseArcMapping(const Graph& graph, + bool die_if_not_symmetric); + // Read a graph file in the simple ".g" format: the file should be a text file // containing only space-separated integers, whose first line is: // [ @@ -140,29 +188,11 @@ bool GraphIsSymmetric(const Graph& graph) { } template -util::StatusOr RemapGraph(const Graph& old_graph, +std::unique_ptr RemapGraph(const Graph& old_graph, const std::vector& new_node_index) { + DCHECK(IsValidPermutation(new_node_index)) << "Invalid permutation"; const int num_nodes = old_graph.num_nodes(); - // Quickly verify that "new_node_index" is a valid permutation. - bool ok = new_node_index.size() == num_nodes; - if (ok) { - std::vector tmp_node_mask(old_graph.num_nodes(), false); - for (const int i : new_node_index) { - if (i < 0 || i >= num_nodes || tmp_node_mask[i]) { - ok = false; - break; - } - tmp_node_mask[i] = true; - } - } - if (!ok) { - return util::Status( - util::error::INVALID_ARGUMENT, - StrCat( - "new_node_index is not a valid permutation of [0..num_nodes-1]," - " with num_nodes = ", - num_nodes)); - } + CHECK_EQ(new_node_index.size(), num_nodes); std::unique_ptr new_graph(new Graph(num_nodes, old_graph.num_arcs())); typedef typename Graph::NodeIndex NodeIndex; typedef typename Graph::ArcIndex ArcIndex; @@ -173,16 +203,29 @@ util::StatusOr RemapGraph(const Graph& old_graph, } } new_graph->Build(); - return new_graph.release(); + return new_graph; } template -std::string GraphToString(const Graph& graph) { +std::string GraphToString(const Graph& graph, GraphToStringFormat format) { std::string out; + std::vector adj; for (const typename Graph::NodeIndex node : graph.AllNodes()) { - for (const typename Graph::ArcIndex arc : graph.OutgoingArcs(node)) { - if (!out.empty()) out += '\n'; - StrAppend(&out, node, "->", graph.Head(arc)); + if (format == PRINT_GRAPH_ARCS) { + for (const typename Graph::ArcIndex arc : graph.OutgoingArcs(node)) { + if (!out.empty()) out += '\n'; + StrAppend(&out, node, "->", graph.Head(arc)); + } + } else { // PRINT_GRAPH_ADJACENCY_LISTS[_SORTED] + adj.clear(); + for (const typename Graph::ArcIndex arc : graph.OutgoingArcs(node)) { + adj.push_back(graph.Head(arc)); + } + if (format == PRINT_GRAPH_ADJACENCY_LISTS_SORTED) { + std::sort(adj.begin(), adj.end()); + } + if (node != 0) out += '\n'; + StrAppend(&out, node, ": ", strings::Join(adj, " ")); } } return out; @@ -206,6 +249,79 @@ std::unique_ptr RemoveSelfArcsAndDuplicateArcs(const Graph& graph) { return g; } +template +void RemoveCyclesFromPath(const Graph& graph, std::vector* arc_path) { + if (arc_path->empty()) return; + + // This maps each node to the latest arc in the given path that leaves it. + std::map last_arc_leaving_node; + for (const int arc : *arc_path) last_arc_leaving_node[graph.Tail(arc)] = arc; + + // Special case for the destination. + // Note that this requires that -1 is not a valid arc of Graph. + last_arc_leaving_node[graph.Head(arc_path->back())] = -1; + + // Reconstruct the path by starting at the source and then following the + // "next" arcs. We override the given arc_path at the same time. + int node = graph.Tail(arc_path->front()); + int new_size = 0; + while (new_size < arc_path->size()) { // To prevent cycle on bad input. + const int arc = FindOrDie(last_arc_leaving_node, node); + if (arc == -1) break; + (*arc_path)[new_size++] = arc; + node = graph.Head(arc); + } + arc_path->resize(new_size); +} + +template +bool PathHasCycle(const Graph& graph, const std::vector& arc_path) { + if (arc_path.empty()) return false; + std::set seen; + seen.insert(graph.Tail(arc_path.front())); + for (const int arc : arc_path) { + if (!InsertIfNotPresent(&seen, graph.Head(arc))) return true; + } + return false; +} + +template +std::vector ComputeOnePossibleReverseArcMapping(const Graph& graph, + bool die_if_not_symmetric) { + std::vector reverse_arc(graph.num_arcs(), -1); + hash_multimap, /*arc index*/ int> arc_map; + for (int arc = 0; arc < graph.num_arcs(); ++arc) { + const int tail = graph.Tail(arc); + const int head = graph.Head(arc); + if (tail == head) { + // Special case: directly map any self-arc to itself. + reverse_arc[arc] = arc; + continue; + } + // Lookup for the reverse arc of the current one... + auto it = arc_map.find({head, tail}); + if (it != arc_map.end()) { + // Found a reverse arc! Store the mapping and remove the + // reverse arc from the map. + reverse_arc[arc] = it->second; + reverse_arc[it->second] = arc; + arc_map.erase(it); + } else { + // Reverse arc not in the map. Add the current arc to the map. + arc_map.insert({{tail, head}, arc}); + } + } + // Algorithm check, for debugging. + DCHECK_EQ(std::count(reverse_arc.begin(), reverse_arc.end(), -1), + arc_map.size()); + if (die_if_not_symmetric) { + CHECK_EQ(arc_map.size(), 0) + << "The graph is not symmetric: " << arc_map.size() << " of " + << graph.num_arcs() << " arcs did not have a reverse."; + } + return reverse_arc; +} + template util::StatusOr ReadGraphFile( const std::string& filename, bool directed, diff --git a/src/sat/boolean_problem.cc b/src/sat/boolean_problem.cc index 56fa5acdb8..c842353c28 100644 --- a/src/sat/boolean_problem.cc +++ b/src/sat/boolean_problem.cc @@ -612,8 +612,7 @@ void FindLinearBooleanProblemSymmetries( for (int node = 0; node < graph->num_nodes(); ++node) { new_node_index[node] = next_index_by_class[equivalence_classes[node]]++; } - std::unique_ptr remapped_graph( - RemapGraph(*graph, new_node_index).ValueOrDie()); + std::unique_ptr remapped_graph = RemapGraph(*graph, new_node_index); const util::Status status = WriteGraphToFile( *remapped_graph, FLAGS_debug_dump_symmetry_graph_to_file, /*directed=*/false, class_size); diff --git a/src/sat/disjunctive.cc b/src/sat/disjunctive.cc index 692e625a9f..17aa2fff45 100644 --- a/src/sat/disjunctive.cc +++ b/src/sat/disjunctive.cc @@ -213,6 +213,9 @@ bool DisjunctiveConstraint::Propagate(Trail* trail) { // Loop until we reach the fixed-point. It should be unique (see Petr Villim // PhD). + // + // TODO(user): Some of these passes are idempotent, so there is no need to + // call them if their input didn't change! Improve that. while (true) { const int64 old_timestamp = integer_trail_->num_enqueues(); @@ -292,44 +295,57 @@ void DisjunctiveConstraint::AddMaxEndReason(int t, IntegerValue upper_bound) { IntegerLiteral::LowerOrEqual(end_vars_[t], upper_bound)); } -bool DisjunctiveConstraint::CheckIntervalForConflict(int t, Trail* trail) { - // TODO(user): instead of this, we could propagate MinEnd()/MaxStart() and - // let the code in integer.h detect empty domains. - if (CapAdd(MinStart(t), MinDuration(t)) > MaxEnd(t)) { - integer_reason_.clear(); - integer_reason_.push_back( - integer_trail_->LowerBoundAsLiteral(start_vars_[t])); - integer_reason_.push_back( - integer_trail_->UpperBoundAsLiteral(end_vars_[t])); - if (duration_vars_[t] != kNoIntegerVariable) { - integer_reason_.push_back( - integer_trail_->LowerBoundAsLiteral(duration_vars_[t])); - } +void DisjunctiveConstraint::AddMaxStartReason(int t, IntegerValue upper_bound) { + integer_reason_.push_back( + IntegerLiteral::LowerOrEqual(start_vars_[t], upper_bound)); +} - if (!task_is_currently_present_[t]) { - // We can propagate reason_for_presence_[t] to false. - // Note that it could have been already propagated in which case we do - // nothing. - if (trail->Assignment().LiteralIsFalse( - Literal(reason_for_presence_[t]))) { - return true; - } - DCHECK_NE(reason_for_presence_[t], kNoLiteralIndex); - literal_reason_.clear(); - integer_trail_->EnqueueLiteral(Literal(reason_for_presence_[t]).Negated(), - literal_reason_, integer_reason_, trail); - return true; - } else { - // Conflict. - std::vector* conflict = trail->MutableConflict(); - conflict->clear(); - if (reason_for_presence_[t] != kNoLiteralIndex) { - conflict->push_back(Literal(reason_for_presence_[t]).Negated()); - } - integer_trail_->MergeReasonInto(integer_reason_, conflict); +bool DisjunctiveConstraint::IncreaseMinStart(int t, + IntegerValue new_min_start) { + if (!integer_trail_->Enqueue( + IntegerLiteral::GreaterOrEqual(start_vars_[t], new_min_start), + literal_reason_, integer_reason_)) { + return false; + } + + // We propagate right away the new min-end lower-bound we have. + const IntegerValue min_end_lb = new_min_start + MinDuration(t); + if (MinEnd(t) < min_end_lb) { + integer_reason_.clear(); + literal_reason_.clear(); + AddMinStartReason(t, new_min_start); + AddMinDurationReason(t); + if (!integer_trail_->Enqueue( + IntegerLiteral::GreaterOrEqual(end_vars_[t], min_end_lb), + literal_reason_, integer_reason_)) { return false; } } + + return true; +} + +bool DisjunctiveConstraint::DecreaseMaxEnd(int t, IntegerValue new_max_end) { + if (!integer_trail_->Enqueue( + IntegerLiteral::LowerOrEqual(end_vars_[t], new_max_end), + literal_reason_, integer_reason_)) { + return false; + } + + // We propagate right away the new max-start upper-bound we have. + const IntegerValue max_start_ub = new_max_end - MinDuration(t); + if (MaxStart(t) > max_start_ub) { + integer_reason_.clear(); + literal_reason_.clear(); + AddMaxEndReason(t, new_max_end); + AddMinDurationReason(t); + if (!integer_trail_->Enqueue( + IntegerLiteral::LowerOrEqual(start_vars_[t], max_start_ub), + literal_reason_, integer_reason_)) { + return false; + } + } + return true; } @@ -441,7 +457,7 @@ bool DisjunctiveConstraint::OverloadCheckingPass(IntegerTrail* integer_trail, DCHECK_NE(reason_for_presence_[t], kNoLiteralIndex); integer_trail->EnqueueLiteral( Literal(reason_for_presence_[t]).Negated(), literal_reason_, - integer_reason_, trail); + integer_reason_); } } @@ -509,7 +525,7 @@ bool DisjunctiveConstraint::DetectablePrecedencePass( integer_reason_.clear(); // We need: - // - MaxStart(ct) < MinEnd(t) for the detectable precedence + // - MaxStart(ct) < MinEnd(t) for the detectable precedence. // - MinStart(ct) > window_start for the min_end_of_critical_tasks reason. const IntegerValue window_start = sorted_tasks[critical_index].min_start; for (int i = critical_index; i < sorted_tasks.size(); ++i) { @@ -520,10 +536,7 @@ bool DisjunctiveConstraint::DetectablePrecedencePass( DCHECK_LT(MaxStart(ct), min_end); AddPresenceAndDurationReason(ct); AddMinStartReason(ct, window_start); - - // TODO(user): Add the reason on MaxStart() instead, it will be more - // correct for task with non-fixed duration. - AddMaxEndReason(ct, CapAdd(min_end, MinDuration(ct) - 1)); + AddMaxStartReason(ct, min_end - 1); } // Add the reason for t (we don't need the max-end or presence reason). @@ -532,13 +545,7 @@ bool DisjunctiveConstraint::DetectablePrecedencePass( // This augment the min-start of t and subsequently it can augment the // next min_end_of_critical_tasks, but our deduction is still valid. - if (!integer_trail->Enqueue( - IntegerLiteral::GreaterOrEqual(start_vars_[t], - min_end_of_critical_tasks), - literal_reason_, integer_reason_)) { - return false; - } - if (!CheckIntervalForConflict(t, trail)) return false; + if (!IncreaseMinStart(t, min_end_of_critical_tasks)) return false; // We need to reorder t inside task_set_. task_set_.NotifyEntryIsNowLastIfPresent({t, MinStart(t), MinDuration(t)}); @@ -587,16 +594,13 @@ bool DisjunctiveConstraint::PrecedencePass(IntegerTrail* integer_trail, Literal(reason_for_beeing_before_[ct]).Negated()); } } + + // TODO(user): If var is actually a min-start of an interval, we + // could push the min-end and check the interval consistency right away. if (!integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(var, min_end), literal_reason_, integer_reason_)) { return false; } - - // TODO(user): for non-optional intervals, we could check right away that - // the domain of var is non-empty. Or for a var that correspond to one - // of the interval bounds, we could detect infeasibility early with - // CheckIntervalForConflict(). Not doing should be fine, it just postpone - // a bit the conflict detection. } } return true; @@ -632,8 +636,8 @@ bool DisjunctiveConstraint::NotLastPass(IntegerTrail* integer_trail, // [(critical tasks) // (duration_t)] // - // So we can deduce that the max-end of t is smaller that the largest - // max-start of the critical tasks. + // So we can deduce that the max-end of t is smaller than or equal to the + // largest max-start of the critical tasks. // // Note that this works as well when task_is_currently_present_[t] is false. int critical_index = 0; @@ -641,49 +645,37 @@ bool DisjunctiveConstraint::NotLastPass(IntegerTrail* integer_trail, task_set_.ComputeMinEnd(/*task_to_ignore=*/t, &critical_index); if (min_end_of_critical_tasks <= MaxStart(t)) continue; - // Find the largest max-start of the critical tasks (excluding t). This - // will be a valid new max-end for t. - IntegerValue new_max_end = kMinIntegerValue; - int task_responsible_for_new_max_end = -1; + // Find the largest max-start of the critical tasks (excluding t). The + // max-end for t need to be smaller than or equal to this. + IntegerValue largest_ct_max_start = kMinIntegerValue; const std::vector& sorted_tasks = task_set_.SortedTasks(); for (int i = critical_index; i < sorted_tasks.size(); ++i) { const int ct = sorted_tasks[i].task; if (t == ct) continue; const IntegerValue max_start = MaxStart(ct); - if (max_start > new_max_end) { - new_max_end = max_start; - task_responsible_for_new_max_end = ct; + if (max_start > largest_ct_max_start) { + largest_ct_max_start = max_start; } } - if (max_end > new_max_end) { + if (max_end > largest_ct_max_start) { literal_reason_.clear(); integer_reason_.clear(); - // We don't need the max-end reason of the critical tasks except the - // one for the task responsible for new_max_end. const IntegerValue window_start = sorted_tasks[critical_index].min_start; for (int i = critical_index; i < sorted_tasks.size(); ++i) { const int ct = sorted_tasks[i].task; if (ct == t) continue; AddPresenceAndDurationReason(ct); AddMinStartReason(ct, window_start); - if (ct == task_responsible_for_new_max_end) { - AddMaxEndReason(ct, MaxEnd(ct)); - } + AddMaxStartReason(ct, largest_ct_max_start); } - // Add the reason for t (we don't need the min-start or presence reason). - AddMinDurationReason(t); - AddMaxEndReason(t, CapAdd(min_end_of_critical_tasks, MinDuration(t) - 1)); + // Add the reason for t, we only need the max-start. + AddMaxStartReason(t, min_end_of_critical_tasks - 1); // Enqueue the new max-end for t. // Note that changing it will not influence the rest of the loop. - if (!integer_trail->Enqueue( - IntegerLiteral::LowerOrEqual(end_vars_[t], new_max_end), - literal_reason_, integer_reason_)) { - return false; - } - if (!CheckIntervalForConflict(t, trail)) return false; + if (!DecreaseMaxEnd(t, largest_ct_max_start)) return false; } } return true; @@ -859,13 +851,9 @@ bool DisjunctiveConstraint::EdgeFindingPass(IntegerTrail* integer_trail, // TODO(user): propagate the precedence Boolean here too? I think it // will be more powerful. Even if eventually all these precedence will // become detectable (see Petr Villim PhD). - if (!integer_trail->Enqueue( - IntegerLiteral::GreaterOrEqual(start_vars_[gray_task], - min_end_of_critical_tasks), - literal_reason_, integer_reason_)) { + if (!IncreaseMinStart(gray_task, min_end_of_critical_tasks)) { return false; } - if (!CheckIntervalForConflict(gray_task, trail)) return false; } // Remove the gray_task from sorted_tasks_. diff --git a/src/sat/disjunctive.h b/src/sat/disjunctive.h index 3d7cf6a7dc..383b548176 100644 --- a/src/sat/disjunctive.h +++ b/src/sat/disjunctive.h @@ -24,19 +24,19 @@ namespace operations_research { namespace sat { -// Enforces a disjunctive (or no overlap) constraints on the given interval +// Enforces a disjunctive (or no overlap) constraint on the given interval // variables. std::function Disjunctive(const std::vector& vars); -// Same as Disjunctive() but also creates a Boolean variables for all the +// Same as Disjunctive() but also creates a Boolean variable for all the // possible precedences of the form (task i is before task j). std::function DisjunctiveWithBooleanPrecedences( const std::vector& vars); // Helper class to compute the min-end of a set of tasks given their min-start -// and min-duration. In Petr Vilim PhD "Global Constraints in Scheduling", this -// corresponds to his Theta-tree except that we use a O(n) implementation for -// most of the function here, not a O(log(n)) one. +// and min-duration. In Petr Vilim's PhD "Global Constraints in Scheduling", +// this corresponds to his Theta-tree except that we use a O(n) implementation +// for most of the function here, not a O(log(n)) one. class TaskSet { public: TaskSet() : optimized_restart_(0) {} @@ -75,9 +75,10 @@ class TaskSet { // // [Bunch of tasks] ... [Bunch of tasks] ... [critical tasks]. // - // We call "critical tasks" the last group. These tasks will be the sole - // responsible for the min-end of the whole set. The returned critical_index - // will be the index of the first critical task in SortedTasks(). + // We call "critical tasks" the last group. These tasks will be solely + // responsible for for the min-end of the whole set. The returned + // critical_index will be the index of the first critical task in + // SortedTasks(). // // A reason for the min end is: // - The min-duration of all the critical tasks. @@ -132,6 +133,9 @@ class DisjunctiveConstraint : public PropagatorInterface { // [(min-duration) ... (min-duration)] // ^ ^ ^ ^ // min-start min-end max-start max-end + // + // Note that for tasks with variable durations, we don't necessarily have + // min-duration between the the min-XXX and max-XXX value. IntegerValue MinDuration(int t) const { return duration_vars_[t] == kNoIntegerVariable ? fixed_durations_[t] @@ -140,11 +144,15 @@ class DisjunctiveConstraint : public PropagatorInterface { IntegerValue MinStart(int t) const { return integer_trail_->LowerBound(start_vars_[t]); } + IntegerValue MaxStart(int t) const { + return integer_trail_->UpperBound(start_vars_[t]); + } + IntegerValue MinEnd(int t) const { + return integer_trail_->LowerBound(end_vars_[t]); + } IntegerValue MaxEnd(int t) const { return integer_trail_->UpperBound(end_vars_[t]); } - IntegerValue MaxStart(int t) const { return MaxEnd(t) - MinDuration(t); } - IntegerValue MinEnd(int t) const { return MinStart(t) + MinDuration(t); } // Helper functions to compute the reason of a propagation. // Append to literal_reason_ and integer_reason_ the corresponding reason. @@ -152,10 +160,15 @@ class DisjunctiveConstraint : public PropagatorInterface { void AddMinDurationReason(int t); void AddMinStartReason(int t, IntegerValue lower_bound); void AddMaxEndReason(int t, IntegerValue upper_bound); + void AddMaxStartReason(int t, IntegerValue upper_bound); - // Checks that the interval [min_start_t, max_end_t] is larger than - // min_duration_t. Returns false and report an conflict otherwise. - bool CheckIntervalForConflict(int t, Trail* trail); + // Enqueues new bounds of an interval. The reasons (literal_reason_ and + // integer_reason_) must already be filled. Note that we automatically push + // min-end and max-start accordingly, so we maintain the invariants: + // - min-end >= min-start + min-duration + // - max-start <= max-end + min-duration + bool IncreaseMinStart(int t, IntegerValue new_min_start); + bool DecreaseMaxEnd(int t, IntegerValue new_max_end); // All these passes use the algorithms described in Petr Vilim PhD "Global // Constraints in Scheduling". Except that we don't use the O(log(n)) balanced diff --git a/src/sat/integer.cc b/src/sat/integer.cc index 1dc48c2d0f..e744fd1b44 100644 --- a/src/sat/integer.cc +++ b/src/sat/integer.cc @@ -18,6 +18,65 @@ namespace operations_research { namespace sat { +void IntegerEncoder::FullyEncodeVariable(IntegerVariable i_var, + std::vector values) { + CHECK_EQ(0, sat_solver_->CurrentDecisionLevel()); + CHECK(!values.empty()); // UNSAT problem. We don't deal with that here. + + STLSortAndRemoveDuplicates(&values); + + // TODO(user): This case is annoying, not sure yet how to best fix the + // variable. There is certainly no need to create a Boolean variable, but + // one needs to talk to IntegerTrail to fix the variable and we don't want + // the encoder to depend on this. So for now we fail here and it is up to + // the caller to deal with this case. + CHECK_NE(values.size(), 1); + + // If the variable has already been fully encoded, for now we check that + // the sets of value is the same. + // + // TODO(user): Take the intersection, and handle that case in the constraints + // creation functions. + if (ContainsKey(full_encoding_index_, i_var)) { + const std::vector& encoding = FullDomainEncoding(i_var); + CHECK_EQ(values.size(), encoding.size()); + for (int i = 0; i < values.size(); ++i) { + CHECK_EQ(values[i], encoding[i].value); + } + return; + } + + std::vector encoding; + if (values.size() == 2) { + const BooleanVariable var = sat_solver_->NewBooleanVariable(); + encoding.push_back({values[0], Literal(var, true)}); + encoding.push_back({values[1], Literal(var, false)}); + } else { + std::vector cst; + for (const IntegerValue value : values) { + const BooleanVariable var = sat_solver_->NewBooleanVariable(); + encoding.push_back({value, Literal(var, true)}); + cst.push_back(LiteralWithCoeff(Literal(var, true), Coefficient(1))); + } + CHECK(sat_solver_->AddLinearConstraint(true, sat::Coefficient(1), true, + sat::Coefficient(1), &cst)); + } + + full_encoding_index_[i_var] = full_encoding_.size(); + full_encoding_.push_back(encoding); // copy because we need it below. + + // Deal with NegationOf(i_var). + // + // TODO(user): This seems a bit wasted, but it does simplify the code at a + // somehow small cost. + std::reverse(encoding.begin(), encoding.end()); + for (auto& entry : encoding) { + entry.value = -entry.value; // Reverse the value. + } + full_encoding_index_[NegationOf(i_var)] = full_encoding_.size(); + full_encoding_.push_back(std::move(encoding)); +} + void IntegerEncoder::AddImplications(IntegerLiteral i_lit, Literal literal) { if (i_lit.var >= encoding_by_var_.size()) { encoding_by_var_.resize(i_lit.var + 1); @@ -89,21 +148,94 @@ bool IntegerTrail::Propagate(Trail* trail) { CHECK_EQ(trail->CurrentDecisionLevel(), integer_decision_levels_.size()); } + // Value encoder. + // + // TODO(user): There is no need to maintain the bounds of such variable if + // they are never used in any constraint! + // + // Algorithm: + // 1/ See if new variables are fully encoded and initialize them. + // 2/ In the loop below, each time a "min" variable was assigned to false, + // update the associated variable bounds, and change the watched "min". + // This step is is O(num variables at false between the old and new min). + // + // The data structure are reversible. + watched_min_.SetLevel(trail->CurrentDecisionLevel()); + current_min_.SetLevel(trail->CurrentDecisionLevel()); + if (encoder_->GetFullyEncodedVariables().size() != num_encoded_variables_) { + num_encoded_variables_ = encoder_->GetFullyEncodedVariables().size(); + + // for now this is only supported at level zero. Otherwise we need to + // inspect the trail to properly compute all the min. + // + // TODO(user): Don't rescan all the variables from scratch, we could only + // scan the new ones. But then we need a mecanism to detect the new ones. + CHECK_EQ(trail->CurrentDecisionLevel(), 0); + for (const auto& entry : encoder_->GetFullyEncodedVariables()) { + IntegerVariable var = entry.first; + const auto& encoding = encoder_->FullDomainEncoding(var); + for (int i = 0; i < encoding.size(); ++i) { + if (!trail_->Assignment().LiteralIsFalse(encoding[i].literal)) { + watched_min_.Set(encoding[i].literal.NegatedIndex(), {var, i}); + current_min_.Set(var, i); + + // No reason because we are at level zero. + if (!Enqueue(IntegerLiteral::GreaterOrEqual(var, encoding[i].value), + {}, {})) { + return false; + } + break; + } + } + } + } + // Process all the "associated" literals and Enqueue() the corresponding // bounds. while (propagation_trail_index_ < trail->Index()) { const Literal literal = (*trail)[propagation_trail_index_++]; - const IntegerLiteral i_lit = encoder_->GetIntegerLiteral(literal); - if (i_lit.var < 0) continue; - // The reason is simply the associated literal. - if (!Enqueue(i_lit, {literal.Negated()}, {})) return false; + // Bound encoder. + const IntegerLiteral i_lit = encoder_->GetIntegerLiteral(literal); + if (i_lit.var >= 0) { + // The reason is simply the associated literal. + if (!Enqueue(i_lit, {literal.Negated()}, {})) return false; + } + + // Value encoder. + if (watched_min_.ContainsKey(literal.Index())) { + // A watched min value just became false. + const auto pair = watched_min_.FindOrDie(literal.Index()); + const IntegerVariable var = pair.first; + const int min = pair.second; + const auto& encoding = encoder_->FullDomainEncoding(var); + std::vector literal_reason = {literal.Negated()}; + for (int i = min + 1; i < encoding.size(); ++i) { + if (!trail_->Assignment().LiteralIsFalse(encoding[i].literal)) { + watched_min_.EraseOrDie(literal.Index()); + watched_min_.Set(encoding[i].literal.NegatedIndex(), {var, i}); + current_min_.Set(var, i); + + // Note that we also need the fact that all smaller value are false + // for the propagation. We use the current lower bound for that. + if (!Enqueue(IntegerLiteral::GreaterOrEqual(var, encoding[i].value), + literal_reason, {LowerBoundAsLiteral(var)})) { + return false; + } + break; + } else { + literal_reason.push_back(encoding[i].literal); + } + } + } } return true; } void IntegerTrail::Untrail(const Trail& trail, int literal_trail_index) { + watched_min_.SetLevel(trail.CurrentDecisionLevel()); + current_min_.SetLevel(trail.CurrentDecisionLevel()); propagation_trail_index_ = std::min(propagation_trail_index_, literal_trail_index); @@ -171,6 +303,28 @@ int IntegerTrail::FindLowestTrailIndexThatExplainBound( return prev_trail_index; } +bool IntegerTrail::EnqueueAssociatedLiteral( + Literal literal, IntegerLiteral i_lit, + const std::vector& literals_reason, + const std::vector& bounds_reason) { + if (!trail_->Assignment().VariableIsAssigned(literal.Variable())) { + // The reason is simply i_lit and will be expanded lazily when needed. + std::vector* unused; + std::vector* integer_reason; + EnqueueLiteral(literal, &unused, &integer_reason); + integer_reason->push_back(i_lit); + return true; + } + if (trail_->Assignment().LiteralIsFalse(literal)) { + std::vector* conflict = trail_->MutableConflict(); + *conflict = literals_reason; + conflict->push_back(literal); + MergeReasonInto(bounds_reason, conflict); + return false; + } + return true; +} + bool IntegerTrail::Enqueue(IntegerLiteral i_lit, const std::vector& literals_reason, const std::vector& bounds_reason) { @@ -178,8 +332,58 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit, if (i_lit.bound <= vars_[i_lit.var].current_bound) return true; ++num_enqueues_; - // Check if the integer variable has an empty domain. + // Deal with fully encoded variable. We want to do that first because this may + // make the IntegerLiteral bound stronger. const IntegerVariable var(i_lit.var); + if (current_min_.ContainsKey(var)) { + // Recover the current min, and propagate to false all the values that + // are in [min, i_lit.value). All these literals have the same reason, so + // we use the "same reason as" mecanism. + const int min_index = current_min_.FindOrDie(var); + const auto& encoding = encoder_->FullDomainEncoding(var); + if (i_lit.bound > encoding[min_index].value) { + const Literal negated_min = encoding[min_index].literal.Negated(); + if (!EnqueueAssociatedLiteral(negated_min, i_lit, literals_reason, + bounds_reason)) { + return false; + } + + int i = min_index + 1; + for (; i < encoding.size(); ++i) { + if (i_lit.bound <= encoding[i].value) break; + const Literal literal = encoding[i].literal.Negated(); + if (!trail_->Assignment().VariableIsAssigned(literal.Variable())) { + trail_->EnqueueWithSameReasonAs(literal, negated_min.Variable()); + } else if (trail_->Assignment().LiteralIsFalse(literal)) { + // Conflict. + std::vector* conflict = trail_->MutableConflict(); + *conflict = literals_reason; + conflict->push_back(literal); + MergeReasonInto(bounds_reason, conflict); + return false; + } + } + + if (i == encoding.size()) { + // Conflict: no possible values left. + std::vector* conflict = trail_->MutableConflict(); + *conflict = literals_reason; + MergeReasonInto(bounds_reason, conflict); + return false; + } else { + // We have a new min. + watched_min_.EraseOrDie(encoding[min_index].literal.NegatedIndex()); + watched_min_.Set(encoding[i].literal.NegatedIndex(), {var, i}); + current_min_.Set(var, i); + + // Adjust the bound of i_lit ! + CHECK_GE(encoding[i].value, i_lit.bound); + i_lit.bound = encoding[i].value.value(); + } + } + } + + // Check if the integer variable has an empty domain. if (i_lit.bound > UpperBound(var)) { if (!IsOptional(var) || trail_->Assignment().LiteralIsFalse(Literal(is_empty_literals_[var]))) { @@ -199,8 +403,7 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit, if (!trail_->Assignment().LiteralIsTrue(is_empty)) { std::vector* literal_reason_ptr; std::vector* integer_reason_ptr; - EnqueueLiteral(is_empty, &literal_reason_ptr, &integer_reason_ptr, - trail_); + EnqueueLiteral(is_empty, &literal_reason_ptr, &integer_reason_ptr); *literal_reason_ptr = literals_reason; *integer_reason_ptr = bounds_reason; integer_reason_ptr->push_back(UpperBoundAsLiteral(var)); @@ -217,18 +420,8 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit, const LiteralIndex literal_index = encoder_->SearchForLiteralAtOrBefore(i_lit); if (literal_index != kNoLiteralIndex) { - const Literal literal(literal_index); - if (!trail_->Assignment().VariableIsAssigned(literal.Variable())) { - std::vector* literal_reason; - std::vector* integer_reason; - EnqueueLiteral(literal, &literal_reason, &integer_reason, trail_); - integer_reason->push_back(i_lit); - } else if (trail_->Assignment().LiteralIsFalse(literal)) { - // Conflict. - std::vector* conflict = trail_->MutableConflict(); - *conflict = literals_reason; - conflict->push_back(literal); - MergeReasonInto(bounds_reason, conflict); + if (!EnqueueAssociatedLiteral(Literal(literal_index), i_lit, + literals_reason, bounds_reason)) { return false; } } @@ -385,9 +578,8 @@ ClauseRef IntegerTrail::Reason(const Trail& trail, int trail_index) const { void IntegerTrail::EnqueueLiteral(Literal literal, std::vector** literal_reason, - std::vector** integer_reason, - Trail* trail) { - const int trail_index = trail->Index(); + std::vector** integer_reason) { + const int trail_index = trail_->Index(); if (trail_index >= literal_reasons_.size()) { literal_reasons_.resize(trail_index + 1); integer_reasons_.resize(trail_index + 1); @@ -400,16 +592,15 @@ void IntegerTrail::EnqueueLiteral(Literal literal, if (integer_reason != nullptr) { *integer_reason = &integer_reasons_[trail_index]; } - trail->Enqueue(literal, propagator_id_); + trail_->Enqueue(literal, propagator_id_); } -void IntegerTrail::EnqueueLiteral(Literal literal, - const std::vector& literal_reason, - const std::vector& integer_reason, - Trail* trail) { +void IntegerTrail::EnqueueLiteral( + Literal literal, const std::vector& literal_reason, + const std::vector& integer_reason) { std::vector* literal_reason_ptr; std::vector* integer_reason_ptr; - EnqueueLiteral(literal, &literal_reason_ptr, &integer_reason_ptr, trail); + EnqueueLiteral(literal, &literal_reason_ptr, &integer_reason_ptr); *literal_reason_ptr = literal_reason; *integer_reason_ptr = integer_reason; } diff --git a/src/sat/integer.h b/src/sat/integer.h index c4b0492d62..24ef03857c 100644 --- a/src/sat/integer.h +++ b/src/sat/integer.h @@ -19,11 +19,13 @@ #include "base/port.h" #include "base/join.h" #include "base/int_type.h" +#include "base/map_util.h" #include "sat/model.h" #include "sat/sat_base.h" #include "sat/sat_solver.h" #include "util/bitset.h" #include "util/iterators.h" +#include "util/rev.h" #include "util/saturated_arithmetic.h" namespace operations_research { @@ -146,8 +148,10 @@ inline std::ostream& operator<<(std::ostream& os, IntegerLiteral i_lit) { // these variables activity and so on. These variables can also be propagated // directly by the learned clauses. // -// TODO(user): Add support for creating literals encoding x == v? for now this -// is not used though. +// This class also support a non-lazy full domain encoding which will create one +// literal per possible value in the domain. See FullyEncodeVariable(). This is +// meant to be called by constraints that directly work on the variable values +// like a table constraint or an all-diff constraint. // // TODO(user): We could also lazily create precedences Booleans between two // arbitrary IntegerVariable. This is better done in the PrecedencesPropagator @@ -168,6 +172,58 @@ class IntegerEncoder { return encoder; } + // This has 3 effects: + // 1/ It restricts the given variable to only take values amongst the given + // ones. + // 2/ It creates one Boolean variable per value that convey the fact that the + // var is equal to this value iff the Boolean is true. If there is only + // 2 values, then just one Boolean variable is created. For more than two + // values, a constraint is also added to enforce that exactly one Boolean + // variable is true. + // 3/ The encoding for NegationOf(var) is automatically created too. It reuses + // the same Boolean variable as the encoding of var. + // + // Calling this more than once is an error (Checked). + // TODO(user): we could instead only keep the intersection and fix the now + // impossible values to zero. + // + // Note(user): There is currently no relation here between + // FullyEncodeVariable() and CreateAssociatedLiteral(). However the + // IntegerTrail class will automatically link the two representations and do + // the right thing. + // + // Note(user): Calling this with just one value will cause a CHECK fail. One + // need to fix the IntegerVariable inside the IntegerTrail instead of calling + // this. + // + // TODO(user): It is currently only possible to call that at the decision + // level zero. This is Checked. + void FullyEncodeVariable(IntegerVariable var, std::vector values); + + // Gets the full encoding of a variable on which FullyEncodeVariable() has + // been called. The returned elements are always sorted by increasing + // IntegerValue. Once created, the encoding never changes, but some Boolean + // variable may become fixed. + struct ValueLiteralPair { + ValueLiteralPair(IntegerValue v, Literal l) : value(v), literal(l) {} + bool operator==(const ValueLiteralPair& o) const { + return value == o.value && literal == o.literal; + } + IntegerValue value; + Literal literal; + }; + const std::vector& FullDomainEncoding( + IntegerVariable var) const { + return full_encoding_[FindOrDie(full_encoding_index_, var)]; + } + + // Returns the set of variable encoded as the keys in a map. The map values + // only have an internal meaning. The set of encoded variables is returned + // with this "weird" api for efficiency. + const hash_map& GetFullyEncodedVariables() const { + return full_encoding_index_; + } + // Creates a new Boolean variable 'var' such that // - if true, then the IntegerLiteral is true. // - if false, then the negated IntegerLiteral is true. @@ -218,14 +274,18 @@ class IntegerEncoder { // doesn't correspond to an IntegerLiteral. ITIVector reverse_encoding_; + // Full domain encoding. The map contains the index in full_encoding_ of + // the fully encoded variable. Each entry in full_encoding_ is sorted by + // IntegerValue and contains the encoding of one IntegerVariable. + hash_map full_encoding_index_; + std::vector> full_encoding_; + DISALLOW_COPY_AND_ASSIGN(IntegerEncoder); }; // This class maintains a set of integer variables with their current bounds. // Bounds can be propagated from an external "source" and this class helps // to maintain the reason for each propagation. -// -// TODO(user): Add support for a lazy encoding of the integer variable in SAT. class IntegerTrail : public Propagator { public: IntegerTrail(IntegerEncoder* encoder, Trail* trail) @@ -323,10 +383,9 @@ class IntegerTrail : public Propagator { // assignment. They are only valid just after this is called. The full literal // reason will be computed lazily when it becomes needed. void EnqueueLiteral(Literal literal, std::vector** literal_reason, - std::vector** integer_reason, Trail* trail); + std::vector** integer_reason); void EnqueueLiteral(Literal literal, const std::vector& literal_reason, - const std::vector& integer_reason, - Trail* trail); + const std::vector& integer_reason); // Returns the reason (as set of Literal currently false) for a given integer // literal. Note that the bound must be less restrictive than the current @@ -355,6 +414,12 @@ class IntegerTrail : public Propagator { } private: + // Helper used by Enqueue() to propagate one of the literal associated to + // the given i_lit and maintained by encoder_. + bool EnqueueAssociatedLiteral(Literal literal, IntegerLiteral i_lit, + const std::vector& literals_reason, + const std::vector& bounds_reason); + // Returns a lower bound on the given var that will always be valid. IntegerValue LevelZeroBound(int var) const { // The level zero bounds are stored at the begining of the trail and they @@ -410,6 +475,13 @@ class IntegerTrail : public Propagator { // The "is_empty" literal of the optional variables or kNoLiteralIndex. ITIVector is_empty_literals_; + // Data used to support the propagation of fully encoded variable. We keep + // for each variable the index in encoder_.GetDomainEncoding() of the first + // literal that is not assigned to false, and call this the "min". + int64 num_encoded_variables_ = 0; + RevMap>> watched_min_; + RevMap> current_min_; + // Temporary data used by MergeReasonInto(). mutable std::vector tmp_queue_; mutable std::vector tmp_trail_indices_; @@ -593,6 +665,15 @@ inline std::function UpperBound(IntegerVariable v) { }; } +// This checks that the variable is fixed. +inline std::function Value(IntegerVariable v) { + return [=](const Model& model) { + const IntegerTrail* trail = model.Get(); + CHECK_EQ(trail->LowerBound(v), trail->UpperBound(v)); + return trail->LowerBound(v).value(); + }; +} + inline std::function GreaterOrEqual(IntegerVariable v, int64 lb) { return [=](Model* model) { if (!model->GetOrCreate()->Enqueue( diff --git a/src/sat/integer_expr.cc b/src/sat/integer_expr.cc index 618b3b4d72..3eea18a264 100644 --- a/src/sat/integer_expr.cc +++ b/src/sat/integer_expr.cc @@ -206,12 +206,12 @@ bool IsOneOfPropagator::Propagate(Trail* trail) { std::vector* literal_reason; std::vector* integer_reason; if (current_min > values_[i]) { - integer_trail_->EnqueueLiteral( - selectors_[i].Negated(), &literal_reason, &integer_reason, trail); + integer_trail_->EnqueueLiteral(selectors_[i].Negated(), + &literal_reason, &integer_reason); integer_reason->push_back(integer_trail_->LowerBoundAsLiteral(var_)); } else if (current_max < values_[i]) { - integer_trail_->EnqueueLiteral( - selectors_[i].Negated(), &literal_reason, &integer_reason, trail); + integer_trail_->EnqueueLiteral(selectors_[i].Negated(), + &literal_reason, &integer_reason); integer_reason->push_back(integer_trail_->UpperBoundAsLiteral(var_)); } } diff --git a/src/sat/intervals.h b/src/sat/intervals.h index fcafc579a6..c88aee4738 100644 --- a/src/sat/intervals.h +++ b/src/sat/intervals.h @@ -201,9 +201,6 @@ inline std::function EndAtEnd(IntervalVariable i1, }; } -// TODO(user): Add a propagator on the interval duration depending -// on the set of alternatives that are currently not executable. -// // This requires that all the alternatives are optional tasks. inline std::function IntervalWithAlternatives( IntervalVariable master, const std::vector& members) { diff --git a/src/sat/precedences.cc b/src/sat/precedences.cc index e5b4c11c97..8a60a04c90 100644 --- a/src/sat/precedences.cc +++ b/src/sat/precedences.cc @@ -305,7 +305,7 @@ void PrecedencesPropagator::PropagateOptionalArcs(Trail* trail) { std::vector* literal_reason; std::vector* integer_reason; integer_trail_->EnqueueLiteral(is_present.Negated(), &literal_reason, - &integer_reason, trail); + &integer_reason); integer_reason->push_back( integer_trail_->LowerBoundAsLiteral(arc.tail_var)); integer_reason->push_back( diff --git a/src/sat/sat_base.h b/src/sat/sat_base.h index bae846310b..441f8702f2 100644 --- a/src/sat/sat_base.h +++ b/src/sat/sat_base.h @@ -269,7 +269,7 @@ class Trail { } // Specific Enqueue() version for the search decision. - void EnqueueSeachDecision(Literal true_literal) { + void EnqueueSearchDecision(Literal true_literal) { Enqueue(true_literal, AssignmentType::kSearchDecision); } diff --git a/src/sat/sat_solver.cc b/src/sat/sat_solver.cc index 143cae7ead..8047a0d72e 100644 --- a/src/sat/sat_solver.cc +++ b/src/sat/sat_solver.cc @@ -1617,7 +1617,7 @@ void SatSolver::EnqueueNewDecision(Literal literal) { decisions_[current_decision_level_] = Decision(trail_->Index(), literal); ++current_decision_level_; trail_->SetDecisionLevel(current_decision_level_); - trail_->EnqueueSeachDecision(literal); + trail_->EnqueueSearchDecision(literal); } Literal SatSolver::NextBranch() { diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 34365896c7..32a75b9f5e 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -322,15 +322,17 @@ class SatSolver { const std::vector& NewlyAddedBinaryClauses(); void ClearNewlyAddedBinaryClauses(); - // Various getters of the current solver state. struct Decision { Decision() : trail_index(-1) {} Decision(int i, Literal l) : trail_index(i), literal(l) {} int trail_index; Literal literal; }; - int CurrentDecisionLevel() const { return current_decision_level_; } + + // Note that the Decisions() vector is always of size NumVariables(), and that + // only the first CurrentDecisionLevel() entries have a meaning. const std::vector& Decisions() const { return decisions_; } + int CurrentDecisionLevel() const { return current_decision_level_; } const Trail& LiteralTrail() const { return *trail_; } const VariablesAssignment& Assignment() const { return trail_->Assignment(); } @@ -372,6 +374,10 @@ class SatSolver { binary_implication_graph_.AddBinaryClause(a, b); } + // Performs propagation of the recently enqueued elements. + // Mainly visible for testing. + bool Propagate(); + private: // Calls Propagate() and returns true if no conflict occured. Otherwise, // learns the conflict, backtracks, enqueues the consequence of the learned @@ -486,8 +492,7 @@ class SatSolver { // True and Enqueue() this change. void EnqueueNewDecision(Literal literal); - // Performs propagation of the recently enqueued elements. - bool Propagate(); + // Returns true if everything has been propagated. bool PropagationIsDone() const; // Update the propagators_ list with the relevant propagators. @@ -943,6 +948,32 @@ inline std::function ClauseConstraint( }; } +// The a => b constraint. +inline std::function Implication(Literal a, Literal b) { + return [=](Model* model) { + model->GetOrCreate()->AddBinaryClause(a.Negated(), b); + }; +} + +// This can be used to enumerate all the solutions. After each SAT call to +// Solve(), calling this will reset the solver and exclude the current solution +// so that the next call to Solve() will give a new solution or UNSAT is there +// is no more new solutions. +inline std::function ExcludeCurrentSolutionAndBacktrack() { + return [=](Model* model) { + SatSolver* sat_solver = model->GetOrCreate(); + + // Note that we only exclude the current decisions, which is an efficient + // way to not get the same SAT assignment. + std::vector exlude_solution; + for (int i = 0; i < sat_solver->CurrentDecisionLevel(); ++i) { + exlude_solution.push_back(sat_solver->Decisions()[i].literal.Negated()); + } + sat_solver->Backtrack(0); + model->Add(ClauseConstraint(exlude_solution)); + }; +} + inline std::function NewSatParameters(std::string params) { return [=](Model* model) { sat::SatParameters parameters; diff --git a/src/sat/table.cc b/src/sat/table.cc new file mode 100644 index 0000000000..7c1111a07e --- /dev/null +++ b/src/sat/table.cc @@ -0,0 +1,242 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "sat/table.h" + +#include "base/map_util.h" +#include "base/stl_util.h" + +namespace operations_research { +namespace sat { + +namespace { + +// Transpose the given "matrix" and transform the value to IntegerValue. +std::vector> Transpose(const std::vector> tuples) { + CHECK(!tuples.empty()); + const int n = tuples.size(); + const int m = tuples[0].size(); + std::vector> transpose(m, std::vector(n)); + for (int i = 0; i < n; ++i) { + CHECK_EQ(m, tuples[i].size()); + for (int j = 0; j < m; ++j) { + transpose[j][i] = tuples[i][j]; + } + } + return transpose; +} + +// Converts the vector representation returned by FullDomainEncoding() to a map. +hash_map GetEncoding(IntegerVariable var, Model* model) { + hash_map encoding; + for (const auto& entry : + model->GetOrCreate()->FullDomainEncoding(var)) { + encoding[entry.value] = entry.literal; + } + return encoding; +} + +// Add the implications and clauses to link one column of a table to the Literal +// controling if the lines are possible or not. The column has the given values, +// and the Literal of the column variable can be retreived using the encoding +// map. +void ProcessOneColumn(const std::vector& line_literals, + const std::vector& values, + const hash_map& encoding, + Model* model) { + CHECK_EQ(line_literals.size(), values.size()); + hash_map> value_to_list_of_line_literals; + + // If a value is false (i.e not possible), then the tuple with this value + // is false too (i.e not possible). + for (int i = 0; i < values.size(); ++i) { + const IntegerValue v = values[i]; + value_to_list_of_line_literals[v].push_back(line_literals[i]); + model->Add(Implication(FindOrDie(encoding, v).Negated(), + line_literals[i].Negated())); + } + + // If all the tuples containing a value are false, then this value must be + // false too. + for (const auto& entry : value_to_list_of_line_literals) { + std::vector clause = entry.second; + clause.push_back(FindOrDie(encoding, entry.first).Negated()); + model->Add(ClauseConstraint(clause)); + } +} + +} // namespace + +std::function TableConstraint( + const std::vector& vars, const std::vector>& tuples) { + return [=](Model* model) { + // Create one Boolean variable per tuple to indicate if it can still be + // selected or not. Note that we don't enforce exactly one tuple to be + // selected because these variables are just used by this constraint, so + // only the information "can't be selected" is important. + // + // TODO(user): If a value in one column is unique, we don't need to create a + // new BooleanVariable corresponding to this line since we can use the one + // corresponding to this value in that column. + std::vector tuple_literals; + for (int i = 0; i < tuples.size(); ++i) { + tuple_literals.push_back(Literal(model->Add(NewBooleanVariable()), true)); + } + + // Fully encode the variables using all the values appearing in the tuples. + IntegerEncoder* encoder = model->GetOrCreate(); + hash_map encoding; + const std::vector>& tr_tuples = Transpose(tuples); + for (int i = 0; i < vars.size(); ++i) { + encoder->FullyEncodeVariable(vars[i], tr_tuples[i]); + encoding = GetEncoding(vars[i], model); + ProcessOneColumn(tuple_literals, tr_tuples[i], encoding, model); + } + }; +} + +std::function TransitionConstraint( + const std::vector& vars, const std::vector>& automata, + int64 initial_state, const std::vector& final_states) { + return [=](Model* model) { + IntegerEncoder* encoder = model->GetOrCreate(); + const int n = vars.size(); + CHECK_GT(n, 0) << "No variables in TransitionConstraint()."; + + // Test precondition. + { + std::set> unique_transition_checker; + for (const std::vector& transition : automata) { + CHECK_EQ(transition.size(), 3); + const std::pair p{transition[0], transition[1]}; + CHECK(!ContainsKey(unique_transition_checker, p)) + << "Duplicate outgoing transitions with value " << transition[1] + << " from state " << transition[0] << "."; + unique_transition_checker.insert(p); + } + } + + // Compute the set of reachable state at each time point. + std::vector> reachable_states(n + 1); + reachable_states[0].insert(initial_state); + reachable_states[n] = {final_states.begin(), final_states.end()}; + + // Forward. + for (int time = 0; time + 1 < n; ++time) { + for (const std::vector& transition : automata) { + if (!ContainsKey(reachable_states[time], transition[0])) continue; + reachable_states[time + 1].insert(transition[2]); + } + } + + // Backward. + for (int time = n - 1; time > 0; --time) { + std::set new_set; + for (const std::vector& transition : automata) { + if (!ContainsKey(reachable_states[time], transition[0])) continue; + if (!ContainsKey(reachable_states[time + 1], transition[2])) continue; + new_set.insert(transition[0]); + } + reachable_states[time].swap(new_set); + } + + // We will model at each time step the current automata state using Boolean + // variables. We will have n+1 time step. At time zero, we start in the + // initial state, and at time n we should be in one of the final states. We + // don't need to create Booleans at at time when there is just one possible + // state (like at time zero). + hash_map encoding; + hash_map in_encoding; + hash_map out_encoding; + for (int time = 0; time < n; ++time) { + // All these vector have the same size. We will use them to enforce a + // local table constraint representing one step of the automata at the + // given time. + std::vector tuple_literals; + std::vector in_states; + std::vector transition_values; + std::vector out_states; + for (const std::vector& transition : automata) { + if (!ContainsKey(reachable_states[time], transition[0])) continue; + if (!ContainsKey(reachable_states[time + 1], transition[2])) continue; + + // TODO(user): if this transition correspond to just one in-state or + // one-out state or one variable value, we could reuse the corresponding + // Boolean variable instead of creating a new one! + tuple_literals.push_back( + Literal(model->Add(NewBooleanVariable()), true)); + in_states.push_back(IntegerValue(transition[0])); + + transition_values.push_back(IntegerValue(transition[1])); + out_states.push_back(IntegerValue(transition[2])); + } + + // Fully instantiate vars[time]. + { + std::vector s = transition_values; + STLSortAndRemoveDuplicates(&s); + + encoding.clear(); + if (s.size() > 1) { + std::vector values(s.begin(), s.end()); + encoder->FullyEncodeVariable(vars[time], values); + encoding = GetEncoding(vars[time], model); + } else { + // Fix vars[time] to its unique possible value. + CHECK_EQ(s.size(), 1); + const int64 unique_value = s.begin()->value(); + model->Add(LowerOrEqual(vars[time], unique_value)); + model->Add(GreaterOrEqual(vars[time], unique_value)); + } + } + + // For each possible out states, create one Boolean variable. + // + // TODO(user): enforce an at most one constraint? it is not really needed + // though, so I am not sure it will improve or hurt the performance. To + // investigate on real problems. + { + std::vector s = out_states; + STLSortAndRemoveDuplicates(&s); + + out_encoding.clear(); + if (s.size() == 2) { + const BooleanVariable var = model->Add(NewBooleanVariable()); + out_encoding[s.front()] = Literal(var, true); + out_encoding[s.back()] = Literal(var, false); + } else if (s.size() > 1) { + // Enforce at most one constraint? + for (const IntegerValue state : s) { + out_encoding[state] = + Literal(model->Add(NewBooleanVariable()), true); + } + } + } + + // Now we link everything together. + if (in_encoding.size() > 1) { + ProcessOneColumn(tuple_literals, in_states, in_encoding, model); + } + if (encoding.size() > 1) { + ProcessOneColumn(tuple_literals, transition_values, encoding, model); + } + if (out_encoding.size() > 1) { + ProcessOneColumn(tuple_literals, out_states, out_encoding, model); + } + in_encoding = out_encoding; + } + }; +} + +} // namespace sat +} // namespace operations_research diff --git a/src/sat/table.h b/src/sat/table.h new file mode 100644 index 0000000000..1866400493 --- /dev/null +++ b/src/sat/table.h @@ -0,0 +1,48 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_TABLE_H_ +#define OR_TOOLS_SAT_TABLE_H_ + +#include "sat/integer.h" +#include "sat/model.h" + +namespace operations_research { +namespace sat { + +// Enforces that the given tuple of variables is equal to one of the given +// tuples. All the tuples must have the same size as var.size(), this is +// Checked. +std::function TableConstraint( + const std::vector& vars, const std::vector>& tuples); + +// Given an automata defined by a set of 3-tuples: +// (state, transition_with_value_as_label, next_state) +// this accepts the sequences of vars.size() variables that are recognized by +// this automata. That is: +// - We start from the initial state. +// - For each variable, we move along the transition labeled by this variable +// value. Moreover, the variable must take a value that correspond to a +// feasible transition. +// - We only accept sequences that ends in one of the final states. +// +// We CHECK that there is only one possible transition for a state/value pair. +// See the test for some examples. +std::function TransitionConstraint( + const std::vector& vars, const std::vector>& automata, + int64 initial_state, const std::vector& final_states); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_TABLE_H_ diff --git a/src/util/rev.h b/src/util/rev.h new file mode 100644 index 0000000000..1e1c0dabdb --- /dev/null +++ b/src/util/rev.h @@ -0,0 +1,129 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Reversible (i.e Backtrackable) classes, used to simplify coding propagators. +#ifndef OR_TOOLS_UTIL_REV_H_ +#define OR_TOOLS_UTIL_REV_H_ + +#include + +#include "base/logging.h" +#include "base/map_util.h" + +namespace operations_research { + +// Like a normal map but support backtrackable operations. +// +// This works on any class "Map" that supports: begin(), end(), find(), erase(), +// insert(), key_type, value_type, mapped_type and const_iterator. +template +class RevMap { + public: + typedef typename Map::key_type key_type; + typedef typename Map::mapped_type mapped_type; + typedef typename Map::value_type value_type; + typedef typename Map::const_iterator const_iterator; + + // Backtracking support: changes the current "level" (always non-negative). + // + // Initially the class starts at level zero. Increasing the level works in + // O(level diff) and saves the state of the current old level. Decreasing the + // level restores the state to what it was at this level and all higher levels + // are forgotten. Everything done at level zero cannot be backtracked over. + void SetLevel(int level); + int Level() const { return first_op_index_of_next_level_.size(); } + + bool ContainsKey(key_type key) const { return operations_research::ContainsKey(map_, key); } + const mapped_type& FindOrDie(key_type key) const { + return operations_research::FindOrDie(map_, key); + } + + void EraseOrDie(key_type key); + void Set(key_type key, mapped_type value); // Adds or overwrites. + + // Wrapper to the underlying const map functions. + int size() const { return map_.size(); } + bool empty() const { return map_.empty(); } + const_iterator find(const key_type& k) const { return map_.find(k); } + const_iterator begin() const { return map_.begin(); } + const_iterator end() const { return map_.end(); } + + private: + Map map_; + + // The operation that needs to be performed to reverse one modification: + // - If is_deletion is true, then we need to delete the entry with given key. + // - Otherwise we need to add back (or overwrite) the saved entry. + struct UndoOperation { + bool is_deletion; + key_type key; + mapped_type value; + }; + + // TODO(user): We could merge the operations with the same key from the same + // level. Investigate and implement if this is worth the effort for our use + // case. + std::vector operations_; + std::vector first_op_index_of_next_level_; +}; + +template +void RevMap::SetLevel(int level) { + DCHECK_GE(level, 0); + if (level < Level()) { + const int backtrack_level = first_op_index_of_next_level_[level]; + first_op_index_of_next_level_.resize(level); // Shrinks. + while (operations_.size() > backtrack_level) { + const UndoOperation& to_undo = operations_.back(); + if (to_undo.is_deletion) { + map_.erase(to_undo.key); + } else { + map_.insert({to_undo.key, to_undo.value}).first->second = to_undo.value; + } + operations_.pop_back(); + } + return; + } + + // This is ok even if level == Level(). + first_op_index_of_next_level_.resize(level, operations_.size()); // Grows. +} + +template +void RevMap::EraseOrDie(key_type key) { + const auto iter = map_.find(key); + if (iter == map_.end()) LOG(FATAL) << "key not present: '" << key << "'."; + if (Level() > 0) { + operations_.push_back({false, key, iter->second}); + } + map_.erase(iter); +} + +template +void RevMap::Set(key_type key, mapped_type value) { + auto insertion_result = map_.insert({key, value}); + if (Level() > 0) { + if (insertion_result.second) { + // It is an insertion. Undo = delete. + operations_.push_back({true, key}); + } else { + // It is a modification. Undo = change back to old value. + operations_.push_back({false, key, insertion_result.first->second}); + } + } + insertion_result.first->second = value; +} + +} // namespace operations_research + +#endif // OR_TOOLS_UTIL_REV_H_