more work on sat; initial connection to the flatzinc interpreter

This commit is contained in:
Laurent Perron
2016-09-22 13:55:16 +02:00
parent cfafaf6d6e
commit 08f556c520
22 changed files with 1328 additions and 196 deletions

View File

@@ -96,6 +96,7 @@ FLATZINC_DEPS = \
$(SRC_DIR)/flatzinc/presolve.h \
$(SRC_DIR)/flatzinc/reporting.h \
$(SRC_DIR)/flatzinc/sat_constraint.h \
$(SRC_DIR)/flatzinc/sat_fz_solver.h \
$(SRC_DIR)/flatzinc/solver_data.h \
$(SRC_DIR)/flatzinc/solver.h \
$(SRC_DIR)/flatzinc/solver_util.h \
@@ -191,6 +192,7 @@ FLATZINC_OBJS=\
$(OBJ_DIR)/flatzinc/presolve.$O \
$(OBJ_DIR)/flatzinc/reporting.$O \
$(OBJ_DIR)/flatzinc/sat_constraint.$O \
$(OBJ_DIR)/flatzinc/sat_fz_solver.$O \
$(OBJ_DIR)/flatzinc/solver.$O \
$(OBJ_DIR)/flatzinc/solver_data.$O \
$(OBJ_DIR)/flatzinc/solver_util.$O
@@ -234,6 +236,9 @@ $(OBJ_DIR)/flatzinc/reporting.$O: $(SRC_DIR)/flatzinc/reporting.cc $(FLATZINC_DE
$(OBJ_DIR)/flatzinc/sat_constraint.$O: $(SRC_DIR)/flatzinc/sat_constraint.cc $(FLATZINC_DEPS)
$(CCC) $(CFLAGS) -c $(SRC_DIR)$Sflatzinc$Ssat_constraint.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssat_constraint.$O
$(OBJ_DIR)/flatzinc/sat_fz_solver.$O: $(SRC_DIR)/flatzinc/sat_fz_solver.cc $(FLATZINC_DEPS)
$(CCC) $(CFLAGS) -c $(SRC_DIR)$Sflatzinc$Ssat_fz_solver.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssat_fz_solver.$O
$(OBJ_DIR)/flatzinc/solver.$O: $(SRC_DIR)/flatzinc/solver.cc $(FLATZINC_DEPS)
$(CCC) $(CFLAGS) -c $(SRC_DIR)$Sflatzinc$Ssolver.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssolver.$O

View File

@@ -128,6 +128,9 @@ $(SRC_DIR)/base/status.h: \
$(SRC_DIR)/base/statusor.h: \
$(SRC_DIR)/base/status.h
$(SRC_DIR)/base/stringpiece_utils.h: \
$(SRC_DIR)/base/stringpiece.h
$(SRC_DIR)/base/stringprintf.h: \
$(SRC_DIR)/base/stringpiece.h
@@ -150,9 +153,6 @@ $(SRC_DIR)/base/sysinfo.h: \
$(SRC_DIR)/base/thorough_hash.h: \
$(SRC_DIR)/base/integral_types.h
$(SRC_DIR)/base/threadpool.h: \
$(SRC_DIR)/base/callback.h
$(SRC_DIR)/base/timer.h: \
$(SRC_DIR)/base/basictypes.h \
$(SRC_DIR)/base/logging.h \
@@ -350,6 +350,10 @@ $(SRC_DIR)/util/range_query_function.h: \
$(SRC_DIR)/util/rational_approximation.h: \
$(SRC_DIR)/base/integral_types.h
$(SRC_DIR)/util/rev.h: \
$(SRC_DIR)/base/logging.h \
$(SRC_DIR)/base/map_util.h
$(SRC_DIR)/util/running_stat.h: \
$(SRC_DIR)/base/logging.h \
$(SRC_DIR)/base/macros.h
@@ -1154,8 +1158,8 @@ $(SRC_DIR)/graph/shortestpaths.h: \
$(SRC_DIR)/graph/util.h: \
$(SRC_DIR)/graph/graph.h \
$(SRC_DIR)/base/hash.h \
$(SRC_DIR)/base/join.h \
$(SRC_DIR)/base/map_util.h \
$(SRC_DIR)/base/murmur.h \
$(SRC_DIR)/base/numbers.h \
$(SRC_DIR)/base/split.h \
@@ -1444,6 +1448,7 @@ SAT_LIB_OBJS = \
$(OBJ_DIR)/sat/sat_solver.$O \
$(OBJ_DIR)/sat/simplification.$O \
$(OBJ_DIR)/sat/symmetry.$O \
$(OBJ_DIR)/sat/table.$O \
$(OBJ_DIR)/sat/util.$O \
$(OBJ_DIR)/sat/boolean_problem.pb.$O \
$(OBJ_DIR)/sat/sat_parameters.pb.$O
@@ -1498,9 +1503,11 @@ $(SRC_DIR)/sat/integer.h: \
$(SRC_DIR)/sat/sat_solver.h \
$(SRC_DIR)/base/int_type.h \
$(SRC_DIR)/base/join.h \
$(SRC_DIR)/base/map_util.h \
$(SRC_DIR)/base/port.h \
$(SRC_DIR)/util/bitset.h \
$(SRC_DIR)/util/iterators.h \
$(SRC_DIR)/util/rev.h \
$(SRC_DIR)/util/saturated_arithmetic.h
$(SRC_DIR)/sat/intervals.h: \
@@ -1584,6 +1591,10 @@ $(SRC_DIR)/sat/symmetry.h: \
$(SRC_DIR)/util/stats.h \
$(SRC_DIR)/algorithms/sparse_permutation.h
$(SRC_DIR)/sat/table.h: \
$(SRC_DIR)/sat/integer.h \
$(SRC_DIR)/sat/model.h
$(SRC_DIR)/sat/util.h: \
$(GEN_DIR)/sat/sat_parameters.pb.h \
$(SRC_DIR)/base/random.h
@@ -1708,6 +1719,13 @@ $(OBJ_DIR)/sat/symmetry.$O: \
$(SRC_DIR)/sat/symmetry.h
$(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/symmetry.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Ssymmetry.$O
$(OBJ_DIR)/sat/table.$O: \
$(SRC_DIR)/sat/table.cc \
$(SRC_DIR)/sat/table.h \
$(SRC_DIR)/base/map_util.h \
$(SRC_DIR)/base/stl_util.h
$(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/table.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Stable.$O
$(OBJ_DIR)/sat/util.$O: \
$(SRC_DIR)/sat/util.cc \
$(SRC_DIR)/sat/util.h
@@ -3038,3 +3056,4 @@ $(GEN_DIR)/constraint_solver/solver_parameters.pb.h: $(GEN_DIR)/constraint_solve
$(OBJ_DIR)/constraint_solver/solver_parameters.pb.$O: $(GEN_DIR)/constraint_solver/solver_parameters.pb.cc
$(CCC) $(CFLAGS) -c $(GEN_DIR)/constraint_solver/solver_parameters.pb.cc $(OBJ_OUT)$(OBJ_DIR)$Sconstraint_solver$Ssolver_parameters.pb.$O

View File

@@ -1596,6 +1596,18 @@ void ParseShortIntLin(fz::SolverData* data, fz::Constraint* ct, IntExpr** left,
*left = nullptr;
*right = nullptr;
if (fzvars.empty() && size != 0) {
// We have a constant array.
CHECK_EQ(ct->arguments[1].values.size(), size);
int64 result = 0;
for (int i = 0; i < size; ++i) {
result += coefficients[i] * ct->arguments[1].values[i];
}
*left = solver->MakeIntConst(result);
*right = solver->MakeIntConst(rhs);
return;
}
switch (size) {
case 0: {
*left = solver->MakeIntConst(0);
@@ -1727,7 +1739,8 @@ bool AreAllVariablesBoolean(fz::SolverData* data, fz::Constraint* ct) {
bool ExtractLinAsShort(fz::SolverData* data, fz::Constraint* ct) {
const int size = ct->arguments[0].values.size();
if (ct->arguments[1].variables.empty()) {
return false;
// Constant linear scalprods will be treated correctly by ParseShortLin.
return true;
}
switch (size) {
case 0:
@@ -2086,7 +2099,6 @@ void ExtractIntLinLeReif(fz::SolverData* data, fz::Constraint* ct) {
void ExtractIntLinNe(fz::SolverData* data, fz::Constraint* ct) {
Solver* const solver = data->solver();
const int size = ct->arguments[0].values.size();
if (ExtractLinAsShort(data, ct)) {
IntExpr* left = nullptr;
IntExpr* right = nullptr;
@@ -2097,12 +2109,8 @@ void ExtractIntLinNe(fz::SolverData* data, fz::Constraint* ct) {
std::vector<int64> coeffs;
int64 rhs = 0;
ParseLongIntLin(data, ct, &vars, &coeffs, &rhs);
if (AreAllBooleans(vars) && AreAllOnes(coeffs)) {
PostBooleanSumInRange(data->Sat(), solver, vars, rhs, size);
} else {
AddConstraint(solver, ct, solver->MakeNonEquality(
solver->MakeScalProd(vars, coeffs), rhs));
}
AddConstraint(solver, ct, solver->MakeNonEquality(
solver->MakeScalProd(vars, coeffs), rhs));
}
}

View File

@@ -33,6 +33,7 @@
#include "flatzinc/parser.h"
#include "flatzinc/presolve.h"
#include "flatzinc/reporting.h"
#include "flatzinc/sat_fz_solver.h"
#include "flatzinc/solver.h"
#include "flatzinc/solver_util.h"
@@ -62,6 +63,8 @@ DEFINE_bool(
"Increase verbosity of the impact based search when used in free search.");
DEFINE_bool(verbose_mt, false, "Verbose Multi-Thread.");
DEFINE_bool(use_fz_sat, false, "Use the SAT/CP solver.");
DECLARE_bool(log_prefix);
DECLARE_bool(use_sat);
@@ -300,6 +303,12 @@ int main(int argc, char** argv) {
operations_research::fz::Model model =
operations_research::fz::ParseFlatzincModel(input,
!FLAGS_read_from_stdin);
operations_research::fz::Solve(model);
if (FLAGS_use_fz_sat) {
operations_research::sat::SolveWithSat(
model, operations_research::fz::SingleThreadParameters());
} else {
operations_research::fz::Solve(model);
}
return EXIT_SUCCESS;
}

View File

@@ -1649,10 +1649,10 @@ bool Presolver::SimplifyIntNeReif(Constraint* ct, std::string* log) {
ContainsKey(int_eq_reif_map_[ct->arguments[0].Var()],
ct->arguments[1].Var())) {
log->append("merge constraint with opposite constraint");
IntegerVariable* const opposite =
IntegerVariable* const opposite_boolvar =
int_eq_reif_map_[ct->arguments[0].Var()][ct->arguments[1].Var()];
ct->arguments[0] = Argument::IntVarRef(opposite);
ct->arguments[1] = Argument::IntVarRef(ct->arguments[1].Var());
ct->arguments[0] = Argument::IntVarRef(opposite_boolvar);
ct->arguments[1] = Argument::IntVarRef(ct->arguments[2].Var());
ct->RemoveArg(2);
ct->type = "bool_not";
return true;

View File

@@ -0,0 +1,228 @@
// Copyright 2010-2014 Google
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "flatzinc/sat_fz_solver.h"
#include "base/map_util.h"
#include "flatzinc/logging.h"
#include "sat/disjunctive.h"
#include "sat/integer.h"
#include "sat/integer_expr.h"
#include "sat/intervals.h"
#include "sat/model.h"
#include "sat/sat_solver.h"
#include "sat/table.h"
namespace operations_research {
namespace sat {
// TODO(user): deal with constant variables.
// In a first version we could create constant IntegerVariable in the solver.
IntegerVariable LookupVar(
const hash_map<fz::IntegerVariable*, IntegerVariable>& var_map,
const fz::Argument& argument) {
CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF);
return FindOrDie(var_map, argument.variables[0]);
}
// TODO(user): deal with constant variables.
// In a first version we could create constant IntegerVariable in the solver.
std::vector<IntegerVariable> LookupVars(
const hash_map<fz::IntegerVariable*, IntegerVariable>& var_map,
const fz::Argument& argument) {
CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF_ARRAY);
std::vector<IntegerVariable> result;
for (fz::IntegerVariable* var : argument.variables) {
result.push_back(FindOrDie(var_map, var));
}
return result;
}
void ExtractIntMin(
const fz::Constraint& ct,
const hash_map<fz::IntegerVariable*, IntegerVariable>& var_map,
Model* sat_model) {
const IntegerVariable a = LookupVar(var_map, ct.arguments[0]);
const IntegerVariable b = LookupVar(var_map, ct.arguments[1]);
const IntegerVariable c = LookupVar(var_map, ct.arguments[2]);
sat_model->Add(IsEqualToMinOf(c, {a, b}));
}
void ExtractIntAbs(
const fz::Constraint& ct,
const hash_map<fz::IntegerVariable*, IntegerVariable>& var_map,
Model* sat_model) {
const IntegerVariable v = LookupVar(var_map, ct.arguments[0]);
const IntegerVariable abs = LookupVar(var_map, ct.arguments[1]);
sat_model->Add(IsEqualToMaxOf(abs, {v, NegationOf(v)}));
}
void ExtractRegular(
const fz::Constraint& ct,
const hash_map<fz::IntegerVariable*, IntegerVariable>& var_map,
Model* sat_model) {
const std::vector<IntegerVariable> vars = LookupVars(var_map, ct.arguments[0]);
const int64 num_states = ct.arguments[1].Value();
const int64 num_values = ct.arguments[2].Value();
const std::vector<int64>& next = ct.arguments[3].values;
std::vector<std::vector<int64>> transitions;
int count = 0;
for (int i = 1; i <= num_states; ++i) {
for (int j = 1; j <= num_values; ++j) {
transitions.push_back({i, j, next[count++]});
}
}
const int64 initial_state = ct.arguments[4].Value();
std::vector<int64> final_states;
switch (ct.arguments[5].type) {
case fz::Argument::INT_VALUE: {
final_states.push_back(ct.arguments[5].values[0]);
break;
}
case fz::Argument::INT_INTERVAL: {
for (int v = ct.arguments[5].values[0]; v <= ct.arguments[5].values[1];
++v) {
final_states.push_back(v);
}
break;
}
case fz::Argument::INT_LIST: {
final_states = ct.arguments[5].values;
break;
}
default: { LOG(FATAL) << "Wrong constraint " << ct.DebugString(); }
}
sat_model->Add(
TransitionConstraint(vars, transitions, initial_state, final_states));
}
// The format is fixed in the flatzinc specification.
std::string SolutionString(
const Model& sat_model,
const hash_map<fz::IntegerVariable*, IntegerVariable>& var_map,
const fz::SolutionOutputSpecs& output) {
if (output.variable != nullptr) {
const int64 value =
sat_model.Get(Value(FindOrDie(var_map, output.variable)));
if (output.display_as_boolean) {
return StringPrintf("%s = %s;", output.name.c_str(),
value == 1 ? "true" : "false");
} else {
return StringPrintf("%s = %" GG_LL_FORMAT "d;", output.name.c_str(),
value);
}
} else {
const int bound_size = output.bounds.size();
std::string result =
StringPrintf("%s = array%dd(", output.name.c_str(), bound_size);
for (int i = 0; i < bound_size; ++i) {
if (output.bounds[i].max_value != 0) {
result.append(StringPrintf("%" GG_LL_FORMAT "d..%" GG_LL_FORMAT "d, ",
output.bounds[i].min_value,
output.bounds[i].max_value));
} else {
result.append("{},");
}
}
result.append("[");
for (int i = 0; i < output.flat_variables.size(); ++i) {
const int64 value =
sat_model.Get(Value(FindOrDie(var_map, output.flat_variables[i])));
if (output.display_as_boolean) {
result.append(StringPrintf(value ? "true" : "false"));
} else {
result.append(StringPrintf("%" GG_LL_FORMAT "d", value));
}
if (i != output.flat_variables.size() - 1) {
result.append(", ");
}
}
result.append("]);");
return result;
}
return "";
}
void SolveWithSat(const fz::Model& model, const fz::FlatzincParameters& p) {
Model sat_model;
// Correspondance between a fz::IntegerVariable and a sat::IntegerVariable.
hash_map<fz::IntegerVariable*, IntegerVariable> var_map;
// Extract all the variables.
FZLOG << "Extracting " << model.variables().size() << " variables. "
<< FZENDL;
for (fz::IntegerVariable* var : model.variables()) {
if (!var->active) continue;
var_map[var] =
sat_model.Add(NewIntegerVariable(var->domain.Min(), var->domain.Max()));
// TODO(user): encode if it is a list of value? This way constraint using
// it will get the intersection with the proper interval.
}
// Extract all the constraints.
FZLOG << "Extracting " << model.constraints().size() << " constraints. "
<< FZENDL;
std::set<std::string> unsupported_types;
for (fz::Constraint* ct : model.constraints()) {
if (ct != nullptr && ct->active) {
if (ct->type == "int_min") {
ExtractIntMin(*ct, var_map, &sat_model);
} else if (ct->type == "int_abs") {
ExtractIntAbs(*ct, var_map, &sat_model);
} else if (ct->type == "regular") {
ExtractRegular(*ct, var_map, &sat_model);
} else {
unsupported_types.insert(ct->type);
}
}
}
if (!unsupported_types.empty()) {
FZLOG << "There is unsuported constraints types in this model: " << FZENDL;
for (const std::string& type : unsupported_types) {
FZLOG << " - " << type;
}
return;
}
// For now assume a decision problem!
//
// TODO(user): deal with other kind of search (optim, all solutions, ...).
//
// TODO(user): Encode IntegerVariable that are not fixed at the end of the
// search.
CHECK(model.objective() == nullptr);
CHECK_EQ(1, p.num_solutions);
FZLOG << "Solving..." << FZENDL;
const SatSolver::Status status = sat_model.GetOrCreate<SatSolver>()->Solve();
FZLOG << "Status: " << status << FZENDL;
// Output!
std::string solution_string;
for (const fz::SolutionOutputSpecs& output : model.output()) {
solution_string.append(SolutionString(sat_model, var_map, output));
solution_string.append("\n");
}
// Print the solution.
// The "----------" is needed by minizinc.
std::cout << solution_string << "----------" << std::endl;
}
} // namespace sat
} // namespace operations_research

View File

@@ -0,0 +1,28 @@
// Copyright 2010-2014 Google
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_
#define OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_
#include "flatzinc/model.h"
#include "flatzinc/solver.h"
namespace operations_research {
namespace sat {
void SolveWithSat(const fz::Model& model, const fz::FlatzincParameters& p);
} // namespace sat
} // namespace operations_research
#endif // OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_

View File

@@ -21,10 +21,12 @@
#include <numeric>
#include <string>
#include "base/numbers.h"
#include "base/hash.h"
#include "base/join.h"
#include "base/numbers.h"
#include "base/split.h"
#include "base/join.h"
#include "base/map_util.h"
#include "base/murmur.h"
#include "graph/graph.h"
#include "util/filelineiter.h"
@@ -41,22 +43,68 @@ bool GraphIsSymmetric(const Graph& graph);
// Creates a remapped copy of graph "graph", where node i becomes node
// new_node_index[i]. The caller takes ownership of the returned graph.
// "new_node_index" must be a valid permutation of [0..num_nodes-1], or
// else the return StatusOr will be in an error state.
// "new_node_index" must be a valid permutation of [0..num_nodes-1] or the
// behavior is undefined (it may die).
// Note that you can call IsValidPermutation() to check it yourself.
template <class Graph>
util::StatusOr<Graph*> RemapGraph(const Graph& graph,
std::unique_ptr<Graph> RemapGraph(const Graph& graph,
const std::vector<int>& new_node_index);
// Returns a std::string representation of a graph: one arc per line. Eg.:
// "1->2\n3->3" for a graph with 4 nodes and 2 arcs (1->2) and (3->3).
// Arcs are sorted by their tail, then by the order of OutgoingArcs().
// Returns true iff the given vector is a permutation of [0..size()-1].
bool IsValidPermutation(const std::vector<int>& v);
// Returns a std::string representation of a graph.
enum GraphToStringFormat {
// One arc per line, eg. "3->1".
PRINT_GRAPH_ARCS,
// One space-separated adjacency list per line, eg. "5 1 3 1".
// Nodes with no outgoing arc get an empty line.
PRINT_GRAPH_ADJACENCY_LISTS,
// Ditto, but the adjacency lists are sorted.
PRINT_GRAPH_ADJACENCY_LISTS_SORTED,
};
template <class Graph>
std::string GraphToString(const Graph& graph);
std::string GraphToString(const Graph& graph, GraphToStringFormat format);
// Returns a copy of "graph", without self-arcs and duplicate arcs.
template <class Graph>
std::unique_ptr<Graph> RemoveSelfArcsAndDuplicateArcs(const Graph& graph);
// Given an arc path, changes it to a sub-path with the same source and
// destination but without any cycle. Nothing happen if the path was already
// without cycle.
//
// The graph class should support Tail(arc) and Head(arc). They should both
// return an integer representing the corresponding tail/head of the passed arc.
//
// TODO(user): In some cases, there is more than one possible solution. We could
// take some arc costs and return the cheapest path instead. Or return the
// shortest path in term of number of arcs.
template <class Graph>
void RemoveCyclesFromPath(const Graph& graph, std::vector<int>* arc_path);
// Returns true iff the given path contains a cycle.
template <class Graph>
bool PathHasCycle(const Graph& graph, const std::vector<int>& arc_path);
// Returns a vector representing a mapping from arcs to arcs such that each arc
// is mapped to another arc with its (tail, head) flipped, if such an arc
// exists (otherwise it is mapped to -1).
// If the graph is symmetric, the returned mapping is bijective and reflexive,
// i.e. out[out[arc]] = arc for all "arc", where "out" is the returned vector.
// If "die_if_not_symmetric" is true, this function CHECKs() that the graph
// is symmetric.
//
// Self-arcs are always mapped to themselves.
//
// Note that since graphs may have multi-arcs, the mapping isn't necessarily
// unique, hence the function name.
template <class Graph>
std::vector<int> ComputeOnePossibleReverseArcMapping(const Graph& graph,
bool die_if_not_symmetric);
// Read a graph file in the simple ".g" format: the file should be a text file
// containing only space-separated integers, whose first line is:
// <num nodes> <num edges> [<num_colors> <index of first node with color #1>
@@ -140,29 +188,11 @@ bool GraphIsSymmetric(const Graph& graph) {
}
template <class Graph>
util::StatusOr<Graph*> RemapGraph(const Graph& old_graph,
std::unique_ptr<Graph> RemapGraph(const Graph& old_graph,
const std::vector<int>& new_node_index) {
DCHECK(IsValidPermutation(new_node_index)) << "Invalid permutation";
const int num_nodes = old_graph.num_nodes();
// Quickly verify that "new_node_index" is a valid permutation.
bool ok = new_node_index.size() == num_nodes;
if (ok) {
std::vector<bool> tmp_node_mask(old_graph.num_nodes(), false);
for (const int i : new_node_index) {
if (i < 0 || i >= num_nodes || tmp_node_mask[i]) {
ok = false;
break;
}
tmp_node_mask[i] = true;
}
}
if (!ok) {
return util::Status(
util::error::INVALID_ARGUMENT,
StrCat(
"new_node_index is not a valid permutation of [0..num_nodes-1],"
" with num_nodes = ",
num_nodes));
}
CHECK_EQ(new_node_index.size(), num_nodes);
std::unique_ptr<Graph> new_graph(new Graph(num_nodes, old_graph.num_arcs()));
typedef typename Graph::NodeIndex NodeIndex;
typedef typename Graph::ArcIndex ArcIndex;
@@ -173,16 +203,29 @@ util::StatusOr<Graph*> RemapGraph(const Graph& old_graph,
}
}
new_graph->Build();
return new_graph.release();
return new_graph;
}
template <class Graph>
std::string GraphToString(const Graph& graph) {
std::string GraphToString(const Graph& graph, GraphToStringFormat format) {
std::string out;
std::vector<typename Graph::NodeIndex> adj;
for (const typename Graph::NodeIndex node : graph.AllNodes()) {
for (const typename Graph::ArcIndex arc : graph.OutgoingArcs(node)) {
if (!out.empty()) out += '\n';
StrAppend(&out, node, "->", graph.Head(arc));
if (format == PRINT_GRAPH_ARCS) {
for (const typename Graph::ArcIndex arc : graph.OutgoingArcs(node)) {
if (!out.empty()) out += '\n';
StrAppend(&out, node, "->", graph.Head(arc));
}
} else { // PRINT_GRAPH_ADJACENCY_LISTS[_SORTED]
adj.clear();
for (const typename Graph::ArcIndex arc : graph.OutgoingArcs(node)) {
adj.push_back(graph.Head(arc));
}
if (format == PRINT_GRAPH_ADJACENCY_LISTS_SORTED) {
std::sort(adj.begin(), adj.end());
}
if (node != 0) out += '\n';
StrAppend(&out, node, ": ", strings::Join(adj, " "));
}
}
return out;
@@ -206,6 +249,79 @@ std::unique_ptr<Graph> RemoveSelfArcsAndDuplicateArcs(const Graph& graph) {
return g;
}
template <class Graph>
void RemoveCyclesFromPath(const Graph& graph, std::vector<int>* arc_path) {
if (arc_path->empty()) return;
// This maps each node to the latest arc in the given path that leaves it.
std::map<int, int> last_arc_leaving_node;
for (const int arc : *arc_path) last_arc_leaving_node[graph.Tail(arc)] = arc;
// Special case for the destination.
// Note that this requires that -1 is not a valid arc of Graph.
last_arc_leaving_node[graph.Head(arc_path->back())] = -1;
// Reconstruct the path by starting at the source and then following the
// "next" arcs. We override the given arc_path at the same time.
int node = graph.Tail(arc_path->front());
int new_size = 0;
while (new_size < arc_path->size()) { // To prevent cycle on bad input.
const int arc = FindOrDie(last_arc_leaving_node, node);
if (arc == -1) break;
(*arc_path)[new_size++] = arc;
node = graph.Head(arc);
}
arc_path->resize(new_size);
}
template <class Graph>
bool PathHasCycle(const Graph& graph, const std::vector<int>& arc_path) {
if (arc_path.empty()) return false;
std::set<int> seen;
seen.insert(graph.Tail(arc_path.front()));
for (const int arc : arc_path) {
if (!InsertIfNotPresent(&seen, graph.Head(arc))) return true;
}
return false;
}
template <class Graph>
std::vector<int> ComputeOnePossibleReverseArcMapping(const Graph& graph,
bool die_if_not_symmetric) {
std::vector<int> reverse_arc(graph.num_arcs(), -1);
hash_multimap<std::pair</*tail*/ int, /*head*/ int>, /*arc index*/ int> arc_map;
for (int arc = 0; arc < graph.num_arcs(); ++arc) {
const int tail = graph.Tail(arc);
const int head = graph.Head(arc);
if (tail == head) {
// Special case: directly map any self-arc to itself.
reverse_arc[arc] = arc;
continue;
}
// Lookup for the reverse arc of the current one...
auto it = arc_map.find({head, tail});
if (it != arc_map.end()) {
// Found a reverse arc! Store the mapping and remove the
// reverse arc from the map.
reverse_arc[arc] = it->second;
reverse_arc[it->second] = arc;
arc_map.erase(it);
} else {
// Reverse arc not in the map. Add the current arc to the map.
arc_map.insert({{tail, head}, arc});
}
}
// Algorithm check, for debugging.
DCHECK_EQ(std::count(reverse_arc.begin(), reverse_arc.end(), -1),
arc_map.size());
if (die_if_not_symmetric) {
CHECK_EQ(arc_map.size(), 0)
<< "The graph is not symmetric: " << arc_map.size() << " of "
<< graph.num_arcs() << " arcs did not have a reverse.";
}
return reverse_arc;
}
template <class Graph>
util::StatusOr<Graph*> ReadGraphFile(
const std::string& filename, bool directed,

View File

@@ -612,8 +612,7 @@ void FindLinearBooleanProblemSymmetries(
for (int node = 0; node < graph->num_nodes(); ++node) {
new_node_index[node] = next_index_by_class[equivalence_classes[node]]++;
}
std::unique_ptr<Graph> remapped_graph(
RemapGraph(*graph, new_node_index).ValueOrDie());
std::unique_ptr<Graph> remapped_graph = RemapGraph(*graph, new_node_index);
const util::Status status = WriteGraphToFile(
*remapped_graph, FLAGS_debug_dump_symmetry_graph_to_file,
/*directed=*/false, class_size);

View File

@@ -213,6 +213,9 @@ bool DisjunctiveConstraint::Propagate(Trail* trail) {
// Loop until we reach the fixed-point. It should be unique (see Petr Villim
// PhD).
//
// TODO(user): Some of these passes are idempotent, so there is no need to
// call them if their input didn't change! Improve that.
while (true) {
const int64 old_timestamp = integer_trail_->num_enqueues();
@@ -292,44 +295,57 @@ void DisjunctiveConstraint::AddMaxEndReason(int t, IntegerValue upper_bound) {
IntegerLiteral::LowerOrEqual(end_vars_[t], upper_bound));
}
bool DisjunctiveConstraint::CheckIntervalForConflict(int t, Trail* trail) {
// TODO(user): instead of this, we could propagate MinEnd()/MaxStart() and
// let the code in integer.h detect empty domains.
if (CapAdd(MinStart(t), MinDuration(t)) > MaxEnd(t)) {
integer_reason_.clear();
integer_reason_.push_back(
integer_trail_->LowerBoundAsLiteral(start_vars_[t]));
integer_reason_.push_back(
integer_trail_->UpperBoundAsLiteral(end_vars_[t]));
if (duration_vars_[t] != kNoIntegerVariable) {
integer_reason_.push_back(
integer_trail_->LowerBoundAsLiteral(duration_vars_[t]));
}
void DisjunctiveConstraint::AddMaxStartReason(int t, IntegerValue upper_bound) {
integer_reason_.push_back(
IntegerLiteral::LowerOrEqual(start_vars_[t], upper_bound));
}
if (!task_is_currently_present_[t]) {
// We can propagate reason_for_presence_[t] to false.
// Note that it could have been already propagated in which case we do
// nothing.
if (trail->Assignment().LiteralIsFalse(
Literal(reason_for_presence_[t]))) {
return true;
}
DCHECK_NE(reason_for_presence_[t], kNoLiteralIndex);
literal_reason_.clear();
integer_trail_->EnqueueLiteral(Literal(reason_for_presence_[t]).Negated(),
literal_reason_, integer_reason_, trail);
return true;
} else {
// Conflict.
std::vector<Literal>* conflict = trail->MutableConflict();
conflict->clear();
if (reason_for_presence_[t] != kNoLiteralIndex) {
conflict->push_back(Literal(reason_for_presence_[t]).Negated());
}
integer_trail_->MergeReasonInto(integer_reason_, conflict);
bool DisjunctiveConstraint::IncreaseMinStart(int t,
IntegerValue new_min_start) {
if (!integer_trail_->Enqueue(
IntegerLiteral::GreaterOrEqual(start_vars_[t], new_min_start),
literal_reason_, integer_reason_)) {
return false;
}
// We propagate right away the new min-end lower-bound we have.
const IntegerValue min_end_lb = new_min_start + MinDuration(t);
if (MinEnd(t) < min_end_lb) {
integer_reason_.clear();
literal_reason_.clear();
AddMinStartReason(t, new_min_start);
AddMinDurationReason(t);
if (!integer_trail_->Enqueue(
IntegerLiteral::GreaterOrEqual(end_vars_[t], min_end_lb),
literal_reason_, integer_reason_)) {
return false;
}
}
return true;
}
bool DisjunctiveConstraint::DecreaseMaxEnd(int t, IntegerValue new_max_end) {
if (!integer_trail_->Enqueue(
IntegerLiteral::LowerOrEqual(end_vars_[t], new_max_end),
literal_reason_, integer_reason_)) {
return false;
}
// We propagate right away the new max-start upper-bound we have.
const IntegerValue max_start_ub = new_max_end - MinDuration(t);
if (MaxStart(t) > max_start_ub) {
integer_reason_.clear();
literal_reason_.clear();
AddMaxEndReason(t, new_max_end);
AddMinDurationReason(t);
if (!integer_trail_->Enqueue(
IntegerLiteral::LowerOrEqual(start_vars_[t], max_start_ub),
literal_reason_, integer_reason_)) {
return false;
}
}
return true;
}
@@ -441,7 +457,7 @@ bool DisjunctiveConstraint::OverloadCheckingPass(IntegerTrail* integer_trail,
DCHECK_NE(reason_for_presence_[t], kNoLiteralIndex);
integer_trail->EnqueueLiteral(
Literal(reason_for_presence_[t]).Negated(), literal_reason_,
integer_reason_, trail);
integer_reason_);
}
}
@@ -509,7 +525,7 @@ bool DisjunctiveConstraint::DetectablePrecedencePass(
integer_reason_.clear();
// We need:
// - MaxStart(ct) < MinEnd(t) for the detectable precedence
// - MaxStart(ct) < MinEnd(t) for the detectable precedence.
// - MinStart(ct) > window_start for the min_end_of_critical_tasks reason.
const IntegerValue window_start = sorted_tasks[critical_index].min_start;
for (int i = critical_index; i < sorted_tasks.size(); ++i) {
@@ -520,10 +536,7 @@ bool DisjunctiveConstraint::DetectablePrecedencePass(
DCHECK_LT(MaxStart(ct), min_end);
AddPresenceAndDurationReason(ct);
AddMinStartReason(ct, window_start);
// TODO(user): Add the reason on MaxStart() instead, it will be more
// correct for task with non-fixed duration.
AddMaxEndReason(ct, CapAdd(min_end, MinDuration(ct) - 1));
AddMaxStartReason(ct, min_end - 1);
}
// Add the reason for t (we don't need the max-end or presence reason).
@@ -532,13 +545,7 @@ bool DisjunctiveConstraint::DetectablePrecedencePass(
// This augment the min-start of t and subsequently it can augment the
// next min_end_of_critical_tasks, but our deduction is still valid.
if (!integer_trail->Enqueue(
IntegerLiteral::GreaterOrEqual(start_vars_[t],
min_end_of_critical_tasks),
literal_reason_, integer_reason_)) {
return false;
}
if (!CheckIntervalForConflict(t, trail)) return false;
if (!IncreaseMinStart(t, min_end_of_critical_tasks)) return false;
// We need to reorder t inside task_set_.
task_set_.NotifyEntryIsNowLastIfPresent({t, MinStart(t), MinDuration(t)});
@@ -587,16 +594,13 @@ bool DisjunctiveConstraint::PrecedencePass(IntegerTrail* integer_trail,
Literal(reason_for_beeing_before_[ct]).Negated());
}
}
// TODO(user): If var is actually a min-start of an interval, we
// could push the min-end and check the interval consistency right away.
if (!integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(var, min_end),
literal_reason_, integer_reason_)) {
return false;
}
// TODO(user): for non-optional intervals, we could check right away that
// the domain of var is non-empty. Or for a var that correspond to one
// of the interval bounds, we could detect infeasibility early with
// CheckIntervalForConflict(). Not doing should be fine, it just postpone
// a bit the conflict detection.
}
}
return true;
@@ -632,8 +636,8 @@ bool DisjunctiveConstraint::NotLastPass(IntegerTrail* integer_trail,
// [(critical tasks)
// (duration_t)]
//
// So we can deduce that the max-end of t is smaller that the largest
// max-start of the critical tasks.
// So we can deduce that the max-end of t is smaller than or equal to the
// largest max-start of the critical tasks.
//
// Note that this works as well when task_is_currently_present_[t] is false.
int critical_index = 0;
@@ -641,49 +645,37 @@ bool DisjunctiveConstraint::NotLastPass(IntegerTrail* integer_trail,
task_set_.ComputeMinEnd(/*task_to_ignore=*/t, &critical_index);
if (min_end_of_critical_tasks <= MaxStart(t)) continue;
// Find the largest max-start of the critical tasks (excluding t). This
// will be a valid new max-end for t.
IntegerValue new_max_end = kMinIntegerValue;
int task_responsible_for_new_max_end = -1;
// Find the largest max-start of the critical tasks (excluding t). The
// max-end for t need to be smaller than or equal to this.
IntegerValue largest_ct_max_start = kMinIntegerValue;
const std::vector<TaskSet::Entry>& sorted_tasks = task_set_.SortedTasks();
for (int i = critical_index; i < sorted_tasks.size(); ++i) {
const int ct = sorted_tasks[i].task;
if (t == ct) continue;
const IntegerValue max_start = MaxStart(ct);
if (max_start > new_max_end) {
new_max_end = max_start;
task_responsible_for_new_max_end = ct;
if (max_start > largest_ct_max_start) {
largest_ct_max_start = max_start;
}
}
if (max_end > new_max_end) {
if (max_end > largest_ct_max_start) {
literal_reason_.clear();
integer_reason_.clear();
// We don't need the max-end reason of the critical tasks except the
// one for the task responsible for new_max_end.
const IntegerValue window_start = sorted_tasks[critical_index].min_start;
for (int i = critical_index; i < sorted_tasks.size(); ++i) {
const int ct = sorted_tasks[i].task;
if (ct == t) continue;
AddPresenceAndDurationReason(ct);
AddMinStartReason(ct, window_start);
if (ct == task_responsible_for_new_max_end) {
AddMaxEndReason(ct, MaxEnd(ct));
}
AddMaxStartReason(ct, largest_ct_max_start);
}
// Add the reason for t (we don't need the min-start or presence reason).
AddMinDurationReason(t);
AddMaxEndReason(t, CapAdd(min_end_of_critical_tasks, MinDuration(t) - 1));
// Add the reason for t, we only need the max-start.
AddMaxStartReason(t, min_end_of_critical_tasks - 1);
// Enqueue the new max-end for t.
// Note that changing it will not influence the rest of the loop.
if (!integer_trail->Enqueue(
IntegerLiteral::LowerOrEqual(end_vars_[t], new_max_end),
literal_reason_, integer_reason_)) {
return false;
}
if (!CheckIntervalForConflict(t, trail)) return false;
if (!DecreaseMaxEnd(t, largest_ct_max_start)) return false;
}
}
return true;
@@ -859,13 +851,9 @@ bool DisjunctiveConstraint::EdgeFindingPass(IntegerTrail* integer_trail,
// TODO(user): propagate the precedence Boolean here too? I think it
// will be more powerful. Even if eventually all these precedence will
// become detectable (see Petr Villim PhD).
if (!integer_trail->Enqueue(
IntegerLiteral::GreaterOrEqual(start_vars_[gray_task],
min_end_of_critical_tasks),
literal_reason_, integer_reason_)) {
if (!IncreaseMinStart(gray_task, min_end_of_critical_tasks)) {
return false;
}
if (!CheckIntervalForConflict(gray_task, trail)) return false;
}
// Remove the gray_task from sorted_tasks_.

View File

@@ -24,19 +24,19 @@
namespace operations_research {
namespace sat {
// Enforces a disjunctive (or no overlap) constraints on the given interval
// Enforces a disjunctive (or no overlap) constraint on the given interval
// variables.
std::function<void(Model*)> Disjunctive(const std::vector<IntervalVariable>& vars);
// Same as Disjunctive() but also creates a Boolean variables for all the
// Same as Disjunctive() but also creates a Boolean variable for all the
// possible precedences of the form (task i is before task j).
std::function<void(Model*)> DisjunctiveWithBooleanPrecedences(
const std::vector<IntervalVariable>& vars);
// Helper class to compute the min-end of a set of tasks given their min-start
// and min-duration. In Petr Vilim PhD "Global Constraints in Scheduling", this
// corresponds to his Theta-tree except that we use a O(n) implementation for
// most of the function here, not a O(log(n)) one.
// and min-duration. In Petr Vilim's PhD "Global Constraints in Scheduling",
// this corresponds to his Theta-tree except that we use a O(n) implementation
// for most of the function here, not a O(log(n)) one.
class TaskSet {
public:
TaskSet() : optimized_restart_(0) {}
@@ -75,9 +75,10 @@ class TaskSet {
//
// [Bunch of tasks] ... [Bunch of tasks] ... [critical tasks].
//
// We call "critical tasks" the last group. These tasks will be the sole
// responsible for the min-end of the whole set. The returned critical_index
// will be the index of the first critical task in SortedTasks().
// We call "critical tasks" the last group. These tasks will be solely
// responsible for for the min-end of the whole set. The returned
// critical_index will be the index of the first critical task in
// SortedTasks().
//
// A reason for the min end is:
// - The min-duration of all the critical tasks.
@@ -132,6 +133,9 @@ class DisjunctiveConstraint : public PropagatorInterface {
// [(min-duration) ... (min-duration)]
// ^ ^ ^ ^
// min-start min-end max-start max-end
//
// Note that for tasks with variable durations, we don't necessarily have
// min-duration between the the min-XXX and max-XXX value.
IntegerValue MinDuration(int t) const {
return duration_vars_[t] == kNoIntegerVariable
? fixed_durations_[t]
@@ -140,11 +144,15 @@ class DisjunctiveConstraint : public PropagatorInterface {
IntegerValue MinStart(int t) const {
return integer_trail_->LowerBound(start_vars_[t]);
}
IntegerValue MaxStart(int t) const {
return integer_trail_->UpperBound(start_vars_[t]);
}
IntegerValue MinEnd(int t) const {
return integer_trail_->LowerBound(end_vars_[t]);
}
IntegerValue MaxEnd(int t) const {
return integer_trail_->UpperBound(end_vars_[t]);
}
IntegerValue MaxStart(int t) const { return MaxEnd(t) - MinDuration(t); }
IntegerValue MinEnd(int t) const { return MinStart(t) + MinDuration(t); }
// Helper functions to compute the reason of a propagation.
// Append to literal_reason_ and integer_reason_ the corresponding reason.
@@ -152,10 +160,15 @@ class DisjunctiveConstraint : public PropagatorInterface {
void AddMinDurationReason(int t);
void AddMinStartReason(int t, IntegerValue lower_bound);
void AddMaxEndReason(int t, IntegerValue upper_bound);
void AddMaxStartReason(int t, IntegerValue upper_bound);
// Checks that the interval [min_start_t, max_end_t] is larger than
// min_duration_t. Returns false and report an conflict otherwise.
bool CheckIntervalForConflict(int t, Trail* trail);
// Enqueues new bounds of an interval. The reasons (literal_reason_ and
// integer_reason_) must already be filled. Note that we automatically push
// min-end and max-start accordingly, so we maintain the invariants:
// - min-end >= min-start + min-duration
// - max-start <= max-end + min-duration
bool IncreaseMinStart(int t, IntegerValue new_min_start);
bool DecreaseMaxEnd(int t, IntegerValue new_max_end);
// All these passes use the algorithms described in Petr Vilim PhD "Global
// Constraints in Scheduling". Except that we don't use the O(log(n)) balanced

View File

@@ -18,6 +18,65 @@
namespace operations_research {
namespace sat {
void IntegerEncoder::FullyEncodeVariable(IntegerVariable i_var,
std::vector<IntegerValue> values) {
CHECK_EQ(0, sat_solver_->CurrentDecisionLevel());
CHECK(!values.empty()); // UNSAT problem. We don't deal with that here.
STLSortAndRemoveDuplicates(&values);
// TODO(user): This case is annoying, not sure yet how to best fix the
// variable. There is certainly no need to create a Boolean variable, but
// one needs to talk to IntegerTrail to fix the variable and we don't want
// the encoder to depend on this. So for now we fail here and it is up to
// the caller to deal with this case.
CHECK_NE(values.size(), 1);
// If the variable has already been fully encoded, for now we check that
// the sets of value is the same.
//
// TODO(user): Take the intersection, and handle that case in the constraints
// creation functions.
if (ContainsKey(full_encoding_index_, i_var)) {
const std::vector<ValueLiteralPair>& encoding = FullDomainEncoding(i_var);
CHECK_EQ(values.size(), encoding.size());
for (int i = 0; i < values.size(); ++i) {
CHECK_EQ(values[i], encoding[i].value);
}
return;
}
std::vector<ValueLiteralPair> encoding;
if (values.size() == 2) {
const BooleanVariable var = sat_solver_->NewBooleanVariable();
encoding.push_back({values[0], Literal(var, true)});
encoding.push_back({values[1], Literal(var, false)});
} else {
std::vector<sat::LiteralWithCoeff> cst;
for (const IntegerValue value : values) {
const BooleanVariable var = sat_solver_->NewBooleanVariable();
encoding.push_back({value, Literal(var, true)});
cst.push_back(LiteralWithCoeff(Literal(var, true), Coefficient(1)));
}
CHECK(sat_solver_->AddLinearConstraint(true, sat::Coefficient(1), true,
sat::Coefficient(1), &cst));
}
full_encoding_index_[i_var] = full_encoding_.size();
full_encoding_.push_back(encoding); // copy because we need it below.
// Deal with NegationOf(i_var).
//
// TODO(user): This seems a bit wasted, but it does simplify the code at a
// somehow small cost.
std::reverse(encoding.begin(), encoding.end());
for (auto& entry : encoding) {
entry.value = -entry.value; // Reverse the value.
}
full_encoding_index_[NegationOf(i_var)] = full_encoding_.size();
full_encoding_.push_back(std::move(encoding));
}
void IntegerEncoder::AddImplications(IntegerLiteral i_lit, Literal literal) {
if (i_lit.var >= encoding_by_var_.size()) {
encoding_by_var_.resize(i_lit.var + 1);
@@ -89,21 +148,94 @@ bool IntegerTrail::Propagate(Trail* trail) {
CHECK_EQ(trail->CurrentDecisionLevel(), integer_decision_levels_.size());
}
// Value encoder.
//
// TODO(user): There is no need to maintain the bounds of such variable if
// they are never used in any constraint!
//
// Algorithm:
// 1/ See if new variables are fully encoded and initialize them.
// 2/ In the loop below, each time a "min" variable was assigned to false,
// update the associated variable bounds, and change the watched "min".
// This step is is O(num variables at false between the old and new min).
//
// The data structure are reversible.
watched_min_.SetLevel(trail->CurrentDecisionLevel());
current_min_.SetLevel(trail->CurrentDecisionLevel());
if (encoder_->GetFullyEncodedVariables().size() != num_encoded_variables_) {
num_encoded_variables_ = encoder_->GetFullyEncodedVariables().size();
// for now this is only supported at level zero. Otherwise we need to
// inspect the trail to properly compute all the min.
//
// TODO(user): Don't rescan all the variables from scratch, we could only
// scan the new ones. But then we need a mecanism to detect the new ones.
CHECK_EQ(trail->CurrentDecisionLevel(), 0);
for (const auto& entry : encoder_->GetFullyEncodedVariables()) {
IntegerVariable var = entry.first;
const auto& encoding = encoder_->FullDomainEncoding(var);
for (int i = 0; i < encoding.size(); ++i) {
if (!trail_->Assignment().LiteralIsFalse(encoding[i].literal)) {
watched_min_.Set(encoding[i].literal.NegatedIndex(), {var, i});
current_min_.Set(var, i);
// No reason because we are at level zero.
if (!Enqueue(IntegerLiteral::GreaterOrEqual(var, encoding[i].value),
{}, {})) {
return false;
}
break;
}
}
}
}
// Process all the "associated" literals and Enqueue() the corresponding
// bounds.
while (propagation_trail_index_ < trail->Index()) {
const Literal literal = (*trail)[propagation_trail_index_++];
const IntegerLiteral i_lit = encoder_->GetIntegerLiteral(literal);
if (i_lit.var < 0) continue;
// The reason is simply the associated literal.
if (!Enqueue(i_lit, {literal.Negated()}, {})) return false;
// Bound encoder.
const IntegerLiteral i_lit = encoder_->GetIntegerLiteral(literal);
if (i_lit.var >= 0) {
// The reason is simply the associated literal.
if (!Enqueue(i_lit, {literal.Negated()}, {})) return false;
}
// Value encoder.
if (watched_min_.ContainsKey(literal.Index())) {
// A watched min value just became false.
const auto pair = watched_min_.FindOrDie(literal.Index());
const IntegerVariable var = pair.first;
const int min = pair.second;
const auto& encoding = encoder_->FullDomainEncoding(var);
std::vector<Literal> literal_reason = {literal.Negated()};
for (int i = min + 1; i < encoding.size(); ++i) {
if (!trail_->Assignment().LiteralIsFalse(encoding[i].literal)) {
watched_min_.EraseOrDie(literal.Index());
watched_min_.Set(encoding[i].literal.NegatedIndex(), {var, i});
current_min_.Set(var, i);
// Note that we also need the fact that all smaller value are false
// for the propagation. We use the current lower bound for that.
if (!Enqueue(IntegerLiteral::GreaterOrEqual(var, encoding[i].value),
literal_reason, {LowerBoundAsLiteral(var)})) {
return false;
}
break;
} else {
literal_reason.push_back(encoding[i].literal);
}
}
}
}
return true;
}
void IntegerTrail::Untrail(const Trail& trail, int literal_trail_index) {
watched_min_.SetLevel(trail.CurrentDecisionLevel());
current_min_.SetLevel(trail.CurrentDecisionLevel());
propagation_trail_index_ =
std::min(propagation_trail_index_, literal_trail_index);
@@ -171,6 +303,28 @@ int IntegerTrail::FindLowestTrailIndexThatExplainBound(
return prev_trail_index;
}
bool IntegerTrail::EnqueueAssociatedLiteral(
Literal literal, IntegerLiteral i_lit,
const std::vector<Literal>& literals_reason,
const std::vector<IntegerLiteral>& bounds_reason) {
if (!trail_->Assignment().VariableIsAssigned(literal.Variable())) {
// The reason is simply i_lit and will be expanded lazily when needed.
std::vector<Literal>* unused;
std::vector<IntegerLiteral>* integer_reason;
EnqueueLiteral(literal, &unused, &integer_reason);
integer_reason->push_back(i_lit);
return true;
}
if (trail_->Assignment().LiteralIsFalse(literal)) {
std::vector<Literal>* conflict = trail_->MutableConflict();
*conflict = literals_reason;
conflict->push_back(literal);
MergeReasonInto(bounds_reason, conflict);
return false;
}
return true;
}
bool IntegerTrail::Enqueue(IntegerLiteral i_lit,
const std::vector<Literal>& literals_reason,
const std::vector<IntegerLiteral>& bounds_reason) {
@@ -178,8 +332,58 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit,
if (i_lit.bound <= vars_[i_lit.var].current_bound) return true;
++num_enqueues_;
// Check if the integer variable has an empty domain.
// Deal with fully encoded variable. We want to do that first because this may
// make the IntegerLiteral bound stronger.
const IntegerVariable var(i_lit.var);
if (current_min_.ContainsKey(var)) {
// Recover the current min, and propagate to false all the values that
// are in [min, i_lit.value). All these literals have the same reason, so
// we use the "same reason as" mecanism.
const int min_index = current_min_.FindOrDie(var);
const auto& encoding = encoder_->FullDomainEncoding(var);
if (i_lit.bound > encoding[min_index].value) {
const Literal negated_min = encoding[min_index].literal.Negated();
if (!EnqueueAssociatedLiteral(negated_min, i_lit, literals_reason,
bounds_reason)) {
return false;
}
int i = min_index + 1;
for (; i < encoding.size(); ++i) {
if (i_lit.bound <= encoding[i].value) break;
const Literal literal = encoding[i].literal.Negated();
if (!trail_->Assignment().VariableIsAssigned(literal.Variable())) {
trail_->EnqueueWithSameReasonAs(literal, negated_min.Variable());
} else if (trail_->Assignment().LiteralIsFalse(literal)) {
// Conflict.
std::vector<Literal>* conflict = trail_->MutableConflict();
*conflict = literals_reason;
conflict->push_back(literal);
MergeReasonInto(bounds_reason, conflict);
return false;
}
}
if (i == encoding.size()) {
// Conflict: no possible values left.
std::vector<Literal>* conflict = trail_->MutableConflict();
*conflict = literals_reason;
MergeReasonInto(bounds_reason, conflict);
return false;
} else {
// We have a new min.
watched_min_.EraseOrDie(encoding[min_index].literal.NegatedIndex());
watched_min_.Set(encoding[i].literal.NegatedIndex(), {var, i});
current_min_.Set(var, i);
// Adjust the bound of i_lit !
CHECK_GE(encoding[i].value, i_lit.bound);
i_lit.bound = encoding[i].value.value();
}
}
}
// Check if the integer variable has an empty domain.
if (i_lit.bound > UpperBound(var)) {
if (!IsOptional(var) ||
trail_->Assignment().LiteralIsFalse(Literal(is_empty_literals_[var]))) {
@@ -199,8 +403,7 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit,
if (!trail_->Assignment().LiteralIsTrue(is_empty)) {
std::vector<Literal>* literal_reason_ptr;
std::vector<IntegerLiteral>* integer_reason_ptr;
EnqueueLiteral(is_empty, &literal_reason_ptr, &integer_reason_ptr,
trail_);
EnqueueLiteral(is_empty, &literal_reason_ptr, &integer_reason_ptr);
*literal_reason_ptr = literals_reason;
*integer_reason_ptr = bounds_reason;
integer_reason_ptr->push_back(UpperBoundAsLiteral(var));
@@ -217,18 +420,8 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit,
const LiteralIndex literal_index =
encoder_->SearchForLiteralAtOrBefore(i_lit);
if (literal_index != kNoLiteralIndex) {
const Literal literal(literal_index);
if (!trail_->Assignment().VariableIsAssigned(literal.Variable())) {
std::vector<Literal>* literal_reason;
std::vector<IntegerLiteral>* integer_reason;
EnqueueLiteral(literal, &literal_reason, &integer_reason, trail_);
integer_reason->push_back(i_lit);
} else if (trail_->Assignment().LiteralIsFalse(literal)) {
// Conflict.
std::vector<Literal>* conflict = trail_->MutableConflict();
*conflict = literals_reason;
conflict->push_back(literal);
MergeReasonInto(bounds_reason, conflict);
if (!EnqueueAssociatedLiteral(Literal(literal_index), i_lit,
literals_reason, bounds_reason)) {
return false;
}
}
@@ -385,9 +578,8 @@ ClauseRef IntegerTrail::Reason(const Trail& trail, int trail_index) const {
void IntegerTrail::EnqueueLiteral(Literal literal,
std::vector<Literal>** literal_reason,
std::vector<IntegerLiteral>** integer_reason,
Trail* trail) {
const int trail_index = trail->Index();
std::vector<IntegerLiteral>** integer_reason) {
const int trail_index = trail_->Index();
if (trail_index >= literal_reasons_.size()) {
literal_reasons_.resize(trail_index + 1);
integer_reasons_.resize(trail_index + 1);
@@ -400,16 +592,15 @@ void IntegerTrail::EnqueueLiteral(Literal literal,
if (integer_reason != nullptr) {
*integer_reason = &integer_reasons_[trail_index];
}
trail->Enqueue(literal, propagator_id_);
trail_->Enqueue(literal, propagator_id_);
}
void IntegerTrail::EnqueueLiteral(Literal literal,
const std::vector<Literal>& literal_reason,
const std::vector<IntegerLiteral>& integer_reason,
Trail* trail) {
void IntegerTrail::EnqueueLiteral(
Literal literal, const std::vector<Literal>& literal_reason,
const std::vector<IntegerLiteral>& integer_reason) {
std::vector<Literal>* literal_reason_ptr;
std::vector<IntegerLiteral>* integer_reason_ptr;
EnqueueLiteral(literal, &literal_reason_ptr, &integer_reason_ptr, trail);
EnqueueLiteral(literal, &literal_reason_ptr, &integer_reason_ptr);
*literal_reason_ptr = literal_reason;
*integer_reason_ptr = integer_reason;
}

View File

@@ -19,11 +19,13 @@
#include "base/port.h"
#include "base/join.h"
#include "base/int_type.h"
#include "base/map_util.h"
#include "sat/model.h"
#include "sat/sat_base.h"
#include "sat/sat_solver.h"
#include "util/bitset.h"
#include "util/iterators.h"
#include "util/rev.h"
#include "util/saturated_arithmetic.h"
namespace operations_research {
@@ -146,8 +148,10 @@ inline std::ostream& operator<<(std::ostream& os, IntegerLiteral i_lit) {
// these variables activity and so on. These variables can also be propagated
// directly by the learned clauses.
//
// TODO(user): Add support for creating literals encoding x == v? for now this
// is not used though.
// This class also support a non-lazy full domain encoding which will create one
// literal per possible value in the domain. See FullyEncodeVariable(). This is
// meant to be called by constraints that directly work on the variable values
// like a table constraint or an all-diff constraint.
//
// TODO(user): We could also lazily create precedences Booleans between two
// arbitrary IntegerVariable. This is better done in the PrecedencesPropagator
@@ -168,6 +172,58 @@ class IntegerEncoder {
return encoder;
}
// This has 3 effects:
// 1/ It restricts the given variable to only take values amongst the given
// ones.
// 2/ It creates one Boolean variable per value that convey the fact that the
// var is equal to this value iff the Boolean is true. If there is only
// 2 values, then just one Boolean variable is created. For more than two
// values, a constraint is also added to enforce that exactly one Boolean
// variable is true.
// 3/ The encoding for NegationOf(var) is automatically created too. It reuses
// the same Boolean variable as the encoding of var.
//
// Calling this more than once is an error (Checked).
// TODO(user): we could instead only keep the intersection and fix the now
// impossible values to zero.
//
// Note(user): There is currently no relation here between
// FullyEncodeVariable() and CreateAssociatedLiteral(). However the
// IntegerTrail class will automatically link the two representations and do
// the right thing.
//
// Note(user): Calling this with just one value will cause a CHECK fail. One
// need to fix the IntegerVariable inside the IntegerTrail instead of calling
// this.
//
// TODO(user): It is currently only possible to call that at the decision
// level zero. This is Checked.
void FullyEncodeVariable(IntegerVariable var, std::vector<IntegerValue> values);
// Gets the full encoding of a variable on which FullyEncodeVariable() has
// been called. The returned elements are always sorted by increasing
// IntegerValue. Once created, the encoding never changes, but some Boolean
// variable may become fixed.
struct ValueLiteralPair {
ValueLiteralPair(IntegerValue v, Literal l) : value(v), literal(l) {}
bool operator==(const ValueLiteralPair& o) const {
return value == o.value && literal == o.literal;
}
IntegerValue value;
Literal literal;
};
const std::vector<ValueLiteralPair>& FullDomainEncoding(
IntegerVariable var) const {
return full_encoding_[FindOrDie(full_encoding_index_, var)];
}
// Returns the set of variable encoded as the keys in a map. The map values
// only have an internal meaning. The set of encoded variables is returned
// with this "weird" api for efficiency.
const hash_map<IntegerVariable, int>& GetFullyEncodedVariables() const {
return full_encoding_index_;
}
// Creates a new Boolean variable 'var' such that
// - if true, then the IntegerLiteral is true.
// - if false, then the negated IntegerLiteral is true.
@@ -218,14 +274,18 @@ class IntegerEncoder {
// doesn't correspond to an IntegerLiteral.
ITIVector<LiteralIndex, IntegerLiteral> reverse_encoding_;
// Full domain encoding. The map contains the index in full_encoding_ of
// the fully encoded variable. Each entry in full_encoding_ is sorted by
// IntegerValue and contains the encoding of one IntegerVariable.
hash_map<IntegerVariable, int> full_encoding_index_;
std::vector<std::vector<ValueLiteralPair>> full_encoding_;
DISALLOW_COPY_AND_ASSIGN(IntegerEncoder);
};
// This class maintains a set of integer variables with their current bounds.
// Bounds can be propagated from an external "source" and this class helps
// to maintain the reason for each propagation.
//
// TODO(user): Add support for a lazy encoding of the integer variable in SAT.
class IntegerTrail : public Propagator {
public:
IntegerTrail(IntegerEncoder* encoder, Trail* trail)
@@ -323,10 +383,9 @@ class IntegerTrail : public Propagator {
// assignment. They are only valid just after this is called. The full literal
// reason will be computed lazily when it becomes needed.
void EnqueueLiteral(Literal literal, std::vector<Literal>** literal_reason,
std::vector<IntegerLiteral>** integer_reason, Trail* trail);
std::vector<IntegerLiteral>** integer_reason);
void EnqueueLiteral(Literal literal, const std::vector<Literal>& literal_reason,
const std::vector<IntegerLiteral>& integer_reason,
Trail* trail);
const std::vector<IntegerLiteral>& integer_reason);
// Returns the reason (as set of Literal currently false) for a given integer
// literal. Note that the bound must be less restrictive than the current
@@ -355,6 +414,12 @@ class IntegerTrail : public Propagator {
}
private:
// Helper used by Enqueue() to propagate one of the literal associated to
// the given i_lit and maintained by encoder_.
bool EnqueueAssociatedLiteral(Literal literal, IntegerLiteral i_lit,
const std::vector<Literal>& literals_reason,
const std::vector<IntegerLiteral>& bounds_reason);
// Returns a lower bound on the given var that will always be valid.
IntegerValue LevelZeroBound(int var) const {
// The level zero bounds are stored at the begining of the trail and they
@@ -410,6 +475,13 @@ class IntegerTrail : public Propagator {
// The "is_empty" literal of the optional variables or kNoLiteralIndex.
ITIVector<IntegerVariable, LiteralIndex> is_empty_literals_;
// Data used to support the propagation of fully encoded variable. We keep
// for each variable the index in encoder_.GetDomainEncoding() of the first
// literal that is not assigned to false, and call this the "min".
int64 num_encoded_variables_ = 0;
RevMap<hash_map<LiteralIndex, std::pair<IntegerVariable, int>>> watched_min_;
RevMap<hash_map<IntegerVariable, int>> current_min_;
// Temporary data used by MergeReasonInto().
mutable std::vector<int> tmp_queue_;
mutable std::vector<int> tmp_trail_indices_;
@@ -593,6 +665,15 @@ inline std::function<int64(const Model&)> UpperBound(IntegerVariable v) {
};
}
// This checks that the variable is fixed.
inline std::function<int64(const Model&)> Value(IntegerVariable v) {
return [=](const Model& model) {
const IntegerTrail* trail = model.Get<IntegerTrail>();
CHECK_EQ(trail->LowerBound(v), trail->UpperBound(v));
return trail->LowerBound(v).value();
};
}
inline std::function<void(Model*)> GreaterOrEqual(IntegerVariable v, int64 lb) {
return [=](Model* model) {
if (!model->GetOrCreate<IntegerTrail>()->Enqueue(

View File

@@ -206,12 +206,12 @@ bool IsOneOfPropagator::Propagate(Trail* trail) {
std::vector<Literal>* literal_reason;
std::vector<IntegerLiteral>* integer_reason;
if (current_min > values_[i]) {
integer_trail_->EnqueueLiteral(
selectors_[i].Negated(), &literal_reason, &integer_reason, trail);
integer_trail_->EnqueueLiteral(selectors_[i].Negated(),
&literal_reason, &integer_reason);
integer_reason->push_back(integer_trail_->LowerBoundAsLiteral(var_));
} else if (current_max < values_[i]) {
integer_trail_->EnqueueLiteral(
selectors_[i].Negated(), &literal_reason, &integer_reason, trail);
integer_trail_->EnqueueLiteral(selectors_[i].Negated(),
&literal_reason, &integer_reason);
integer_reason->push_back(integer_trail_->UpperBoundAsLiteral(var_));
}
}

View File

@@ -201,9 +201,6 @@ inline std::function<void(Model*)> EndAtEnd(IntervalVariable i1,
};
}
// TODO(user): Add a propagator on the interval duration depending
// on the set of alternatives that are currently not executable.
//
// This requires that all the alternatives are optional tasks.
inline std::function<void(Model*)> IntervalWithAlternatives(
IntervalVariable master, const std::vector<IntervalVariable>& members) {

View File

@@ -305,7 +305,7 @@ void PrecedencesPropagator::PropagateOptionalArcs(Trail* trail) {
std::vector<Literal>* literal_reason;
std::vector<IntegerLiteral>* integer_reason;
integer_trail_->EnqueueLiteral(is_present.Negated(), &literal_reason,
&integer_reason, trail);
&integer_reason);
integer_reason->push_back(
integer_trail_->LowerBoundAsLiteral(arc.tail_var));
integer_reason->push_back(

View File

@@ -269,7 +269,7 @@ class Trail {
}
// Specific Enqueue() version for the search decision.
void EnqueueSeachDecision(Literal true_literal) {
void EnqueueSearchDecision(Literal true_literal) {
Enqueue(true_literal, AssignmentType::kSearchDecision);
}

View File

@@ -1617,7 +1617,7 @@ void SatSolver::EnqueueNewDecision(Literal literal) {
decisions_[current_decision_level_] = Decision(trail_->Index(), literal);
++current_decision_level_;
trail_->SetDecisionLevel(current_decision_level_);
trail_->EnqueueSeachDecision(literal);
trail_->EnqueueSearchDecision(literal);
}
Literal SatSolver::NextBranch() {

View File

@@ -322,15 +322,17 @@ class SatSolver {
const std::vector<BinaryClause>& NewlyAddedBinaryClauses();
void ClearNewlyAddedBinaryClauses();
// Various getters of the current solver state.
struct Decision {
Decision() : trail_index(-1) {}
Decision(int i, Literal l) : trail_index(i), literal(l) {}
int trail_index;
Literal literal;
};
int CurrentDecisionLevel() const { return current_decision_level_; }
// Note that the Decisions() vector is always of size NumVariables(), and that
// only the first CurrentDecisionLevel() entries have a meaning.
const std::vector<Decision>& Decisions() const { return decisions_; }
int CurrentDecisionLevel() const { return current_decision_level_; }
const Trail& LiteralTrail() const { return *trail_; }
const VariablesAssignment& Assignment() const { return trail_->Assignment(); }
@@ -372,6 +374,10 @@ class SatSolver {
binary_implication_graph_.AddBinaryClause(a, b);
}
// Performs propagation of the recently enqueued elements.
// Mainly visible for testing.
bool Propagate();
private:
// Calls Propagate() and returns true if no conflict occured. Otherwise,
// learns the conflict, backtracks, enqueues the consequence of the learned
@@ -486,8 +492,7 @@ class SatSolver {
// True and Enqueue() this change.
void EnqueueNewDecision(Literal literal);
// Performs propagation of the recently enqueued elements.
bool Propagate();
// Returns true if everything has been propagated.
bool PropagationIsDone() const;
// Update the propagators_ list with the relevant propagators.
@@ -943,6 +948,32 @@ inline std::function<void(Model*)> ClauseConstraint(
};
}
// The a => b constraint.
inline std::function<void(Model*)> Implication(Literal a, Literal b) {
return [=](Model* model) {
model->GetOrCreate<SatSolver>()->AddBinaryClause(a.Negated(), b);
};
}
// This can be used to enumerate all the solutions. After each SAT call to
// Solve(), calling this will reset the solver and exclude the current solution
// so that the next call to Solve() will give a new solution or UNSAT is there
// is no more new solutions.
inline std::function<void(Model*)> ExcludeCurrentSolutionAndBacktrack() {
return [=](Model* model) {
SatSolver* sat_solver = model->GetOrCreate<SatSolver>();
// Note that we only exclude the current decisions, which is an efficient
// way to not get the same SAT assignment.
std::vector<Literal> exlude_solution;
for (int i = 0; i < sat_solver->CurrentDecisionLevel(); ++i) {
exlude_solution.push_back(sat_solver->Decisions()[i].literal.Negated());
}
sat_solver->Backtrack(0);
model->Add(ClauseConstraint(exlude_solution));
};
}
inline std::function<SatParameters(Model*)> NewSatParameters(std::string params) {
return [=](Model* model) {
sat::SatParameters parameters;

242
src/sat/table.cc Normal file
View File

@@ -0,0 +1,242 @@
// Copyright 2010-2014 Google
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "sat/table.h"
#include "base/map_util.h"
#include "base/stl_util.h"
namespace operations_research {
namespace sat {
namespace {
// Transpose the given "matrix" and transform the value to IntegerValue.
std::vector<std::vector<IntegerValue>> Transpose(const std::vector<std::vector<int64>> tuples) {
CHECK(!tuples.empty());
const int n = tuples.size();
const int m = tuples[0].size();
std::vector<std::vector<IntegerValue>> transpose(m, std::vector<IntegerValue>(n));
for (int i = 0; i < n; ++i) {
CHECK_EQ(m, tuples[i].size());
for (int j = 0; j < m; ++j) {
transpose[j][i] = tuples[i][j];
}
}
return transpose;
}
// Converts the vector representation returned by FullDomainEncoding() to a map.
hash_map<IntegerValue, Literal> GetEncoding(IntegerVariable var, Model* model) {
hash_map<IntegerValue, Literal> encoding;
for (const auto& entry :
model->GetOrCreate<IntegerEncoder>()->FullDomainEncoding(var)) {
encoding[entry.value] = entry.literal;
}
return encoding;
}
// Add the implications and clauses to link one column of a table to the Literal
// controling if the lines are possible or not. The column has the given values,
// and the Literal of the column variable can be retreived using the encoding
// map.
void ProcessOneColumn(const std::vector<Literal>& line_literals,
const std::vector<IntegerValue>& values,
const hash_map<IntegerValue, Literal>& encoding,
Model* model) {
CHECK_EQ(line_literals.size(), values.size());
hash_map<IntegerValue, std::vector<Literal>> value_to_list_of_line_literals;
// If a value is false (i.e not possible), then the tuple with this value
// is false too (i.e not possible).
for (int i = 0; i < values.size(); ++i) {
const IntegerValue v = values[i];
value_to_list_of_line_literals[v].push_back(line_literals[i]);
model->Add(Implication(FindOrDie(encoding, v).Negated(),
line_literals[i].Negated()));
}
// If all the tuples containing a value are false, then this value must be
// false too.
for (const auto& entry : value_to_list_of_line_literals) {
std::vector<Literal> clause = entry.second;
clause.push_back(FindOrDie(encoding, entry.first).Negated());
model->Add(ClauseConstraint(clause));
}
}
} // namespace
std::function<void(Model*)> TableConstraint(
const std::vector<IntegerVariable>& vars, const std::vector<std::vector<int64>>& tuples) {
return [=](Model* model) {
// Create one Boolean variable per tuple to indicate if it can still be
// selected or not. Note that we don't enforce exactly one tuple to be
// selected because these variables are just used by this constraint, so
// only the information "can't be selected" is important.
//
// TODO(user): If a value in one column is unique, we don't need to create a
// new BooleanVariable corresponding to this line since we can use the one
// corresponding to this value in that column.
std::vector<Literal> tuple_literals;
for (int i = 0; i < tuples.size(); ++i) {
tuple_literals.push_back(Literal(model->Add(NewBooleanVariable()), true));
}
// Fully encode the variables using all the values appearing in the tuples.
IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
hash_map<IntegerValue, Literal> encoding;
const std::vector<std::vector<IntegerValue>>& tr_tuples = Transpose(tuples);
for (int i = 0; i < vars.size(); ++i) {
encoder->FullyEncodeVariable(vars[i], tr_tuples[i]);
encoding = GetEncoding(vars[i], model);
ProcessOneColumn(tuple_literals, tr_tuples[i], encoding, model);
}
};
}
std::function<void(Model*)> TransitionConstraint(
const std::vector<IntegerVariable>& vars, const std::vector<std::vector<int64>>& automata,
int64 initial_state, const std::vector<int64>& final_states) {
return [=](Model* model) {
IntegerEncoder* encoder = model->GetOrCreate<IntegerEncoder>();
const int n = vars.size();
CHECK_GT(n, 0) << "No variables in TransitionConstraint().";
// Test precondition.
{
std::set<std::pair<int64, int64>> unique_transition_checker;
for (const std::vector<int64>& transition : automata) {
CHECK_EQ(transition.size(), 3);
const std::pair<int64, int64> p{transition[0], transition[1]};
CHECK(!ContainsKey(unique_transition_checker, p))
<< "Duplicate outgoing transitions with value " << transition[1]
<< " from state " << transition[0] << ".";
unique_transition_checker.insert(p);
}
}
// Compute the set of reachable state at each time point.
std::vector<std::set<int64>> reachable_states(n + 1);
reachable_states[0].insert(initial_state);
reachable_states[n] = {final_states.begin(), final_states.end()};
// Forward.
for (int time = 0; time + 1 < n; ++time) {
for (const std::vector<int64>& transition : automata) {
if (!ContainsKey(reachable_states[time], transition[0])) continue;
reachable_states[time + 1].insert(transition[2]);
}
}
// Backward.
for (int time = n - 1; time > 0; --time) {
std::set<int64> new_set;
for (const std::vector<int64>& transition : automata) {
if (!ContainsKey(reachable_states[time], transition[0])) continue;
if (!ContainsKey(reachable_states[time + 1], transition[2])) continue;
new_set.insert(transition[0]);
}
reachable_states[time].swap(new_set);
}
// We will model at each time step the current automata state using Boolean
// variables. We will have n+1 time step. At time zero, we start in the
// initial state, and at time n we should be in one of the final states. We
// don't need to create Booleans at at time when there is just one possible
// state (like at time zero).
hash_map<IntegerValue, Literal> encoding;
hash_map<IntegerValue, Literal> in_encoding;
hash_map<IntegerValue, Literal> out_encoding;
for (int time = 0; time < n; ++time) {
// All these vector have the same size. We will use them to enforce a
// local table constraint representing one step of the automata at the
// given time.
std::vector<Literal> tuple_literals;
std::vector<IntegerValue> in_states;
std::vector<IntegerValue> transition_values;
std::vector<IntegerValue> out_states;
for (const std::vector<int64>& transition : automata) {
if (!ContainsKey(reachable_states[time], transition[0])) continue;
if (!ContainsKey(reachable_states[time + 1], transition[2])) continue;
// TODO(user): if this transition correspond to just one in-state or
// one-out state or one variable value, we could reuse the corresponding
// Boolean variable instead of creating a new one!
tuple_literals.push_back(
Literal(model->Add(NewBooleanVariable()), true));
in_states.push_back(IntegerValue(transition[0]));
transition_values.push_back(IntegerValue(transition[1]));
out_states.push_back(IntegerValue(transition[2]));
}
// Fully instantiate vars[time].
{
std::vector<IntegerValue> s = transition_values;
STLSortAndRemoveDuplicates(&s);
encoding.clear();
if (s.size() > 1) {
std::vector<IntegerValue> values(s.begin(), s.end());
encoder->FullyEncodeVariable(vars[time], values);
encoding = GetEncoding(vars[time], model);
} else {
// Fix vars[time] to its unique possible value.
CHECK_EQ(s.size(), 1);
const int64 unique_value = s.begin()->value();
model->Add(LowerOrEqual(vars[time], unique_value));
model->Add(GreaterOrEqual(vars[time], unique_value));
}
}
// For each possible out states, create one Boolean variable.
//
// TODO(user): enforce an at most one constraint? it is not really needed
// though, so I am not sure it will improve or hurt the performance. To
// investigate on real problems.
{
std::vector<IntegerValue> s = out_states;
STLSortAndRemoveDuplicates(&s);
out_encoding.clear();
if (s.size() == 2) {
const BooleanVariable var = model->Add(NewBooleanVariable());
out_encoding[s.front()] = Literal(var, true);
out_encoding[s.back()] = Literal(var, false);
} else if (s.size() > 1) {
// Enforce at most one constraint?
for (const IntegerValue state : s) {
out_encoding[state] =
Literal(model->Add(NewBooleanVariable()), true);
}
}
}
// Now we link everything together.
if (in_encoding.size() > 1) {
ProcessOneColumn(tuple_literals, in_states, in_encoding, model);
}
if (encoding.size() > 1) {
ProcessOneColumn(tuple_literals, transition_values, encoding, model);
}
if (out_encoding.size() > 1) {
ProcessOneColumn(tuple_literals, out_states, out_encoding, model);
}
in_encoding = out_encoding;
}
};
}
} // namespace sat
} // namespace operations_research

48
src/sat/table.h Normal file
View File

@@ -0,0 +1,48 @@
// Copyright 2010-2014 Google
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_SAT_TABLE_H_
#define OR_TOOLS_SAT_TABLE_H_
#include "sat/integer.h"
#include "sat/model.h"
namespace operations_research {
namespace sat {
// Enforces that the given tuple of variables is equal to one of the given
// tuples. All the tuples must have the same size as var.size(), this is
// Checked.
std::function<void(Model*)> TableConstraint(
const std::vector<IntegerVariable>& vars, const std::vector<std::vector<int64>>& tuples);
// Given an automata defined by a set of 3-tuples:
// (state, transition_with_value_as_label, next_state)
// this accepts the sequences of vars.size() variables that are recognized by
// this automata. That is:
// - We start from the initial state.
// - For each variable, we move along the transition labeled by this variable
// value. Moreover, the variable must take a value that correspond to a
// feasible transition.
// - We only accept sequences that ends in one of the final states.
//
// We CHECK that there is only one possible transition for a state/value pair.
// See the test for some examples.
std::function<void(Model*)> TransitionConstraint(
const std::vector<IntegerVariable>& vars, const std::vector<std::vector<int64>>& automata,
int64 initial_state, const std::vector<int64>& final_states);
} // namespace sat
} // namespace operations_research
#endif // OR_TOOLS_SAT_TABLE_H_

129
src/util/rev.h Normal file
View File

@@ -0,0 +1,129 @@
// Copyright 2010-2014 Google
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Reversible (i.e Backtrackable) classes, used to simplify coding propagators.
#ifndef OR_TOOLS_UTIL_REV_H_
#define OR_TOOLS_UTIL_REV_H_
#include <vector>
#include "base/logging.h"
#include "base/map_util.h"
namespace operations_research {
// Like a normal map but support backtrackable operations.
//
// This works on any class "Map" that supports: begin(), end(), find(), erase(),
// insert(), key_type, value_type, mapped_type and const_iterator.
template <class Map>
class RevMap {
public:
typedef typename Map::key_type key_type;
typedef typename Map::mapped_type mapped_type;
typedef typename Map::value_type value_type;
typedef typename Map::const_iterator const_iterator;
// Backtracking support: changes the current "level" (always non-negative).
//
// Initially the class starts at level zero. Increasing the level works in
// O(level diff) and saves the state of the current old level. Decreasing the
// level restores the state to what it was at this level and all higher levels
// are forgotten. Everything done at level zero cannot be backtracked over.
void SetLevel(int level);
int Level() const { return first_op_index_of_next_level_.size(); }
bool ContainsKey(key_type key) const { return operations_research::ContainsKey(map_, key); }
const mapped_type& FindOrDie(key_type key) const {
return operations_research::FindOrDie(map_, key);
}
void EraseOrDie(key_type key);
void Set(key_type key, mapped_type value); // Adds or overwrites.
// Wrapper to the underlying const map functions.
int size() const { return map_.size(); }
bool empty() const { return map_.empty(); }
const_iterator find(const key_type& k) const { return map_.find(k); }
const_iterator begin() const { return map_.begin(); }
const_iterator end() const { return map_.end(); }
private:
Map map_;
// The operation that needs to be performed to reverse one modification:
// - If is_deletion is true, then we need to delete the entry with given key.
// - Otherwise we need to add back (or overwrite) the saved entry.
struct UndoOperation {
bool is_deletion;
key_type key;
mapped_type value;
};
// TODO(user): We could merge the operations with the same key from the same
// level. Investigate and implement if this is worth the effort for our use
// case.
std::vector<UndoOperation> operations_;
std::vector<int> first_op_index_of_next_level_;
};
template <class Map>
void RevMap<Map>::SetLevel(int level) {
DCHECK_GE(level, 0);
if (level < Level()) {
const int backtrack_level = first_op_index_of_next_level_[level];
first_op_index_of_next_level_.resize(level); // Shrinks.
while (operations_.size() > backtrack_level) {
const UndoOperation& to_undo = operations_.back();
if (to_undo.is_deletion) {
map_.erase(to_undo.key);
} else {
map_.insert({to_undo.key, to_undo.value}).first->second = to_undo.value;
}
operations_.pop_back();
}
return;
}
// This is ok even if level == Level().
first_op_index_of_next_level_.resize(level, operations_.size()); // Grows.
}
template <class Map>
void RevMap<Map>::EraseOrDie(key_type key) {
const auto iter = map_.find(key);
if (iter == map_.end()) LOG(FATAL) << "key not present: '" << key << "'.";
if (Level() > 0) {
operations_.push_back({false, key, iter->second});
}
map_.erase(iter);
}
template <class Map>
void RevMap<Map>::Set(key_type key, mapped_type value) {
auto insertion_result = map_.insert({key, value});
if (Level() > 0) {
if (insertion_result.second) {
// It is an insertion. Undo = delete.
operations_.push_back({true, key});
} else {
// It is a modification. Undo = change back to old value.
operations_.push_back({false, key, insertion_result.first->second});
}
}
insertion_result.first->second = value;
}
} // namespace operations_research
#endif // OR_TOOLS_UTIL_REV_H_