From 7b802c0015457db1694983a685ad7779f24b2702 Mon Sep 17 00:00:00 2001 From: "lperron@google.com" Date: Mon, 27 Jan 2014 15:05:30 +0000 Subject: [PATCH] add core computation to sat solver; add vehicle dependent dimensions in routing --- examples/cpp/sat_runner.cc | 81 +- makefiles/Makefile.cpp.mk | 34 +- src/constraint_solver/assignment.cc | 265 ++-- src/constraint_solver/constraint_solver.h | 51 +- src/constraint_solver/routing.cc | 239 ++-- src/constraint_solver/routing.h | 52 +- src/constraint_solver/routing_search.cc | 55 +- src/sat/boolean_problem.cc | 12 + src/sat/boolean_problem.h | 5 + src/sat/clause.cc | 512 ++++++++ src/sat/clause.h | 357 ++++++ src/sat/pb_constraint.cc | 47 +- src/sat/pb_constraint.h | 26 +- src/sat/sat_base.h | 34 +- src/sat/sat_conflict.cc | 528 -------- src/sat/sat_parameters.proto | 6 + src/sat/sat_solver.cc | 1374 ++++++++++++--------- src/sat/sat_solver.h | 500 ++------ src/sat/unsat_proof.cc | 169 +++ src/sat/unsat_proof.h | 118 ++ src/util/bitset.h | 7 + src/util/saturated_arithmetic.h | 14 + 22 files changed, 2677 insertions(+), 1809 deletions(-) create mode 100644 src/sat/clause.cc create mode 100644 src/sat/clause.h delete mode 100644 src/sat/sat_conflict.cc create mode 100644 src/sat/unsat_proof.cc create mode 100644 src/sat/unsat_proof.h diff --git a/examples/cpp/sat_runner.cc b/examples/cpp/sat_runner.cc index d9e60221b1..bef9bdaea3 100644 --- a/examples/cpp/sat_runner.cc +++ b/examples/cpp/sat_runner.cc @@ -28,6 +28,7 @@ #include "cpp/sat_cnf_reader.h" #include "sat/boolean_problem.h" #include "sat/sat_solver.h" +#include "util/time_limit.h" DEFINE_string( input, "", @@ -63,11 +64,16 @@ DEFINE_bool(search_optimal, false, "If true, search for the optimal solution. " "The algorithm is currently really basic."); -// TODO(user): Adds minisat to the mix. + +DEFINE_bool(refine_core, false, + "If true, turn on the unsat_proof parameters and if the problem is " + "UNSAT, refine as much as possible its UNSAT core in order to get " + "a small one."); namespace operations_research { namespace sat { namespace { + // To benefit from the operations_research namespace, we put all the main() code // here. int Run() { @@ -80,6 +86,13 @@ int Run() { CHECK(google::protobuf::TextFormat::ParseFromString(FLAGS_params, ¶meters)) << FLAGS_params; } + parameters.set_log_search_progress(true); + + // Enforce some parameters if we are looking for UNSAT core. + if (FLAGS_refine_core) { + parameters.set_unsat_proof(true); + parameters.set_treat_binary_clauses_separately(false); + } // Initialize the solver. SatSolver solver; @@ -88,7 +101,7 @@ int Run() { // Read the problem. LinearBooleanProblem problem; if (HasSuffixString(FLAGS_input, ".opb") || - HasSuffixString(FLAGS_input, ".opb.bz2")) { + HasSuffixString(FLAGS_input, ".opb.bz2")) { OpbReader reader; if (!reader.Load(FLAGS_input, &problem)) { LOG(FATAL) << "Cannot load file '" << FLAGS_input << "'."; @@ -99,10 +112,10 @@ int Run() { if (!reader.Load(FLAGS_input, &problem)) { LOG(FATAL) << "Cannot load file '" << FLAGS_input << "'."; } - } else { file::ReadFileToProtoOrDie(FLAGS_input, &problem); } + // Load the problem into the solver. if (!LoadBooleanProblem(problem, &solver)) { LOG(FATAL) << "Couldn't load problem '" << FLAGS_input << "'."; @@ -113,11 +126,33 @@ int Run() { LOG(FATAL) << "Issue when setting the objective bounds."; } + // Heuristics to drive the SAT search. + UseObjectiveForSatAssignmentPreference(problem, &solver); + // Basic search for the optimal value by calling multiple times the solver. if (FLAGS_search_optimal && problem.type() == LinearBooleanProblem::MINIMIZATION) { + TimeLimit time_limit(parameters.max_time_in_seconds()); Coefficient objective = std::numeric_limits::max(); - while (solver.Solve() == SatSolver::MODEL_SAT) { + int old_num_fixed_variables = 0; + while (true) { + const SatSolver::Status result = solver.Solve(); + if (result == SatSolver::MODEL_UNSAT) { + if (objective == std::numeric_limits::max()) { + LOG(INFO) << "The problem is UNSAT"; + break; + } + LOG(INFO) << "Optimal found!"; + LOG(INFO) << "Objective = " << objective; + LOG(INFO) << "Time = " << time_limit.GetElapsedTime(); + break; + } + if (result != SatSolver::MODEL_SAT) { + LOG(INFO) << "Search aborted."; + LOG(INFO) << "Objective = " << objective; + LOG(INFO) << "Time = " << time_limit.GetElapsedTime(); + break; + } CHECK(IsAssignmentValid(problem, solver.Assignment())); const Coefficient old_objective = objective; objective = ComputeObjectiveValue(problem, solver.Assignment()); @@ -126,10 +161,15 @@ int Run() { if (!AddObjectiveConstraint(problem, false, 0, true, objective - 1, &solver)) { LOG(INFO) << "UNSAT (when tightenning the objective constraint)."; + LOG(INFO) << "Optimal found!"; + LOG(INFO) << "Objective = " << objective; + LOG(INFO) << "Time = " << time_limit.GetElapsedTime(); break; } + parameters.set_max_time_in_seconds(time_limit.GetTimeLeft()); + solver.SetParameters(parameters); + } - LOG(INFO) << "Optimal found! " << objective; return EXIT_SUCCESS; } @@ -139,6 +179,37 @@ int Run() { CHECK(IsAssignmentValid(problem, solver.Assignment())); } + // Unsat with verification. + // Note(user): For now we just compute an UNSAT core and check it. + if (result == SatSolver::MODEL_UNSAT && parameters.unsat_proof()) { + std::vector core; + solver.ComputeUnsatCore(&core); + LOG(INFO) << "UNSAT. Identified a core of " << core.size() + << " constraints."; + + // The following block is mainly for testing the UNSAT core feature. + if (FLAGS_refine_core) { + int old_core_size = core.size(); + LinearBooleanProblem old_problem; + LinearBooleanProblem core_unsat_problem; + old_problem.CopyFrom(problem); + int i = 1; + do { + ExtractSubproblem(old_problem, core, &core_unsat_problem); + core_unsat_problem.set_name(StringPrintf("Subproblem #%d", i)); + old_core_size = core.size(); + old_problem.CopyFrom(core_unsat_problem); + SatSolver new_solver; + new_solver.SetParameters(parameters); + CHECK(LoadBooleanProblem(core_unsat_problem, &new_solver)); + CHECK_EQ(new_solver.Solve(), SatSolver::MODEL_UNSAT) << "Wrong core!"; + new_solver.ComputeUnsatCore(&core); + LOG(INFO) << "Core #" << i << " checked, next size is " << core.size(); + ++i; + } while (core.size() != old_core_size); + } + } + if (!FLAGS_output.empty()) { if (result == SatSolver::MODEL_SAT) { StoreAssignment(solver.Assignment(), problem.mutable_assignment()); diff --git a/makefiles/Makefile.cpp.mk b/makefiles/Makefile.cpp.mk index fe582619b8..920a9728a0 100644 --- a/makefiles/Makefile.cpp.mk +++ b/makefiles/Makefile.cpp.mk @@ -1103,44 +1103,48 @@ $(BIN_DIR)/integer_programming$E: $(DYNAMIC_LP_DEPS) $(OBJ_DIR)/integer_programm # Sat solver +sat: bin/sat_runner$E + SAT_LIB_OBJS = \ $(OBJ_DIR)/boolean_problem.pb.$O\ $(OBJ_DIR)/boolean_problem.$O\ $(OBJ_DIR)/pb_constraint.$O\ - $(OBJ_DIR)/sat_conflict.$O\ + $(OBJ_DIR)/clause.$O\ $(OBJ_DIR)/sat_parameters.pb.$O\ $(OBJ_DIR)/sat_solver.$O\ - -sat: bin/sat_runner$E + $(OBJ_DIR)/unsat_proof.$O\ satlibs: $(DYNAMIC_SAT_DEPS) $(STATIC_SAT_DEPS) -$(OBJ_DIR)/sat_solver.$O: $(SRC_DIR)/sat/sat_solver.cc $(SRC_DIR)/sat/sat_solver.h $(SRC_DIR)/sat/sat_base.h $(GEN_DIR)/sat/sat_parameters.pb.h +$(OBJ_DIR)/sat_solver.$O:$(SRC_DIR)/sat/sat_solver.cc $(SRC_DIR)/sat/sat_solver.h $(SRC_DIR)/sat/sat_base.h $(SRC_DIR)/sat/clause.h $(SRC_DIR)/sat/unsat_proof.h $(GEN_DIR)/sat/sat_parameters.pb.h $(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/sat_solver.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat_solver.$O -$(OBJ_DIR)/sat_conflict.$O: $(SRC_DIR)/sat/sat_conflict.cc $(SRC_DIR)/sat/sat_solver.h $(SRC_DIR)/sat/sat_base.h $(GEN_DIR)/sat/sat_parameters.pb.h - $(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/sat_conflict.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat_conflict.$O - -$(OBJ_DIR)/boolean_problem.$O: $(SRC_DIR)/sat/boolean_problem.cc $(SRC_DIR)/sat/boolean_problem.h $(GEN_DIR)/sat/boolean_problem.pb.h $(SRC_DIR)/sat/sat_solver.h $(SRC_DIR)/sat/sat_base.h $(GEN_DIR)/sat/sat_parameters.pb.h +$(OBJ_DIR)/boolean_problem.$O:$(SRC_DIR)/sat/boolean_problem.cc $(SRC_DIR)/sat/boolean_problem.h $(GEN_DIR)/sat/boolean_problem.pb.h $(SRC_DIR)/sat/sat_solver.h $(SRC_DIR)/sat/sat_base.h $(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/boolean_problem.cc $(OBJ_OUT)$(OBJ_DIR)$Sboolean_problem.$O -$(GEN_DIR)/sat/boolean_problem.pb.cc: $(SRC_DIR)/sat/boolean_problem.proto +$(GEN_DIR)/sat/boolean_problem.pb.cc:$(SRC_DIR)/sat/boolean_problem.proto $(PROTOBUF_DIR)/bin/protoc --proto_path=$(INC_DIR) --cpp_out=$(GEN_DIR) $(SRC_DIR)/sat/boolean_problem.proto -$(GEN_DIR)/sat/boolean_problem.pb.h: $(GEN_DIR)/sat/boolean_problem.pb.cc +$(GEN_DIR)/sat/boolean_problem.pb.h:$(GEN_DIR)/sat/boolean_problem.pb.cc -$(OBJ_DIR)/boolean_problem.pb.$O: $(GEN_DIR)/sat/boolean_problem.pb.cc $(GEN_DIR)/sat/boolean_problem.pb.h +$(OBJ_DIR)/boolean_problem.pb.$O:$(GEN_DIR)/sat/boolean_problem.pb.cc $(GEN_DIR)/sat/boolean_problem.pb.h $(CCC) $(CFLAGS) -c $(GEN_DIR)/sat/boolean_problem.pb.cc $(OBJ_OUT)$(OBJ_DIR)$Sboolean_problem.pb.$O -$(OBJ_DIR)/pb_constraint.$O: $(SRC_DIR)/sat/pb_constraint.cc $(SRC_DIR)/sat/sat_base.h +$(OBJ_DIR)/pb_constraint.$O:$(SRC_DIR)/sat/pb_constraint.cc $(SRC_DIR)/sat/sat_base.h $(SRC_DIR)/sat/pb_constraint.h $(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/pb_constraint.cc $(OBJ_OUT)$(OBJ_DIR)$Spb_constraint.$O -$(GEN_DIR)/sat/sat_parameters.pb.cc: $(SRC_DIR)/sat/sat_parameters.proto +$(OBJ_DIR)/clause.$O:$(SRC_DIR)/sat/clause.cc $(SRC_DIR)/sat/sat_base.h $(SRC_DIR)/sat/clause.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/clause.cc $(OBJ_OUT)$(OBJ_DIR)$Sclause.$O + +$(OBJ_DIR)/unsat_proof.$O:$(SRC_DIR)/sat/unsat_proof.cc $(SRC_DIR)/sat/sat_base.h $(SRC_DIR)/sat/unsat_proof.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)/sat/unsat_proof.cc $(OBJ_OUT)$(OBJ_DIR)$Sunsat_proof.$O + +$(GEN_DIR)/sat/sat_parameters.pb.cc:$(SRC_DIR)/sat/sat_parameters.proto $(PROTOBUF_DIR)/bin/protoc --proto_path=$(INC_DIR) --cpp_out=$(GEN_DIR) $(SRC_DIR)/sat/sat_parameters.proto -$(GEN_DIR)/sat/sat_parameters.pb.h: $(GEN_DIR)/sat/sat_parameters.pb.cc +$(GEN_DIR)/sat/sat_parameters.pb.h:$(GEN_DIR)/sat/sat_parameters.pb.cc -$(OBJ_DIR)/sat_parameters.pb.$O: $(GEN_DIR)/sat/sat_parameters.pb.cc $(GEN_DIR)/sat/sat_parameters.pb.h +$(OBJ_DIR)/sat_parameters.pb.$O:$(GEN_DIR)/sat/sat_parameters.pb.cc $(GEN_DIR)/sat/sat_parameters.pb.h $(CCC) $(CFLAGS) -c $(GEN_DIR)/sat/sat_parameters.pb.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat_parameters.pb.$O $(LIB_DIR)/$(LIBPREFIX)sat.$(DYNAMIC_LIB_SUFFIX): $(SAT_LIB_OBJS) diff --git a/src/constraint_solver/assignment.cc b/src/constraint_solver/assignment.cc index 3364f812e3..7087b8082e 100644 --- a/src/constraint_solver/assignment.cc +++ b/src/constraint_solver/assignment.cc @@ -300,18 +300,16 @@ void SequenceVarElement::Restore() { void SequenceVarElement::LoadFromProto( const SequenceVarAssignmentProto& sequence_var_assignment_proto) { - for (int i = 0; i < sequence_var_assignment_proto.forward_sequence_size(); - ++i) { - forward_sequence_.push_back( - sequence_var_assignment_proto.forward_sequence(i)); + for (const int32 forward_sequence : + sequence_var_assignment_proto.forward_sequence()) { + forward_sequence_.push_back(forward_sequence); } - for (int i = 0; i < sequence_var_assignment_proto.backward_sequence_size(); - ++i) { - backward_sequence_.push_back( - sequence_var_assignment_proto.backward_sequence(i)); + for (const int32 backward_sequence : + sequence_var_assignment_proto.backward_sequence()) { + backward_sequence_.push_back(backward_sequence); } - for (int i = 0; i < sequence_var_assignment_proto.unperformed_size(); ++i) { - unperformed_.push_back(sequence_var_assignment_proto.unperformed(i)); + for (const int32 unperformed : sequence_var_assignment_proto.unperformed()) { + unperformed_.push_back(unperformed); } if (sequence_var_assignment_proto.active()) { Activate(); @@ -325,14 +323,14 @@ void SequenceVarElement::WriteToProto( SequenceVarAssignmentProto* sequence_var_assignment_proto) const { sequence_var_assignment_proto->set_var_id(var_->name()); sequence_var_assignment_proto->set_active(Activated()); - for (int i = 0; i < forward_sequence_.size(); ++i) { - sequence_var_assignment_proto->add_forward_sequence(forward_sequence_[i]); + for (const int forward_sequence : forward_sequence_) { + sequence_var_assignment_proto->add_forward_sequence(forward_sequence); } - for (int i = 0; i < backward_sequence_.size(); ++i) { - sequence_var_assignment_proto->add_backward_sequence(backward_sequence_[i]); + for (const int backward_sequence : backward_sequence_) { + sequence_var_assignment_proto->add_backward_sequence(backward_sequence); } - for (int i = 0; i < unperformed_.size(); ++i) { - sequence_var_assignment_proto->add_unperformed(unperformed_[i]); + for (const int unperformed : unperformed_) { + sequence_var_assignment_proto->add_unperformed(unperformed); } } @@ -401,23 +399,23 @@ void SequenceVarElement::SetUnperformed(const std::vector& unperformed) { bool SequenceVarElement::CheckClassInvariants() { hash_set visited; - for (ConstIter > it(forward_sequence_); !it.at_end(); ++it) { - if (ContainsKey(visited, *it)) { + for (const int forward_sequence : forward_sequence_) { + if (ContainsKey(visited, forward_sequence)) { return false; } - visited.insert(*it); + visited.insert(forward_sequence); } - for (ConstIter > it(backward_sequence_); !it.at_end(); ++it) { - if (ContainsKey(visited, *it)) { + for (const int backward_sequence : backward_sequence_) { + if (ContainsKey(visited, backward_sequence)) { return false; } - visited.insert(*it); + visited.insert(backward_sequence); } - for (ConstIter > it(unperformed_); !it.at_end(); ++it) { - if (ContainsKey(visited, *it)) { + for (const int unperformed : unperformed_) { + if (ContainsKey(visited, unperformed)) { return false; } - visited.insert(*it); + visited.insert(unperformed); } return true; } @@ -595,8 +593,7 @@ bool Assignment::Save(File* file) const { template void RealSave(AssignmentProto* const assignment_proto, const Container& container, Proto* (AssignmentProto::*Add)()) { - for (int i = 0; i < container.Size(); ++i) { - const Element& element = container.Element(i); + for (const Element& element : container.elements()) { const Var* const var = element.Var(); const std::string& name = var->name(); if (!name.empty()) { @@ -636,8 +633,7 @@ void Assignment::Save(AssignmentProto* const assignment_proto) const { template void RealDebugString(const Container& container, std::string* const out) { - for (int i = 0; i < container.Size(); ++i) { - const Element& element = container.Element(i); + for (const Element& element : container.elements()) { if (element.Var() != nullptr) { StringAppendF(out, "%s %s | ", element.Var()->name().c_str(), element.DebugString().c_str()); @@ -659,235 +655,236 @@ std::string Assignment::DebugString() const { return out; } -IntVarElement* Assignment::Add(IntVar* const v) { - return int_var_container_.Add(v); +IntVarElement* Assignment::Add(IntVar* const var) { + return int_var_container_.Add(var); } -void Assignment::Add(const std::vector& v) { - for (ConstIter > it(v); !it.at_end(); ++it) { - Add(*it); +void Assignment::Add(const std::vector& vars) { + for (IntVar* const var : vars) { + Add(var); } } -IntVarElement* Assignment::FastAdd(IntVar* const v) { - return int_var_container_.FastAdd(v); +IntVarElement* Assignment::FastAdd(IntVar* const var) { + return int_var_container_.FastAdd(var); } -int64 Assignment::Min(const IntVar* const v) const { - return int_var_container_.Element(v).Min(); +int64 Assignment::Min(const IntVar* const var) const { + return int_var_container_.Element(var).Min(); } -int64 Assignment::Max(const IntVar* const v) const { - return int_var_container_.Element(v).Max(); +int64 Assignment::Max(const IntVar* const var) const { + return int_var_container_.Element(var).Max(); } -int64 Assignment::Value(const IntVar* const v) const { - return int_var_container_.Element(v).Value(); +int64 Assignment::Value(const IntVar* const var) const { + return int_var_container_.Element(var).Value(); } -bool Assignment::Bound(const IntVar* const v) const { - return int_var_container_.Element(v).Bound(); +bool Assignment::Bound(const IntVar* const var) const { + return int_var_container_.Element(var).Bound(); } -void Assignment::SetMin(const IntVar* const v, int64 m) { - int_var_container_.MutableElement(v)->SetMin(m); +void Assignment::SetMin(const IntVar* const var, int64 m) { + int_var_container_.MutableElement(var)->SetMin(m); } -void Assignment::SetMax(const IntVar* const v, int64 m) { - int_var_container_.MutableElement(v)->SetMax(m); +void Assignment::SetMax(const IntVar* const var, int64 m) { + int_var_container_.MutableElement(var)->SetMax(m); } -void Assignment::SetRange(const IntVar* const v, int64 l, int64 u) { - int_var_container_.MutableElement(v)->SetRange(l, u); +void Assignment::SetRange(const IntVar* const var, int64 l, int64 u) { + int_var_container_.MutableElement(var)->SetRange(l, u); } -void Assignment::SetValue(const IntVar* const v, int64 value) { - int_var_container_.MutableElement(v)->SetValue(value); +void Assignment::SetValue(const IntVar* const var, int64 value) { + int_var_container_.MutableElement(var)->SetValue(value); } // ----- Interval Var ----- -IntervalVarElement* Assignment::Add(IntervalVar* const v) { - return interval_var_container_.Add(v); +IntervalVarElement* Assignment::Add(IntervalVar* const var) { + return interval_var_container_.Add(var); } void Assignment::Add(const std::vector& vars) { - for (ConstIter > it(vars); !it.at_end(); ++it) { - Add(*it); + for (IntervalVar* const var : vars) { + Add(var); } } -IntervalVarElement* Assignment::FastAdd(IntervalVar* const v) { - return interval_var_container_.FastAdd(v); +IntervalVarElement* Assignment::FastAdd(IntervalVar* const var) { + return interval_var_container_.FastAdd(var); } -int64 Assignment::StartMin(const IntervalVar* const v) const { - return interval_var_container_.Element(v).StartMin(); +int64 Assignment::StartMin(const IntervalVar* const var) const { + return interval_var_container_.Element(var).StartMin(); } -int64 Assignment::StartMax(const IntervalVar* const v) const { - return interval_var_container_.Element(v).StartMax(); +int64 Assignment::StartMax(const IntervalVar* const var) const { + return interval_var_container_.Element(var).StartMax(); } -int64 Assignment::StartValue(const IntervalVar* const v) const { - return interval_var_container_.Element(v).StartValue(); +int64 Assignment::StartValue(const IntervalVar* const var) const { + return interval_var_container_.Element(var).StartValue(); } -int64 Assignment::DurationMin(const IntervalVar* const v) const { - return interval_var_container_.Element(v).DurationMin(); +int64 Assignment::DurationMin(const IntervalVar* const var) const { + return interval_var_container_.Element(var).DurationMin(); } -int64 Assignment::DurationMax(const IntervalVar* const v) const { - return interval_var_container_.Element(v).DurationMax(); +int64 Assignment::DurationMax(const IntervalVar* const var) const { + return interval_var_container_.Element(var).DurationMax(); } -int64 Assignment::DurationValue(const IntervalVar* const v) const { - return interval_var_container_.Element(v).DurationValue(); +int64 Assignment::DurationValue(const IntervalVar* const var) const { + return interval_var_container_.Element(var).DurationValue(); } -int64 Assignment::EndMin(const IntervalVar* const v) const { - return interval_var_container_.Element(v).EndMin(); +int64 Assignment::EndMin(const IntervalVar* const var) const { + return interval_var_container_.Element(var).EndMin(); } -int64 Assignment::EndMax(const IntervalVar* const v) const { - return interval_var_container_.Element(v).EndMax(); +int64 Assignment::EndMax(const IntervalVar* const var) const { + return interval_var_container_.Element(var).EndMax(); } -int64 Assignment::EndValue(const IntervalVar* const v) const { - return interval_var_container_.Element(v).EndValue(); +int64 Assignment::EndValue(const IntervalVar* const var) const { + return interval_var_container_.Element(var).EndValue(); } -int64 Assignment::PerformedMin(const IntervalVar* const v) const { - return interval_var_container_.Element(v).PerformedMin(); +int64 Assignment::PerformedMin(const IntervalVar* const var) const { + return interval_var_container_.Element(var).PerformedMin(); } -int64 Assignment::PerformedMax(const IntervalVar* const v) const { - return interval_var_container_.Element(v).PerformedMax(); +int64 Assignment::PerformedMax(const IntervalVar* const var) const { + return interval_var_container_.Element(var).PerformedMax(); } -int64 Assignment::PerformedValue(const IntervalVar* const v) const { - return interval_var_container_.Element(v).PerformedValue(); +int64 Assignment::PerformedValue(const IntervalVar* const var) const { + return interval_var_container_.Element(var).PerformedValue(); } -void Assignment::SetStartMin(const IntervalVar* const v, int64 m) { - interval_var_container_.MutableElement(v)->SetStartMin(m); +void Assignment::SetStartMin(const IntervalVar* const var, int64 m) { + interval_var_container_.MutableElement(var)->SetStartMin(m); } -void Assignment::SetStartMax(const IntervalVar* const v, int64 m) { - interval_var_container_.MutableElement(v)->SetStartMax(m); +void Assignment::SetStartMax(const IntervalVar* const var, int64 m) { + interval_var_container_.MutableElement(var)->SetStartMax(m); } -void Assignment::SetStartRange(const IntervalVar* const v, int64 mi, int64 ma) { - interval_var_container_.MutableElement(v)->SetStartRange(mi, ma); +void Assignment::SetStartRange(const IntervalVar* const var, int64 mi, + int64 ma) { + interval_var_container_.MutableElement(var)->SetStartRange(mi, ma); } -void Assignment::SetStartValue(const IntervalVar* const v, int64 value) { - interval_var_container_.MutableElement(v)->SetStartValue(value); +void Assignment::SetStartValue(const IntervalVar* const var, int64 value) { + interval_var_container_.MutableElement(var)->SetStartValue(value); } -void Assignment::SetDurationMin(const IntervalVar* const v, int64 m) { - interval_var_container_.MutableElement(v)->SetDurationMin(m); +void Assignment::SetDurationMin(const IntervalVar* const var, int64 m) { + interval_var_container_.MutableElement(var)->SetDurationMin(m); } -void Assignment::SetDurationMax(const IntervalVar* const v, int64 m) { - interval_var_container_.MutableElement(v)->SetDurationMax(m); +void Assignment::SetDurationMax(const IntervalVar* const var, int64 m) { + interval_var_container_.MutableElement(var)->SetDurationMax(m); } -void Assignment::SetDurationRange(const IntervalVar* const v, int64 mi, +void Assignment::SetDurationRange(const IntervalVar* const var, int64 mi, int64 ma) { - interval_var_container_.MutableElement(v)->SetDurationRange(mi, ma); + interval_var_container_.MutableElement(var)->SetDurationRange(mi, ma); } -void Assignment::SetDurationValue(const IntervalVar* const v, int64 value) { - interval_var_container_.MutableElement(v)->SetDurationValue(value); +void Assignment::SetDurationValue(const IntervalVar* const var, int64 value) { + interval_var_container_.MutableElement(var)->SetDurationValue(value); } -void Assignment::SetEndMin(const IntervalVar* const v, int64 m) { - interval_var_container_.MutableElement(v)->SetEndMin(m); +void Assignment::SetEndMin(const IntervalVar* const var, int64 m) { + interval_var_container_.MutableElement(var)->SetEndMin(m); } -void Assignment::SetEndMax(const IntervalVar* const v, int64 m) { - interval_var_container_.MutableElement(v)->SetEndMax(m); +void Assignment::SetEndMax(const IntervalVar* const var, int64 m) { + interval_var_container_.MutableElement(var)->SetEndMax(m); } -void Assignment::SetEndRange(const IntervalVar* const v, int64 mi, int64 ma) { - interval_var_container_.MutableElement(v)->SetEndRange(mi, ma); +void Assignment::SetEndRange(const IntervalVar* const var, int64 mi, int64 ma) { + interval_var_container_.MutableElement(var)->SetEndRange(mi, ma); } -void Assignment::SetEndValue(const IntervalVar* const v, int64 value) { - interval_var_container_.MutableElement(v)->SetEndValue(value); +void Assignment::SetEndValue(const IntervalVar* const var, int64 value) { + interval_var_container_.MutableElement(var)->SetEndValue(value); } -void Assignment::SetPerformedMin(const IntervalVar* const v, int64 m) { - interval_var_container_.MutableElement(v)->SetPerformedMin(m); +void Assignment::SetPerformedMin(const IntervalVar* const var, int64 m) { + interval_var_container_.MutableElement(var)->SetPerformedMin(m); } -void Assignment::SetPerformedMax(const IntervalVar* const v, int64 m) { - interval_var_container_.MutableElement(v)->SetPerformedMax(m); +void Assignment::SetPerformedMax(const IntervalVar* const var, int64 m) { + interval_var_container_.MutableElement(var)->SetPerformedMax(m); } -void Assignment::SetPerformedRange(const IntervalVar* const v, int64 mi, +void Assignment::SetPerformedRange(const IntervalVar* const var, int64 mi, int64 ma) { - interval_var_container_.MutableElement(v)->SetPerformedRange(mi, ma); + interval_var_container_.MutableElement(var)->SetPerformedRange(mi, ma); } -void Assignment::SetPerformedValue(const IntervalVar* const v, int64 value) { - interval_var_container_.MutableElement(v)->SetPerformedValue(value); +void Assignment::SetPerformedValue(const IntervalVar* const var, int64 value) { + interval_var_container_.MutableElement(var)->SetPerformedValue(value); } // ----- Sequence Var ----- -SequenceVarElement* Assignment::Add(SequenceVar* const v) { - return sequence_var_container_.Add(v); +SequenceVarElement* Assignment::Add(SequenceVar* const var) { + return sequence_var_container_.Add(var); } void Assignment::Add(const std::vector& vars) { - for (ConstIter > it(vars); !it.at_end(); ++it) { - Add(*it); + for (SequenceVar* const var : vars) { + Add(var); } } -SequenceVarElement* Assignment::FastAdd(SequenceVar* const v) { - return sequence_var_container_.FastAdd(v); +SequenceVarElement* Assignment::FastAdd(SequenceVar* const var) { + return sequence_var_container_.FastAdd(var); } -const std::vector& Assignment::ForwardSequence(const SequenceVar* const v) +const std::vector& Assignment::ForwardSequence(const SequenceVar* const var) const { - return sequence_var_container_.Element(v).ForwardSequence(); + return sequence_var_container_.Element(var).ForwardSequence(); } -const std::vector& Assignment::BackwardSequence(const SequenceVar* const v) +const std::vector& Assignment::BackwardSequence(const SequenceVar* const var) const { - return sequence_var_container_.Element(v).BackwardSequence(); + return sequence_var_container_.Element(var).BackwardSequence(); } -const std::vector& Assignment::Unperformed(const SequenceVar* const v) const { - return sequence_var_container_.Element(v).Unperformed(); +const std::vector& Assignment::Unperformed(const SequenceVar* const var) const { + return sequence_var_container_.Element(var).Unperformed(); } -void Assignment::SetSequence(const SequenceVar* const v, +void Assignment::SetSequence(const SequenceVar* const var, const std::vector& forward_sequence, const std::vector& backward_sequence, const std::vector& unperformed) { - sequence_var_container_.MutableElement(v) + sequence_var_container_.MutableElement(var) ->SetSequence(forward_sequence, backward_sequence, unperformed); } -void Assignment::SetForwardSequence(const SequenceVar* const v, +void Assignment::SetForwardSequence(const SequenceVar* const var, const std::vector& forward_sequence) { - sequence_var_container_.MutableElement(v) + sequence_var_container_.MutableElement(var) ->SetForwardSequence(forward_sequence); } -void Assignment::SetBackwardSequence(const SequenceVar* const v, +void Assignment::SetBackwardSequence(const SequenceVar* const var, const std::vector& backward_sequence) { - sequence_var_container_.MutableElement(v) + sequence_var_container_.MutableElement(var) ->SetBackwardSequence(backward_sequence); } -void Assignment::SetUnperformed(const SequenceVar* const v, +void Assignment::SetUnperformed(const SequenceVar* const var, const std::vector& unperformed) { - sequence_var_container_.MutableElement(v)->SetUnperformed(unperformed); + sequence_var_container_.MutableElement(var)->SetUnperformed(unperformed); } // ----- Objective ----- diff --git a/src/constraint_solver/constraint_solver.h b/src/constraint_solver/constraint_solver.h index 41d42402f7..479a41c2dd 100644 --- a/src/constraint_solver/constraint_solver.h +++ b/src/constraint_solver/constraint_solver.h @@ -172,15 +172,9 @@ struct SolverParameters { NO_COMPRESSION, COMPRESS_WITH_ZLIB }; - enum ProfileLevel { - NO_PROFILING, - NORMAL_PROFILING - }; + enum ProfileLevel { NO_PROFILING, NORMAL_PROFILING }; - enum TraceLevel { - NO_TRACE, - NORMAL_TRACE - }; + enum TraceLevel { NO_TRACE, NORMAL_TRACE }; static const TrailCompression kDefaultTrailCompression; static const int kDefaultTrailBlockSize; @@ -232,16 +226,9 @@ struct DefaultPhaseParameters { CHOOSE_MAX_VALUE_IMPACT = 2, }; - enum ValueSelection { - SELECT_MIN_IMPACT = 0, - SELECT_MAX_IMPACT = 1, - }; + enum ValueSelection { SELECT_MIN_IMPACT = 0, SELECT_MAX_IMPACT = 1, }; - enum DisplayLevel { - NONE = 0, - NORMAL = 1, - VERBOSE = 2 - }; + enum DisplayLevel { NONE = 0, NORMAL = 1, VERBOSE = 2 }; static const int kDefaultNumberOfSplits; static const int kDefaultHeuristicPeriod; @@ -843,12 +830,7 @@ class Solver { // This enum is used internally in private methods Solver::PushState and // Solver::PopState to tag states in the search tree. - enum MarkerType { - SENTINEL, - SIMPLE_MARKER, - CHOICE_POINT, - REVERSIBLE_ACTION - }; + enum MarkerType { SENTINEL, SIMPLE_MARKER, CHOICE_POINT, REVERSIBLE_ACTION }; // This enum represents the state of the solver w.r.t. the search. enum SolverState { @@ -3008,10 +2990,7 @@ class Solver { DemonProfiler* const demon_profiler_; // interval of constants cached, inclusive: - enum { - MIN_CACHED_INT_CONST = -8, - MAX_CACHED_INT_CONST = 8 - }; + enum { MIN_CACHED_INT_CONST = -8, MAX_CACHED_INT_CONST = 8 }; IntVar* cached_constants_[MAX_CACHED_INT_CONST + 1 - MIN_CACHED_INT_CONST]; // Cached constraints. @@ -4648,15 +4627,14 @@ class AssignmentContainer { const E& Element(int index) const { return elements_[index]; } int Size() const { return elements_.size(); } void Store() { - for (int i = 0; i < elements_.size(); ++i) { - elements_[i].Store(); + for (E& element : elements_) { + element.Store(); } } void Restore() { - for (int i = 0; i < elements_.size(); ++i) { - E* element = &elements_[i]; - if (element->Activated()) { - element->Restore(); + for (E& element : elements_) { + if (element.Activated()) { + element.Restore(); } } } @@ -4674,10 +4652,9 @@ class AssignmentContainer { // Do not use the hash_map::== operator! It does not just compare content, // but also how the map is hashed (e.g., number of buckets). This is not // what we want. - typedef ConstIter > Iterator; - for (Iterator it(container.elements_); !it.at_end(); ++it) { - const int position = FindWithDefault(elements_map_, it->Var(), -1); - if (position < 0 || elements_[position] != *it) { + for (const E& element : container.elements_) { + const int position = FindWithDefault(elements_map_, element.Var(), -1); + if (position < 0 || elements_[position] != element) { return false; } } diff --git a/src/constraint_solver/routing.cc b/src/constraint_solver/routing.cc index dadade4963..cedf5fbf3d 100644 --- a/src/constraint_solver/routing.cc +++ b/src/constraint_solver/routing.cc @@ -202,7 +202,7 @@ class LightFunctionElementConstraint : public Constraint { IntVar* const var_; IntVar* const index_; - std::unique_ptr > values_; + std::unique_ptr> values_; }; Constraint* MakeLightElement(Solver* const solver, IntVar* const var, @@ -259,7 +259,7 @@ class LightFunctionElement2Constraint : public Constraint { IntVar* const var_; IntVar* const index1_; IntVar* const index2_; - std::unique_ptr > values_; + std::unique_ptr> values_; }; Constraint* MakeLightElement2( @@ -625,9 +625,9 @@ class RoutingCache { } private: - ITIVector > + ITIVector> cached_; - ITIVector > + ITIVector> cache_; std::unique_ptr callback_; }; @@ -750,9 +750,8 @@ RoutingModel::RoutingModel(int nodes, int vehicles) Initialize(); } -RoutingModel::RoutingModel( - int nodes, int vehicles, - const std::vector >& start_ends) +RoutingModel::RoutingModel(int nodes, int vehicles, + const std::vector>& start_ends) : nodes_(nodes), vehicles_(vehicles), no_cycle_constraint_(nullptr), @@ -830,7 +829,7 @@ RoutingModel::RoutingModel(int nodes, int vehicles, CHECK_EQ(vehicles, starts.size()); CHECK_EQ(vehicles, ends.size()); hash_set depot_set; - std::vector > start_ends(starts.size()); + std::vector> start_ends(starts.size()); for (int i = 0; i < starts.size(); ++i) { depot_set.insert(starts[i]); depot_set.insert(ends[i]); @@ -886,7 +885,16 @@ void RoutingModel::AddNoCycleConstraintInternal() { bool RoutingModel::AddDimension(NodeEvaluator2* evaluator, int64 slack_max, int64 capacity, bool fix_start_cumul_to_zero, const std::string& dimension_name) { - return AddDimensionWithCapacityInternal(evaluator, slack_max, capacity, + const std::vector evaluators(vehicles_, evaluator); + return AddDimensionWithCapacityInternal(evaluators, slack_max, capacity, + nullptr, fix_start_cumul_to_zero, + dimension_name); +} + +bool RoutingModel::AddDimensionWithVehicleTransits( + const std::vector& evaluators, int64 slack_max, int64 capacity, + bool fix_start_cumul_to_zero, const std::string& dimension_name) { + return AddDimensionWithCapacityInternal(evaluators, slack_max, capacity, nullptr, fix_start_cumul_to_zero, dimension_name); } @@ -895,13 +903,23 @@ bool RoutingModel::AddDimensionWithVehicleCapacity( NodeEvaluator2* evaluator, int64 slack_max, VehicleEvaluator* vehicle_capacity, bool fix_start_cumul_to_zero, const std::string& dimension_name) { + const std::vector evaluators(vehicles_, evaluator); return AddDimensionWithCapacityInternal( - evaluator, slack_max, kint64max, vehicle_capacity, + evaluators, slack_max, kint64max, vehicle_capacity, + fix_start_cumul_to_zero, dimension_name); +} + +bool RoutingModel::AddDimensionWithVehicleTransitAndCapacity( + const std::vector& evaluators, int64 slack_max, + VehicleEvaluator* vehicle_capacity, bool fix_start_cumul_to_zero, + const std::string& dimension_name) { + return AddDimensionWithCapacityInternal( + evaluators, slack_max, kint64max, vehicle_capacity, fix_start_cumul_to_zero, dimension_name); } bool RoutingModel::AddDimensionWithCapacityInternal( - NodeEvaluator2* evaluator, int64 slack_max, int64 capacity, + const std::vector& evaluators, int64 slack_max, int64 capacity, VehicleEvaluator* vehicle_capacity, bool fix_start_cumul_to_zero, const std::string& dimension_name) { CheckDepot(); @@ -910,8 +928,20 @@ bool RoutingModel::AddDimensionWithCapacityInternal( dimension_name_to_index_[dimension_name] = dimension_index; dimensions_.push_back(new RoutingDimension(this, dimension_name)); RoutingDimension* const dimension = dimensions_[dimension_index]; - dimension->Initialize(vehicle_capacity, capacity, - NewCachedCallback(evaluator), slack_max); + std::vector cached_evaluators; + hash_map evaluator_to_cached; + for (NodeEvaluator2* const evaluator : evaluators) { + CHECK(evaluator != nullptr); + NodeEvaluator2* cached_evaluator = + FindPtrOrNull(evaluator_to_cached, evaluator); + if (cached_evaluator == nullptr) { + cached_evaluator = NewCachedCallback(evaluator); + evaluator_to_cached[evaluator] = cached_evaluator; + } + cached_evaluators.push_back(cached_evaluator); + } + dimension->Initialize(vehicle_capacity, capacity, cached_evaluators, + slack_max); solver_->AddConstraint(solver_->MakePathCumul( nexts_, active_, dimension->cumuls(), dimension->transits())); if (fix_start_cumul_to_zero) { @@ -923,7 +953,11 @@ bool RoutingModel::AddDimensionWithCapacityInternal( } return true; } else { - delete evaluator; + hash_set evaluator_set(evaluators.begin(), + evaluators.end()); + for (NodeEvaluator2* const evaluator : evaluator_set) { + delete evaluator; + } delete vehicle_capacity; return false; } @@ -1095,7 +1129,7 @@ void RoutingModel::ComputeCostClasses() { const int64 coeff = dimension->vehicle_span_cost_coefficients()[vehicle]; if (coeff == 0) continue; cost_class.dimension_transit_evaluator_and_cost_coefficient.push_back( - std::make_pair(dimension->transit_evaluator(), coeff)); + std::make_pair(dimension->transit_evaluator(vehicle), coeff)); } std::sort(cost_class.dimension_transit_evaluator_and_cost_coefficient.begin(), cost_class.dimension_transit_evaluator_and_cost_coefficient.end()); @@ -1183,13 +1217,13 @@ void RoutingModel::AddLocalSearchOperator(LocalSearchOperator* ls_operator) { int64 RoutingModel::GetDepot() const { return vehicles() > 0 ? Start(0) : -1; } void RoutingModel::SetDepot(NodeIndex depot) { - std::vector > start_end(vehicles_, - std::make_pair(depot, depot)); + std::vector> start_end(vehicles_, + std::make_pair(depot, depot)); SetStartEnd(start_end); } void RoutingModel::SetStartEnd( - const std::vector >& start_ends) { + const std::vector>& start_ends) { if (is_depot_set_) { LOG(WARNING) << "A depot has already been specified, ignoring new ones"; return; @@ -1472,6 +1506,7 @@ struct VehicleClass { // constraints by iterating on a list of arcs appearing in descending order // of priority. // TODO(user): Use the dimension class in this class. +// TODO(user): Add support for vehicle-dependent dimension transits. class RouteConstructor { public: RouteConstructor(Assignment* const assignment, RoutingModel* const model, @@ -1530,17 +1565,19 @@ class RouteConstructor { if (node_to_vehicle_class_index_[node1] < 0) { for (int dimension_index = 0; dimension_index < dimensions_.size(); ++dimension_index) { - cumuls_[dimension_index][node1] = std::max( - dimensions_[dimension_index]->GetTransitValue(start_depot, node1), - dimensions_[dimension_index]->CumulVar(node1)->Min()); + cumuls_[dimension_index][node1] = + std::max(dimensions_[dimension_index]->GetTransitValue(start_depot, + node1, 0), + dimensions_[dimension_index]->CumulVar(node1)->Min()); } } if (node_to_vehicle_class_index_[node2] < 0) { for (int dimension_index = 0; dimension_index < dimensions_.size(); ++dimension_index) { - cumuls_[dimension_index][node2] = std::max( - dimensions_[dimension_index]->GetTransitValue(start_depot, node2), - dimensions_[dimension_index]->CumulVar(node2)->Min()); + cumuls_[dimension_index][node2] = + std::max(dimensions_[dimension_index]->GetTransitValue(start_depot, + node2, 0), + dimensions_[dimension_index]->CumulVar(node2)->Min()); } } @@ -1627,14 +1664,10 @@ class RouteConstructor { } } - const std::vector >& final_routes() const { return final_routes_; } + const std::vector>& final_routes() const { return final_routes_; } private: - enum MergeStatus { - FIRST_SECOND, - SECOND_FIRST, - NO_MERGE - }; + enum MergeStatus { FIRST_SECOND, SECOND_FIRST, NO_MERGE }; struct RouteSort { bool operator()(const std::vector& route1, const std::vector& route2) { @@ -1665,7 +1698,7 @@ class RouteConstructor { bool FeasibleRoute(const std::vector& route, int64 route_cumul, int dimension_index) { const RoutingDimension& dimension = *dimensions_[dimension_index]; - ConstIter > it(route); + ConstIter> it(route); int64 cumul = route_cumul; while (!it.at_end()) { const int previous = *it; @@ -1678,7 +1711,7 @@ class RouteConstructor { } const int next = *it; int64 available_from_previous = - cumul_previous + dimension.GetTransitValue(previous, next); + cumul_previous + dimension.GetTransitValue(previous, next, 0); int64 available_cumul_next = std::max(cumuls_[dimension_index][next], available_from_previous); @@ -1719,7 +1752,7 @@ class RouteConstructor { dimension.CumulVar(non_depot_node)->Max()); int64 available_from_tail1 = cumuls_[dimension_index][tail1] + - dimension.GetTransitValue(tail1, head2); + dimension.GetTransitValue(tail1, head2, 0); int64 new_available_cumul_head2 = std::max(cumuls_[dimension_index][head2], available_from_tail1); @@ -1745,7 +1778,7 @@ class RouteConstructor { : cumuls_[dimension_index][tail2]; if (!feasible_route || (new_possible_cumul_tail2 + - dimension.GetTransitValue(tail2, end_depot) > + dimension.GetTransitValue(tail2, end_depot, 0) > depot_threashold)) { return false; } @@ -1920,12 +1953,12 @@ class RouteConstructor { const std::vector vehicle_classes_; std::vector nexts_; std::vector dimensions_; // Not owned. - std::vector > cumuls_; - std::vector > new_possible_cumuls_; - std::vector > routes_; + std::vector> cumuls_; + std::vector> new_possible_cumuls_; + std::vector> routes_; std::vector in_route_; hash_set deleted_routes_; - std::vector > final_routes_; + std::vector> final_routes_; std::vector chains_; hash_set deleted_chains_; std::vector final_chains_; @@ -2060,8 +2093,8 @@ class SavingsBuilder : public DecisionBuilder { std::vector dimensions_; int64 nodes_number_; int depot_; - std::vector > costs_; - std::vector > neighbors_; + std::vector> costs_; + std::vector> neighbors_; std::vector savings_list_; double route_shape_parameter_; std::vector vehicle_classes_; @@ -2092,7 +2125,7 @@ struct SweepNodeSortDistance { } SweepNodeDistanceComparator; SweepArranger::SweepArranger( - const ITIVector >& points) + const ITIVector>& points) : coordinates_(2 * points.size(), 0), sectors_(1) { for (RoutingModel::NodeIndex i(0); i < points.size(); ++i) { coordinates_[2 * i] = points[i].first; @@ -2317,7 +2350,7 @@ class FastOnePathBuilder : public DecisionBuilder { } RoutingModel* const model_; - std::unique_ptr > evaluator_; + std::unique_ptr> evaluator_; }; // Decision builder to build a solution with all nodes inactive. It does no @@ -2684,7 +2717,7 @@ IntVar* RoutingModel::ApplyLocks(const std::vector& locks) { } bool RoutingModel::ApplyLocksToAllVehicles( - const std::vector >& locks, bool close_routes) { + const std::vector>& locks, bool close_routes) { preassignment_->Clear(); return RoutesToAssignment(locks, true, close_routes, preassignment_); } @@ -2824,7 +2857,7 @@ Assignment* RoutingModel::DoRestoreAssignment() { return nullptr; } -bool RoutingModel::RoutesToAssignment(const std::vector >& routes, +bool RoutingModel::RoutesToAssignment(const std::vector>& routes, bool ignore_inactive_nodes, bool close_routes, Assignment* const assignment) const { @@ -2945,7 +2978,7 @@ bool RoutingModel::RoutesToAssignment(const std::vector > } Assignment* RoutingModel::ReadAssignmentFromRoutes( - const std::vector >& routes, bool ignore_inactive_nodes) { + const std::vector>& routes, bool ignore_inactive_nodes) { QuietCloseModel(); if (!RoutesToAssignment(routes, ignore_inactive_nodes, true, assignment_)) { return nullptr; @@ -2957,7 +2990,7 @@ Assignment* RoutingModel::ReadAssignmentFromRoutes( } void RoutingModel::AssignmentToRoutes(const Assignment& assignment, - std::vector >* const routes) + std::vector>* const routes) const { CHECK(closed_); CHECK(routes != nullptr); @@ -3917,10 +3950,12 @@ int64 RoutingModel::GetDimensionSpanCost(const std::string& name) const { : 0; } int64 RoutingModel::GetTransitValue(const std::string& dimension_name, - int64 from_index, int64 to_index) const { + int64 from_index, int64 to_index, + int64 vehicle) const { DimensionIndex dimension_index(-1); if (FindCopy(dimension_name_to_index_, dimension_name, &dimension_index)) { - return dimensions_[dimension_index]->GetTransitValue(from_index, to_index); + return dimensions_[dimension_index]->GetTransitValue(from_index, to_index, + vehicle); } else { return 0; } @@ -4011,9 +4046,10 @@ RoutingDimension::RoutingDimension(RoutingModel* model, const std::string& name) void RoutingDimension::Initialize( RoutingModel::VehicleEvaluator* vehicle_capacity, int64 capacity, - RoutingModel::NodeEvaluator2* transit_evaluator, int64 slack_max) { + const std::vector& transit_evaluators, + int64 slack_max) { InitializeCumuls(vehicle_capacity, capacity); - InitializeTransits(transit_evaluator, slack_max); + InitializeTransits(transit_evaluators, slack_max); } namespace { @@ -4145,30 +4181,88 @@ int64 WrappedEvaluator(RoutingModel* model, DCHECK(evaluator != nullptr); return evaluator->Run(model->IndexToNode(from), model->IndexToNode(to)); } + +template +int64 IthElementOrValue(const std::vector& v, int64 index) { + return index >= 0 ? v[index] : value; +} + +template +int64 IthEvaluatorValueOrValue(const std::vector* evaluators, int64 from, + int64 to, int64 eval_index) { + return eval_index >= 0 ? (*evaluators)[eval_index]->Run(from, to) : value; +} } // namespace void RoutingDimension::InitializeTransits( - RoutingModel::NodeEvaluator2* transit_evaluator, int64 slack_max) { - CHECK(transit_evaluator != nullptr); - transit_evaluator->CheckIsRepeatable(); + const std::vector& transit_evaluators, + int64 slack_max) { + CHECK_EQ(model_->vehicles(), transit_evaluators.size()); + for (const RoutingModel::NodeEvaluator2* const evaluator : + transit_evaluators) { + CHECK(evaluator != nullptr); + evaluator->CheckIsRepeatable(); + } Solver* const solver = model_->solver(); const int size = model_->Size(); transits_.resize(size); slacks_.resize(size); + // Compute transit classes + class_evaluators_.clear(); + transit_evaluators_.clear(); + hash_map evaluator_to_class; + std::vector vehicle_to_class(transit_evaluators.size(), -1); + for (int i = 0; i < transit_evaluators.size(); ++i) { + RoutingModel::NodeEvaluator2* const evaluator = transit_evaluators[i]; + int evaluator_class = -1; + if (!FindCopy(evaluator_to_class, evaluator, &evaluator_class)) { + evaluator_class = class_evaluators_.size(); + evaluator_to_class[evaluator] = evaluator_class; + class_evaluators_.emplace_back( + NewPermanentCallback(&WrappedEvaluator, model_, evaluator)); + } + vehicle_to_class[i] = evaluator_class; + transit_evaluators_.push_back(class_evaluators_[evaluator_class].get()); + } + CHECK(!class_evaluators_.empty()); for (int i = 0; i < size; ++i) { IntVar* fixed_transit = nullptr; if (FLAGS_routing_use_light_propagation) { - fixed_transit = solver->MakeIntVar(kint64min, kint64max); - solver->AddConstraint(MakeLightElement( - solver, fixed_transit, model_->NextVar(i), - NewPermanentCallback(&WrappedEvaluator, model_, transit_evaluator, - static_cast(i)))); + if (class_evaluators_.size() == 1) { + fixed_transit = solver->MakeIntVar(kint64min, kint64max); + solver->AddConstraint( + MakeLightElement(solver, fixed_transit, model_->NextVar(i), + NewPermanentCallback(&WrappedEvaluator, model_, + transit_evaluators[0], + static_cast(i)))); + } else { + fixed_transit = solver->MakeIntVar(kint64min, kint64max); + solver->AddConstraint(MakeLightElement2( + solver, fixed_transit, model_->NextVar(i), model_->VehicleVar(i), + NewPermanentCallback( + &IthEvaluatorValueOrValue, + &transit_evaluators_, static_cast(i)))); + } } else { - fixed_transit = - solver->MakeElement(NewPermanentCallback(&WrappedEvaluator, model_, - transit_evaluator, - static_cast(i)), - model_->NextVar(i))->Var(); + if (class_evaluators_.size() == 1) { + fixed_transit = + solver->MakeElement(NewPermanentCallback(&WrappedEvaluator, model_, + transit_evaluators[0], + static_cast(i)), + model_->NextVar(i))->Var(); + } else { + IntVar* const vehicle_class_var = + solver->MakeElement(NewPermanentCallback(&IthElementOrValue<-1>, + vehicle_to_class), + model_->VehicleVar(i))->Var(); + fixed_transit = + solver->MakeElement( + NewPermanentCallback( + &IthEvaluatorValueOrValue< + std::unique_ptr, 0LL>, + &class_evaluators_, static_cast(i)), + model_->NextVar(i), vehicle_class_var)->Var(); + } } if (slack_max == 0) { transits_[i] = fixed_transit; @@ -4179,14 +4273,12 @@ void RoutingDimension::InitializeTransits( slacks_[i] = slack_var; } } - transit_evaluator_.reset( - NewPermanentCallback(&WrappedEvaluator, model_, transit_evaluator)); } -int64 RoutingDimension::GetTransitValue(int64 from_index, - int64 to_index) const { - DCHECK(transit_evaluator_ != nullptr); - return transit_evaluator_->Run(from_index, to_index); +int64 RoutingDimension::GetTransitValue(int64 from_index, int64 to_index, + int64 vehicle) const { + DCHECK(transit_evaluators_[vehicle] != nullptr); + return transit_evaluators_[vehicle]->Run(from_index, to_index); } void RoutingDimension::SetSpanCostCoefficientForVehicle(int64 coefficient, @@ -4371,13 +4463,6 @@ void RoutingDimension::SetupGlobalSpanCost(std::vector* cost_elements) } } -namespace { -template -int64 GetIthElementOrZero(const T& v, int64 i) { - return i < 0 ? 0 : v[i]; -} -} // namespace - void RoutingDimension::SetupSlackCosts(std::vector* cost_elements) const { if (model_->vehicles() == 0) return; // Figure out whether all vehicles have the same span cost coefficient or not. @@ -4419,7 +4504,7 @@ void RoutingDimension::SetupSlackCosts(std::vector* cost_elements) cons } else { IntVar* cost_coefficient_var = solver->MakeElement( - NewPermanentCallback(&GetIthElementOrZero >, + NewPermanentCallback(&IthElementOrValue<0LL>, vehicle_span_cost_coefficients_), model_->VehicleVar(var_index))->Var(); cost_elements->push_back( diff --git a/src/constraint_solver/routing.h b/src/constraint_solver/routing.h index 72931e6684..f835c0db19 100644 --- a/src/constraint_solver/routing.h +++ b/src/constraint_solver/routing.h @@ -480,12 +480,19 @@ class RoutingModel { // Takes ownership of the callback 'evaluator'. bool AddDimension(NodeEvaluator2* evaluator, int64 slack_max, int64 capacity, bool fix_start_cumul_to_zero, const std::string& name); + bool AddDimensionWithVehicleTransits( + const std::vector& evaluators, int64 slack_max, + int64 capacity, bool fix_start_cumul_to_zero, const std::string& name); // Takes ownership of both 'evaluator' and 'vehicle_capacity' callbacks. bool AddDimensionWithVehicleCapacity(NodeEvaluator2* evaluator, int64 slack_max, VehicleEvaluator* vehicle_capacity, bool fix_start_cumul_to_zero, const std::string& name); + bool AddDimensionWithVehicleTransitAndCapacity( + const std::vector& evaluators, int64 slack_max, + VehicleEvaluator* vehicle_capacity, bool fix_start_cumul_to_zero, + const std::string& name); // Creates a dimension where the transit variable is constrained to be // equal to 'value'; 'capacity' is the upper bound of the cumul variables. // 'name' is the name used to reference the dimension; this name is used to @@ -963,7 +970,8 @@ class RoutingModel { int64 GetDimensionTransitCost(const std::string& d) const; void SetDimensionSpanCost(const std::string& d, int64 c); int64 GetDimensionSpanCost(const std::string& d) const; - int64 GetTransitValue(const std::string& d, int64 from, int64 to) const; + int64 GetTransitValue(const std::string& d, int64 from, int64 to, + int64 vehicle) const; #ifndef SWIG const std::vector& CumulVars(const std::string& dimension_name) const; #endif @@ -1035,11 +1043,10 @@ class RoutingModel { void SetStartEnd(const std::vector >& start_end); void AddDisjunctionInternal(const std::vector& nodes, int64 penalty); void AddNoCycleConstraintInternal(); - bool AddDimensionWithCapacityInternal(NodeEvaluator2* evaluator, - int64 slack_max, int64 capacity, - VehicleEvaluator* vehicle_capacity, - bool fix_start_cumul_to_zero, - const std::string& dimension_name); + bool AddDimensionWithCapacityInternal( + const std::vector& evaluators, int64 slack_max, + int64 capacity, VehicleEvaluator* vehicle_capacity, + bool fix_start_cumul_to_zero, const std::string& dimension_name); DimensionIndex GetDimensionIndex(const std::string& dimension_name) const; void ComputeCostClasses(); int64 GetArcCostForClassInternal(int64 from_index, int64 to_index, @@ -1206,7 +1213,7 @@ class RoutingDimension { // Returns the transition value for a given pair of nodes (as var index); // this value is the one taken by the corresponding transit variable when // the 'next' variable for 'from_index' is bound to 'to_index'. - int64 GetTransitValue(int64 from_index, int64 to_index) const; + int64 GetTransitValue(int64 from_index, int64 to_index, int64 vehicle) const; // Get the cumul, transit and slack variables for the given node (given as // int64 var index). IntVar* CumulVar(int64 index) const { return cumuls_[index]; } @@ -1222,9 +1229,10 @@ class RoutingDimension { RoutingModel::VehicleEvaluator* capacity_evaluator() const { return capacity_evaluator_.get(); } - // Returns the callback evaluating the transit value between to node indices. - Solver::IndexEvaluator2* transit_evaluator() const { - return transit_evaluator_.get(); + // Returns the callback evaluating the transit value between two node indices + // for a given vehicle. + Solver::IndexEvaluator2* transit_evaluator(int vehicle) const { + return transit_evaluators_[vehicle]; } #endif // Sets a cost proportional to the dimension span on a given vehicle, @@ -1328,14 +1336,15 @@ class RoutingDimension { }; RoutingDimension(RoutingModel* model, const std::string& name); - void Initialize(RoutingModel::VehicleEvaluator* vehicle_capacity, - int64 capacity, - RoutingModel::NodeEvaluator2* transit_evaluator, - int64 slack_max); + void Initialize( + RoutingModel::VehicleEvaluator* vehicle_capacity, int64 capacity, + const std::vector& transit_evaluators, + int64 slack_max); void InitializeCumuls(RoutingModel::VehicleEvaluator* vehicle_capacity, int64 capacity); - void InitializeTransits(RoutingModel::NodeEvaluator2* transit_evaluator, - int64 slack_max); + void InitializeTransits( + const std::vector& transit_evaluators, + int64 slack_max); // Sets up the cost variables related to cumul soft upper bounds. void SetupCumulVarSoftUpperBoundCosts(std::vector* cost_elements) const; // Sets up the cost variables related to the global span and per-vehicle span @@ -1346,7 +1355,10 @@ class RoutingDimension { std::vector cumuls_; std::unique_ptr capacity_evaluator_; std::vector transits_; - std::unique_ptr transit_evaluator_; + // "transit_evaluators_" does the indexing by vehicle, while + // "class_evaluators_" does the de-duplicated ownership. + std::vector transit_evaluators_; + std::vector > class_evaluators_; std::vector slacks_; int64 global_span_cost_coefficient_; std::vector vehicle_span_cost_coefficients_; @@ -1496,9 +1508,9 @@ class CheapestInsertionFilteredDecisionBuilder // possible insertion positions of node 'node_to_insert' in the partial route // starting at node 'start' and adds them to 'valued_position', a list of // unsorted pairs of (cost, position to insert the node). - void AppendEvaluatedPositionsAfter( - int64 node_to_insert, int64 start, int64 next_after_start, - std::vector* valued_positions); + void AppendEvaluatedPositionsAfter(int64 node_to_insert, int64 start, + int64 next_after_start, + std::vector* valued_positions); std::unique_ptr > evaluator_; }; diff --git a/src/constraint_solver/routing_search.cc b/src/constraint_solver/routing_search.cc index 2e57e2cf37..f078983333 100644 --- a/src/constraint_solver/routing_search.cc +++ b/src/constraint_solver/routing_search.cc @@ -194,7 +194,7 @@ BasePathFilter::BasePathFilter(const std::vector& nexts, int next_domain_size, Callback1* objective_callback) : RoutingLocalSearchFilter(nexts, objective_callback), - node_path_starts_(next_domain_size), + node_path_starts_(next_domain_size, kUnassigned), paths_(nexts.size(), -1) {} bool BasePathFilter::Accept(const Assignment* delta, @@ -385,7 +385,7 @@ class PathCumulFilter : public BasePathFilter { const std::vector cumuls_; const std::vector slacks_; std::vector start_to_vehicle_; - Solver::IndexEvaluator2* const evaluator_; + std::vector evaluators_; int64 total_current_cumul_cost_value_; // Map between paths and path soft cumul bound costs. The paths are indexed // by the index of the start node of the path. @@ -419,7 +419,7 @@ PathCumulFilter::PathCumulFilter(const RoutingModel& routing_model, objective_callback), cumuls_(dimension.cumuls()), slacks_(dimension.slacks()), - evaluator_(dimension.transit_evaluator()), + evaluators_(routing_model.vehicles(), nullptr), total_current_cumul_cost_value_(0), current_cumul_cost_values_(), cumul_cost_delta_(0), @@ -473,6 +473,7 @@ PathCumulFilter::PathCumulFilter(const RoutingModel& routing_model, start_to_vehicle_.resize(Size(), -1); for (int i = 0; i < routing_model.vehicles(); ++i) { start_to_vehicle_[routing_model.Start(i)] = i; + evaluators_[i] = dimension.transit_evaluator(i); } } @@ -500,6 +501,8 @@ void PathCumulFilter::OnSynchronize() { // For each path, compute the minimum end cumul and store the max of these. for (int r = 0; r < NumPaths(); ++r) { int64 node = Start(r); + const int vehicle = start_to_vehicle_[Start(r)]; + Solver::IndexEvaluator2* const evaluator = evaluators_[vehicle]; // First pass: evaluating route length to reserve memory to store route // information. int number_of_route_arcs = 0; @@ -515,7 +518,7 @@ void PathCumulFilter::OnSynchronize() { int64 total_transit = 0; while (node < Size()) { const int64 next = Value(node); - const int64 transit = evaluator_->Run(node, next); + const int64 transit = evaluator->Run(node, next); total_transit += transit; const int64 transit_slack = transit + slacks_[node]->Min(); current_path_transits_.PushTransit(r, node, next, transit_slack); @@ -527,7 +530,6 @@ void PathCumulFilter::OnSynchronize() { if (FilterSlackCost()) { const int64 start = ComputePathMaxStartFromEndCumul(current_path_transits_, r, cumul); - const int vehicle = start_to_vehicle_[Start(r)]; current_cumul_cost_value += vehicle_span_cost_coefficients_[vehicle] * (cumul - start - total_transit); } @@ -574,6 +576,7 @@ bool PathCumulFilter::AcceptPath(const Assignment::IntContainer& container, const int64 capacity = capacity_evaluator_ == nullptr ? kint64max : capacity_evaluator_->Run(vehicle); + Solver::IndexEvaluator2* const evaluator = evaluators_[vehicle]; // Check that the path is feasible with regards to cumul bounds, scanning // the paths from start to end (caching path node sequences and transits // for further span cost filtering). @@ -584,7 +587,7 @@ bool PathCumulFilter::AcceptPath(const Assignment::IntContainer& container, lns_detected_ = true; return true; } - const int64 transit = evaluator_->Run(node, next); + const int64 transit = evaluator->Run(node, next); total_transit += transit; const int64 transit_slack = transit + slacks_[node]->Min(); delta_path_transits_.PushTransit(path, node, next, transit_slack); @@ -796,6 +799,10 @@ IntVarFilteredDecisionBuilder::IntVarFilteredDecisionBuilder( } Decision* IntVarFilteredDecisionBuilder::Next(Solver* solver) { + // Wiping assignment when starting a new search. + assignment_->MutableIntVarContainer()->Clear(); + assignment_->MutableIntVarContainer()->Resize(vars_.size()); + SynchronizeFilters(); SetValuesFromDomains(); if (BuildSolution()) { assignment_->Restore(); @@ -945,8 +952,7 @@ void CheapestInsertionFilteredDecisionBuilder::InsertBetween(int64 node, MakeDisjunctionNodesUnperformed(node); } -void -CheapestInsertionFilteredDecisionBuilder::AppendEvaluatedPositionsAfter( +void CheapestInsertionFilteredDecisionBuilder::AppendEvaluatedPositionsAfter( int64 node_to_insert, int64 start, int64 next_after_start, std::vector* valued_positions) { CHECK(valued_positions != nullptr); @@ -956,8 +962,8 @@ CheapestInsertionFilteredDecisionBuilder::AppendEvaluatedPositionsAfter( (insert_after == start) ? next_after_start : Value(insert_after); valued_positions->push_back( std::make_pair(evaluator_->Run(insert_after, node_to_insert) + - evaluator_->Run(node_to_insert, insert_before) - - evaluator_->Run(insert_after, insert_before), + evaluator_->Run(node_to_insert, insert_before) - + evaluator_->Run(insert_after, insert_before), insert_after)); insert_after = insert_before; } @@ -996,7 +1002,7 @@ bool GlobalCheapestInsertionFilteredDecisionBuilder::BuildSolution() { found = false; ComputeEvaluatorSortedPositionPairs(&insertion_pairs); for (const std::pair, std::pair>& insertion_pair : - insertion_pairs) { + insertion_pairs) { const int64 pickup = insertion_pair.first.second; const int64 pickup_insertion = insertion_pair.first.first; const int64 pickup_insertion_next = Value(pickup_insertion); @@ -1004,9 +1010,9 @@ bool GlobalCheapestInsertionFilteredDecisionBuilder::BuildSolution() { const int64 delivery = insertion_pair.second.second; const int64 delivery_insertion = insertion_pair.second.first; DCHECK_NE(delivery_insertion, pickup_insertion); - const int64 delivery_insertion_next = - (delivery_insertion == pickup) ? pickup_insertion_next - : Value(delivery_insertion); + const int64 delivery_insertion_next = (delivery_insertion == pickup) + ? pickup_insertion_next + : Value(delivery_insertion); InsertBetween(delivery, delivery_insertion, delivery_insertion_next); if (Commit()) { found = true; @@ -1033,9 +1039,9 @@ bool GlobalCheapestInsertionFilteredDecisionBuilder::BuildSolution() { return Commit(); } -void GlobalCheapestInsertionFilteredDecisionBuilder:: - ComputeEvaluatorSortedPositions( - std::vector* sorted_positions) { +void +GlobalCheapestInsertionFilteredDecisionBuilder::ComputeEvaluatorSortedPositions( + std::vector* sorted_positions) { CHECK(sorted_positions != nullptr); sorted_positions->clear(); std::vector> valued_insertions; @@ -1050,9 +1056,8 @@ void GlobalCheapestInsertionFilteredDecisionBuilder:: &valued_positions); } for (const std::pair& valued_position : valued_positions) { - valued_insertions.push_back(std::make_pair(valued_position.first, - std::make_pair(valued_position.second, - node))); + valued_insertions.push_back(std::make_pair( + valued_position.first, std::make_pair(valued_position.second, node))); } } SortAndExtractPairSeconds(&valued_insertions, sorted_positions); @@ -1066,7 +1071,7 @@ void GlobalCheapestInsertionFilteredDecisionBuilder:: std::vector>> valued_positions; for (const RoutingModel::NodePair node_pair : - model()->GetPickupAndDeliveryPairs()) { + model()->GetPickupAndDeliveryPairs()) { const int64 pickup = node_pair.first; const int64 delivery = node_pair.second; if (Contains(pickup) || Contains(delivery)) { @@ -1078,15 +1083,14 @@ void GlobalCheapestInsertionFilteredDecisionBuilder:: AppendEvaluatedPositionsAfter(pickup, start, Value(start), &valued_pickup_positions); for (const ValuedPosition& valued_pickup_position : - valued_pickup_positions) { + valued_pickup_positions) { const int64 pickup_position = valued_pickup_position.second; CHECK(!model()->IsEnd(pickup_position)); std::vector valued_delivery_positions; - AppendEvaluatedPositionsAfter(delivery, pickup, - Value(pickup_position), + AppendEvaluatedPositionsAfter(delivery, pickup, Value(pickup_position), &valued_delivery_positions); for (const ValuedPosition& valued_delivery_position : - valued_delivery_positions) { + valued_delivery_positions) { valued_positions.push_back(std::make_pair( valued_pickup_position.first + valued_delivery_position.first, std::make_pair(std::make_pair(pickup_position, pickup), @@ -1098,7 +1102,6 @@ void GlobalCheapestInsertionFilteredDecisionBuilder:: SortAndExtractPairSeconds(&valued_positions, sorted_positions); } - // LocalCheapestInsertionFilteredDecisionBuilder LocalCheapestInsertionFilteredDecisionBuilder:: diff --git a/src/sat/boolean_problem.cc b/src/sat/boolean_problem.cc index a380984832..08ac1c9b27 100644 --- a/src/sat/boolean_problem.cc +++ b/src/sat/boolean_problem.cc @@ -231,5 +231,17 @@ void StoreAssignment(const VariablesAssignment& assignment, } } +void ExtractSubproblem(const LinearBooleanProblem& problem, + const std::vector& constraint_indices, + LinearBooleanProblem* subproblem) { + subproblem->CopyFrom(problem); + subproblem->set_name("Subproblem of " + problem.name()); + subproblem->clear_constraints(); + for (int index : constraint_indices) { + CHECK_LT(index, problem.constraints_size()); + subproblem->add_constraints()->MergeFrom(problem.constraints(index)); + } +} + } // namespace sat } // namespace operations_research diff --git a/src/sat/boolean_problem.h b/src/sat/boolean_problem.h index 12722a5efe..cbaccd5634 100644 --- a/src/sat/boolean_problem.h +++ b/src/sat/boolean_problem.h @@ -53,6 +53,11 @@ std::string LinearBooleanProblemToCnfString(const LinearBooleanProblem& problem) void StoreAssignment(const VariablesAssignment& assignment, BooleanAssignment* output); +// Constructs a sub-problem formed by the constraints with given indices. +void ExtractSubproblem(const LinearBooleanProblem& problem, + const std::vector& constraint_indices, + LinearBooleanProblem* subproblem); + } // namespace sat } // namespace operations_research diff --git a/src/sat/clause.cc b/src/sat/clause.cc new file mode 100644 index 0000000000..b91a12e8b2 --- /dev/null +++ b/src/sat/clause.cc @@ -0,0 +1,512 @@ +// Copyright 2010-2013 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "sat/clause.h" + +#include +#include "base/unique_ptr.h" +#include +#include + +#include "base/integral_types.h" +#include "base/logging.h" +#include "base/sysinfo.h" +#include "base/join.h" +#include "util/time_limit.h" +#include "base/stl_util.h" + +namespace operations_research { +namespace sat { + +namespace { + +// Returns true if the given watcher list contains the given clause. +template +bool WatcherListContains(const std::vector& list, + const SatClause& candidate) { + for (const Watcher& watcher : list) { + if (watcher.clause == &candidate) return true; + } + return false; +} + +// A simple wrapper to simplify the erase(std::remove_if()) pattern. +template +void RemoveIf(Container c, Predicate p) { + c->erase(std::remove_if(c->begin(), c->end(), p), c->end()); +} + +// Removes dettached clauses from a watcher list. +template +bool CleanUpPredicate(const Watcher& watcher) { + return !watcher.clause->IsAttached(); +} + +} // namespace + +// ----- LiteralWatchers ----- + +LiteralWatchers::LiteralWatchers() + : is_clean_(true), + num_inspected_clauses_(0), + num_watched_clauses_(0), + stats_("LiteralWatchers") {} + +LiteralWatchers::~LiteralWatchers() { + IF_STATS_ENABLED(LOG(INFO) << stats_.StatString()); +} + +void LiteralWatchers::Resize(int num_variables) { + DCHECK(is_clean_); + watchers_on_false_.resize(num_variables << 1); + needs_cleaning_.resize(num_variables << 1, false); + statistics_.resize(num_variables); +} + +// Note that this is the only place where we add Watcher so the DCHECK +// guarantees that there are no duplicates. +void LiteralWatchers::AttachOnFalse(Literal a, Literal b, SatClause* clause) { + SCOPED_TIME_STAT(&stats_); + DCHECK(is_clean_); + DCHECK(!WatcherListContains(watchers_on_false_[a.Index()], *clause)); + watchers_on_false_[a.Index()].push_back(Watcher(clause, b)); +} + +bool LiteralWatchers::PropagateOnFalse(Literal false_literal, Trail* trail) { + SCOPED_TIME_STAT(&stats_); + DCHECK(is_clean_); + std::vector& watchers = watchers_on_false_[false_literal.Index()]; + const VariablesAssignment& assignment = trail->Assignment(); + int new_index = 0; + + // Note(user): It sounds better to inspect the list in order, this is because + // small clauses like binary or ternary clauses will often propagate and thus + // stay at the beginning of the list. + const int initial_size = watchers.size(); + for (int i = 0; i < initial_size; ++i) { + ++num_inspected_clauses_; + + // Don't even look at the clause memory if the blocking literal is true. + if (assignment.IsLiteralTrue(watchers[i].blocking_literal)) { + watchers[new_index] = watchers[i]; + ++new_index; + continue; + } + + SatClause* clause = watchers[i].clause; + if (!clause->PropagateOnFalse(false_literal, trail)) { + // Conflict: All literals of this clause are false. + memmove(&watchers[new_index], &watchers[i], + (initial_size - i) * sizeof(Watcher)); + watchers.resize(new_index + initial_size - i); + return false; + } + + // Update the watched literal if clause->FirstLiteral() changed. + // See the contract of PropagateOnFalse(). + if (clause->FirstLiteral() != false_literal) { + AttachOnFalse(clause->FirstLiteral(), clause->SecondLiteral(), clause); + } else { + watchers[new_index] = Watcher(clause, clause->SecondLiteral()); + ++new_index; + } + } + watchers.resize(new_index); + return true; +} + +bool LiteralWatchers::AttachAndPropagate(SatClause* clause, Trail* trail) { + SCOPED_TIME_STAT(&stats_); + ++num_watched_clauses_; + UpdateStatistics(*clause, /*added=*/true); + clause->SortLiterals(statistics_, parameters_); + return clause->AttachAndEnqueuePotentialUnitPropagation(trail, this); +} + +void LiteralWatchers::LazyDetach(SatClause* clause) { + SCOPED_TIME_STAT(&stats_); + --num_watched_clauses_; + UpdateStatistics(*clause, /*added=*/false); + clause->LazyDetach(); + is_clean_ = false; + needs_cleaning_[clause->FirstLiteral().Index()] = true; + needs_cleaning_[clause->SecondLiteral().Index()] = true; +} + +void LiteralWatchers::CleanUpWatchers() { + SCOPED_TIME_STAT(&stats_); + for (int i = 0; i < needs_cleaning_.size(); ++i) { + if (needs_cleaning_[LiteralIndex(i)]) { + RemoveIf(&(watchers_on_false_[LiteralIndex(i)]), + CleanUpPredicate); + needs_cleaning_[LiteralIndex(i)] = false; + } + } + is_clean_ = true; +} + +void LiteralWatchers::UpdateStatistics(const SatClause& clause, bool added) { + SCOPED_TIME_STAT(&stats_); + for (const Literal literal : clause) { + const VariableIndex var = literal.Variable(); + const int direction = added ? 1 : -1; + statistics_[var].num_appearances += direction; + statistics_[var].weighted_num_appearances += + 1.0 / clause.Size() * direction; + if (literal.IsPositive()) { + statistics_[var].num_positive_clauses += direction; + } else { + statistics_[var].num_negative_clauses += direction; + } + } +} + +// ----- BinaryImplicationGraph ----- + +void BinaryImplicationGraph::Resize(int num_variables) { + SCOPED_TIME_STAT(&stats_); + implications_.resize(num_variables << 1); +} + +void BinaryImplicationGraph::AddBinaryClause(Literal a, Literal b) { + SCOPED_TIME_STAT(&stats_); + implications_[a.Negated().Index()].push_back(b); + implications_[b.Negated().Index()].push_back(a); +} + +void BinaryImplicationGraph::AddBinaryConflict(Literal a, Literal b, + Trail* trail) { + SCOPED_TIME_STAT(&stats_); + AddBinaryClause(a, b); + if (trail->Assignment().IsLiteralFalse(a)) { + trail->EnqueueWithBinaryReason(b, a); + } else if (trail->Assignment().IsLiteralFalse(b)) { + trail->EnqueueWithBinaryReason(a, b); + } +} + +bool BinaryImplicationGraph::PropagateOnTrue(Literal true_literal, + Trail* trail) { + SCOPED_TIME_STAT(&stats_); + const VariablesAssignment& assignment = trail->Assignment(); + for (Literal literal : implications_[true_literal.Index()]) { + if (assignment.IsLiteralTrue(literal)) { + // Note(user): I tried to update the reason here if the literal was + // enqueued after the true_literal on the trail. This property is + // important for ComputeFirstUIPConflict() to work since it needs the + // trail order to be a topological order for the deduction graph. + // But the performance where not too good... + continue; + } + + ++num_propagations_; + if (assignment.IsLiteralFalse(literal)) { + // Conflict. + temporary_clause_[0] = true_literal.Negated(); + temporary_clause_[1] = literal; + trail->SetFailingClause( + ClauseRef(&temporary_clause_[0], &temporary_clause_[0] + 2)); + return false; + } else { + // Propagation. + trail->EnqueueWithBinaryReason(literal, true_literal.Negated()); + } + } + return true; +} + +void BinaryImplicationGraph::MinimizeClause(const Trail& trail, + std::vector* conflict) { + SCOPED_TIME_STAT(&stats_); + is_marked_.ClearAndResize(LiteralIndex(implications_.size())); + is_removed_.ClearAndResize(LiteralIndex(implications_.size())); + for (Literal lit : *conflict) { + is_marked_.Set(lit.Index()); + } + + // Identify and remove the redundant literals from the given conflict. + // 1/ If a -> b then a can be removed from the conflict clause. + // This is because not b -> not a. + // 2/ a -> b can only happen if level(a) <= level(b). + // 3/ Because of 2/, cycles can appear only at the same level. + // The vector is_removed_ is used to avoid removing all elements of a + // cycle. Note that this is not optimal in the sense that we may not remove + // a literal that can be removed. + // + // TODO(user): no need to explore the unique literal of the current decision + // level since it can't be removed. + int index = 0; + for (int i = 0; i < conflict->size(); ++i) { + const Literal lit = (*conflict)[i]; + const int lit_level = trail.Info(lit.Variable()).level; + bool keep_literal = true; + for (Literal implied : implications_[lit.Index()]) { + if (is_marked_[implied.Index()]) { + DCHECK_LE(lit_level, trail.Info(implied.Variable()).level); + if (lit_level == trail.Info(implied.Variable()).level && + is_removed_[implied.Index()]) + continue; + keep_literal = false; + break; + } + } + if (keep_literal) { + (*conflict)[index] = lit; + ++index; + } else { + is_removed_.Set(lit.Index()); + } + } + if (index < conflict->size()) { + ++num_minimization_; + num_literals_removed_ += conflict->size() - index; + conflict->erase(conflict->begin() + index, conflict->end()); + } +} + +void BinaryImplicationGraph::RemoveFixedVariables( + const VariablesAssignment& assigment) { + SCOPED_TIME_STAT(&stats_); + is_marked_.ClearAndResize(LiteralIndex(implications_.size())); + for (LiteralIndex i(0); i < implications_.size(); ++i) { + if (assigment.IsLiteralTrue(Literal(i))) { + // If b is true and a -> b then because not b -> not a, all the + // implications list that contains b will be marked by this process. + for (Literal lit : implications_[Literal(i).NegatedIndex()]) { + is_marked_.Set(lit.NegatedIndex()); + } + STLClearObject(&(implications_[i])); + STLClearObject(&(implications_[Literal(i).NegatedIndex()])); + } + } + for (LiteralIndex i(0); i < implications_.size(); ++i) { + if (is_marked_[i]) { + RemoveIf(&implications_[i], + std::bind1st(std::mem_fun(&VariablesAssignment::IsLiteralTrue), + &assigment)); + } + } +} + +// ----- SatClause ----- + +// static +SatClause* SatClause::Create(const std::vector& literals, ClauseType type, + ResolutionNode* node) { + CHECK_GE(literals.size(), 2); + SatClause* clause = reinterpret_cast( + ::operator new(sizeof(SatClause) + literals.size() * sizeof(Literal))); + clause->size_ = literals.size(); + for (int i = 0; i < literals.size(); ++i) { + clause->literals_[i] = literals[i]; + } + clause->is_learned_ = (type == LEARNED_CLAUSE); + clause->is_attached_ = false; + clause->activity_ = 0.0; + clause->lbd_ = 0; + clause->resolution_node_ = node; + return clause; +} + +// Note that for an attached clause, removing fixed literal is okay because if +// any of them is assigned, then the clause is necessary true. +bool SatClause::RemoveFixedLiteralsAndTestIfTrue( + const VariablesAssignment& assignment, std::vector* removed_literals) { + removed_literals->clear(); + DCHECK(is_attached_); + if (assignment.IsVariableAssigned(literals_[0].Variable()) || + assignment.IsVariableAssigned(literals_[1].Variable())) { + DCHECK(IsSatisfied(assignment)); + return true; + } + int j = 2; + for (int i = 2; i < size_; ++i) { + if (assignment.IsVariableAssigned(literals_[i].Variable())) { + if (assignment.IsLiteralTrue(literals_[i])) return true; + removed_literals->push_back(literals_[i]); + } else { + literals_[j] = literals_[i]; + ++j; + } + } + size_ = j; + return false; +} + +namespace { + +// Support struct to sort literals for ordering. +struct WeightedLiteral { + WeightedLiteral(Literal l, int w) : literal(l), weight(w) {} + + Literal literal; + int weight; +}; + +// Lexical order, by smaller weight, then by smaller literal to break ties. +bool LiteralWithSmallerWeightFirst(const WeightedLiteral& wv1, + const WeightedLiteral& wv2) { + return (wv1.weight < wv2.weight) || + (wv1.weight == wv2.weight && + wv1.literal.SignedValue() < wv2.literal.SignedValue()); +} + +// Lexical order, by larger weight, then by smaller literal to break ties. +bool LiteralWithLargerWeightFirst(const WeightedLiteral& wv1, + const WeightedLiteral& wv2) { + return (wv1.weight > wv2.weight) || + (wv1.weight == wv2.weight && + wv1.literal.SignedValue() < wv2.literal.SignedValue()); +} + +} // namespace + +void SatClause::SortLiterals( + const ITIVector& statistics, + const SatParameters& parameters) { + CHECK(!IsAttached()); + const SatParameters::LiteralOrdering literal_order = + parameters.literal_ordering(); + if (literal_order != SatParameters::LITERAL_IN_ORDER) { + std::vector order; + for (Literal literal : *this) { + int weight = literal.IsPositive() + ? statistics[literal.Variable()].num_positive_clauses + : statistics[literal.Variable()].num_negative_clauses; + order.push_back(WeightedLiteral(literal, weight)); + } + switch (literal_order) { + case SatParameters::VAR_MIN_USAGE: { + std::sort(order.begin(), order.end(), LiteralWithSmallerWeightFirst); + break; + } + case SatParameters::VAR_MAX_USAGE: { + std::sort(order.begin(), order.end(), LiteralWithLargerWeightFirst); + break; + } + default: { break; } + } + for (int i = 0; i < order.size(); ++i) { + literals_[i] = order[i].literal; + } + } +} + +bool SatClause::AttachAndEnqueuePotentialUnitPropagation( + Trail* trail, LiteralWatchers* demons) { + CHECK(!IsAttached()); + // Select the first two literals that are not assigned to false and put them + // on position 0 and 1. + int num_literal_not_false = 0; + for (int i = 0; i < size_; ++i) { + if (!trail->Assignment().IsLiteralFalse(literals_[i])) { + std::swap(literals_[i], literals_[num_literal_not_false]); + ++num_literal_not_false; + if (num_literal_not_false == 2) { + break; + } + } + } + + // Returns false if all the literals were false. + // This should only happen on an UNSAT problem, and there is no need to attach + // the clause in this case. + if (num_literal_not_false == 0) return false; + + if (num_literal_not_false == 1) { + // To maintain the validity of the 2-watcher algorithm, we need to watch + // the false literal with the highest decision levels. + int max_level = trail->Info(literals_[1].Variable()).level; + for (int i = 2; i < size_; ++i) { + const int level = trail->Info(literals_[i].Variable()).level; + if (level > max_level) { + max_level = level; + std::swap(literals_[1], literals_[i]); + } + } + + // If there is a propagation, make literals_[1] the propagated literal and + // enqueue it. + if (!trail->Assignment().IsLiteralTrue(literals_[0])) { + std::swap(literals_[0], literals_[1]); + trail->EnqueueWithSatClauseReason(literals_[1], this); + } + } + + // Attach the watchers. + is_attached_ = true; + demons->AttachOnFalse(literals_[0], literals_[1], this); + demons->AttachOnFalse(literals_[1], literals_[0], this); + return true; +} + +// Propagates one watched literal becoming false. This method maintains the +// invariant that watched literals are always in position 0 and 1. +bool SatClause::PropagateOnFalse(Literal watched_literal, Trail* trail) { + const VariablesAssignment& assignment = trail->Assignment(); + DCHECK(IsAttached()); + DCHECK_GE(size_, 2); + DCHECK(assignment.IsLiteralFalse(watched_literal)); + + // The instantiated literal should be in position 0. + if (literals_[1] == watched_literal) { + literals_[1] = literals_[0]; + literals_[0] = watched_literal; + } + DCHECK_EQ(literals_[0], watched_literal); + + // If the other watched literal is true, do nothing. + if (assignment.IsLiteralTrue(literals_[1])) return true; + + for (int i = 2; i < size_; ++i) { + if (assignment.IsLiteralFalse(literals_[i])) continue; + + // Note(user): If the value of literals_[i] is true, it is possible to leave + // the watched literal unchanged. However this seems less efficient. Even if + // we swap it with the literal at position 2 to speed up future checks. + + // literal[i] is undefined or true, it's now the new literal to watch. + literals_[0] = literals_[i]; + literals_[i] = watched_literal; + return true; + } + + // Literals_[1] is either false or undefined, all other literals are false. + if (assignment.IsLiteralFalse(literals_[1])) { + trail->SetFailingSatClause(ToClauseRef(), this); + trail->SetFailingResolutionNode(resolution_node_); + return false; + } + + // Literals_[1] is undefined, set it to true. + trail->EnqueueWithSatClauseReason(literals_[1], this); + return true; +} + +bool SatClause::IsSatisfied(const VariablesAssignment& assignment) const { + for (const Literal literal : *this) { + if (assignment.IsLiteralTrue(literal)) return true; + } + return false; +} + +std::string SatClause::DebugString() const { + std::string result; + for (const Literal literal : *this) { + if (!result.empty()) result.append(" "); + result.append(literal.DebugString()); + } + return result; +} + +} // namespace sat +} // namespace operations_research diff --git a/src/sat/clause.h b/src/sat/clause.h new file mode 100644 index 0000000000..f190edbcaa --- /dev/null +++ b/src/sat/clause.h @@ -0,0 +1,357 @@ +// Copyright 2010-2013 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// This file contains the solver internal representation of the clauses and the +// classes used for their propagation. + +#ifndef OR_TOOLS_SAT_CLAUSE_H_ +#define OR_TOOLS_SAT_CLAUSE_H_ + +#include "base/unique_ptr.h" +#include +#include +#include + +#include "base/integral_types.h" +#include "base/logging.h" +#include "base/scoped_ptr.h" +#include "base/stringprintf.h" +#include "base/timer.h" +#include "base/int_type_indexed_vector.h" +#include "base/int_type.h" +#include "sat/sat_base.h" +#include "sat/sat_parameters.pb.h" +#include "util/bitset.h" +#include "util/stats.h" + +namespace operations_research { +namespace sat { + +// Forward declarations. +// TODO(user): This cyclic dependency can be relatively easily removed. +class LiteralWatchers; + +// Variable information. This is updated each time we attach/detach a clause. +struct VariableInfo { + VariableInfo() + : num_positive_clauses(0), + num_negative_clauses(0), + num_appearances(0), + weighted_num_appearances(0.0) {} + + int num_positive_clauses; + int num_negative_clauses; + int num_appearances; + double weighted_num_appearances; +}; + +// This is how the SatSolver store a clause. A clause is just a disjunction of +// literals. In many places, we just use std::vector to encode one. However, +// the solver needs to keep a few extra fields attached to each clause. +class SatClause { + public: + // Creates a sat clause. There must be at least 2 literals. + // Smaller clause are treated separatly and never constructed. + enum ClauseType { + PROBLEM_CLAUSE, + LEARNED_CLAUSE, + }; + static SatClause* Create(const std::vector& literals, ClauseType type, + ResolutionNode* node); + + // Number of literals in the clause. + int Size() const { return size_; } + + // Allows for range based iteration: for (Literal literal : clause) {}. + const Literal* const begin() const { return &(literals_[0]); } + const Literal* const end() const { return &(literals_[size_]); } + + // Returns a ClauseRef that point to this clause. + ClauseRef ToClauseRef() const { return ClauseRef(begin(), end()); } + + // Returns the first and second literals. These are always the watched + // literals if the clause is attached in the LiteralWatchers. + Literal FirstLiteral() const { return literals_[0]; } + Literal SecondLiteral() const { return literals_[1]; } + + // Removes literals that are fixed. This should only be called at level 0 + // where a literal is fixed iff it is assigned. Aborts and returns true if + // they are not all false. + bool RemoveFixedLiteralsAndTestIfTrue(const VariablesAssignment& assignment, + std::vector* removed_literals); + + // Propagates watched_literal which just became false in the clause. Returns + // false if an inconsistency was detected. + // + // IMPORTANT: If a new literal needs watching instead, then FirstLiteral() + // will be the new watched literal, otherwise it will be equal to the given + // watched_literal. + bool PropagateOnFalse(Literal watched_literal, Trail* trail); + + // True if the clause is learned. + bool IsLearned() const { return is_learned_; } + + // Returns true if the clause is satisfied for the given assignment. Note that + // the assignment may be partial, so false does not mean that the clause can't + // be satisfied by completing the assignment. + bool IsSatisfied(const VariablesAssignment& assignment) const; + + // Sorts the literals of the clause depending on the given parameters and + // statistics. Do not call this on an attached clause. + void SortLiterals(const ITIVector& statistics, + const SatParameters& parameters); + + // Sets up the 2-watchers data structure. It selects two non-false literals + // and attaches the clause to the event: one of the watched literals become + // false. It returns false if the clause only contains literals assigned to + // false. If only one literals is not false, it propagates it to true if it + // is not already assigned. + bool AttachAndEnqueuePotentialUnitPropagation(Trail* trail, + LiteralWatchers* demons); + + // Modify and get the clause activity. + void IncreaseActivity(double increase) { activity_ += increase; } + void MultiplyActivity(double factor) { activity_ *= factor; } + double Activity() const { return activity_; } + + // Set and get the clause LBD (Literal Blocks Distance). The LBD is not + // computed here. See ComputeClauseLbd() in SatSolver. + void SetLbd(int value) { lbd_ = value; } + int Lbd() const { return lbd_; } + + // Returns true if the clause is attached to a LiteralWatchers. + bool IsAttached() const { return is_attached_; } + + // Marks the clause so that the next call to CleanUpWatchers() can identify it + // and actually detach it. + void LazyDetach() { is_attached_ = false; } + + // Returns the node of the resolution DAG associated to this clause. + // This will always be nullptr if the parameter unsat_proof() is false. + ResolutionNode* ResolutionNodePointer() const { return resolution_node_; } + void ChangeResolutionNode(ResolutionNode* node) { resolution_node_ = node; } + + std::string DebugString() const; + + private: + // The data is packed so that only 16 bytes are used for these fields. + // Note that the max lbd is the maximum depth of the search tree (decision + // levels), so it should fit easily in 29 bits. Note that we can also upper + // bound it without hurting too much the clause cleaning heuristic. + bool is_learned_ : 1; + bool is_attached_ : 1; + int lbd_ : 30; + int size_ : 32; + double activity_; + + // This is only needed when the parameter unsat_proof() is true. + // TODO(user): It is possible to use less memory when this is not the case + // by some tweaks in Create() and in the way we access it. + ResolutionNode* resolution_node_; + + // This class store the literals inline, and literals_ mark the starts of the + // variable length portion. + Literal literals_[0]; + + DISALLOW_COPY_AND_ASSIGN(SatClause); +}; + +// Stores the 2-watched literals data structure. See +// http://www.cs.berkeley.edu/~necula/autded/lecture24-sat.pdf for +// detail. +class LiteralWatchers { + public: + LiteralWatchers(); + ~LiteralWatchers(); + + // Resizes the data structure. + void Resize(int num_variables); + + // Attaches the given clause. This eventually propagates a literal which is + // enqueued on the trail. Returns false if a contradiction was encountered. + bool AttachAndPropagate(SatClause* clause, Trail* trail); + + // Attaches the given clause to the event: the given literal becomes false. + // The blocking_literal can be any literal from the clause, it is used to + // speed up PropagateOnFalse() by skipping the clause if it is true. + void AttachOnFalse(Literal literal, Literal blocking_literal, + SatClause* clause); + + // Lazily detach the given clause. The deletion will actually occur when + // CleanUpWatchers() is called. The later needs to be called before any other + // function in this class can be called. This is DCHECKed. + void LazyDetach(SatClause* clause); + void CleanUpWatchers(); + + // Launches all propagation when the given literal becomes false. + // Returns false if a contradiction was encountered. + bool PropagateOnFalse(Literal false_literal, Trail* trail); + + // Total number of clauses inspected during calls to PropagateOnFalse(). + int64 num_inspected_clauses() const { return num_inspected_clauses_; } + + // Number of clauses currently watched. + int64 num_watched_clauses() const { return num_watched_clauses_; } + + // Returns some statistics on the number of appearance of this variable in + // all the attached clauses. + const VariableInfo& VariableStatistic(VariableIndex var) const { + return statistics_[var]; + } + + // Parameters management. + void SetParameters(const SatParameters& parameters) { + parameters_ = parameters; + } + + private: + // Updates statistics_ for the literals in the given clause. added indicates + // if we are adding the clause or deleting it. + void UpdateStatistics(const SatClause& clause, bool added); + + // Contains, for each literal, the list of clauses that need to be inspected + // when the corresponding literal becomes false. + struct Watcher { + Watcher() {} + Watcher(SatClause* c, Literal b) : clause(c), blocking_literal(b) {} + SatClause* clause; + Literal blocking_literal; + }; + ITIVector > watchers_on_false_; + + // Indicates if the corresponding watchers_on_false_ list need to be + // cleaned. The boolean is_clean_ is just used in DCHECKs. + ITIVector needs_cleaning_; + bool is_clean_; + + ITIVector statistics_; + SatParameters parameters_; + int64 num_inspected_clauses_; + int64 num_watched_clauses_; + mutable StatsGroup stats_; + DISALLOW_COPY_AND_ASSIGN(LiteralWatchers); +}; + +// Special class to store and propagate clauses of size 2 (i.e. implication). +// Such clauses are never deleted. +// +// TODO(user): All the variables in a strongly connected component are +// equivalent and can be thus merged as one. This is relatively cheap to compute +// from time to time (linear complexity). We will also get contradiction (a <=> +// not a) this way. +// +// TODO(user): An implication (a => not a) implies that a is false. I am not +// sure it is worth detecting that because if the solver assign a to true, it +// will learn that right away. I don't think we can do it faster. +// +// TODO(user): The implication graph can be pruned. This is called the +// transitive reduction of a graph. For instance If a => {b,c} and b => {c}, +// then there is no need to store a => {c}. The transitive reduction is unique +// on an acyclic graph. Computing it will allow for a faster propagation and +// memory reduction. It is however not cheap. Maybe simple lazy heuristics to +// remove redundant arcs are better. Note that all the learned clauses we add +// will never be redundant (but they could introduce cycles). +// +// TODO(user): Add a preprocessor to remove duplicates in the implication lists. +// Note that all the learned clauses we had will never create duplicates. +// +// References for most of the above TODO and more: +// - Brafman RI, "A simplifier for propositional formulas with many binary +// clauses", IEEE Trans Syst Man Cybern B Cybern. 2004 Feb;34(1):52-9. +// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.28.4911 +// - Marijn J. H. Heule, Matti Järvisalo, Armin Biere, "Efficient CNF +// Simplification Based on Binary Implication Graphs", Theory and Applications +// of Satisfiability Testing - SAT 2011, Lecture Notes in Computer Science +// Volume 6695, 2011, pp 201-215 +// http://www.cs.helsinki.fi/u/mjarvisa/papers/heule-jarvisalo-biere.sat11.pdf +class BinaryImplicationGraph { + public: + BinaryImplicationGraph() + : num_propagations_(0), + num_minimization_(0), + num_literals_removed_(0), + stats_("BinaryImplicationGraph") {} + ~BinaryImplicationGraph() { + IF_STATS_ENABLED(LOG(INFO) << stats_.StatString()); + } + + // Resizes the data structure. + void Resize(int num_variables); + + // Adds the binary clause (a OR b), which is the same as (not a => b). + // Note that it is also equivalent to (not b => a). + void AddBinaryClause(Literal a, Literal b); + + // Same as AddBinaryClause() but enqueues a possible unit propagation. + void AddBinaryConflict(Literal a, Literal b, Trail* trail); + + // Propagates all the direct implications of the given literal becoming true. + // Returns false if a conflict was encountered, in which case + // trail->SetFailingClause() will be called with the correct size 2 clause. + // This calls trail->Enqueue() on the newly assigned literals. + bool PropagateOnTrue(Literal true_literal, Trail* trail); + + // Uses the binary implication graph to minimize the given clause by removing + // literals that implies others. + // + // TODO(user): The current algorithm is minimalist, and just look at direct + // implication. Investigate recursive version. + void MinimizeClause(const Trail& trail, std::vector* clause); + + // This must only be called at decision level 0 after all the possible + // propagations. It: + // - Removes the variable at true from the implications lists. + // - Frees the propagation list of the assigned literals. + void RemoveFixedVariables(const VariablesAssignment& assigment); + + // Number of literal propagated by this class (including conflicts). + int64 num_propagations() const { return num_propagations_; } + + // MinimizeClause() stats. + int64 num_minimization() const { return num_minimization_; } + int64 num_literals_removed() const { return num_literals_removed_; } + + // Returns the number of current implications. + int64 NumberOfImplications() const { + int num = 0; + for (const std::vector& v : implications_) num += v.size(); + return num / 2; + } + + private: + // This is indexed by the Index() of a literal. Each list stores the + // literals that are implied if the index literal becomes true. + ITIVector > implications_; + + // Holds the last conflicting binary clause. + Literal temporary_clause_[2]; + + // Some stats. + int64 num_propagations_; + int64 num_minimization_; + int64 num_literals_removed_; + + // Bitset used by MinimizeClause(). + // TODO(user): use the same one as the one used in the classic minimization + // because they are already initialized. Moreover they contains more + // information. + SparseBitset is_marked_; + SparseBitset is_removed_; + + mutable StatsGroup stats_; + DISALLOW_COPY_AND_ASSIGN(BinaryImplicationGraph); +}; + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_CLAUSE_H_ diff --git a/src/sat/pb_constraint.cc b/src/sat/pb_constraint.cc index caa8728cff..1ce12e64e4 100644 --- a/src/sat/pb_constraint.cc +++ b/src/sat/pb_constraint.cc @@ -12,23 +12,13 @@ // limitations under the License. #include "sat/pb_constraint.h" +#include "util/saturated_arithmetic.h" + namespace operations_research { namespace sat { namespace { -// Returns false if the addition overflow/underflow. Otherwise returns true -// and performs the addition *b += a; -bool SafeAdd(Coefficient a, Coefficient* b) { - if (a > 0) { - if (*b > std::numeric_limits::max() - a) return false; - } else { - if (*b < std::numeric_limits::min() - a) return false; - } - *b += a; - return true; -} - bool LiteralComparator(const LiteralWithCoeff& a, const LiteralWithCoeff& b) { return a.literal.Index() < b.literal.Index(); } @@ -58,13 +48,13 @@ bool PbCannonicalForm(std::vector* cst, Coefficient* bound_shi if (representative != nullptr && current.literal.Variable() == representative->literal.Variable()) { if (current.literal == representative->literal) { - if (!SafeAdd(current.coefficient, &(representative->coefficient))) + if (!SafeAddInto(current.coefficient, &(representative->coefficient))) return false; } else { // Here current_literal is equal to (1 - representative). - if (!SafeAdd(-current.coefficient, &(representative->coefficient))) + if (!SafeAddInto(-current.coefficient, &(representative->coefficient))) return false; - if (!SafeAdd(-current.coefficient, bound_shift)) return false; + if (!SafeAddInto(-current.coefficient, bound_shift)) return false; } } else { if (representative != nullptr && representative->coefficient == 0) { @@ -85,11 +75,11 @@ bool PbCannonicalForm(std::vector* cst, Coefficient* bound_shi for (int i = 0; i < cst->size(); ++i) { const LiteralWithCoeff current = (*cst)[i]; if (current.coefficient < 0) { - if (!SafeAdd(-current.coefficient, bound_shift)) return false; + if (!SafeAddInto(-current.coefficient, bound_shift)) return false; (*cst)[i].coefficient = -current.coefficient; (*cst)[i].literal = current.literal.Negated(); } - if (!SafeAdd((*cst)[i].coefficient, max_value)) return false; + if (!SafeAddInto((*cst)[i].coefficient, max_value)) return false; } // Finally sort by increasing coefficients. @@ -108,7 +98,8 @@ bool LinearConstraintIsCannonical(const std::vector& cst) { } UpperBoundedLinearConstraint::UpperBoundedLinearConstraint( - const std::vector& cst) { + const std::vector& cst, ResolutionNode* node) + : node_(node) { DCHECK(!cst.empty()); DCHECK(std::is_sorted(cst.begin(), cst.end(), CoeffComparator)); literals_.reserve(cst.size()); @@ -208,6 +199,9 @@ void UpperBoundedLinearConstraint::FillReason(const Trail& trail, return; } + // This is needed for unsat proof. + const bool include_level_zero = trail.NeedFixedLiteralsInReason(); + // Compute the initial reason which is formed by all the literals of the // constraint that were assigned to true at the time of the propagation. // We remove literals with a level of 0 since they are not needed. @@ -219,7 +213,7 @@ void UpperBoundedLinearConstraint::FillReason(const Trail& trail, const Literal literal = literals_[i]; if (trail.Assignment().IsLiteralTrue(literal) && trail.Info(literal.Variable()).trail_index <= source_trail_index) { - if (trail.Info(literal.Variable()).level != 0) { + if (include_level_zero || trail.Info(literal.Variable()).level != 0) { reason->push_back(literal.Negated()); } current_rhs -= coeffs_[coeff_index]; @@ -282,7 +276,7 @@ void UpperBoundedLinearConstraint::Untrail(Coefficient* slack) { // TODO(user): This is relatively slow. Take the "transpose" all at once, and // maybe put small constraints first on the to_update_ lists. bool PbConstraints::AddConstraint(const std::vector& cst, - Coefficient rhs) { + Coefficient rhs, ResolutionNode* node) { SCOPED_TIME_STAT(&stats_); DCHECK(!cst.empty()); DCHECK(std::is_sorted(cst.begin(), cst.end(), CoeffComparator)); @@ -291,6 +285,9 @@ bool PbConstraints::AddConstraint(const std::vector& cst, // added constraint. if (!constraints_.empty() && constraints_.back().HasIdenticalTerms(cst)) { if (rhs < constraints_.back().Rhs()) { + // The new constraint is tighther, so we also replace the ResolutionNode. + // TODO(user): The old one could be unlocked at this point. + constraints_.back().ChangeResolutionNode(node); return constraints_.back().InitializeRhs(rhs, propagation_trail_index_, &slacks_.back(), trail_, &reason_scratchpad_); @@ -301,7 +298,7 @@ bool PbConstraints::AddConstraint(const std::vector& cst, } const ConstraintIndex cst_index(constraints_.size()); - constraints_.emplace_back(UpperBoundedLinearConstraint(cst)); + constraints_.emplace_back(UpperBoundedLinearConstraint(cst, node)); slacks_.push_back(0); if (!constraints_.back().InitializeRhs(rhs, propagation_trail_index_, &slacks_.back(), trail_, @@ -335,9 +332,13 @@ bool PbConstraints::PropagateNext() { if (slack < 0 && !conflict) { update.need_untrail_inspection = true; ++num_constraint_lookups_; + // Important: we must use the conflict_scratchpad_ here not the + // reason_scratchpad_. if (!constraints_[update.index.value()].Propagate( - order, &slacks_[update.index], trail_, &reason_scratchpad_)) { - trail_->SetFailingClause(ClauseRef(reason_scratchpad_)); + order, &slacks_[update.index], trail_, &conflict_scratchpad_)) { + trail_->SetFailingClause(ClauseRef(conflict_scratchpad_)); + trail_->SetFailingResolutionNode( + constraints_[update.index.value()].ResolutionNodePointer()); conflict = true; } } diff --git a/src/sat/pb_constraint.h b/src/sat/pb_constraint.h index a97d20eb6e..4867d7bdbd 100644 --- a/src/sat/pb_constraint.h +++ b/src/sat/pb_constraint.h @@ -79,7 +79,8 @@ bool LinearConstraintIsCannonical(const std::vector& cst); class UpperBoundedLinearConstraint { public: // Takes a pseudo-Boolean formula in canonical form. - explicit UpperBoundedLinearConstraint(const std::vector& cst); + UpperBoundedLinearConstraint(const std::vector& cst, + ResolutionNode* node); // Returns true if the given terms are the same as the one in this constraint. bool HasIdenticalTerms(const std::vector& cst); @@ -134,6 +135,13 @@ class UpperBoundedLinearConstraint { void FillReason(const Trail& trail, int source_trail_index, std::vector* reason); + // Returns the resolution node associated to this constraint. Note that it can + // be nullptr if the solver is not configured to compute the reason for an + // unsatisfiable problem or if this constraint is not relevant for the current + // core computation. + ResolutionNode* ResolutionNodePointer() const { return node_; } + void ChangeResolutionNode(ResolutionNode* node) { node_ = node; } + private: Coefficient GetCurrentRhsFromSlack(Coefficient slack) { return (index_ < 0) ? slack : coeffs_[index_] + slack; @@ -157,6 +165,8 @@ class UpperBoundedLinearConstraint { std::vector coeffs_; std::vector starts_; std::vector literals_; + + ResolutionNode* node_; }; // Class responsible for managing a set of pseudo-Boolean constraints and their @@ -185,7 +195,8 @@ class PbConstraints { // // Note(user): There is an optimization if the last constraint added is the // same as the one we are trying to add. - bool AddConstraint(const std::vector& cst, Coefficient rhs); + bool AddConstraint(const std::vector& cst, Coefficient rhs, + ResolutionNode* node); int NumberOfConstraints() const { return constraints_.size(); } // If some literals enqueued on the trail haven't been processed by this class @@ -232,7 +243,11 @@ class PbConstraints { int propagation_trail_index_; // Temporary vector to hold the reason of a pseudo-Boolean propagation. + // Important: the conflict must use another vector since these scratchpads + // must remain valid as long as they are needed by the sat solver and we do + // need to compute reasons and not overwrite the conflict. mutable std::vector reason_scratchpad_; + mutable std::vector conflict_scratchpad_; // We use a dequeue to store the pseudo-Boolean constraint because we want // pointers to its elements to be still valid after more push_back(). @@ -283,13 +298,8 @@ class PbReasonCache { const AssignmentInfo& info = trail_.Info(var); return std::make_pair(info.pb_constraint, info.source_trail_index); } -#if defined(_MSC_VER) - hash_map, - VariableIndex, - PairPointerIntHasher > map_; -#else + hash_map, VariableIndex> map_; -#endif DISALLOW_COPY_AND_ASSIGN(PbReasonCache); }; diff --git a/src/sat/sat_base.h b/src/sat/sat_base.h index 47296e11f2..4c91a16f87 100644 --- a/src/sat/sat_base.h +++ b/src/sat/sat_base.h @@ -94,8 +94,8 @@ class VariablesAssignment { public: VariablesAssignment() {} void Resize(int num_variables) { - assignment_.ClearAndResize(LiteralIndex(num_variables << 1)); - last_assignment_.ClearAndResize(LiteralIndex(num_variables << 1)); + assignment_.Resize(LiteralIndex(num_variables << 1)); + last_assignment_.Resize(LiteralIndex(num_variables << 1)); } // Makes the given literal true by assigning its underlying variable to either @@ -195,6 +195,7 @@ class ClauseRef { // Forward declaration of the classes needed to compute the reason of an // assignment. +class ResolutionNode; class SatClause; class UpperBoundedLinearConstraint; @@ -210,9 +211,8 @@ struct AssignmentInfo { // function. This AssignmentInfo can then hold a pointer to an HasReason // class. Currently, this is not done this way for efficiency. enum Type { - PREPROCESSING, - SEARCH_DECISION, UNIT_REASON, + SEARCH_DECISION, CLAUSE_PROPAGATION, BINARY_PROPAGATION, PB_PROPAGATION, @@ -239,6 +239,7 @@ struct AssignmentInfo { }; union { SatClause* sat_clause; + ResolutionNode* resolution_node; UpperBoundedLinearConstraint* pb_constraint; }; }; @@ -248,7 +249,9 @@ struct AssignmentInfo { // and the information of each assignment. class Trail { public: - Trail() : num_enqueues_(0), trail_index_(0) { current_info_.level = 0; } + Trail() : num_enqueues_(0), trail_index_(0), need_level_zero_(false) { + current_info_.level = 0; + } void Resize(int num_variables) { assignment_.Resize(num_variables); @@ -271,6 +274,10 @@ class Trail { } // Specific Enqueue() version for our different constraint types. + void EnqueueWithUnitReason(Literal true_literal, ResolutionNode* node) { + current_info_.resolution_node = node; + Enqueue(true_literal, AssignmentInfo::UNIT_REASON); + } void EnqueueWithBinaryReason(Literal true_literal, Literal reason) { current_info_.literal = reason; Enqueue(true_literal, AssignmentInfo::BINARY_PROPAGATION); @@ -313,8 +320,16 @@ class Trail { failing_clause_ = ref; failing_sat_clause_ = nullptr; } + void SetFailingResolutionNode(ResolutionNode* node) { failing_node_ = node; } ClauseRef FailingClause() const { return failing_clause_; } SatClause* FailingSatClause() const { return failing_sat_clause_; } + ResolutionNode* FailingResolutionNode() const { return failing_node_; } + + // This is required for producing correct unsat proof. Recall that a fixed + // literals is one assigned at level zero. The option is here so every code + // that needs it can easily access it. + bool NeedFixedLiteralsInReason() const { return need_level_zero_; } + void SetNeedFixedLiteralsInReason(bool value) { need_level_zero_ = value; } // Getters. int64 NumberOfEnqueues() const { return num_enqueues_; } @@ -323,6 +338,13 @@ class Trail { const VariablesAssignment& Assignment() const { return assignment_; } const AssignmentInfo& Info(VariableIndex var) const { return info_[var]; } + // Sets the new resolution node for a variable that is fixed. + void SetFixedVariableInfo(VariableIndex var, ResolutionNode* node) { + CHECK_EQ(info_[var].level, 0); + info_[var].type = AssignmentInfo::UNIT_REASON; + info_[var].resolution_node = node; + } + private: int64 num_enqueues_; int trail_index_; @@ -332,6 +354,8 @@ class Trail { ITIVector info_; ClauseRef failing_clause_; SatClause* failing_sat_clause_; + ResolutionNode* failing_node_; + bool need_level_zero_; DISALLOW_COPY_AND_ASSIGN(Trail); }; diff --git a/src/sat/sat_conflict.cc b/src/sat/sat_conflict.cc deleted file mode 100644 index 36667395b1..0000000000 --- a/src/sat/sat_conflict.cc +++ /dev/null @@ -1,528 +0,0 @@ -// Copyright 2010-2013 Google -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include -#include -#include - -#include "base/integral_types.h" -#include "base/logging.h" -#include "base/stringprintf.h" -#include "base/concise_iterator.h" -#include "base/stl_util.h" -#include "sat/sat_solver.h" - -namespace operations_research { -namespace sat { - -// This method will compute a first UIP conflict -// http://www.cs.tau.ac.il/~msagiv/courses/ATP/iccad2001_final.pdf -// http://gauss.ececs.uc.edu/SAT/articles/FAIA185-0131.pdf -void SatSolver::ComputeFirstUIPConflict( - ClauseRef failing_clause, std::vector* conflict, - std::vector* discarded_last_level_literals) { - SCOPED_TIME_STAT(&stats_); - - // This will be used to mark all the literals inspected while we process the - // conflict and the reasons behind each of its variable assignments. - is_marked_.ClearAndResize(num_variables_); - - conflict->clear(); - discarded_last_level_literals->clear(); - const int current_level = CurrentDecisionLevel(); - int num_literal_at_current_level_that_needs_to_be_processed = 0; - DCHECK_GT(current_level, 0); - - // To find the 1-UIP conflict clause, we start by the failing_clause, and - // expand each of its literal using the reason for this literal assignement to - // false. The is_marked_ set allow us to never expand the same literal twice. - // - // The expansion is not done (i.e. stop) for literal that where assigned at a - // decision level below the current one. If the level of such literal is not - // zero, it is added to the conflict clause. - // - // Now, the trick is that we use the trail to expand the literal of the - // current level in a very specific order. Namely the reverse order of the one - // in which they where infered. We stop as soon as - // num_literal_at_current_level_that_needs_to_be_processed is exactly one. - // - // This last literal will be the first UIP because by definition all the - // propagation done at the current level will pass though it at some point. - ClauseRef clause_to_expand = failing_clause; - DCHECK(!clause_to_expand.IsEmpty()); - int trail_index = trail_.Index() - 1; - while (true) { - for (const Literal literal : clause_to_expand) { - const VariableIndex var = literal.Variable(); - if (!is_marked_[var]) { - is_marked_.Set(var); - const int level = DecisionLevel(var); - if (level == current_level) { - ++num_literal_at_current_level_that_needs_to_be_processed; - } else if (level > 0) { - // Note that all these literals are currently false since the clause - // to expand was used to infer the value of a literal at this level. - DCHECK(trail_.Assignment().IsLiteralFalse(literal)); - conflict->push_back(literal); - } - } - } - - // Find next marked literal to expand from the trail. - DCHECK_GT(num_literal_at_current_level_that_needs_to_be_processed, 0); - while (!is_marked_[trail_[trail_index].Variable()]) { - --trail_index; - DCHECK_GE(trail_index, 0); - DCHECK_EQ(DecisionLevel(trail_[trail_index].Variable()), current_level); - } - - if (num_literal_at_current_level_that_needs_to_be_processed == 1) { - // We have the first UIP. Add its negation to the conflict clause. - // This way, after backtracking to the proper level, the conflict clause - // will be unit, and infer the negation of the UIP that caused the fail. - conflict->push_back(trail_[trail_index].Negated()); - break; - } - - const Literal literal = trail_[trail_index]; - discarded_last_level_literals->push_back(literal); - - // If we already encountered the same reason, we can just skip this literal - // which is what setting clause_to_expand to the empty clause do. - if (reason_cache_.FirstVariableWithSameReason(literal.Variable()) != - literal.Variable()) { - clause_to_expand = ClauseRef(); - } else { - clause_to_expand = Reason(literal.Variable()); - DCHECK(!clause_to_expand.IsEmpty()); - } - - --num_literal_at_current_level_that_needs_to_be_processed; - --trail_index; - } -} - -void SatSolver::MinimizeConflict(std::vector* conflict) { - SCOPED_TIME_STAT(&stats_); - const int old_size = conflict->size(); - switch (parameters_.minimization_algorithm()) { - case SatParameters::NONE: - return; - case SatParameters::SIMPLE: { - MinimizeConflictSimple(conflict); - break; - } - case SatParameters::RECURSIVE: { - MinimizeConflictRecursively(conflict); - break; - } - case SatParameters::EXPERIMENTAL: { - MinimizeConflictExperimental(conflict); - break; - } - } - if (conflict->size() < old_size) { - ++counters_.num_minimizations; - counters_.num_literals_removed += old_size - conflict->size(); - } -} - -// This simple version just looks for any literal that is directly infered by -// other literals of the conflict. It is directly infered if the literals of its -// reason clause are either from level 0 or from the conflict itself. -// -// Note that because of the assignement struture, there is no need to process -// the literals of the conflict in order. While exploring the reason for a -// literal assignement, there will be no cycles. -void SatSolver::MinimizeConflictSimple(std::vector* conflict) { - SCOPED_TIME_STAT(&stats_); - is_marked_.ClearAndResize(num_variables_); - for (Literal literal : *conflict) { - is_marked_.Set(literal.Variable()); - } - int index = 0; - const int current_level = CurrentDecisionLevel(); - for (int i = 0; i < conflict->size(); ++i) { - const VariableIndex var = (*conflict)[i].Variable(); - bool can_be_removed = false; - if (DecisionLevel(var) != current_level) { - // It is important not to call Reason(var) when it can be avoided. - const ClauseRef reason = Reason(var); - if (!reason.IsEmpty()) { - can_be_removed = true; - for (Literal literal : reason) { - if (DecisionLevel(literal.Variable()) == 0) continue; - if (!is_marked_[literal.Variable()]) { - can_be_removed = false; - break; - } - } - } - } - if (!can_be_removed) { - (*conflict)[index] = (*conflict)[i]; - ++index; - } - } - conflict->erase(conflict->begin() + index, conflict->end()); -} - -// This is similar to MinimizeConflictSimple() except that for each literal of -// the conflict, the literals of its reason are recursively expended using their -// reason and so on. The recusion stop until we show that the initial literal -// can be infered from the conflict variables alone, or if we show that this is -// not the case. The result of any variable expension will be cached in order -// not to be expended again. -void SatSolver::MinimizeConflictRecursively(std::vector* conflict) { - SCOPED_TIME_STAT(&stats_); - - // is_marked_ will contains all the conflict literals plus the literals that - // have been shown to depends only on the conflict literals. is_independent_ - // will contains the literals that have been shown NOT to depends only on the - // conflict literals. The too set are exclusive for non-conflict literals, but - // a conflict literal (which is always marked) can be independent if we showed - // that it can't be removed from the clause. - // - // Optimization: There is no need to call is_marked_.ClearAndResize() or to - // mark the conflict literals since this was already done by - // ComputeFirstUIPConflict(). - is_independent_.ClearAndResize(num_variables_); - - // min_trail_index_per_level_ will always be reset to all - // std::numeric_limits::max() at the end. This is used to prune the - // search because any literal at a given level with an index smaller or equal - // to min_trail_index_per_level_[level] can't be redundant. - if (CurrentDecisionLevel() >= min_trail_index_per_level_.size()) { - min_trail_index_per_level_.resize(CurrentDecisionLevel() + 1, - std::numeric_limits::max()); - } - - // Compute the number of variable at each decision levels. This will be used - // to pruned the DFS because we know that the minimized conflict will have at - // least one variable of each decision levels. Because such variable can't be - // eliminated using lower decision levels variable otherwise it will have been - // propagated. - for (Literal literal : *conflict) { - const VariableIndex var = literal.Variable(); - const int level = DecisionLevel(var); - min_trail_index_per_level_[level] = - std::min(min_trail_index_per_level_[level], trail_.Info(var).trail_index); - } - - // Remove the redundant variable from the conflict. That is the ones that can - // be infered by some other variables in the conflict. - int index = 0; - for (int i = 0; i < conflict->size(); ++i) { - const VariableIndex var = (*conflict)[i].Variable(); - if (trail_.Info(var).trail_index <= - min_trail_index_per_level_[DecisionLevel(var)] || - !CanBeInferedFromConflictVariables(var)) { - // Mark the conflict variable as independent. Note that is_marked_[var] - // will still be true. - is_independent_.Set(var); - (*conflict)[index] = (*conflict)[i]; - ++index; - } - } - - // Reset min_trail_index_per_level_. This works since we can never eliminate - // all the literals from the same level. - conflict->resize(index); - for (Literal literal : *conflict) { - min_trail_index_per_level_[DecisionLevel(literal.Variable())] = - std::numeric_limits::max(); - } -} - -bool SatSolver::CanBeInferedFromConflictVariables(VariableIndex variable) { - // Test for an already processed variable with the same reason. - { - DCHECK(is_marked_[variable]); - const VariableIndex v = reason_cache_.FirstVariableWithSameReason(variable); - if (v != variable) return !is_independent_[v]; - } - - // This function implement an iterative DFS from the given variable. It uses - // the reason clause as adjacency lists. dfs_stack_ can be seens as the - // recursive call stack of the variable we are currently processing. All its - // adjacent variable will be pushed into variable_to_process_, and we will - // then dequeue them one by one and process them. - dfs_stack_.assign(1, variable); - variable_to_process_.assign(1, variable); - - // First we expand the reason for the given variable. - DCHECK(!Reason(variable).IsEmpty()); - for (Literal literal : Reason(variable)) { - const VariableIndex var = literal.Variable(); - if (var == variable) continue; - const int level = DecisionLevel(var); - if (level == 0 || is_marked_[var]) continue; - if (trail_.Info(var).trail_index <= min_trail_index_per_level_[level] || - is_independent_[var]) - return false; - variable_to_process_.push_back(var); - } - - // Then we start the DFS. - while (!variable_to_process_.empty()) { - const VariableIndex current_var = variable_to_process_.back(); - if (current_var == dfs_stack_.back()) { - // We finished the DFS of the variable dfs_stack_.back(), this can be seen - // as a recursive call terminating. - if (dfs_stack_.size() > 1) { - DCHECK(!is_marked_[current_var]); - is_marked_.Set(current_var); - } - variable_to_process_.pop_back(); - dfs_stack_.pop_back(); - continue; - } - - // If this variable became marked since the we pushed it, we can skip it. - if (is_marked_[current_var]) { - variable_to_process_.pop_back(); - continue; - } - - // This case will never be encountered since we abort right away as soon - // as an independent variable is found. - DCHECK(!is_independent_[current_var]); - - // Test for an already processed variable with the same reason. - { - const VariableIndex v = - reason_cache_.FirstVariableWithSameReason(current_var); - if (v != current_var) { - if (is_independent_[v]) break; - DCHECK(is_marked_[v]); - variable_to_process_.pop_back(); - continue; - } - } - - // Expand the variable. This can be seen as making a recursive call. - dfs_stack_.push_back(current_var); - bool abort_early = false; - DCHECK(!Reason(current_var).IsEmpty()); - for (Literal literal : Reason(current_var)) { - const VariableIndex var = literal.Variable(); - if (var == current_var) continue; - const int level = DecisionLevel(var); - if (level == 0 || is_marked_[var]) continue; - if (trail_.Info(var).trail_index <= min_trail_index_per_level_[level] || - is_independent_[var]) { - abort_early = true; - break; - } - variable_to_process_.push_back(var); - } - if (abort_early) break; - } - - // All the variable left on the dfs_stack_ are independent. - for (const VariableIndex var : dfs_stack_) { - is_independent_.Set(var); - } - return dfs_stack_.empty(); -} - -namespace { - -struct WeightedVariable { - WeightedVariable(VariableIndex v, int w) : var(v), weight(w) {} - - VariableIndex var; - int weight; -}; - -// Lexical order, by larger weight, then by smaller variable number -// to break ties -struct VariableWithLargerWeightFirst { - bool operator()(const WeightedVariable& wv1, - const WeightedVariable& wv2) const { - return (wv1.weight > wv2.weight || - (wv1.weight == wv2.weight && wv1.var < wv2.var)); - } -}; -} // namespace. - -// This function allows a conflict variable to be replaced by another variable -// not originally in the conflict. Greater reduction and backtracking can be -// achieved this way, but the effect of this is not clear. -// -// TODO(user): More investigation needed. This seems to help on the Hanoi -// problems, but degrades performance on others. -// -// TODO(user): Find a reference for this? neither minisat nor glucose do that, -// they just do MinimizeConflictRecursively() with a different implementation. -// Note that their behavior also make more sense with the way they (and we) bump -// the variable activities. -void SatSolver::MinimizeConflictExperimental(std::vector* conflict) { - SCOPED_TIME_STAT(&stats_); - - // First, sort the variables in the conflict by decreasing decision levels. - // Also initialize is_marked_ to true for all conflict variables. - is_marked_.ClearAndResize(num_variables_); - const int current_level = CurrentDecisionLevel(); - std::vector variables_sorted_by_level; - for (Literal literal : *conflict) { - const VariableIndex var = literal.Variable(); - is_marked_.Set(var); - const int level = DecisionLevel(var); - if (level < current_level) { - variables_sorted_by_level.push_back(WeightedVariable(var, level)); - } - } - std::sort(variables_sorted_by_level.begin(), variables_sorted_by_level.end(), - VariableWithLargerWeightFirst()); - - // Then process the reason of the variable with highest level first. - std::vector to_remove; - for (WeightedVariable weighted_var : variables_sorted_by_level) { - const VariableIndex var = weighted_var.var; - - // A nullptr reason means that this was a decision variable from the - // previous levels. - const ClauseRef reason = Reason(var); - if (reason.IsEmpty()) continue; - - // Compute how many and which literals from the current reason do not appear - // in the current conflict. Level 0 literals are ignored. - std::vector not_contained_literals; - for (const Literal reason_literal : reason) { - const VariableIndex reason_var = reason_literal.Variable(); - - // We ignore level 0 variables. - if (DecisionLevel(reason_var) == 0) continue; - - // We have a reason literal whose variable is not yet seen. - // If there is more than one, break right away, we will not minimize the - // current conflict with this variable. - if (!is_marked_[reason_var]) { - not_contained_literals.push_back(reason_literal); - if (not_contained_literals.size() > 1) break; - } - } - if (not_contained_literals.empty()) { - // This variable will be deleted from the conflict. Note that we don't - // unmark it. This is because this variable can be infered from the other - // variables in the conflict, so it is okay to skip it when processing the - // reasons of other variables. - to_remove.push_back(var); - } else if (not_contained_literals.size() == 1) { - // Replace the literal from variable var with the only - // not_contained_literals from the current reason. - to_remove.push_back(var); - is_marked_.Set(not_contained_literals.front().Variable()); - conflict->push_back(not_contained_literals.front()); - } - } - - // Unmark the variable that should be removed from the conflict. - for (VariableIndex var : to_remove) { - is_marked_.Clear(var); - } - - // Remove the now unmarked literals from the conflict. - int index = 0; - for (int i = 0; i < conflict->size(); ++i) { - const Literal literal = (*conflict)[i]; - if (is_marked_[literal.Variable()]) { - (*conflict)[index] = literal; - ++index; - } - } - conflict->erase(conflict->begin() + index, conflict->end()); -} - -namespace { - -// Order the clause by increasing LBD (Literal Blocks Distance) first. For the -// same LBD they are ordered by decreasing activity. -bool ClauseOrdering(SatClause* a, SatClause* b) { - if (a->Lbd() == b->Lbd()) return a->Activity() > b->Activity(); - return a->Lbd() < b->Lbd(); -} - -} // namespace - -void SatSolver::InitLearnedClauseLimit() { - const int num_learned_clauses = learned_clauses_.size(); - target_number_of_learned_clauses_ = - num_learned_clauses + parameters_.clause_cleanup_increment(); - num_learned_clause_before_cleanup_ = - target_number_of_learned_clauses_ / parameters_.clause_cleanup_ratio() - - num_learned_clauses; - VLOG(1) << "reduced learned database to " << num_learned_clauses - << " clauses. Next cleanup in " << num_learned_clause_before_cleanup_ - << " conflicts."; -} - -void SatSolver::CompressLearnedClausesIfNeeded() { - if (num_learned_clause_before_cleanup_ > 0) return; - SCOPED_TIME_STAT(&stats_); - - // First time? - if (learned_clauses_.size() == 0) { - InitLearnedClauseLimit(); - return; - } - - // Move the clause that should be kept at the beginning and sort the other - // using the ClauseOrdering order. - std::vector::iterator clause_to_keep_end = std::partition( - learned_clauses_.begin(), learned_clauses_.end(), - std::bind1st(std::mem_fun(&SatSolver::ClauseShouldBeKept), this)); - std::sort(clause_to_keep_end, learned_clauses_.end(), ClauseOrdering); - - // Compute the index of the first clause to delete. - const int num_learned_clauses = learned_clauses_.size(); - const int first_clause_to_delete = - std::max(static_cast(clause_to_keep_end - learned_clauses_.begin()), - std::min(num_learned_clauses, target_number_of_learned_clauses_)); - - // Delete all the learned clause after 'first_clause_to_delete'. - for (int i = first_clause_to_delete; i < num_learned_clauses; ++i) { - SatClause* clause = learned_clauses_[i]; - watched_clauses_.LazyDetach(clause); - } - watched_clauses_.CleanUpWatchers(); - for (int i = first_clause_to_delete; i < num_learned_clauses; ++i) { - counters_.num_literals_forgotten += learned_clauses_[i]->Size(); - delete learned_clauses_[i]; - } - learned_clauses_.resize(first_clause_to_delete); - InitLearnedClauseLimit(); -} - -bool SatSolver::ShouldRestart() { - SCOPED_TIME_STAT(&stats_); - if (conflicts_until_next_restart_ != 0) return false; - restart_count_++; - conflicts_until_next_restart_ = - parameters_.restart_period() * SUniv(restart_count_ + 1); - return true; -} - -void SatSolver::InitRestart() { - SCOPED_TIME_STAT(&stats_); - restart_count_ = 0; - if (parameters_.restart_period() > 0) { - DCHECK_EQ(SUniv(1), 1); - conflicts_until_next_restart_ = parameters_.restart_period(); - } else { - conflicts_until_next_restart_ = -1; - } -} - -} // namespace sat -} // namespace operations_research diff --git a/src/sat/sat_parameters.proto b/src/sat/sat_parameters.proto index 27edd0c969..72705122c5 100644 --- a/src/sat/sat_parameters.proto +++ b/src/sat/sat_parameters.proto @@ -173,4 +173,10 @@ message SatParameters { // Whether the solver should log the search progress to LOG(INFO). optional bool log_search_progress = 41 [default = false]; + + // Indicates if the solver maintain in memory the information needed to + // generate an UNSAT core if the problem is unsat or to generate a full + // resolution proof. This can potentially use a lot of memory and may slow + // down the solver a bit. + optional bool unsat_proof = 42 [default = false]; } diff --git a/src/sat/sat_solver.cc b/src/sat/sat_solver.cc index 4c654f34c0..d71a1385bd 100644 --- a/src/sat/sat_solver.cc +++ b/src/sat/sat_solver.cc @@ -22,507 +22,20 @@ #include "base/sysinfo.h" #include "base/join.h" #include "util/time_limit.h" +#include "util/saturated_arithmetic.h" #include "base/stl_util.h" namespace operations_research { namespace sat { -namespace { - -// Returns true if the given watcher list contains the given clause. -template -bool WatcherListContains(const std::vector& list, - const SatClause& candidate) { - for (const Watcher& watcher : list) { - if (watcher.clause == &candidate) return true; - } - return false; -} - -// A simple wrapper to simplify the erase(std::remove_if()) pattern. -template -void RemoveIf(Container c, Predicate p) { - c->erase(std::remove_if(c->begin(), c->end(), p), c->end()); -} - -// Removes dettached clauses from a watcher list. -template -bool CleanUpPredicate(const Watcher& watcher) { - return !watcher.clause->IsAttached(); -} - -// Compares literals by variable first, then sign. -bool CompareLiteral(Literal l1, Literal l2) { return l1.Index() < l2.Index(); } -} // namespace - -// ----- LiteralWatchers ----- - -LiteralWatchers::LiteralWatchers() - : is_clean_(true), - num_inspected_clauses_(0), - num_watched_clauses_(0), - stats_("LiteralWatchers") {} - -LiteralWatchers::~LiteralWatchers() { - IF_STATS_ENABLED(LOG(INFO) << stats_.StatString()); -} - -void LiteralWatchers::Resize(int num_variables) { - DCHECK(is_clean_); - watchers_on_false_.resize(num_variables << 1); - needs_cleaning_.resize(num_variables << 1, false); - statistics_.resize(num_variables); -} - -// Note that this is the only place where we add Watcher so the DCHECK -// guarantees that there are no duplicates. -void LiteralWatchers::AttachOnFalse(Literal a, Literal b, SatClause* clause) { - SCOPED_TIME_STAT(&stats_); - DCHECK(is_clean_); - DCHECK(!WatcherListContains(watchers_on_false_[a.Index()], *clause)); - watchers_on_false_[a.Index()].push_back(Watcher(clause, b)); -} - -bool LiteralWatchers::PropagateOnFalse(Literal false_literal, Trail* trail) { - SCOPED_TIME_STAT(&stats_); - DCHECK(is_clean_); - std::vector& watchers = watchers_on_false_[false_literal.Index()]; - const VariablesAssignment& assignment = trail->Assignment(); - int new_index = 0; - - // Note(user): It sounds better to inspect the list in order, this is because - // small clauses like binary or ternary clauses will often propagate and thus - // stay at the beginning of the list. - const int initial_size = watchers.size(); - for (int i = 0; i < initial_size; ++i) { - ++num_inspected_clauses_; - - // Don't even look at the clause memory if the blocking literal is true. - if (assignment.IsLiteralTrue(watchers[i].blocking_literal)) { - watchers[new_index] = watchers[i]; - ++new_index; - continue; - } - - SatClause* clause = watchers[i].clause; - if (!clause->PropagateOnFalse(false_literal, trail)) { - // Conflict: All literals of this clause are false. - memmove(&watchers[new_index], &watchers[i], - (initial_size - i) * sizeof(Watcher)); - watchers.resize(new_index + initial_size - i); - return false; - } - - // Update the watched literal if clause->FirstLiteral() changed. - // See the contract of PropagateOnFalse(). - if (clause->FirstLiteral() != false_literal) { - AttachOnFalse(clause->FirstLiteral(), clause->SecondLiteral(), clause); - } else { - watchers[new_index] = Watcher(clause, clause->SecondLiteral()); - ++new_index; - } - } - watchers.resize(new_index); - return true; -} - -bool LiteralWatchers::AttachAndPropagate(SatClause* clause, Trail* trail) { - SCOPED_TIME_STAT(&stats_); - ++num_watched_clauses_; - UpdateStatistics(*clause, /*added=*/true); - clause->SortLiterals(statistics_, parameters_); - return clause->AttachAndEnqueuePotentialUnitPropagation(trail, this); -} - -void LiteralWatchers::LazyDetach(SatClause* clause) { - SCOPED_TIME_STAT(&stats_); - --num_watched_clauses_; - UpdateStatistics(*clause, /*added=*/false); - clause->LazyDetach(); - is_clean_ = false; - needs_cleaning_[clause->FirstLiteral().Index()] = true; - needs_cleaning_[clause->SecondLiteral().Index()] = true; -} - -void LiteralWatchers::CleanUpWatchers() { - SCOPED_TIME_STAT(&stats_); - for (int i = 0; i < needs_cleaning_.size(); ++i) { - if (needs_cleaning_[LiteralIndex(i)]) { - RemoveIf(&(watchers_on_false_[LiteralIndex(i)]), - CleanUpPredicate); - needs_cleaning_[LiteralIndex(i)] = false; - } - } - is_clean_ = true; -} - -void LiteralWatchers::UpdateStatistics(const SatClause& clause, bool added) { - SCOPED_TIME_STAT(&stats_); - for (const Literal literal : clause) { - const VariableIndex var = literal.Variable(); - const int direction = added ? 1 : -1; - statistics_[var].num_appearances += direction; - statistics_[var].weighted_num_appearances += - 1.0 / clause.Size() * direction; - if (literal.IsPositive()) { - statistics_[var].num_positive_clauses += direction; - } else { - statistics_[var].num_negative_clauses += direction; - } - } -} - -// ----- BinaryImplicationGraph ----- - -void BinaryImplicationGraph::Resize(int num_variables) { - SCOPED_TIME_STAT(&stats_); - implications_.resize(num_variables << 1); -} - -void BinaryImplicationGraph::AddBinaryClause(Literal a, Literal b) { - SCOPED_TIME_STAT(&stats_); - implications_[a.Negated().Index()].push_back(b); - implications_[b.Negated().Index()].push_back(a); -} - -void BinaryImplicationGraph::AddBinaryConflict(Literal a, Literal b, - Trail* trail) { - SCOPED_TIME_STAT(&stats_); - AddBinaryClause(a, b); - if (trail->Assignment().IsLiteralFalse(a)) { - trail->EnqueueWithBinaryReason(b, a); - } else if (trail->Assignment().IsLiteralFalse(b)) { - trail->EnqueueWithBinaryReason(a, b); - } -} - -bool BinaryImplicationGraph::PropagateOnTrue(Literal true_literal, - Trail* trail) { - SCOPED_TIME_STAT(&stats_); - const VariablesAssignment& assignment = trail->Assignment(); - for (Literal literal : implications_[true_literal.Index()]) { - if (assignment.IsLiteralTrue(literal)) { - // Note(user): I tried to update the reason here if the literal was - // enqueued after the true_literal on the trail. This property is - // important for ComputeFirstUIPConflict() to work since it needs the - // trail order to be a topological order for the deduction graph. - // But the performance where not too good... - continue; - } - - ++num_propagations_; - if (assignment.IsLiteralFalse(literal)) { - // Conflict. - temporary_clause_[0] = true_literal.Negated(); - temporary_clause_[1] = literal; - trail->SetFailingClause( - ClauseRef(&temporary_clause_[0], &temporary_clause_[0] + 2)); - return false; - } else { - // Propagation. - trail->EnqueueWithBinaryReason(literal, true_literal.Negated()); - } - } - return true; -} - -void BinaryImplicationGraph::MinimizeClause(const Trail& trail, - std::vector* conflict) { - SCOPED_TIME_STAT(&stats_); - is_marked_.ClearAndResize(LiteralIndex(implications_.size())); - is_removed_.ClearAndResize(LiteralIndex(implications_.size())); - for (Literal lit : *conflict) { - is_marked_.Set(lit.Index()); - } - - // Identify and remove the redundant literals from the given conflict. - // 1/ If a -> b then a can be removed from the conflict clause. - // This is because not b -> not a. - // 2/ a -> b can only happen if level(a) <= level(b). - // 3/ Because of 2/, cycles can appear only at the same level. - // The vector is_removed_ is used to avoid removing all elements of a - // cycle. Note that this is not optimal in the sense that we may not remove - // a literal that can be removed. - // - // TODO(user): no need to explore the unique literal of the current decision - // level since it can't be removed. - int index = 0; - for (int i = 0; i < conflict->size(); ++i) { - const Literal lit = (*conflict)[i]; - const int lit_level = trail.Info(lit.Variable()).level; - bool keep_literal = true; - for (Literal implied : implications_[lit.Index()]) { - if (is_marked_[implied.Index()]) { - DCHECK_LE(lit_level, trail.Info(implied.Variable()).level); - if (lit_level == trail.Info(implied.Variable()).level && - is_removed_[implied.Index()]) - continue; - keep_literal = false; - break; - } - } - if (keep_literal) { - (*conflict)[index] = lit; - ++index; - } else { - is_removed_.Set(lit.Index()); - } - } - if (index < conflict->size()) { - ++num_minimization_; - num_literals_removed_ += conflict->size() - index; - conflict->erase(conflict->begin() + index, conflict->end()); - } -} - -void BinaryImplicationGraph::RemoveFixedVariables( - const VariablesAssignment& assigment) { - SCOPED_TIME_STAT(&stats_); - is_marked_.ClearAndResize(LiteralIndex(implications_.size())); - for (LiteralIndex i(0); i < implications_.size(); ++i) { - if (assigment.IsLiteralTrue(Literal(i))) { - // If b is true and a -> b then because not b -> not a, all the - // implications list that contains b will be marked by this process. - for (Literal lit : implications_[Literal(i).NegatedIndex()]) { - is_marked_.Set(lit.NegatedIndex()); - } - STLClearObject(&(implications_[i])); - STLClearObject(&(implications_[Literal(i).NegatedIndex()])); - } - } - for (LiteralIndex i(0); i < implications_.size(); ++i) { - if (is_marked_[i]) { - RemoveIf(&implications_[i], - std::bind1st(std::mem_fun(&VariablesAssignment::IsLiteralTrue), - &assigment)); - } - } -} - -// ----- SatClause ----- - -// static -SatClause* SatClause::Create(const std::vector& literals, ClauseType type) { - CHECK_GE(literals.size(), 2); - SatClause* clause = reinterpret_cast( - ::operator new(sizeof(SatClause) + literals.size() * sizeof(Literal))); - clause->size_ = literals.size(); - for (int i = 0; i < literals.size(); ++i) { - clause->literals_[i] = literals[i]; - } - clause->is_learned_ = (type == LEARNED_CLAUSE); - clause->is_attached_ = false; - clause->activity_ = 0.0; - clause->lbd_ = 0; - return clause; -} - -// This currently only checks for trivially true clause. -SatClause::SimplifyStatus SatClause::Simplify() { - std::vector copy(begin(), end()); - std::sort(copy.begin(), copy.end(), CompareLiteral); - for (int i = 0; i < copy.size() - 1; ++i) { - if (copy[i] == copy[i + 1].Negated()) return CLAUSE_ALWAYS_TRUE; - } - return CLAUSE_ACTIVE; -} - -// Note that for an attached clause, removing fixed literal is okay because if -// any of them is assigned, then the clause is necessary true. -bool SatClause::RemoveFixedLiteralsAndTestIfTrue( - const VariablesAssignment& assignment) { - DCHECK(is_attached_); - if (assignment.IsVariableAssigned(literals_[0].Variable()) || - assignment.IsVariableAssigned(literals_[1].Variable())) { - DCHECK(IsSatisfied(assignment)); - return true; - } - int j = 2; - for (int i = 2; i < size_; ++i) { - if (assignment.IsVariableAssigned(literals_[i].Variable())) { - if (assignment.IsLiteralTrue(literals_[i])) return true; - } else { - literals_[j] = literals_[i]; - ++j; - } - } - size_ = j; - return false; -} - -namespace { - -// Support struct to sort literals for ordering. -struct WeightedLiteral { - WeightedLiteral(Literal l, int w) : literal(l), weight(w) {} - - Literal literal; - int weight; -}; - -// Lexical order, by smaller weight, then by smaller literal to break ties. -bool LiteralWithSmallerWeightFirst(const WeightedLiteral& wv1, - const WeightedLiteral& wv2) { - return (wv1.weight < wv2.weight) || - (wv1.weight == wv2.weight && - wv1.literal.SignedValue() < wv2.literal.SignedValue()); -} - -// Lexical order, by larger weight, then by smaller literal to break ties. -bool LiteralWithLargerWeightFirst(const WeightedLiteral& wv1, - const WeightedLiteral& wv2) { - return (wv1.weight > wv2.weight) || - (wv1.weight == wv2.weight && - wv1.literal.SignedValue() < wv2.literal.SignedValue()); -} - -} // namespace - -void SatClause::SortLiterals( - const ITIVector& statistics, - const SatParameters& parameters) { - CHECK(!IsAttached()); - const SatParameters::LiteralOrdering literal_order = - parameters.literal_ordering(); - if (literal_order != SatParameters::LITERAL_IN_ORDER) { - std::vector order; - for (Literal literal : *this) { - int weight = literal.IsPositive() - ? statistics[literal.Variable()].num_positive_clauses - : statistics[literal.Variable()].num_negative_clauses; - order.push_back(WeightedLiteral(literal, weight)); - } - switch (literal_order) { - case SatParameters::VAR_MIN_USAGE: { - std::sort(order.begin(), order.end(), LiteralWithSmallerWeightFirst); - break; - } - case SatParameters::VAR_MAX_USAGE: { - std::sort(order.begin(), order.end(), LiteralWithLargerWeightFirst); - break; - } - default: { break; } - } - for (int i = 0; i < order.size(); ++i) { - literals_[i] = order[i].literal; - } - } -} - -bool SatClause::AttachAndEnqueuePotentialUnitPropagation( - Trail* trail, LiteralWatchers* demons) { - CHECK(!IsAttached()); - // Select the first two literals that are not assigned to false and put them - // on position 0 and 1. - int num_literal_not_false = 0; - for (int i = 0; i < size_; ++i) { - if (!trail->Assignment().IsLiteralFalse(literals_[i])) { - std::swap(literals_[i], literals_[num_literal_not_false]); - ++num_literal_not_false; - if (num_literal_not_false == 2) { - break; - } - } - } - - // Returns false if all the literals were false. - // This should only happen on an UNSAT problem, and there is no need to attach - // the clause in this case. - if (num_literal_not_false == 0) return false; - - if (num_literal_not_false == 1) { - // To maintain the validity of the 2-watcher algorithm, we need to watch - // the false literal with the highest decision levels. - int max_level = trail->Info(literals_[1].Variable()).level; - for (int i = 2; i < size_; ++i) { - const int level = trail->Info(literals_[i].Variable()).level; - if (level > max_level) { - max_level = level; - std::swap(literals_[1], literals_[i]); - } - } - - // If there is a propagation, make literals_[1] the propagated literal and - // enqueue it. - if (!trail->Assignment().IsLiteralTrue(literals_[0])) { - std::swap(literals_[0], literals_[1]); - trail->EnqueueWithSatClauseReason(literals_[1], this); - } - } - - // Attach the watchers. - is_attached_ = true; - demons->AttachOnFalse(literals_[0], literals_[1], this); - demons->AttachOnFalse(literals_[1], literals_[0], this); - return true; -} - -// Propagates one watched literal becoming false. This method maintains the -// invariant that watched literals are always in position 0 and 1. -bool SatClause::PropagateOnFalse(Literal watched_literal, Trail* trail) { - const VariablesAssignment& assignment = trail->Assignment(); - DCHECK(IsAttached()); - DCHECK_GE(size_, 2); - DCHECK(assignment.IsLiteralFalse(watched_literal)); - - // The instantiated literal should be in position 0. - if (literals_[1] == watched_literal) { - literals_[1] = literals_[0]; - literals_[0] = watched_literal; - } - DCHECK_EQ(literals_[0], watched_literal); - - // If the other watched literal is true, do nothing. - if (assignment.IsLiteralTrue(literals_[1])) return true; - - for (int i = 2; i < size_; ++i) { - if (assignment.IsLiteralFalse(literals_[i])) continue; - - // Note(user): If the value of literals_[i] is true, it is possible to leave - // the watched literal unchanged. However this seems less efficient. Even if - // we swap it with the literal at position 2 to speed up future checks. - - // literal[i] is undefined or true, it's now the new literal to watch. - literals_[0] = literals_[i]; - literals_[i] = watched_literal; - return true; - } - - // Literals_[1] is either false or undefined, all other literals are false. - if (assignment.IsLiteralFalse(literals_[1])) { - trail->SetFailingSatClause(ToClauseRef(), this); - return false; - } - - // Literals_[1] is undefined, set it to true. - trail->EnqueueWithSatClauseReason(literals_[1], this); - return true; -} - -bool SatClause::IsSatisfied(const VariablesAssignment& assignment) const { - for (const Literal literal : *this) { - if (assignment.IsLiteralTrue(literal)) return true; - } - return false; -} - -std::string SatClause::DebugString() const { - std::string result; - for (const Literal literal : *this) { - if (!result.empty()) result.append(" "); - result.append(literal.DebugString()); - } - return result; -} - -// ----- SatSolver ----- - SatSolver::SatSolver() : num_variables_(0), + num_constraints_(0), pb_constraints_(&trail_), current_decision_level_(0), propagation_trail_index_(0), binary_propagation_trail_index_(0), + num_processed_fixed_variables_(0), counters_(), is_model_unsat_(false), variable_activity_increment_(1.0), @@ -532,10 +45,33 @@ SatSolver::SatSolver() conflicts_until_next_restart_(0), restart_count_(0), reason_cache_(trail_), + is_relevant_for_core_computation_(true), stats_("SatSolver") {} SatSolver::~SatSolver() { IF_STATS_ENABLED(LOG(INFO) << stats_.StatString()); + if (parameters_.unsat_proof()) { + // We need to free the memory used by the ResolutionNode of the clauses + for (SatClause* clause : learned_clauses_) { + unsat_proof_.UnlockNode(clause->ResolutionNodePointer()); + } + for (SatClause* clause : problem_clauses_) { + unsat_proof_.UnlockNode(clause->ResolutionNodePointer()); + } + // We also have to free the ResolutionNode of the variable assigned at + // level 0. + for (int i = 0; i < trail_.Index(); ++i) { + const AssignmentInfo& info = trail_.Info(trail_[i].Variable()); + if (info.type == AssignmentInfo::UNIT_REASON) { + ResolutionNode* node = info.resolution_node; + unsat_proof_.UnlockNode(node); + } + } + // And the one from the pseudo-Boolean constraints. + for (ResolutionNode* node : to_unlock_) { + unsat_proof_.UnlockNode(node); + } + } STLDeleteElements(&problem_clauses_); STLDeleteElements(&learned_clauses_); } @@ -571,6 +107,7 @@ void SatSolver::SetParameters(const SatParameters& parameters) { SCOPED_TIME_STAT(&stats_); parameters_ = parameters; watched_clauses_.SetParameters(parameters); + trail_.SetNeedFixedLiteralsInReason(parameters.unsat_proof()); } std::string SatSolver::Indent() const { @@ -594,39 +131,67 @@ bool SatSolver::ModelUnsat() { return false; } +bool SatSolver::AddUnitClause(Literal true_literal) { + SCOPED_TIME_STAT(&stats_); + CHECK_EQ(CurrentDecisionLevel(), 0); + if (trail_.Assignment().IsLiteralFalse(true_literal)) return false; + if (trail_.Assignment().IsLiteralTrue(true_literal)) return true; + trail_.EnqueueWithUnitReason(true_literal, CreateRootResolutionNode()); + ++num_constraints_; + return true; +} + bool SatSolver::AddProblemClause(const std::vector& literals) { SCOPED_TIME_STAT(&stats_); + + // TODO(user): To avoid duplication, we currently just call + // AddLinearConstraint(). Make a faster specific version if that becomes a + // performance issue. + tmp_pb_constraint_.clear(); + for (Literal lit : literals) { + tmp_pb_constraint_.push_back(LiteralWithCoeff(lit, 1)); + } + return AddLinearConstraint( + /*has_lower_bound=*/true, /*lower_bound=*/1, + /*has_lower_bound=*/false, /*upper_bound=*/0, &tmp_pb_constraint_); +} + +bool SatSolver::AddProblemClauseInternal(const std::vector& literals, + ResolutionNode* node) { + SCOPED_TIME_STAT(&stats_); CHECK_EQ(CurrentDecisionLevel(), 0); // Deals with clause of size 0 (always false) and 1 (set a literal) right away // so we guarantee that a SatClause is always of size greater than one. This // simplifies the code. - if (literals.size() == 0) return ModelUnsat(); - if (literals.size() == 1) return TestValidityAndEnqueueIfNeeded(literals[0]); - - std::unique_ptr clause( - SatClause::Create(literals, SatClause::PROBLEM_CLAUSE)); - switch (clause->Simplify()) { - case SatClause::CLAUSE_ALWAYS_FALSE: - return ModelUnsat(); - case SatClause::CLAUSE_ALWAYS_TRUE: - FALLTHROUGH_INTENDED; - case SatClause::CLAUSE_SUBSUMED: - return true; - case SatClause::CLAUSE_ACTIVE: { - if (parameters_.treat_binary_clauses_separately() && - clause->Size() == 2) { - binary_implication_graph_.AddBinaryClause(clause->FirstLiteral(), - clause->SecondLiteral()); - } else { - if (!watched_clauses_.AttachAndPropagate(clause.get(), &trail_)) { - return ModelUnsat(); - } - problem_clauses_.push_back(clause.release()); - } + CHECK_GT(literals.size(), 0); + if (literals.size() == 1) { + if (trail_.Assignment().IsLiteralFalse(literals[0])) { + if (node != nullptr) unsat_proof_.UnlockNode(node); + return false; + } + if (trail_.Assignment().IsLiteralTrue(literals[0])) { + if (node != nullptr) unsat_proof_.UnlockNode(node); return true; } + trail_.EnqueueWithUnitReason(literals[0], node); // Not assigned. + return true; } + + // Create a new clause. + std::unique_ptr clause( + SatClause::Create(literals, SatClause::PROBLEM_CLAUSE, node)); + + if (parameters_.treat_binary_clauses_separately() && clause->Size() == 2) { + binary_implication_graph_.AddBinaryClause(clause->FirstLiteral(), + clause->SecondLiteral()); + } else { + if (!watched_clauses_.AttachAndPropagate(clause.get(), &trail_)) { + return ModelUnsat(); + } + problem_clauses_.push_back(clause.release()); + } + return true; } bool SatSolver::AddLinearConstraintInternal(const std::vector& cst, @@ -637,6 +202,9 @@ bool SatSolver::AddLinearConstraintInternal(const std::vector& if (rhs < 0) return ModelUnsat(); // Unsatisfiable constraint. if (rhs >= max_value) return true; // Always satisfied constraint. + // Create the associated resolution node. + ResolutionNode* node = CreateRootResolutionNode(); + // A linear upper bounded constraint is a clause if the only problematic // assignment is the one where all the literals are true. Since they are // ordered by coefficient, this is easy to check. @@ -646,12 +214,17 @@ bool SatSolver::AddLinearConstraintInternal(const std::vector& for (const LiteralWithCoeff& term : cst) { literals_scratchpad_.push_back(term.literal.Negated()); } - return AddProblemClause(literals_scratchpad_); + return AddProblemClauseInternal(literals_scratchpad_, node); } + // Remember that we need to unlock the node passed to pb constraints. + // TODO(user): Find a cleaner way. Also, if the pb_constraints_ do not need + // this node in the end, we delay its memory release because of this. + if (node != nullptr) to_unlock_.push_back(node); + // TODO(user): If this constraint forces all its literal to false (when rhs is // zero for instance), we still add it. Optimize this? - return pb_constraints_.AddConstraint(cst, rhs); + return pb_constraints_.AddConstraint(cst, rhs, node); } bool SatSolver::AddLinearConstraint(bool use_lower_bound, @@ -661,42 +234,69 @@ bool SatSolver::AddLinearConstraint(bool use_lower_bound, std::vector* cst) { SCOPED_TIME_STAT(&stats_); CHECK_EQ(CurrentDecisionLevel(), 0); + + // This block removes assigned literals from the constraint. + // + // Note(user): We could make this work with unsat_proof() on by adding the + // removed literals (with a coeff of the good sign) as dependencies to the + // ResolutionNode associated with this constraint. However, for pseudo-Boolean + // constraints, we would loose the minimization of the reason which seems + // important in order to get smaller core. + Coefficient fixed_variable_shift = 0; + if (!parameters_.unsat_proof()) { + int index = 0; + for (const LiteralWithCoeff& term : *cst) { + if (trail_.Assignment().IsLiteralFalse(term.literal)) continue; + if (trail_.Assignment().IsLiteralTrue(term.literal)) { + CHECK(SafeAddInto(-term.coefficient, &fixed_variable_shift)); + continue; + } + (*cst)[index] = term; + ++index; + } + cst->resize(index); + } + + // Cannonicalize the constraint. Coefficient bound_shift; Coefficient max_value; CHECK(PbCannonicalForm(cst, &bound_shift, &max_value)); + CHECK(SafeAddInto(fixed_variable_shift, &bound_shift)); + if (use_upper_bound) { - if (!AddLinearConstraintInternal(*cst, upper_bound + bound_shift, - max_value)) { - return ModelUnsat(); - } + Coefficient ub = upper_bound; + CHECK(SafeAddInto(bound_shift, &ub)); + if (!AddLinearConstraintInternal(*cst, ub, max_value)) return ModelUnsat(); } if (use_lower_bound) { + // We transform the constraint into an upper-bounded one. for (int i = 0; i < cst->size(); ++i) { (*cst)[i].literal = (*cst)[i].literal.Negated(); } - if (!AddLinearConstraintInternal( - *cst, max_value - (lower_bound + bound_shift), max_value)) { - return ModelUnsat(); - } + Coefficient ub = max_value; + CHECK(SafeAddInto(-lower_bound, &ub)); + CHECK(SafeAddInto(-bound_shift, &ub)); + if (!AddLinearConstraintInternal(*cst, ub, max_value)) return ModelUnsat(); } + ++num_constraints_; return true; } void SatSolver::AddLearnedClauseAndEnqueueUnitPropagation( - const std::vector& literals) { + const std::vector& literals, ResolutionNode* node) { SCOPED_TIME_STAT(&stats_); if (literals.size() == 1) { // A length 1 clause fix a literal for all the search. // ComputeBacktrackLevel() should have returned 0. CHECK_EQ(CurrentDecisionLevel(), 0); - trail_.Enqueue(literals[0], AssignmentInfo::UNIT_REASON); + trail_.EnqueueWithUnitReason(literals[0], node); } else { if (parameters_.treat_binary_clauses_separately() && literals.size() == 2) { binary_implication_graph_.AddBinaryConflict(literals[0], literals[1], &trail_); } else { SatClause* clause = - SatClause::Create(literals, SatClause::LEARNED_CLAUSE); + SatClause::Create(literals, SatClause::LEARNED_CLAUSE, node); CompressLearnedClausesIfNeeded(); --num_learned_clause_before_cleanup_; learned_clauses_.emplace_back(clause); @@ -716,6 +316,7 @@ bool SatSolver::InitialPropagation() { if (!Propagate()) { return ModelUnsat(); } + ProcessNewlyFixedVariableResolutionNodes(); ProcessNewlyFixedVariables(); return true; } @@ -733,6 +334,7 @@ int SatSolver::EnqueueDecisionAndBackjumpOnConflict(Literal true_literal) { // TODO(user): Do more advanced preprocessing? if (CurrentDecisionLevel() == 0) { if (num_processed_fixed_variables_ < trail_.Index()) { + ProcessNewlyFixedVariableResolutionNodes(); ProcessNewlyFixedVariables(); } } @@ -748,23 +350,21 @@ int SatSolver::EnqueueDecisionAndBackjumpOnConflict(Literal true_literal) { reason_cache_.Clear(); // A conflict occured, compute a nice reason for this failure. - std::vector reason; - std::vector discarded_last_level_literals; - ComputeFirstUIPConflict(trail_.FailingClause(), &reason, - &discarded_last_level_literals); - DCHECK(IsConflictValid(reason)); + ComputeFirstUIPConflict(trail_.FailingClause(), &learned_conflict_, + &reason_used_to_infer_the_conflict_); + DCHECK(IsConflictValid(learned_conflict_)); // Update the activity of all the variables in the first UIP clause. // Also update the activity of the last level variables expanded (and // thus discarded) during the first UIP computation. Note that both // sets are disjoint. - const int initial_lbd = ComputeLbd(reason); + const int initial_lbd = ComputeLbd(learned_conflict_); const int lbd_limit = parameters_.use_lbd() && parameters_.use_glucose_bump_again_strategy() ? initial_lbd : 0; - BumpVariableActivities(reason, lbd_limit); - BumpVariableActivities(discarded_last_level_literals, lbd_limit); + BumpVariableActivities(learned_conflict_, lbd_limit); + BumpVariableActivities(reason_used_to_infer_the_conflict_, lbd_limit); // Bump the clause activities. // Note that the activity of the learned clause will be bumped too @@ -772,26 +372,34 @@ int SatSolver::EnqueueDecisionAndBackjumpOnConflict(Literal true_literal) { if (trail_.FailingSatClause() != nullptr) { BumpClauseActivity(trail_.FailingSatClause()); } - BumpReasonActivities(discarded_last_level_literals); + BumpReasonActivities(reason_used_to_infer_the_conflict_); - // Minimize the reason. - MinimizeConflict(&reason); - DCHECK(IsConflictValid(reason)); - DCHECK_EQ(initial_lbd, ComputeLbd(reason)); + // Minimize the learned conflict. + MinimizeConflict(&learned_conflict_, &reason_used_to_infer_the_conflict_); + DCHECK(IsConflictValid(learned_conflict_)); + DCHECK_EQ(initial_lbd, ComputeLbd(learned_conflict_)); if (parameters_.treat_binary_clauses_separately() && parameters_.use_binary_clauses_minimization()) { // Note that on the contrary to the MinimizeConflict() above that // just uses the reason graph, this minimization can change the // clause LBD and even the backtracking level. - binary_implication_graph_.MinimizeClause(trail_, &reason); - DCHECK(IsConflictValid(reason)); + binary_implication_graph_.MinimizeClause(trail_, &learned_conflict_); + DCHECK(IsConflictValid(learned_conflict_)); } + // Compute the resolution node if needed. + ResolutionNode* node = + parameters_.unsat_proof() + ? CreateResolutionNode( + trail_.FailingResolutionNode(), + ClauseRef(reason_used_to_infer_the_conflict_)) + : nullptr; + // Backtrack and add the reason to the set of learned clause. - counters_.num_literals_learned += reason.size(); - Backtrack(ComputeBacktrackLevel(reason)); + counters_.num_literals_learned += learned_conflict_.size(); + Backtrack(ComputeBacktrackLevel(learned_conflict_)); first_propagation_index = trail_.Index(); - AddLearnedClauseAndEnqueueUnitPropagation(reason); + AddLearnedClauseAndEnqueueUnitPropagation(learned_conflict_, node); // Decay the activities. UpdateVariableActivityIncrement(); @@ -1017,7 +625,9 @@ void SatSolver::BumpVariableActivities(const std::vector& literals, const double max_activity_value = parameters_.max_variable_activity_value(); for (const Literal literal : literals) { const VariableIndex var = literal.Variable(); - if (DecisionLevel(var) == CurrentDecisionLevel() && + const int level = DecisionLevel(var); + if (level == 0) continue; + if (level == CurrentDecisionLevel() && trail_.Info(var).type == AssignmentInfo::CLAUSE_PROPAGATION && trail_.Info(var).sat_clause->IsLearned() && trail_.Info(var).sat_clause->Lbd() < bump_again_lbd_limit) { @@ -1034,7 +644,8 @@ void SatSolver::BumpReasonActivities(const std::vector& literals) { SCOPED_TIME_STAT(&stats_); for (const Literal literal : literals) { const VariableIndex var = literal.Variable(); - if (trail_.Info(var).type == AssignmentInfo::CLAUSE_PROPAGATION) { + if (DecisionLevel(var) > 0 && + trail_.Info(var).type == AssignmentInfo::CLAUSE_PROPAGATION) { BumpClauseActivity(trail_.Info(var).sat_clause); } } @@ -1214,9 +825,39 @@ double SatSolver::ComputeInitialVariableWeight(VariableIndex var) const { } } +void SatSolver::ProcessNewlyFixedVariableResolutionNodes() { + if (!parameters_.unsat_proof()) return; + CHECK_GE(num_processed_fixed_variables_, 0); + for (int i = num_processed_fixed_variables_; i < trail_.Index(); ++i) { + const AssignmentInfo& info = trail_.Info(trail_[i].Variable()); + if (info.type == AssignmentInfo::UNIT_REASON) continue; + CHECK_NE(info.type, AssignmentInfo::SEARCH_DECISION); + CHECK_NE(info.type, AssignmentInfo::BINARY_PROPAGATION); + + // We need this loop to remove the propagated literal from the reason. + // TODO(user): The reason should probably not contains the propagated + // literal in the first place. Clean that up. + literals_scratchpad_.clear(); + for (Literal literal : Reason(trail_[i].Variable())) { + if (literal != trail_[i]) literals_scratchpad_.push_back(literal); + } + + // Note that this works because level 0 literals are part of the reason + // at this point. + ResolutionNode* new_node = + CreateResolutionNode(info.type == AssignmentInfo::CLAUSE_PROPAGATION + ? info.sat_clause->ResolutionNodePointer() + : info.pb_constraint->ResolutionNodePointer(), + ClauseRef(literals_scratchpad_)); + trail_.SetFixedVariableInfo(trail_[i].Variable(), new_node); + } +} + void SatSolver::ProcessNewlyFixedVariables() { SCOPED_TIME_STAT(&stats_); DCHECK_EQ(CurrentDecisionLevel(), 0); + std::vector removed_literals; + std::vector resolution_nodes; int num_detached_clauses = 0; int num_binary = 0; @@ -1229,17 +870,29 @@ void SatSolver::ProcessNewlyFixedVariables() { for (int i = 0; i < 2; ++i) { for (SatClause* clause : (i == 0) ? problem_clauses_ : learned_clauses_) { if (clause->IsAttached()) { - if (clause->RemoveFixedLiteralsAndTestIfTrue(trail_.Assignment())) { + if (clause->RemoveFixedLiteralsAndTestIfTrue(trail_.Assignment(), + &removed_literals)) { // The clause is always true, detach it. + // TODO(user): Unlock its associated resolution node right away since + // the solver will not be able to reach it again. watched_clauses_.LazyDetach(clause); ++num_detached_clauses; - } else if (clause->Size() == 2 && - parameters_.treat_binary_clauses_separately()) { - // The clause is now a binary clause, treat it separately. - binary_implication_graph_.AddBinaryClause(clause->FirstLiteral(), - clause->SecondLiteral()); - watched_clauses_.LazyDetach(clause); - ++num_binary; + } else if (!removed_literals.empty()) { + if (clause->Size() == 2 && + parameters_.treat_binary_clauses_separately()) { + // The clause is now a binary clause, treat it separately. + binary_implication_graph_.AddBinaryClause(clause->FirstLiteral(), + clause->SecondLiteral()); + watched_clauses_.LazyDetach(clause); + ++num_binary; + } else if (parameters_.unsat_proof()) { + // The "new" clause is derived from the old one plus the level 0 + // literals. + ResolutionNode* new_node = CreateResolutionNode( + clause->ResolutionNodePointer(), ClauseRef(removed_literals)); + unsat_proof_.UnlockNode(clause->ResolutionNodePointer()); + clause->ChangeResolutionNode(new_node); + } } } } @@ -1256,6 +909,12 @@ void SatSolver::ProcessNewlyFixedVariables() { learned_clauses_.begin(), learned_clauses_.end(), std::bind1st(std::mem_fun(&SatSolver::IsClauseAttachedOrUsedAsReason), this)); + if (parameters_.unsat_proof()) { + for (std::vector::iterator it = iter; it != learned_clauses_.end(); + ++it) { + unsat_proof_.UnlockNode((*it)->ResolutionNodePointer()); + } + } STLDeleteContainerPointers(iter, learned_clauses_.end()); learned_clauses_.erase(iter, learned_clauses_.end()); } @@ -1308,18 +967,9 @@ bool SatSolver::Propagate() { return true; } -bool SatSolver::TestValidityAndEnqueueIfNeeded(Literal literal) { - SCOPED_TIME_STAT(&stats_); - if (trail_.Assignment().IsLiteralFalse(literal)) return false; - if (trail_.Assignment().IsLiteralTrue(literal)) return true; - trail_.Enqueue(literal, AssignmentInfo::PREPROCESSING); // Not assigned. - return true; -} - ClauseRef SatSolver::Reason(VariableIndex var) const { DCHECK(trail_.Assignment().IsVariableAssigned(var)); switch (trail_.Info(var).type) { - case AssignmentInfo::PREPROCESSING: case AssignmentInfo::SEARCH_DECISION: case AssignmentInfo::UNIT_REASON: return ClauseRef(); @@ -1428,6 +1078,23 @@ void SatSolver::Untrail(int target_trail_index) { binary_propagation_trail_index_ = target_trail_index; } +void SatSolver::ComputeUnsatCore(std::vector* core) { + SCOPED_TIME_STAT(&stats_); + CHECK(parameters_.unsat_proof()); + CHECK_EQ(is_model_unsat_, true); + + ProcessNewlyFixedVariableResolutionNodes(); + + // Generate the resolution node corresponding to the last conflict. + ResolutionNode* final_node = CreateResolutionNode( + trail_.FailingResolutionNode(), trail_.FailingClause()); + CHECK(final_node != nullptr); + + // Compute the core and free up the final_node. + unsat_proof_.ComputeUnsatCore(final_node, core); + unsat_proof_.UnlockNode(final_node); +} + bool SatSolver::IsAssignmentValid(const VariablesAssignment& assignment) const { SCOPED_TIME_STAT(&stats_); VLOG(2) << "Checking solution"; @@ -1471,5 +1138,586 @@ std::string SatSolver::DebugString(const SatClause& clause) const { return result; } +ResolutionNode* SatSolver::CreateRootResolutionNode() { + SCOPED_TIME_STAT(&stats_); + return parameters_.unsat_proof() && is_relevant_for_core_computation_ + ? unsat_proof_.CreateNewRootNode(num_constraints_) + : nullptr; +} + +ResolutionNode* SatSolver::CreateResolutionNode( + ResolutionNode* failing_clause_resolution_node, + ClauseRef reason_used_to_infer_the_conflict) { + SCOPED_TIME_STAT(&stats_); + tmp_parents_.clear(); + + // Note that nullptr is a valid resolution node. It means that the associated + // deduction doesn't depend on the set of constraint we care about. + if (failing_clause_resolution_node != nullptr) { + tmp_parents_.push_back(failing_clause_resolution_node); + } + for (Literal literal : reason_used_to_infer_the_conflict) { + // We currently support only two reason types. + const AssignmentInfo& info = trail_.Info(literal.Variable()); + ResolutionNode* node = nullptr; + switch (info.type) { + case AssignmentInfo::CLAUSE_PROPAGATION: + CHECK(info.sat_clause != nullptr); + node = info.sat_clause->ResolutionNodePointer(); + break; + case AssignmentInfo::UNIT_REASON: + node = info.resolution_node; + break; + case AssignmentInfo::PB_PROPAGATION: + CHECK(info.pb_constraint != nullptr); + node = info.pb_constraint->ResolutionNodePointer(); + break; + case AssignmentInfo::SEARCH_DECISION: + case AssignmentInfo::BINARY_PROPAGATION: + LOG(FATAL) << "This shouldn't happen"; + break; + } + if (node != nullptr) tmp_parents_.push_back(node); + } + return tmp_parents_.empty() + ? nullptr + : unsat_proof_.CreateNewResolutionNode(&tmp_parents_); +} + +// This method will compute a first UIP conflict +// http://www.cs.tau.ac.il/~msagiv/courses/ATP/iccad2001_final.pdf +// http://gauss.ececs.uc.edu/SAT/articles/FAIA185-0131.pdf +void SatSolver::ComputeFirstUIPConflict( + ClauseRef failing_clause, std::vector* conflict, + std::vector* reason_used_to_infer_the_conflict) { + SCOPED_TIME_STAT(&stats_); + + // This will be used to mark all the literals inspected while we process the + // conflict and the reasons behind each of its variable assignments. + is_marked_.ClearAndResize(num_variables_); + + conflict->clear(); + reason_used_to_infer_the_conflict->clear(); + const int current_level = CurrentDecisionLevel(); + int num_literal_at_current_level_that_needs_to_be_processed = 0; + DCHECK_GT(current_level, 0); + + // To find the 1-UIP conflict clause, we start by the failing_clause, and + // expand each of its literal using the reason for this literal assignement to + // false. The is_marked_ set allow us to never expand the same literal twice. + // + // The expansion is not done (i.e. stop) for literal that where assigned at a + // decision level below the current one. If the level of such literal is not + // zero, it is added to the conflict clause. + // + // Now, the trick is that we use the trail to expand the literal of the + // current level in a very specific order. Namely the reverse order of the one + // in which they where infered. We stop as soon as + // num_literal_at_current_level_that_needs_to_be_processed is exactly one. + // + // This last literal will be the first UIP because by definition all the + // propagation done at the current level will pass though it at some point. + ClauseRef clause_to_expand = failing_clause; + DCHECK(!clause_to_expand.IsEmpty()); + int trail_index = trail_.Index() - 1; + while (true) { + for (const Literal literal : clause_to_expand) { + const VariableIndex var = literal.Variable(); + if (!is_marked_[var]) { + is_marked_.Set(var); + const int level = DecisionLevel(var); + if (level == current_level) { + ++num_literal_at_current_level_that_needs_to_be_processed; + } else if (level > 0) { + // Note that all these literals are currently false since the clause + // to expand was used to infer the value of a literal at this level. + DCHECK(trail_.Assignment().IsLiteralFalse(literal)); + conflict->push_back(literal); + } else { + reason_used_to_infer_the_conflict->push_back(literal); + } + } + } + + // Find next marked literal to expand from the trail. + DCHECK_GT(num_literal_at_current_level_that_needs_to_be_processed, 0); + while (!is_marked_[trail_[trail_index].Variable()]) { + --trail_index; + DCHECK_GE(trail_index, 0); + DCHECK_EQ(DecisionLevel(trail_[trail_index].Variable()), current_level); + } + + if (num_literal_at_current_level_that_needs_to_be_processed == 1) { + // We have the first UIP. Add its negation to the conflict clause. + // This way, after backtracking to the proper level, the conflict clause + // will be unit, and infer the negation of the UIP that caused the fail. + conflict->push_back(trail_[trail_index].Negated()); + break; + } + + const Literal literal = trail_[trail_index]; + reason_used_to_infer_the_conflict->push_back(literal); + + // If we already encountered the same reason, we can just skip this literal + // which is what setting clause_to_expand to the empty clause do. + if (reason_cache_.FirstVariableWithSameReason(literal.Variable()) != + literal.Variable()) { + clause_to_expand = ClauseRef(); + } else { + clause_to_expand = Reason(literal.Variable()); + DCHECK(!clause_to_expand.IsEmpty()); + } + + --num_literal_at_current_level_that_needs_to_be_processed; + --trail_index; + } +} + +void SatSolver::MinimizeConflict( + std::vector* conflict, + std::vector* reason_used_to_infer_the_conflict) { + SCOPED_TIME_STAT(&stats_); + + const int old_size = conflict->size(); + switch (parameters_.minimization_algorithm()) { + case SatParameters::NONE: + return; + case SatParameters::SIMPLE: { + MinimizeConflictSimple(conflict); + break; + } + case SatParameters::RECURSIVE: { + MinimizeConflictRecursively(conflict); + break; + } + case SatParameters::EXPERIMENTAL: { + MinimizeConflictExperimental(conflict); + break; + } + } + if (conflict->size() < old_size) { + ++counters_.num_minimizations; + counters_.num_literals_removed += old_size - conflict->size(); + } + + // TODO(user): This has been only checked with the RECURSIVE algorithm at + // this point. + if (parameters_.unsat_proof()) { + CHECK_EQ(parameters_.minimization_algorithm(), SatParameters::RECURSIVE); + + // Loop over all the marked variable. The reason of the one that are not of + // the last level (already added) and are not independent where used to + // minimize the clause. + const int current_level = CurrentDecisionLevel(); + const std::vector& marked = is_marked_.PositionsSetAtLeastOnce(); + for (int i = 0; i < marked.size(); ++i) { + if (DecisionLevel(marked[i]) == current_level) continue; + if (!is_independent_[marked[i]]) { + reason_used_to_infer_the_conflict->push_back(Literal(marked[i], true)); + } + } + } +} + +// This simple version just looks for any literal that is directly infered by +// other literals of the conflict. It is directly infered if the literals of its +// reason clause are either from level 0 or from the conflict itself. +// +// Note that because of the assignement struture, there is no need to process +// the literals of the conflict in order. While exploring the reason for a +// literal assignement, there will be no cycles. +void SatSolver::MinimizeConflictSimple(std::vector* conflict) { + SCOPED_TIME_STAT(&stats_); + is_marked_.ClearAndResize(num_variables_); + for (Literal literal : *conflict) { + is_marked_.Set(literal.Variable()); + } + int index = 0; + const int current_level = CurrentDecisionLevel(); + for (int i = 0; i < conflict->size(); ++i) { + const VariableIndex var = (*conflict)[i].Variable(); + bool can_be_removed = false; + if (DecisionLevel(var) != current_level) { + // It is important not to call Reason(var) when it can be avoided. + const ClauseRef reason = Reason(var); + if (!reason.IsEmpty()) { + can_be_removed = true; + for (Literal literal : reason) { + if (DecisionLevel(literal.Variable()) == 0) continue; + if (!is_marked_[literal.Variable()]) { + can_be_removed = false; + break; + } + } + } + } + if (!can_be_removed) { + (*conflict)[index] = (*conflict)[i]; + ++index; + } + } + conflict->erase(conflict->begin() + index, conflict->end()); +} + +// This is similar to MinimizeConflictSimple() except that for each literal of +// the conflict, the literals of its reason are recursively expended using their +// reason and so on. The recusion stop until we show that the initial literal +// can be infered from the conflict variables alone, or if we show that this is +// not the case. The result of any variable expension will be cached in order +// not to be expended again. +void SatSolver::MinimizeConflictRecursively(std::vector* conflict) { + SCOPED_TIME_STAT(&stats_); + + // is_marked_ will contains all the conflict literals plus the literals that + // have been shown to depends only on the conflict literals. is_independent_ + // will contains the literals that have been shown NOT to depends only on the + // conflict literals. The too set are exclusive for non-conflict literals, but + // a conflict literal (which is always marked) can be independent if we showed + // that it can't be removed from the clause. + // + // Optimization: There is no need to call is_marked_.ClearAndResize() or to + // mark the conflict literals since this was already done by + // ComputeFirstUIPConflict(). + is_independent_.ClearAndResize(num_variables_); + + // min_trail_index_per_level_ will always be reset to all + // std::numeric_limits::max() at the end. This is used to prune the + // search because any literal at a given level with an index smaller or equal + // to min_trail_index_per_level_[level] can't be redundant. + if (CurrentDecisionLevel() >= min_trail_index_per_level_.size()) { + min_trail_index_per_level_.resize(CurrentDecisionLevel() + 1, + std::numeric_limits::max()); + } + + // Compute the number of variable at each decision levels. This will be used + // to pruned the DFS because we know that the minimized conflict will have at + // least one variable of each decision levels. Because such variable can't be + // eliminated using lower decision levels variable otherwise it will have been + // propagated. + for (Literal literal : *conflict) { + const VariableIndex var = literal.Variable(); + const int level = DecisionLevel(var); + min_trail_index_per_level_[level] = + std::min(min_trail_index_per_level_[level], trail_.Info(var).trail_index); + } + + // Remove the redundant variable from the conflict. That is the ones that can + // be infered by some other variables in the conflict. + int index = 0; + for (int i = 0; i < conflict->size(); ++i) { + const VariableIndex var = (*conflict)[i].Variable(); + if (trail_.Info(var).trail_index <= + min_trail_index_per_level_[DecisionLevel(var)] || + !CanBeInferedFromConflictVariables(var)) { + // Mark the conflict variable as independent. Note that is_marked_[var] + // will still be true. + is_independent_.Set(var); + (*conflict)[index] = (*conflict)[i]; + ++index; + } + } + + // Reset min_trail_index_per_level_. This works since we can never eliminate + // all the literals from the same level. + conflict->resize(index); + for (Literal literal : *conflict) { + min_trail_index_per_level_[DecisionLevel(literal.Variable())] = + std::numeric_limits::max(); + } +} + +bool SatSolver::CanBeInferedFromConflictVariables(VariableIndex variable) { + // Test for an already processed variable with the same reason. + { + DCHECK(is_marked_[variable]); + const VariableIndex v = reason_cache_.FirstVariableWithSameReason(variable); + if (v != variable) return !is_independent_[v]; + } + + // This function implement an iterative DFS from the given variable. It uses + // the reason clause as adjacency lists. dfs_stack_ can be seens as the + // recursive call stack of the variable we are currently processing. All its + // adjacent variable will be pushed into variable_to_process_, and we will + // then dequeue them one by one and process them. + dfs_stack_.assign(1, variable); + variable_to_process_.assign(1, variable); + + // First we expand the reason for the given variable. + DCHECK(!Reason(variable).IsEmpty()); + for (Literal literal : Reason(variable)) { + const VariableIndex var = literal.Variable(); + if (var == variable) continue; + const int level = DecisionLevel(var); + if (is_marked_[var]) continue; + if (level == 0) { + // Note that this is not needed if the solver is not configured to produce + // an unsat proof. However, the (level == 0) test shoud always be false in + // this case because there will never be literals of level zero in any + // reason when we don't want a proof. + is_marked_.Set(var); + continue; + } + if (trail_.Info(var).trail_index <= min_trail_index_per_level_[level] || + is_independent_[var]) { + return false; + } + variable_to_process_.push_back(var); + } + + // Then we start the DFS. + while (!variable_to_process_.empty()) { + const VariableIndex current_var = variable_to_process_.back(); + if (current_var == dfs_stack_.back()) { + // We finished the DFS of the variable dfs_stack_.back(), this can be seen + // as a recursive call terminating. + if (dfs_stack_.size() > 1) { + DCHECK(!is_marked_[current_var]); + is_marked_.Set(current_var); + } + variable_to_process_.pop_back(); + dfs_stack_.pop_back(); + continue; + } + + // If this variable became marked since the we pushed it, we can skip it. + if (is_marked_[current_var]) { + variable_to_process_.pop_back(); + continue; + } + + // This case will never be encountered since we abort right away as soon + // as an independent variable is found. + DCHECK(!is_independent_[current_var]); + + // Test for an already processed variable with the same reason. + { + const VariableIndex v = + reason_cache_.FirstVariableWithSameReason(current_var); + if (v != current_var) { + if (is_independent_[v]) break; + DCHECK(is_marked_[v]); + variable_to_process_.pop_back(); + continue; + } + } + + // Expand the variable. This can be seen as making a recursive call. + dfs_stack_.push_back(current_var); + bool abort_early = false; + DCHECK(!Reason(current_var).IsEmpty()); + for (Literal literal : Reason(current_var)) { + const VariableIndex var = literal.Variable(); + if (var == current_var) continue; + const int level = DecisionLevel(var); + if (level == 0 || is_marked_[var]) continue; + if (trail_.Info(var).trail_index <= min_trail_index_per_level_[level] || + is_independent_[var]) { + abort_early = true; + break; + } + variable_to_process_.push_back(var); + } + if (abort_early) break; + } + + // All the variable left on the dfs_stack_ are independent. + for (const VariableIndex var : dfs_stack_) { + is_independent_.Set(var); + } + return dfs_stack_.empty(); +} + +namespace { + +struct WeightedVariable { + WeightedVariable(VariableIndex v, int w) : var(v), weight(w) {} + + VariableIndex var; + int weight; +}; + +// Lexical order, by larger weight, then by smaller variable number +// to break ties +struct VariableWithLargerWeightFirst { + bool operator()(const WeightedVariable& wv1, + const WeightedVariable& wv2) const { + return (wv1.weight > wv2.weight || + (wv1.weight == wv2.weight && wv1.var < wv2.var)); + } +}; +} // namespace. + +// This function allows a conflict variable to be replaced by another variable +// not originally in the conflict. Greater reduction and backtracking can be +// achieved this way, but the effect of this is not clear. +// +// TODO(user): More investigation needed. This seems to help on the Hanoi +// problems, but degrades performance on others. +// +// TODO(user): Find a reference for this? neither minisat nor glucose do that, +// they just do MinimizeConflictRecursively() with a different implementation. +// Note that their behavior also make more sense with the way they (and we) bump +// the variable activities. +void SatSolver::MinimizeConflictExperimental(std::vector* conflict) { + SCOPED_TIME_STAT(&stats_); + + // First, sort the variables in the conflict by decreasing decision levels. + // Also initialize is_marked_ to true for all conflict variables. + is_marked_.ClearAndResize(num_variables_); + const int current_level = CurrentDecisionLevel(); + std::vector variables_sorted_by_level; + for (Literal literal : *conflict) { + const VariableIndex var = literal.Variable(); + is_marked_.Set(var); + const int level = DecisionLevel(var); + if (level < current_level) { + variables_sorted_by_level.push_back(WeightedVariable(var, level)); + } + } + std::sort(variables_sorted_by_level.begin(), variables_sorted_by_level.end(), + VariableWithLargerWeightFirst()); + + // Then process the reason of the variable with highest level first. + std::vector to_remove; + for (WeightedVariable weighted_var : variables_sorted_by_level) { + const VariableIndex var = weighted_var.var; + + // A nullptr reason means that this was a decision variable from the + // previous levels. + const ClauseRef reason = Reason(var); + if (reason.IsEmpty()) continue; + + // Compute how many and which literals from the current reason do not appear + // in the current conflict. Level 0 literals are ignored. + std::vector not_contained_literals; + for (const Literal reason_literal : reason) { + const VariableIndex reason_var = reason_literal.Variable(); + + // We ignore level 0 variables. + if (DecisionLevel(reason_var) == 0) continue; + + // We have a reason literal whose variable is not yet seen. + // If there is more than one, break right away, we will not minimize the + // current conflict with this variable. + if (!is_marked_[reason_var]) { + not_contained_literals.push_back(reason_literal); + if (not_contained_literals.size() > 1) break; + } + } + if (not_contained_literals.empty()) { + // This variable will be deleted from the conflict. Note that we don't + // unmark it. This is because this variable can be infered from the other + // variables in the conflict, so it is okay to skip it when processing the + // reasons of other variables. + to_remove.push_back(var); + } else if (not_contained_literals.size() == 1) { + // Replace the literal from variable var with the only + // not_contained_literals from the current reason. + to_remove.push_back(var); + is_marked_.Set(not_contained_literals.front().Variable()); + conflict->push_back(not_contained_literals.front()); + } + } + + // Unmark the variable that should be removed from the conflict. + for (VariableIndex var : to_remove) { + is_marked_.Clear(var); + } + + // Remove the now unmarked literals from the conflict. + int index = 0; + for (int i = 0; i < conflict->size(); ++i) { + const Literal literal = (*conflict)[i]; + if (is_marked_[literal.Variable()]) { + (*conflict)[index] = literal; + ++index; + } + } + conflict->erase(conflict->begin() + index, conflict->end()); +} + +namespace { + +// Order the clause by increasing LBD (Literal Blocks Distance) first. For the +// same LBD they are ordered by decreasing activity. +bool ClauseOrdering(SatClause* a, SatClause* b) { + if (a->Lbd() == b->Lbd()) return a->Activity() > b->Activity(); + return a->Lbd() < b->Lbd(); +} + +} // namespace + +void SatSolver::InitLearnedClauseLimit() { + const int num_learned_clauses = learned_clauses_.size(); + target_number_of_learned_clauses_ = + num_learned_clauses + parameters_.clause_cleanup_increment(); + num_learned_clause_before_cleanup_ = + target_number_of_learned_clauses_ / parameters_.clause_cleanup_ratio() - + num_learned_clauses; + VLOG(1) << "reduced learned database to " << num_learned_clauses + << " clauses. Next cleanup in " << num_learned_clause_before_cleanup_ + << " conflicts."; +} + +void SatSolver::CompressLearnedClausesIfNeeded() { + if (num_learned_clause_before_cleanup_ > 0) return; + SCOPED_TIME_STAT(&stats_); + + // First time? + if (learned_clauses_.size() == 0) { + InitLearnedClauseLimit(); + return; + } + + // Move the clause that should be kept at the beginning and sort the other + // using the ClauseOrdering order. + std::vector::iterator clause_to_keep_end = std::partition( + learned_clauses_.begin(), learned_clauses_.end(), + std::bind1st(std::mem_fun(&SatSolver::ClauseShouldBeKept), this)); + std::sort(clause_to_keep_end, learned_clauses_.end(), ClauseOrdering); + + // Compute the index of the first clause to delete. + const int num_learned_clauses = learned_clauses_.size(); + const int first_clause_to_delete = + std::max(static_cast(clause_to_keep_end - learned_clauses_.begin()), + std::min(num_learned_clauses, target_number_of_learned_clauses_)); + + // Delete all the learned clause after 'first_clause_to_delete'. + for (int i = first_clause_to_delete; i < num_learned_clauses; ++i) { + SatClause* clause = learned_clauses_[i]; + watched_clauses_.LazyDetach(clause); + if (clause->ResolutionNodePointer() != nullptr) { + unsat_proof_.UnlockNode(clause->ResolutionNodePointer()); + } + } + watched_clauses_.CleanUpWatchers(); + for (int i = first_clause_to_delete; i < num_learned_clauses; ++i) { + counters_.num_literals_forgotten += learned_clauses_[i]->Size(); + delete learned_clauses_[i]; + } + learned_clauses_.resize(first_clause_to_delete); + InitLearnedClauseLimit(); +} + +bool SatSolver::ShouldRestart() { + SCOPED_TIME_STAT(&stats_); + if (conflicts_until_next_restart_ != 0) return false; + restart_count_++; + conflicts_until_next_restart_ = + parameters_.restart_period() * SUniv(restart_count_ + 1); + return true; +} + +void SatSolver::InitRestart() { + SCOPED_TIME_STAT(&stats_); + restart_count_ = 0; + if (parameters_.restart_period() > 0) { + DCHECK_EQ(SUniv(1), 1); + conflicts_until_next_restart_ = parameters_.restart_period(); + } else { + conflicts_until_next_restart_ = -1; + } +} + } // namespace sat } // namespace operations_research diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 3352c777fa..5764a3dc59 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -32,8 +32,10 @@ #include "base/int_type.h" #include "base/random.h" #include "sat/pb_constraint.h" +#include "sat/clause.h" #include "sat/sat_base.h" #include "sat/sat_parameters.pb.h" +#include "sat/unsat_proof.h" #include "util/bitset.h" #include "util/stats.h" #include "base/adjustable_priority_queue.h" @@ -41,363 +43,6 @@ namespace operations_research { namespace sat { -// Forward declarations. -// TODO(user): This cyclic dependency can be relatively easily removed. -class LiteralWatchers; - -// Returns the ith element of the strategy S^univ proposed by M. Luby et al. in -// Optimal Speedup of Las Vegas Algorithms, Information Processing Letters 1993. -// This is used to decide the number of conflicts allowed before the next -// restart. This method, used by most SAT solvers, is usually referenced as -// Luby. -// Returns 2^{k-1} when i == 2^k - 1 -// and SUniv(i - 2^{k-1} + 1) when 2^{k-1} <= i < 2^k - 1. -// The sequence is defined for i > 0 and starts with: -// {1, 1, 2, 1, 1, 2, 4, 1, 1, 2, 1, 1, 2, 4, 8, ...} -inline int SUniv(int i) { - DCHECK_GT(i, 0); - while (i > 2) { - const int most_significant_bit_position = - MostSignificantBitPosition64(i + 1); - if ((1 << most_significant_bit_position) == i + 1) { - return 1 << (most_significant_bit_position - 1); - } - i -= (1 << most_significant_bit_position) - 1; - } - return 1; -} - -// Variable information. This is updated each time we attach/detach a clause. -struct VariableInfo { - VariableInfo() - : num_positive_clauses(0), - num_negative_clauses(0), - num_appearances(0), - weighted_num_appearances(0.0) {} - - int num_positive_clauses; - int num_negative_clauses; - int num_appearances; - double weighted_num_appearances; -}; - -// Priority queue element to support variable ordering by larger weight first. -struct WeightedVarQueueElement { - WeightedVarQueueElement() - : heap_index(-1), weight(0.0), tie_breaker(0.0), variable(-1) {} - - // Interface for the AdjustablePriorityQueue. - void SetHeapIndex(int h) { heap_index = h; } - int GetHeapIndex() const { return heap_index; } - - // Priority order. - bool operator<(const WeightedVarQueueElement& other) const { - return weight < other.weight || - (weight == other.weight && - (tie_breaker < other.tie_breaker || - (tie_breaker == other.tie_breaker && variable < other.variable))); - } - - int heap_index; - double weight; - double tie_breaker; - VariableIndex variable; -}; - -// This is how the SatSolver store a clause. A clause is just a disjunction of -// literals. In many places, we just use std::vector to encode one. However, -// the solver needs to keep a few extra fields attached to each clause. -class SatClause { - public: - // Creates a sat clause. There must be at least 2 literals. - // Smaller clause are treated separatly and never constructed. - enum ClauseType { - PROBLEM_CLAUSE, - LEARNED_CLAUSE, - }; - static SatClause* Create(const std::vector& literals, ClauseType type); - - // Number of literals in the clause. - int Size() const { return size_; } - - // Allows for range based iteration: for (Literal literal : clause) {}. - const Literal* const begin() const { return &(literals_[0]); } - const Literal* const end() const { return &(literals_[size_]); } - - // Returns a ClauseRef that point to this clause. - ClauseRef ToClauseRef() const { return ClauseRef(begin(), end()); } - - // Returns the first and second literals. These are always the watched - // literals if the clause is attached in the LiteralWatchers. - Literal FirstLiteral() const { return literals_[0]; } - Literal SecondLiteral() const { return literals_[1]; } - - // Tries to simplify the clause. - enum SimplifyStatus { - CLAUSE_ALWAYS_TRUE, - CLAUSE_ALWAYS_FALSE, - CLAUSE_SUBSUMED, - CLAUSE_ACTIVE, - }; - SimplifyStatus Simplify(); - - // Removes literals that are fixed. This should only be called at level 0 - // where a literal is fixed iff it is assigned. Aborts and returns true if - // they are not all false. - bool RemoveFixedLiteralsAndTestIfTrue(const VariablesAssignment& assignment); - - // Propagates watched_literal which just became false in the clause. Returns - // false if an inconsistency was detected. - // - // IMPORTANT: If a new literal needs watching instead, then FirstLiteral() - // will be the new watched literal, otherwise it will be equal to the given - // watched_literal. - bool PropagateOnFalse(Literal watched_literal, Trail* trail); - - // True if the clause is learned. - bool IsLearned() const { return is_learned_; } - - // Returns true if the clause is satisfied for the given assignment. - bool IsSatisfied(const VariablesAssignment& assignment) const; - - // Sorts the literals of the clause depending on the given parameters and - // statistics. Do not call this on an attached clause. - void SortLiterals(const ITIVector& statistics, - const SatParameters& parameters); - - // Sets up the 2-watchers data structure. It selects two non-false literals - // and attaches the clause to the event: one of the watched literals become - // false. It returns false if the clause only contains literals assigned to - // false. If only one literals is not false, it propagates it to true if it - // is not already assigned. - bool AttachAndEnqueuePotentialUnitPropagation(Trail* trail, - LiteralWatchers* demons); - - // Modify and get the clause activity. - void IncreaseActivity(double increase) { activity_ += increase; } - void MultiplyActivity(double factor) { activity_ *= factor; } - double Activity() const { return activity_; } - - // Set and get the clause LBD (Literal Blocks Distance). The LBD is not - // computed here. See ComputeClauseLbd() in SatSolver. - void SetLbd(int value) { lbd_ = value; } - int Lbd() const { return lbd_; } - - // Returns true if the clause is attached to a LiteralWatchers. - bool IsAttached() const { return is_attached_; } - - // Marks the clause so that the next call to CleanUpWatchers() can identify it - // and actually detach it. - void LazyDetach() { is_attached_ = false; } - - std::string DebugString() const; - - private: - // The data is packed so that only 16 bytes are used for these fields. - // Note that the max lbd is the maximum depth of the search tree (decision - // levels), so it should fit easily in 29 bits. Note that we can also upper - // bound it without hurting too much the clause cleaning heuristic. - bool is_learned_ : 1; - bool is_attached_ : 1; - int lbd_ : 30; - int size_ : 32; - double activity_; - - // This class store the literals inline, and literals_ mark the starts of the - // variable length portion. - Literal literals_[0]; - - DISALLOW_COPY_AND_ASSIGN(SatClause); -}; - -// ----- LiteralWatchers ----- - -// Stores the 2-watched literals data structure. See -// http://www.cs.berkeley.edu/~necula/autded/lecture24-sat.pdf for -// detail. -class LiteralWatchers { - public: - LiteralWatchers(); - ~LiteralWatchers(); - - // Resizes the data structure. - void Resize(int num_variables); - - // Attaches the given clause. This eventually propagates a literal which is - // enqueued on the trail. Returns false if a contradiction was encountered. - bool AttachAndPropagate(SatClause* clause, Trail* trail); - - // Attaches the given clause to the event: the given literal becomes false. - // The blocking_literal can be any literal from the clause, it is used to - // speed up PropagateOnFalse() by skipping the clause if it is true. - void AttachOnFalse(Literal literal, Literal blocking_literal, - SatClause* clause); - - // Lazily detach the given clause. The deletion will actually occur when - // CleanUpWatchers() is called. The later needs to be called before any other - // function in this class can be called. This is DCHECKed. - void LazyDetach(SatClause* clause); - void CleanUpWatchers(); - - // Launches all propagation when the given literal becomes false. - // Returns false if a contradiction was encountered. - bool PropagateOnFalse(Literal false_literal, Trail* trail); - - // Total number of clauses inspected during calls to PropagateOnFalse(). - int64 num_inspected_clauses() const { return num_inspected_clauses_; } - - // Number of clauses currently watched. - int64 num_watched_clauses() const { return num_watched_clauses_; } - - // Returns some statistics on the number of appearance of this variable in - // all the attached clauses. - const VariableInfo& VariableStatistic(VariableIndex var) const { - return statistics_[var]; - } - - // Parameters management. - void SetParameters(const SatParameters& parameters) { - parameters_ = parameters; - } - - private: - // Updates statistics_ for the literals in the given clause. added indicates - // if we are adding the clause or deleting it. - void UpdateStatistics(const SatClause& clause, bool added); - - // Contains, for each literal, the list of clauses that need to be inspected - // when the corresponding literal becomes false. - struct Watcher { - Watcher() {} - Watcher(SatClause* c, Literal b) : clause(c), blocking_literal(b) {} - SatClause* clause; - Literal blocking_literal; - }; - ITIVector > watchers_on_false_; - - // Indicates if the corresponding watchers_on_false_ list need to be - // cleaned. The boolean is_clean_ is just used in DCHECKs. - ITIVector needs_cleaning_; - bool is_clean_; - - ITIVector statistics_; - SatParameters parameters_; - int64 num_inspected_clauses_; - int64 num_watched_clauses_; - mutable StatsGroup stats_; - DISALLOW_COPY_AND_ASSIGN(LiteralWatchers); -}; - -// Special class to store and propagate clauses of size 2 (i.e. implication). -// Such clauses are never deleted. -// -// TODO(user): All the variables in a strongly connected component are -// equivalent and can be thus merged as one. This is relatively cheap to compute -// from time to time (linear complexity). We will also get contradiction (a <=> -// not a) this way. -// -// TODO(user): An implication (a => not a) implies that a is false. I am not -// sure it is worth detecting that because if the solver assign a to true, it -// will learn that right away. I don't think we can do it faster. -// -// TODO(user): The implication graph can be pruned. This is called the -// transitive reduction of a graph. For instance If a => {b,c} and b => {c}, -// then there is no need to store a => {c}. The transitive reduction is unique -// on an acyclic graph. Computing it will allow for a faster propagation and -// memory reduction. It is however not cheap. Maybe simple lazy heuristics to -// remove redundant arcs are better. Note that all the learned clauses we add -// will never be redundant (but they could introduce cycles). -// -// TODO(user): Add a preprocessor to remove duplicates in the implication lists. -// Note that all the learned clauses we had will never create duplicates. -// -// References for most of the above TODO and more: -// - Brafman RI, "A simplifier for propositional formulas with many binary -// clauses", IEEE Trans Syst Man Cybern B Cybern. 2004 Feb;34(1):52-9. -// http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.28.4911 -// - Marijn J. H. Heule, Matti Järvisalo, Armin Biere, "Efficient CNF -// Simplification Based on Binary Implication Graphs", Theory and Applications -// of Satisfiability Testing - SAT 2011, Lecture Notes in Computer Science -// Volume 6695, 2011, pp 201-215 -// http://www.cs.helsinki.fi/u/mjarvisa/papers/heule-jarvisalo-biere.sat11.pdf -class BinaryImplicationGraph { - public: - BinaryImplicationGraph() - : num_propagations_(0), - num_minimization_(0), - num_literals_removed_(0), - stats_("BinaryImplicationGraph") {} - ~BinaryImplicationGraph() { - IF_STATS_ENABLED(LOG(INFO) << stats_.StatString()); - } - - // Resizes the data structure. - void Resize(int num_variables); - - // Adds the binary clause (a OR b), which is the same as (not a => b). - // Note that it is also equivalent to (not b => a). - void AddBinaryClause(Literal a, Literal b); - - // Same as AddBinaryClause() but enqueues a possible unit propagation. - void AddBinaryConflict(Literal a, Literal b, Trail* trail); - - // Propagates all the direct implications of the given literal becoming true. - // Returns false if a conflict was encountered, in which case - // trail->SetFailingClause() will be called with the correct size 2 clause. - // This calls trail->Enqueue() on the newly assigned literals. - bool PropagateOnTrue(Literal true_literal, Trail* trail); - - // Uses the binary implication graph to minimize the given clause by removing - // literals that implies others. - // - // TODO(user): The current algorithm is minimalist, and just look at direct - // implication. Investigate recursive version. - void MinimizeClause(const Trail& trail, std::vector* clause); - - // This must only be called at decision level 0 after all the possible - // propagations. It: - // - Removes the variable at true from the implications lists. - // - Frees the propagation list of the assigned literals. - void RemoveFixedVariables(const VariablesAssignment& assigment); - - // Number of literal propagated by this class (including conflicts). - int64 num_propagations() const { return num_propagations_; } - - // MinimizeClause() stats. - int64 num_minimization() const { return num_minimization_; } - int64 num_literals_removed() const { return num_literals_removed_; } - - // Returns the number of current implications. - int64 NumberOfImplications() const { - int num = 0; - for (const std::vector& v : implications_) num += v.size(); - return num / 2; - } - - private: - // This is indexed by the Index() of a literal. Each list stores the - // literals that are implied if the index literal becomes true. - ITIVector > implications_; - - // Holds the last conflicting binary clause. - Literal temporary_clause_[2]; - - // Some stats. - int64 num_propagations_; - int64 num_minimization_; - int64 num_literals_removed_; - - // Bitset used by MinimizeClause(). - // TODO(user): use the same one as the one used in the classic minimization - // because they are already initialized. Moreover they contains more - // information. - SparseBitset is_marked_; - SparseBitset is_removed_; - - mutable StatsGroup stats_; - DISALLOW_COPY_AND_ASSIGN(BinaryImplicationGraph); -}; - // A constant used by the EnqueueDecision*() API. const int kUnsatTrailIndex = -1; @@ -416,10 +61,13 @@ class SatSolver { // Increases the number of variables of the current problem. void SetNumVariables(int num_variables); + // Fixes a variable so that the given literal is true. This can be used to + // solve a subproblem where some variables are fixed. Note that it is more + // efficient to add such unit clause before all the others. + bool AddUnitClause(Literal true_literal); + // Adds a clause to the problem. Returns false if the clause is always false // and thus make the problem unsatisfiable. - // - // TODO(user): Remove this from the API and only use AddLinearConstraint()? bool AddProblemClause(const std::vector& literals); // Adds a pseudo-Boolean constraint to the problem. Returns false if the @@ -432,11 +80,28 @@ class SatSolver { // of the problem more and more. Just re-adding such constraint is relatively // efficient. // - // TODO(user): Add error handling for overflow/underflow. + // OVERFLOW: The sum of the absolute value of all the coefficients + // in the constraint must not overflow. This is currently CHECKed(). + // TODO(user): Instead of failing, implement an error handling code. bool AddLinearConstraint(bool use_lower_bound, Coefficient lower_bound, bool use_upper_bound, Coefficient upper_bound, std::vector* cst); + // Advanced usage. This is only relevant when trying to compute an unsat core. + // All the constraints added by one of the Add*() function above when this was + // set to true will be considered for the core. All the others will just be + // ignored (and thus save memory during the solve). This starts with a value + // of true. + void SetNextConstraintsRelevanceForUnsatCore(bool value) { + is_relevant_for_core_computation_ = value; + } + + // Returns the number of time AddProblemClause() or AddLinearConstraint() was + // called. This will also be the unique index associated to the next + // constraint that will be added. This unique index is used by UnsatCore() to + // indicates what constraints are part of the core. + int NumConstraints() { return num_constraints_; } + // Gives a hint so the solver tries to find a solution with the given literal // sets to true. The weight is a positive number reflecting the relative // importance between multiple calls to SetAssignmentPreference(). @@ -473,6 +138,16 @@ class SatSolver { }; Status Solve(); + // Returns an UNSAT core. That is a subset of the problem clauses that are + // still UNSAT. A problem constraint of index #i is the one that was added + // with the i-th call to AddProblemClause() or AddLinearConstraint(), see + // NumConstraints(). + // + // Preconditions: + // - Solve() must be called with the parameters unsat_proof() set to true. + // - It must have returned MODEL_UNSAT. + void ComputeUnsatCore(std::vector* core); + // Returns true if a given assignment is a solution of the current problem. // TODO(user): This currently only check normal clauses. Fix it to include // binary clauses and linear constraints. @@ -588,9 +263,10 @@ class SatSolver { IsClauseUsedAsReason(clause); } - // Returns false if the literal is already assigned to false. - // Otherwise, returns true and Enqueue it if it is unassigned. - bool TestValidityAndEnqueueIfNeeded(Literal literal); + // Add a problem clause. Not that the clause is assumed to be "cleaned", that + // is no duplicate variables (not strictly required) and not empty. + bool AddProblemClauseInternal(const std::vector& literals, + ResolutionNode* node); // This is used by all the Add*LinearConstraint() functions. It detects // infeasible/trivial constraints or clause constraints and takes the proper @@ -603,7 +279,7 @@ class SatSolver { // literals of the learned close except one will be false. Thus the last one // will be implied True. This function also Enqueue() the implied literal. void AddLearnedClauseAndEnqueueUnitPropagation( - const std::vector& literals); + const std::vector& literals, ResolutionNode* node); // Creates a new decision which corresponds to setting the given literal to // True and Enqueue() this change. @@ -620,6 +296,11 @@ class SatSolver { // and add them to the priority queue with the correct weight. void Untrail(int trail_index); + // Update the resolution node associated to all the newly fixed variables so + // each node expresses the reason why this variable was assigned. This is + // needed because level zero variables are treated differently by the solver. + void ProcessNewlyFixedVariableResolutionNodes(); + // Simplifies the problem when new variables are assigned at level 0. void ProcessNewlyFixedVariables(); @@ -635,15 +316,29 @@ class SatSolver { // learning in a boolean satisfiability solver" Proceedings of the 2001 // IEEE/ACM international conference on Computer-aided design, Pages 279-285. // http://www.cs.tau.ac.il/~msagiv/courses/ATP/iccad2001_final.pdf - void ComputeFirstUIPConflict(ClauseRef failing_clause, - std::vector* conflict, - std::vector* discarded_last_level_literals); + void ComputeFirstUIPConflict( + ClauseRef failing_clause, std::vector* conflict, + std::vector* reason_used_to_infer_the_conflict); + + // Creates the root resolution node associated with the current constraint. + // This will returns nullptr if the solver is not configured to compute unsat + // core or if the current constraint is not relevant for the core computation. + ResolutionNode* CreateRootResolutionNode(); + + // Creates a ResolutionNode associated to a learned conflict. Basically, the + // node will hold the information that the learned clause can be derived from + // the conflict clause and all the reason that where used during the + // computation of the first uip conflict. + ResolutionNode* CreateResolutionNode( + ResolutionNode* failing_clause_resolution_node, + ClauseRef reason_used_to_infer_the_conflict); // Applies some heuristics to a conflict in order to minimize its size and/or // replace literals by other literals from lower decision levels. The first // function choose which one of the other functions to call depending on the // parameters. - void MinimizeConflict(std::vector* conflict); + void MinimizeConflict(std::vector* conflict, + std::vector* reason_used_to_infer_the_conflict); void MinimizeConflictExperimental(std::vector* conflict); void MinimizeConflictSimple(std::vector* conflict); void MinimizeConflictRecursively(std::vector* conflict); @@ -722,6 +417,9 @@ class SatSolver { VariableIndex num_variables_; + // The number of constraints of the initial problem that where added. + int num_constraints_; + // Original clauses of the problem and clauses learned during search. // These vector have the ownership of the pointers. We currently do not use // std::unique_ptr because it can't be used with STL algorithm @@ -793,6 +491,28 @@ class SatSolver { // Variable ordering (priority will be adjusted dynamically). The variable in // the queue are said to be active. queue_elements_ holds the elements used by // var_ordering_ (it uses pointers). + struct WeightedVarQueueElement { + WeightedVarQueueElement() + : heap_index(-1), weight(0.0), tie_breaker(0.0), variable(-1) {} + + // Interface for the AdjustablePriorityQueue. + void SetHeapIndex(int h) { heap_index = h; } + int GetHeapIndex() const { return heap_index; } + + // Priority order. The AdjustablePriorityQueue returns the largest element + // first. + bool operator<(const WeightedVarQueueElement& other) const { + return weight < other.weight || + (weight == other.weight && + (tie_breaker < other.tie_breaker || + (tie_breaker == other.tie_breaker && variable < other.variable))); + } + + int heap_index; + double weight; + double tie_breaker; + VariableIndex variable; + }; AdjustablePriorityQueue var_ordering_; ITIVector queue_elements_; @@ -840,17 +560,61 @@ class SatSolver { DEFINE_INT_TYPE(SatDecisionLevel, int); SparseBitset is_level_marked_; + // Temporary vectors used by EnqueueDecisionAndBackjumpOnConflict(). + std::vector learned_conflict_; + std::vector reason_used_to_infer_the_conflict_; + // "cache" to avoid inspecting many times the same reason during conflict // analysis. PbReasonCache reason_cache_; + // Stores the resolution DAG. + // This is only used is parameters_.unsat_proof() is true. + UnsatProof unsat_proof_; + // A random number generator. mutable MTRandom random_; + // Temporary vector used by AddProblemClause(). + std::vector tmp_pb_constraint_; + + // List of nodes that will need to be unlocked when this class is destructed. + // TODO(user): This is currently used for the pseudo-Boolean constraint + // resolution nodes, and is not really clean. + std::vector to_unlock_; + + // Temporary vector used by CreateResolutionNode(). + std::vector tmp_parents_; + + // Boolean used to include/exclude constraints from the core computation. + bool is_relevant_for_core_computation_; + mutable StatsGroup stats_; DISALLOW_COPY_AND_ASSIGN(SatSolver); }; +// Returns the ith element of the strategy S^univ proposed by M. Luby et al. in +// Optimal Speedup of Las Vegas Algorithms, Information Processing Letters 1993. +// This is used to decide the number of conflicts allowed before the next +// restart. This method, used by most SAT solvers, is usually referenced as +// Luby. +// Returns 2^{k-1} when i == 2^k - 1 +// and SUniv(i - 2^{k-1} + 1) when 2^{k-1} <= i < 2^k - 1. +// The sequence is defined for i > 0 and starts with: +// {1, 1, 2, 1, 1, 2, 4, 1, 1, 2, 1, 1, 2, 4, 8, ...} +inline int SUniv(int i) { + DCHECK_GT(i, 0); + while (i > 2) { + const int most_significant_bit_position = + MostSignificantBitPosition64(i + 1); + if ((1 << most_significant_bit_position) == i + 1) { + return 1 << (most_significant_bit_position - 1); + } + i -= (1 << most_significant_bit_position) - 1; + } + return 1; +} + } // namespace sat } // namespace operations_research diff --git a/src/sat/unsat_proof.cc b/src/sat/unsat_proof.cc new file mode 100644 index 0000000000..09d9eee0e6 --- /dev/null +++ b/src/sat/unsat_proof.cc @@ -0,0 +1,169 @@ +// Copyright 2010-2013 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "sat/unsat_proof.h" + +namespace operations_research { +namespace sat { + +// A node of the resolution DAG. +struct ResolutionNode { + public: + // Constructor for the root nodes. + ResolutionNode() + : is_locked_(true), + is_problem_node_(true), + is_marked_(false), + ref_count_(1), + parents_() {} + + // Constructor for the inner nodes. + // We use a swap based constructor to avoid a copy. + explicit ResolutionNode(std::vector* to_swap) + : is_locked_(true), + is_problem_node_(false), + is_marked_(false), + ref_count_(1) { + CHECK(!to_swap->empty()); + to_swap->swap(parents_); + for (ResolutionNode* node : parents_) { + CHECK(node->IsLocked()); + ++(node->ref_count_); + } + } + + // We just check that the object was properly cleaned. + ~ResolutionNode() { DCHECK(parents_.empty()); } + + // Returns the list of parents ResolutionNode pointer. + const std::vector& parents() const { return parents_; } + + // Decrements the reference counter of this node and returns true if it + // reaches zero. In the later case, the object ownership is transfered to the + // caller who is free to delete it. Nodes on which DecrementReferenceCounter() + // must be called are appended to to_decrement. + bool DecrementReferenceCounter(std::vector* to_decrement) { + CHECK_GT(ref_count_, 0); + // Nothing to do if the reference count is still positive. + --ref_count_; + if (ref_count_ > 0) return false; + for (ResolutionNode* node : parents_) { + to_decrement->push_back(node); + } + parents_.clear(); + return true; + } + + // Setter/Getter for a boolean marker. + bool MarkAndTestIfFirstTime() { + if (is_marked_) return false; + is_marked_ = true; + return true; + } + void ClearMark() { is_marked_ = false; } + + bool IsProblemNode() const { return is_problem_node_; } + bool IsLocked() const { return is_locked_; } + void Unlock() { is_locked_ = false; } + + private: + // Indicates if this node is "locked". That means it is referenced from + // outside the UnsatProof classes and as such it can be deleted. + bool is_locked_ : 1; + + // Indicates if this node correspond to a problem node or not. + bool is_problem_node_ : 1; + + // Marker used by algorithms traversing the DAG of ResolutionNode. + bool is_marked_ : 1; + + // Number of ResolutionNodePointer pointing to this ResolutionNode. + // This is used to implement a reference counting and delete this object when + // the count reach 0. We do not use shared_ptr for two reasons: + // - Its size is the one of 2 pointers which is too much. + // - Since our nodes form a DAG which is potentially very deep, it may cause + // too much recursive call between the destructors. + int32 ref_count_; + + // The clause corresponding to this Resolution node can be derived from the + // clauses corresponding to the parents by the "resolution rule" (or + // subsumption): (A v x) and (B v not(x)) => A v B. + // + // The parents are stored in order so that we start by the first parent clause + // and then resolve it by each of the following clause in order. + std::vector parents_; + + DISALLOW_COPY_AND_ASSIGN(ResolutionNode); +}; + +UnsatProof::~UnsatProof() { + CHECK_EQ(num_nodes_, 0); +} + +ResolutionNode* UnsatProof::CreateNewRootNode(int clause_index) { + ++num_nodes_; + ResolutionNode* node = new ResolutionNode(); + root_node_to_clause_index_[node] = clause_index; + return node; +} + +ResolutionNode* UnsatProof::CreateNewResolutionNode( + std::vector* to_swap) { + ++num_nodes_; + ResolutionNode* node = new ResolutionNode(to_swap); + CHECK(!node->parents().empty()); + return node; +} + +void UnsatProof::UnlockNode(ResolutionNode* node) { + if (node == nullptr) return; + CHECK(node->IsLocked()) << "Node already released!"; + node->Unlock(); + node_stack_.clear(); + node_stack_.push_back(node); + while (!node_stack_.empty()) { + ResolutionNode* current_node = node_stack_.back(); + node_stack_.pop_back(); + if (current_node->DecrementReferenceCounter(&node_stack_)) { + --num_nodes_; + delete current_node; + } + } +} + +void UnsatProof::ComputeUnsatCore( + ResolutionNode* final_node, std::vector* core) const { + core->clear(); + to_unmark_.clear(); + node_stack_.assign(1, final_node); + while (!node_stack_.empty()) { + const ResolutionNode* current_node = node_stack_.back(); + node_stack_.pop_back(); + for (ResolutionNode* node : current_node->parents()) { + if (node->MarkAndTestIfFirstTime()) { + to_unmark_.push_back(node); + if (node->IsProblemNode()) { + core->push_back(root_node_to_clause_index_.find(node)->second); + } + if (!node->parents().empty()) { + node_stack_.push_back(node); + } + } + } + } + + // Clean after us. + for (ResolutionNode* node : to_unmark_) node->ClearMark(); +} + +} // namespace sat +} // namespace operations_research diff --git a/src/sat/unsat_proof.h b/src/sat/unsat_proof.h new file mode 100644 index 0000000000..edfbfe1d26 --- /dev/null +++ b/src/sat/unsat_proof.h @@ -0,0 +1,118 @@ +// Copyright 2010-2013 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// This file contains the in-memory data structure used by the SAT solver to +// generate unsatisfiability proof and UNSAT cores. +// +// Good references for the algorithm used are: +// - Roberto A., Robert N., Albert O., Enric R.-C. "Efficient Generation of +// Unsatisfiability Proofs and Cores in SAT", +// http://www.lsi.upc.edu/~oliveras/espai/papers/lpar08.pdf +// - Paul Beame, Henry Kautz, Ashish Sabharwal, "Understanding the Power of +// Clause Learning", https://www.cs.rochester.edu/~kautz/papers/learnIjcai.pdf +// - TraceCheck: http://fmv.jku.at/tracecheck/index.html + +#ifndef OR_TOOLS_SAT_UNSAT_PROOF_H_ +#define OR_TOOLS_SAT_UNSAT_PROOF_H_ + +#include "base/hash.h" +#include + +#include "base/hash.h" +#include "sat/sat_base.h" + +namespace operations_research { +namespace sat { + +// Forward declaration. A client only need to manipulates pointer to this class +// which is defined in the .cc +class ResolutionNode; + +// An UNSAT resolution proof will be given as a Directed Acyclic Graph (DAG) of +// clauses. Each clause corresponds to a ResolutionNode. Nodes without parent +// correspond to initial problems clauses. The other nodes correspond to new +// clauses that can be infered from its parents using the basic "resolution +// rule" or subsumption: (A v x) and (B v not(x)) => A v B. +// +// The order of the parents of each node will be such that we can reconstruct +// the clause associated to it by starting by the first parent clause and then +// resolving it by each of the following clause in order. There will be only +// one way to perform each resolution. +class UnsatProof { + public: + UnsatProof() : num_nodes_(0) {} + ~UnsatProof(); + + // Creates a new root node corresponding to an original problem clause with + // given index. UnlockNode() will need to be called before this class is + // deleted, see the comment there. + ResolutionNode* CreateNewRootNode(int clause_index); + + // Creates a new ResolutionNode with given parents. The vector parents must + // not be empty and will be swapped with an empty vector. UnlockNode() will + // need to be called before this class is deleted, see the comment there. Note + // that we check that all the given parents are locked. + // + // For CheckUnsatProof() to work, the parents must be provided as decribed in + // the top level comment of this class. It is possible to remove this + // restriction, but it is a small price to pay for the SAT solver and it + // simplifies the code of CheckUnsatProof(). + ResolutionNode* CreateNewResolutionNode(std::vector* parents); + + // Unlocks the given node so it can be deleted if it is not used as a parents + // to any other node. This can only be called on locked node (there is a + // check). + // + // The idea is that the SAT solver can call UnlockNode() as soon as it known + // that the node can't be used directly to infer another clause. This way, + // this class may be able to free up some memory. + void UnlockNode(ResolutionNode* node); + + // Returns the number of ResolutionNode currently stored by this class. + // Nodes that where deleted are not counted. + int NumNodes() const { return num_nodes_; } + + // Returns the set of original clause indices (the one provided to + // CreateNewRootNode()) from which we can deduce the clause corresponding to + // the given final_node. If final_node is associated with the the empty + // conflict, this will return an UNSAT core. + void ComputeUnsatCore(ResolutionNode* final_node, std::vector* core) const; + + // TODO(user): to implement. This will need to know the clause associated to + // each root node to be able to reconstruct the clause associated to each + // node. There is also some complications for more general constraints like + // pseudo-boolean ones because they can produce many different reason clauses + // and so we need more than one root for each of these clauses. + bool CheckUnsatProof(ResolutionNode* final_node) const; + + private: + // See NumNodes(). + int num_nodes_; + + // Temporary vector used by UnlockNode() and GetCore(). + mutable std::vector node_stack_; + + // Temporary vector used by GetCore(). + mutable std::vector to_unmark_; + + // Index to identify in the original problem the constraint corresponding to + // this root node. Note that duplicate indices are allowed which make sense + // when an original constraint was expanded into multiple clauses internally. + hash_map root_node_to_clause_index_; + + DISALLOW_COPY_AND_ASSIGN(UnsatProof); +}; + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_UNSAT_PROOF_H_ diff --git a/src/util/bitset.h b/src/util/bitset.h index 245a816a0b..326c456db0 100644 --- a/src/util/bitset.h +++ b/src/util/bitset.h @@ -405,6 +405,13 @@ class Bitset64 { Set(size_ - 1, value); } + // Resize the Bitset64 to the given number of bits. New bits are sets to 0. + void Resize(IndexType size) { + DCHECK_GE(size.value(), 0); + size_ = size > 0 ? size : IndexType(0); + data_.resize(BitLength64(size_.value()), 0); + } + // Changes the number of bits the Bitset64 can hold and set all of them to 0. void ClearAndResize(IndexType size) { DCHECK_GE(size.value(), 0); diff --git a/src/util/saturated_arithmetic.h b/src/util/saturated_arithmetic.h index 51d6934799..473077f839 100644 --- a/src/util/saturated_arithmetic.h +++ b/src/util/saturated_arithmetic.h @@ -13,11 +13,25 @@ #ifndef OR_TOOLS_UTIL_SATURATED_ARITHMETIC_H_ #define OR_TOOLS_UTIL_SATURATED_ARITHMETIC_H_ +#include + #include "base/integral_types.h" namespace operations_research { // ---------- Overflow utility functions ---------- +// Performs *b += a and returns false iff the addition overflow or underflow. +template +bool SafeAddInto(IntegerType a, IntegerType* b) { + if (a > 0) { + if (*b > std::numeric_limits::max() - a) return false; + } else { + if (*b < std::numeric_limits::min() - a) return false; + } + *b += a; + return true; +} + // A note on overflow treatment. // kint64min and kint64max are treated as infinity. // Thus if the computation overflows, the result is always kint64m(ax/in).