From 6609fac882d62671262856878cf1c69fbc7ef3f2 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Thu, 8 Jun 2017 12:33:16 +0200 Subject: [PATCH] AC all different in the SAT solver; change the way integer variables are encoded on top of boolean variables in the SAT solver; change protobuf utilities --- cmake/external/gflags.cmake | 2 +- examples/cpp/mps_driver.cc | 3 +- examples/cpp/solve.cc | 6 +- makefiles/Makefile.gen.mk | 23 +- ortools/base/inlined_vector.h | 74 +++-- ortools/base/span.h | 38 +-- ortools/glop/basis_representation.cc | 8 +- ortools/glop/entering_variable.cc | 16 +- ortools/glop/lp_solver.cc | 2 +- ortools/glop/lu_factorization.cc | 2 +- ortools/glop/markowitz.cc | 6 +- ortools/glop/rank_one_update.h | 1 - ortools/glop/revised_simplex.cc | 40 +-- ortools/glop/status.cc | 2 - ortools/glop/status.h | 2 +- ortools/sat/all_different.cc | 366 ++++++++++++++++++++++++ ortools/sat/all_different.h | 113 ++++++++ ortools/sat/clause.cc | 5 +- ortools/sat/clause.h | 5 +- ortools/sat/cp_constraints.cc | 16 +- ortools/sat/cp_model_checker.cc | 8 +- ortools/sat/cp_model_checker.h | 1 + ortools/sat/cp_model_presolve.cc | 42 ++- ortools/sat/cp_model_solver.cc | 32 ++- ortools/sat/cp_model_utils.h | 1 - ortools/sat/cumulative.cc | 1 + ortools/sat/integer.cc | 403 ++++++++++++--------------- ortools/sat/integer.h | 89 +++--- ortools/sat/integer_expr.cc | 47 ++++ ortools/sat/integer_expr.h | 32 +-- ortools/sat/optimization.cc | 2 +- ortools/sat/sat_solver.cc | 26 +- ortools/sat/sat_solver.h | 4 +- ortools/sat/table.cc | 94 +++---- ortools/util/BUILD | 6 + ortools/util/file_util.cc | 75 +++++ ortools/util/file_util.h | 123 ++++++++ ortools/util/proto_tools.cc | 49 +--- ortools/util/proto_tools.h | 16 +- ortools/util/random_engine.h | 27 ++ 40 files changed, 1227 insertions(+), 581 deletions(-) create mode 100644 ortools/sat/all_different.cc create mode 100644 ortools/sat/all_different.h create mode 100644 ortools/util/file_util.cc create mode 100644 ortools/util/file_util.h create mode 100644 ortools/util/random_engine.h diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake index 625a374d5c..529a993900 100644 --- a/cmake/external/gflags.cmake +++ b/cmake/external/gflags.cmake @@ -20,7 +20,7 @@ ExternalProject_Add(Gflags_project -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON) -ADD_LIBRARY(Gflags STATIC IMPORTED) +ADD_LIBRARY(Gflags STATIC IMPORTED) SET_PROPERTY(TARGET Gflags PROPERTY IMPORTED_LOCATION ${CMAKE_CURRENT_BINARY_DIR}/gflags_project/src/gflags/lib/libgflags.a) SET(Gflags_LIBRARIES "") LIST(APPEND Gflags_LIBRARIES Gflags) diff --git a/examples/cpp/mps_driver.cc b/examples/cpp/mps_driver.cc index d523221ac3..4c6ae78e39 100644 --- a/examples/cpp/mps_driver.cc +++ b/examples/cpp/mps_driver.cc @@ -51,7 +51,6 @@ DEFINE_string(params, "", "them (i.e. in case of conflicts, --params wins)"); using operations_research::FullProtocolMessageAsString; -using operations_research::ReadFileToProto; using operations_research::glop::GetProblemStatusString; using operations_research::glop::GlopParameters; using operations_research::glop::LinearProgram; @@ -103,7 +102,7 @@ int main(int argc, char* argv[]) { continue; } } else { - ReadFileToProto(file_name, &model_proto); + file::ReadFileToProto(file_name, &model_proto); MPModelProtoToLinearProgram(model_proto, &linear_program); } if (FLAGS_mps_dump_problem) { diff --git a/examples/cpp/solve.cc b/examples/cpp/solve.cc index 3bb4684b4e..91fa3864ff 100644 --- a/examples/cpp/solve.cc +++ b/examples/cpp/solve.cc @@ -33,7 +33,7 @@ #include "ortools/lp_data/lp_data.h" #include "ortools/lp_data/mps_reader.h" #include "ortools/lp_data/proto_utils.h" -#include "ortools/util/proto_tools.h" +#include "ortools/util/file_util.h" DEFINE_string(input, "", "REQUIRED: Input file name."); DEFINE_string(solver, "glop", @@ -164,8 +164,8 @@ void Run() { LinearProgramToMPModelProto(linear_program_fixed, &model_proto); } } else { - ReadFileToProto(FLAGS_input, &model_proto); - ReadFileToProto(FLAGS_input, &request_proto); + file::ReadFileToProto(FLAGS_input, &model_proto); + file::ReadFileToProto(FLAGS_input, &request_proto); // If the input proto is in binary format, both ReadFileToProto could return // true. Instead use the actual number of variables found to test the // correct format of the input. diff --git a/makefiles/Makefile.gen.mk b/makefiles/Makefile.gen.mk index 6ae2bf2d9a..c31fae6c4d 100644 --- a/makefiles/Makefile.gen.mk +++ b/makefiles/Makefile.gen.mk @@ -271,6 +271,7 @@ UTIL_DEPS = \ UTIL_LIB_OBJS = \ $(OBJ_DIR)/util/bitset.$O \ $(OBJ_DIR)/util/cached_log.$O \ + $(OBJ_DIR)/util/file_util.$O \ $(OBJ_DIR)/util/fp_utils.$O \ $(OBJ_DIR)/util/graph_export.$O \ $(OBJ_DIR)/util/piecewise_linear_function.$O \ @@ -410,6 +411,11 @@ $(OBJ_DIR)/util/cached_log.$O: \ $(SRC_DIR)/ortools/base/logging.h $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sutil$Scached_log.cc $(OBJ_OUT)$(OBJ_DIR)$Sutil$Scached_log.$O +$(OBJ_DIR)/util/file_util.$O: \ + $(SRC_DIR)/ortools/util/file_util.cc + $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sutil$Sfile_util.cc $(OBJ_OUT)$(OBJ_DIR)$Sutil$Sfile_util.$O + + $(OBJ_DIR)/util/fp_utils.$O: \ $(SRC_DIR)/ortools/util/fp_utils.cc \ $(SRC_DIR)/ortools/util/bitset.h \ @@ -1385,6 +1391,7 @@ $(OBJ_DIR)/algorithms/sparse_permutation.$O: \ $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Salgorithms$Ssparse_permutation.cc $(OBJ_OUT)$(OBJ_DIR)$Salgorithms$Ssparse_permutation.$O SAT_DEPS = \ + $(SRC_DIR)/ortools/sat/all_different.h \ $(SRC_DIR)/ortools/sat/boolean_problem.h \ $(GEN_DIR)/ortools/sat/boolean_problem.pb.h \ $(SRC_DIR)/ortools/sat/clause.h \ @@ -1455,6 +1462,7 @@ SAT_DEPS = \ $(GEN_DIR)/ortools/linear_solver/linear_solver.pb.h SAT_LIB_OBJS = \ + $(OBJ_DIR)/sat/all_different.$O \ $(OBJ_DIR)/sat/boolean_problem.$O \ $(OBJ_DIR)/sat/clause.$O \ $(OBJ_DIR)/sat/cp_constraints.$O \ @@ -1713,6 +1721,20 @@ $(OBJ_DIR)/sat/boolean_problem.$O: \ $(SRC_DIR)/ortools/graph/util.h $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Sboolean_problem.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Sboolean_problem.$O +$(OBJ_DIR)/sat/all_different.$O: \ + $(SRC_DIR)/ortools/sat/all_different.cc \ + $(SRC_DIR)/ortools/sat/all_different.h \ + $(SRC_DIR)/ortools/base/commandlineflags.h \ + $(SRC_DIR)/ortools/base/hash.h \ + $(SRC_DIR)/ortools/base/join.h \ + $(SRC_DIR)/ortools/base/map_util.h \ + $(SRC_DIR)/ortools/base/stringprintf.h \ + $(SRC_DIR)/ortools/algorithms/find_graph_symmetries.h \ + $(SRC_DIR)/ortools/graph/graph.h \ + $(SRC_DIR)/ortools/graph/io.h \ + $(SRC_DIR)/ortools/graph/util.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Sall_different.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Sall_different.$O + $(OBJ_DIR)/sat/clause.$O: \ $(SRC_DIR)/ortools/sat/clause.cc \ $(SRC_DIR)/ortools/sat/clause.h \ @@ -3301,4 +3323,3 @@ $(GEN_DIR)/ortools/constraint_solver/solver_parameters.pb.h: $(GEN_DIR)/ortools/ $(OBJ_DIR)/constraint_solver/solver_parameters.pb.$O: $(GEN_DIR)/ortools/constraint_solver/solver_parameters.pb.cc $(CCC) $(CFLAGS) -c $(GEN_DIR)/ortools/constraint_solver/solver_parameters.pb.cc $(OBJ_OUT)$(OBJ_DIR)$Sconstraint_solver$Ssolver_parameters.pb.$O - diff --git a/ortools/base/inlined_vector.h b/ortools/base/inlined_vector.h index 55c873f3b4..5d01cef7fa 100644 --- a/ortools/base/inlined_vector.h +++ b/ortools/base/inlined_vector.h @@ -11,34 +11,30 @@ // See the License for the specific language governing permissions and // limitations under the License. + #ifndef OR_TOOLS_BASE_INLINED_VECTOR_H_ #define OR_TOOLS_BASE_INLINED_VECTOR_H_ -// An InlinedVector is like a std::vector, except that storage +// An gtl::InlinedVector is like a std::vector, except that storage // for sequences of length <= N are provided inline without requiring // any heap allocation. Typically N is very small (e.g., 4) so that // sequences that are expected to be short do not require allocations. // // Only some of the std::vector<> operations are currently implemented. // Other operations may be added as needed to facilitate migrating -// code that uses std::vector<> to InlinedVector<>. -// -// NOTE: If you want an inlined version to replace use of a -// std::vector, consider using util::bitmap::InlinedBitVector -// in ortools/base/inlined_bitvector.h -// +// code that uses std::vector<> to gtl::InlinedVector<>. #include +#include #include #include #include -#include +#include // NOLINT(build/include_order) #include #include #include #include -#include // NOLINT(build/include_order) #include "ortools/base/logging.h" @@ -129,7 +125,7 @@ class InlinedVector { } } const_pointer data() const { - return const_cast*>(this)->data(); + return const_cast*>(this)->data(); } // Remove all elements @@ -386,7 +382,7 @@ class InlinedVector { }; // 2) Construct a T with args at not-yet-initialized memory pointed by dst. struct Construct { - template + template void operator()(T* dst, Args&&... args) const { new (dst) T(std::forward(args)...); } @@ -458,54 +454,54 @@ class InlinedVector { // Provide linkage for constants. template -const size_t InlinedVector::kSizeUnaligned; +const size_t gtl::InlinedVector::kSizeUnaligned; template -const size_t InlinedVector::kSize; +const size_t gtl::InlinedVector::kSize; template -const unsigned int InlinedVector::kSentinel; +const unsigned int gtl::InlinedVector::kSentinel; template -const size_t InlinedVector::kFit1; +const size_t gtl::InlinedVector::kFit1; template -const size_t InlinedVector::kFit; +const size_t gtl::InlinedVector::kFit; template -inline void swap(InlinedVector& a, InlinedVector& b) { +inline void swap(gtl::InlinedVector& a, gtl::InlinedVector& b) { a.swap(b); } template -inline bool operator==(const InlinedVector& a, - const InlinedVector& b) { +inline bool operator==(const gtl::InlinedVector& a, + const gtl::InlinedVector& b) { return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); } template -inline bool operator!=(const InlinedVector& a, - const InlinedVector& b) { +inline bool operator!=(const gtl::InlinedVector& a, + const gtl::InlinedVector& b) { return !(a == b); } template -inline bool operator<(const InlinedVector& a, - const InlinedVector& b) { +inline bool operator<(const gtl::InlinedVector& a, + const gtl::InlinedVector& b) { return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); } template -inline bool operator>(const InlinedVector& a, - const InlinedVector& b) { +inline bool operator>(const gtl::InlinedVector& a, + const gtl::InlinedVector& b) { return b < a; } template -inline bool operator<=(const InlinedVector& a, - const InlinedVector& b) { +inline bool operator<=(const gtl::InlinedVector& a, + const gtl::InlinedVector& b) { return !(b < a); } template -inline bool operator>=(const InlinedVector& a, - const InlinedVector& b) { +inline bool operator>=(const gtl::InlinedVector& a, + const gtl::InlinedVector& b) { return !(a < b); } @@ -513,12 +509,12 @@ inline bool operator>=(const InlinedVector& a, // Implementation template -inline InlinedVector::InlinedVector() { +inline gtl::InlinedVector::InlinedVector() { InitRep(); } template -inline InlinedVector::InlinedVector(size_t n) { +inline gtl::InlinedVector::InlinedVector(size_t n) { InitRep(); if (n > capacity()) { Grow(n); // Must use Nop in case T is not copyable @@ -528,7 +524,7 @@ inline InlinedVector::InlinedVector(size_t n) { } template -inline InlinedVector::InlinedVector(size_t n, const value_type& elem) { +inline gtl::InlinedVector::InlinedVector(size_t n, const value_type& elem) { InitRep(); if (n > capacity()) { Grow(n); // Can use Nop since we know we have nothing to copy @@ -538,13 +534,13 @@ inline InlinedVector::InlinedVector(size_t n, const value_type& elem) { } template -inline InlinedVector::InlinedVector(const InlinedVector& v) { +inline gtl::InlinedVector::InlinedVector(const InlinedVector& v) { InitRep(); *this = v; } template -typename InlinedVector::iterator InlinedVector::insert( +typename gtl::InlinedVector::iterator gtl::InlinedVector::insert( iterator pos, const value_type& v) { DCHECK_GE(pos, begin()); DCHECK_LE(pos, end()); @@ -568,7 +564,7 @@ typename InlinedVector::iterator InlinedVector::insert( } template -typename InlinedVector::iterator InlinedVector::erase( +typename gtl::InlinedVector::iterator gtl::InlinedVector::erase( iterator first, iterator last) { DCHECK_LE(begin(), first); DCHECK_LE(first, last); @@ -583,7 +579,7 @@ typename InlinedVector::iterator InlinedVector::erase( } template -void InlinedVector::swap(InlinedVector& other) { +void gtl::InlinedVector::swap(InlinedVector& other) { using std::swap; // Augment ADL with std::swap. if (&other == this) { return; @@ -636,14 +632,14 @@ void InlinedVector::swap(InlinedVector& other) { template template -inline void InlinedVector::AppendRange(Iter first, Iter last, +inline void gtl::InlinedVector::AppendRange(Iter first, Iter last, std::input_iterator_tag) { std::copy(first, last, std::back_inserter(*this)); } template template -inline void InlinedVector::AppendRange(Iter first, Iter last, +inline void gtl::InlinedVector::AppendRange(Iter first, Iter last, std::forward_iterator_tag) { typedef typename std::iterator_traits::difference_type Length; Length length = std::distance(first, last); @@ -655,7 +651,7 @@ inline void InlinedVector::AppendRange(Iter first, Iter last, template template -inline void InlinedVector::AppendRange(Iter first, Iter last) { +inline void gtl::InlinedVector::AppendRange(Iter first, Iter last) { typedef typename std::iterator_traits::iterator_category IterTag; AppendRange(first, last, IterTag()); } diff --git a/ortools/base/span.h b/ortools/base/span.h index fea74883ea..81daec3854 100644 --- a/ortools/base/span.h +++ b/ortools/base/span.h @@ -11,6 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. + #ifndef OR_TOOLS_BASE_SPAN_H_ #define OR_TOOLS_BASE_SPAN_H_ @@ -93,15 +94,13 @@ // MyMutatingRoutine(my_proto.mutable_value()); #include -#include #include #include #include "ortools/base/inlined_vector.h" namespace gtl { - -namespace array_slice_internal { +namespace internal { // Template logic for generic constructors. @@ -285,9 +284,7 @@ class SpanImplBase { if (data() == other.data()) return true; return std::equal(data(), data() + size(), other.data()); } - bool operator!=(const SpanImplBase& other) const { - return !(*this == other); - } + bool operator!=(const SpanImplBase& other) const { return !(*this == other); } private: pointer ptr_; @@ -311,7 +308,7 @@ class SpanImpl : public SpanImplBase { template explicit SpanImpl(const C& v) : SpanImplBase(ContainerData::Get(std::addressof(v)), - ContainerSize::Get(std::addressof(v))) {} + ContainerSize::Get(std::addressof(v))) {} }; template @@ -327,16 +324,14 @@ class MutableSpanImpl : public SpanImplBase { template explicit MutableSpanImpl(C* v) : SpanImplBase(ContainerMutableData::Get(v), - ContainerSize::Get(v)) {} + ContainerSize::Get(v)) {} }; - -} // namespace array_slice_internal - +} // namespace internal template class Span { private: - typedef array_slice_internal::SpanImpl Impl; + typedef internal::SpanImpl Impl; public: typedef T value_type; @@ -365,14 +360,14 @@ class Span { : impl_(a, N) {} template - Span(const InlinedVector& v) // NOLINT(runtime/explicit) + Span(const gtl::InlinedVector& v) // NOLINT(runtime/explicit) : impl_(v.data(), v.size()) {} // The constructor for any class supplying 'data() const' that returns either // const T* or a less const-qualified version of it, and 'some_integral_type // size() const'. google::protobuf::RepeatedField, std::string and (since C++11) // std::vector and std::array are examples of this. See - // array_slice_internal.h for details. + // span_internal.h for details. template > Span(const V& v) // NOLINT(runtime/explicit) @@ -445,7 +440,7 @@ class Span { template class MutableSpan { private: - typedef array_slice_internal::MutableSpanImpl Impl; + typedef internal::MutableSpanImpl Impl; public: typedef T value_type; @@ -474,15 +469,14 @@ class MutableSpan { : impl_(a, N) {} template - MutableSpan( - InlinedVector* v) // NOLINT(runtime/explicit) + MutableSpan(gtl::InlinedVector* v) // NOLINT(runtime/explicit) : impl_(v->data(), v->size()) {} // The constructor for any class supplying 'T* data()' or 'T* mutable_data()' // (the former is called if both exist), and 'some_integral_type size() // const'. google::protobuf::RepeatedField is an example of this. Also supports std::string // arguments, when T==char. The appropriate ctor is selected using SFINAE. See - // array_slice_internal.h for details. + // span_internal.h for details. template > MutableSpan(V* v) // NOLINT(runtime/explicit) @@ -518,12 +512,8 @@ class MutableSpan { void pop_back() { remove_suffix(1); } void pop_front() { remove_prefix(1); } - bool operator==(Span other) const { - return Span(*this) == other; - } - bool operator!=(Span other) const { - return Span(*this) != other; - } + bool operator==(Span other) const { return Span(*this) == other; } + bool operator!=(Span other) const { return Span(*this) != other; } // DEPRECATED(jacobsa): Please use data() instead. pointer mutable_data() const { return impl_.data(); } diff --git a/ortools/glop/basis_representation.cc b/ortools/glop/basis_representation.cc index 1edd2cd5de..01f2a7f431 100644 --- a/ortools/glop/basis_representation.cc +++ b/ortools/glop/basis_representation.cc @@ -209,7 +209,7 @@ void BasisFactorization::Clear() { Status BasisFactorization::Initialize() { SCOPED_TIME_STAT(&stats_); Clear(); - if (IsIdentityBasis()) return Status::OK; + if (IsIdentityBasis()) return Status::OK(); MatrixView basis_matrix; basis_matrix.PopulateFromBasis(matrix_, basis_); return lu_factorization_.ComputeFactorization(basis_matrix); @@ -218,7 +218,7 @@ Status BasisFactorization::Initialize() { bool BasisFactorization::IsRefactorized() const { return num_updates_ == 0; } Status BasisFactorization::Refactorize() { - if (IsRefactorized()) return Status::OK; + if (IsRefactorized()) return Status::OK(); return ForceRefactorization(); } @@ -283,7 +283,7 @@ Status BasisFactorization::MiddleProductFormUpdate( GLOP_RETURN_AND_LOG_ERROR(Status::ERROR_LU, "Degenerate rank-one update."); } rank_one_factorization_.Update(elementary_update_matrix); - return Status::OK; + return Status::OK(); } Status BasisFactorization::Update(ColIndex entering_col, @@ -301,7 +301,7 @@ Status BasisFactorization::Update(ColIndex entering_col, } ++num_updates_; tau_computation_can_be_optimized_ = false; - return Status::OK; + return Status::OK(); } return ForceRefactorization(); } diff --git a/ortools/glop/entering_variable.cc b/ortools/glop/entering_variable.cc index 0dda0e9192..1e211ab1a3 100644 --- a/ortools/glop/entering_variable.cc +++ b/ortools/glop/entering_variable.cc @@ -65,7 +65,7 @@ Status EnteringVariable::PrimalChooseEnteringColumn(ColIndex* entering_col) { } if (*entering_col != kInvalidCol) { unused_columns_.Clear(*entering_col); - return Status::OK; + return Status::OK(); } ResetUnusedColumns(); if (parameters_.normalize_using_column_norm()) { @@ -80,19 +80,19 @@ Status EnteringVariable::PrimalChooseEnteringColumn(ColIndex* entering_col) { DantzigChooseEnteringColumn(entering_col); } } - return Status::OK; + return Status::OK(); case GlopParameters::STEEPEST_EDGE: NormalizedChooseEnteringColumn(entering_col); - return Status::OK; + return Status::OK(); case GlopParameters::DEVEX: NormalizedChooseEnteringColumn(entering_col); - return Status::OK; + return Status::OK(); } LOG(DFATAL) << "Unknown pricing rule: " << GlopParameters_PricingRule_Name(rule_) << ". Using steepest edge."; NormalizedChooseEnteringColumn(entering_col); - return Status::OK; + return Status::OK(); } namespace { @@ -273,7 +273,7 @@ Status EnteringVariable::DualChooseEnteringColumn( stats_.num_perfect_ties.Add(equivalent_entering_choices_.size())); } - if (*entering_col == kInvalidCol) return Status::OK; + if (*entering_col == kInvalidCol) return Status::OK(); *pivot = update_coefficient[*entering_col]; // If the step is 0.0, we make sure the reduced cost is 0.0 so @@ -294,7 +294,7 @@ Status EnteringVariable::DualChooseEnteringColumn( // the pertubed problem is solved to the optimal. reduced_costs_->ShiftCost(*entering_col); } - return Status::OK; + return Status::OK(); } Status EnteringVariable::DualPhaseIChooseEnteringColumn( @@ -400,7 +400,7 @@ Status EnteringVariable::DualPhaseIChooseEnteringColumn( } *pivot = (*entering_col == kInvalidCol) ? 0.0 : update_coefficient[*entering_col]; - return Status::OK; + return Status::OK(); } void EnteringVariable::SetParameters(const GlopParameters& parameters) { diff --git a/ortools/glop/lp_solver.cc b/ortools/glop/lp_solver.cc index 363b14429d..542c2727ff 100644 --- a/ortools/glop/lp_solver.cc +++ b/ortools/glop/lp_solver.cc @@ -33,7 +33,7 @@ #include "ortools/util/fp_utils.h" #ifndef ANDROID_JNI -#include "ortools/util/proto_tools.h" +#include "ortools/util/file_util.h" #endif DEFINE_bool(lp_solver_enable_fp_exceptions, false, diff --git a/ortools/glop/lu_factorization.cc b/ortools/glop/lu_factorization.cc index 77586fd64c..29a9158423 100644 --- a/ortools/glop/lu_factorization.cc +++ b/ortools/glop/lu_factorization.cc @@ -58,7 +58,7 @@ Status LuFactorization::ComputeFactorization(const MatrixView& matrix) { stats_.basis_num_entries.Add(matrix.num_entries().value()); }); DCHECK(CheckFactorization(matrix, Fractional(1e-6))); - return Status::OK; + return Status::OK(); } void LuFactorization::RightSolve(DenseColumn* x) const { diff --git a/ortools/glop/markowitz.cc b/ortools/glop/markowitz.cc index f0581eee48..a1d8913a57 100644 --- a/ortools/glop/markowitz.cc +++ b/ortools/glop/markowitz.cc @@ -31,7 +31,7 @@ Status Markowitz::ComputeRowAndColumnPermutation(const MatrixView& basis_matrix, row_perm->assign(num_rows, kInvalidRow); // Get the empty matrix corner case out of the way. - if (basis_matrix.IsEmpty()) return Status::OK; + if (basis_matrix.IsEmpty()) return Status::OK(); basis_matrix_ = &basis_matrix; // Initialize all the matrices. @@ -132,7 +132,7 @@ Status Markowitz::ComputeRowAndColumnPermutation(const MatrixView& basis_matrix, 1.0 * stats_num_pivots_without_fill_in / end_index); stats_.degree_two_pivot_columns.Add(1.0 * stats_degree_two_pivot_columns / end_index); - return Status::OK; + return Status::OK(); } Status Markowitz::ComputeLU(const MatrixView& basis_matrix, @@ -152,7 +152,7 @@ Status Markowitz::ComputeLU(const MatrixView& basis_matrix, upper_.Swap(upper); DCHECK(lower->IsLowerTriangular()); DCHECK(upper->IsUpperTriangular()); - return Status::OK; + return Status::OK(); } void Markowitz::Clear() { diff --git a/ortools/glop/rank_one_update.h b/ortools/glop/rank_one_update.h index f6e2adbd06..148ac2580a 100644 --- a/ortools/glop/rank_one_update.h +++ b/ortools/glop/rank_one_update.h @@ -15,7 +15,6 @@ #define OR_TOOLS_GLOP_RANK_ONE_UPDATE_H_ #include "ortools/base/logging.h" -#include "ortools/glop/status.h" #include "ortools/lp_data/lp_types.h" #include "ortools/lp_data/lp_utils.h" #include "ortools/lp_data/sparse.h" diff --git a/ortools/glop/revised_simplex.cc b/ortools/glop/revised_simplex.cc index 4989da3385..9f0ac97191 100644 --- a/ortools/glop/revised_simplex.cc +++ b/ortools/glop/revised_simplex.cc @@ -161,7 +161,7 @@ Status RevisedSimplex::Solve(const LinearProgram& lp, TimeLimit* time_limit) { } if (FLAGS_simplex_stop_after_first_basis) { DisplayAllStats(); - return Status::OK; + return Status::OK(); } const bool use_dual = parameters_.use_dual_simplex(); @@ -358,7 +358,7 @@ Status RevisedSimplex::Solve(const LinearProgram& lp, TimeLimit* time_limit) { num_optimization_iterations_ = num_iterations_ - num_feasibility_iterations_; DisplayAllStats(); - return Status::OK; + return Status::OK(); } ProblemStatus RevisedSimplex::GetProblemStatus() const { @@ -1037,7 +1037,7 @@ Status RevisedSimplex::InitializeFirstBasis(const RowToColMapping& basis) { variable_values_.RecomputeBasicVariableValues(); const Fractional tolerance = parameters_.primal_feasibility_tolerance(); DCHECK_LE(variable_values_.ComputeMaximumPrimalResidual(), tolerance); - return Status::OK; + return Status::OK(); } Status RevisedSimplex::Initialize(const LinearProgram& lp) { @@ -1188,7 +1188,7 @@ Status RevisedSimplex::Initialize(const LinearProgram& lp) { VLOG(1) << "Incremental solve."; } DCHECK(BasisIsConsistent()); - return Status::OK; + return Status::OK(); } void RevisedSimplex::DisplayBasicVariableStatistics() { @@ -1581,7 +1581,7 @@ Status RevisedSimplex::ChooseLeavingVariableRow( // helps a lot on the Netlib problems. if (!basis_factorization_.IsRefactorized()) { *refactorize = true; - return Status::OK; + return Status::OK(); } // Note(user): This reduces quite a bit the number of iterations. @@ -1632,7 +1632,7 @@ Status RevisedSimplex::ChooseLeavingVariableRow( ratio_test_stats_.abs_used_pivot.Add(std::abs(direction_[*leaving_row])); } }); - return Status::OK; + return Status::OK(); } template @@ -1842,7 +1842,7 @@ Status RevisedSimplex::DualChooseLeavingVariableRow(RowIndex* leaving_row, // Return right away if there is no leaving variable. // Fill cost_variation and target_bound otherwise. - if (*leaving_row == kInvalidRow) return Status::OK; + if (*leaving_row == kInvalidRow) return Status::OK(); const ColIndex leaving_col = basis_[*leaving_row]; const Fractional value = variable_values_.Get(leaving_col); if (value < lower_bound_[leaving_col]) { @@ -1854,7 +1854,7 @@ Status RevisedSimplex::DualChooseLeavingVariableRow(RowIndex* leaving_row, *target_bound = upper_bound_[leaving_col]; DCHECK_LT(*cost_variation, 0.0); } - return Status::OK; + return Status::OK(); } namespace { @@ -1994,7 +1994,7 @@ Status RevisedSimplex::DualPhaseIChooseLeavingVariableRow( // If there is no dual-infeasible position, we are done. *leaving_row = kInvalidRow; - if (num_dual_infeasible_positions_ == 0) return Status::OK; + if (num_dual_infeasible_positions_ == 0) return Status::OK(); // TODO(user): Reuse parameters_.optimization_rule() to decide if we use // steepest edge or the normal Dantzig pricing. @@ -2028,7 +2028,7 @@ Status RevisedSimplex::DualPhaseIChooseLeavingVariableRow( // Returns right away if there is no leaving variable or fill the other // return values otherwise. - if (*leaving_row == kInvalidRow) return Status::OK; + if (*leaving_row == kInvalidRow) return Status::OK(); *cost_variation = dual_pricing_vector_[*leaving_row]; const ColIndex leaving_col = basis_[*leaving_row]; if (*cost_variation < 0.0) { @@ -2037,7 +2037,7 @@ Status RevisedSimplex::DualPhaseIChooseLeavingVariableRow( *target_bound = lower_bound_[leaving_col]; } DCHECK(IsFinite(*target_bound)); - return Status::OK; + return Status::OK(); } template @@ -2169,7 +2169,7 @@ Status RevisedSimplex::UpdateAndPivot(ColIndex entering_col, if (basis_factorization_.IsRefactorized()) { PermuteBasis(); } - return Status::OK; + return Status::OK(); } bool RevisedSimplex::NeedsBasisRefactorization(bool refactorize) { @@ -2199,7 +2199,7 @@ Status RevisedSimplex::RefactorizeBasisIfNeeded(bool* refactorize) { PermuteBasis(); } *refactorize = false; - return Status::OK; + return Status::OK(); } // Minimizes c.x subject to A.x = 0 where A is an mxn-matrix, c an n-vector, and @@ -2256,7 +2256,7 @@ Status RevisedSimplex::Minimize(TimeLimit* time_limit) { << " has been reached."; problem_status_ = ProblemStatus::PRIMAL_FEASIBLE; objective_limit_reached_ = true; - return Status::OK; + return Status::OK(); } } else if (feasibility_phase_) { // Note that direction_non_zero_ contains the positions of the basic @@ -2466,7 +2466,7 @@ Status RevisedSimplex::Minimize(TimeLimit* time_limit) { iteration_stats_.degenerate_run_size.Add( num_consecutive_degenerate_iterations_); } - return Status::OK; + return Status::OK(); } // TODO(user): Two other approaches for the phase I described in Koberstein's @@ -2555,7 +2555,7 @@ Status RevisedSimplex::DualMinimize(TimeLimit* time_limit) { << " has been reached."; problem_status_ = ProblemStatus::DUAL_FEASIBLE; objective_limit_reached_ = true; - return Status::OK; + return Status::OK(); } } @@ -2604,7 +2604,7 @@ Status RevisedSimplex::DualMinimize(TimeLimit* time_limit) { } else { problem_status_ = ProblemStatus::OPTIMAL; } - return Status::OK; + return Status::OK(); } update_row_.ComputeUpdateRow(leaving_row); @@ -2643,7 +2643,7 @@ Status RevisedSimplex::DualMinimize(TimeLimit* time_limit) { ChangeSign(&solution_dual_ray_row_combination_); } } - return Status::OK; + return Status::OK(); } // If the coefficient is too small, we recompute the reduced costs. @@ -2675,7 +2675,7 @@ Status RevisedSimplex::DualMinimize(TimeLimit* time_limit) { AdvanceDeterministicTime(time_limit); if (num_iterations_ == parameters_.max_number_of_iterations() || time_limit->LimitReached()) { - return Status::OK; + return Status::OK(); } IF_STATS_ENABLED({ @@ -2730,7 +2730,7 @@ Status RevisedSimplex::DualMinimize(TimeLimit* time_limit) { } ++num_iterations_; } - return Status::OK; + return Status::OK(); } ColIndex RevisedSimplex::SlackColIndex(RowIndex row) const { diff --git a/ortools/glop/status.cc b/ortools/glop/status.cc index eb94e2cc0e..f31845f1de 100644 --- a/ortools/glop/status.cc +++ b/ortools/glop/status.cc @@ -26,8 +26,6 @@ Status::Status(ErrorCode error_code, std::string error_message) : error_code_(error_code), error_message_(error_code == NO_ERROR ? "" : std::move(error_message)) {} -const Status Status::OK; - std::string GetErrorCodeString(Status::ErrorCode error_code) { switch (error_code) { case Status::NO_ERROR: diff --git a/ortools/glop/status.h b/ortools/glop/status.h index b4802bd1a8..dcbbb628ba 100644 --- a/ortools/glop/status.h +++ b/ortools/glop/status.h @@ -52,7 +52,7 @@ class Status { Status(ErrorCode error_code, std::string error_message); // Improves readability but identical to 0-arg constructor. - static const Status OK; + static const Status OK() { return Status(); } // Accessors. ErrorCode error_code() const { return error_code_; } diff --git a/ortools/sat/all_different.cc b/ortools/sat/all_different.cc new file mode 100644 index 0000000000..8421c096cb --- /dev/null +++ b/ortools/sat/all_different.cc @@ -0,0 +1,366 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/all_different.h" + +#include "ortools/base/strongly_connected_components.h" + +namespace operations_research { +namespace sat { + +std::function AllDifferentAC( + const std::vector& variables) { + return [=](Model* model) { + if (variables.size() < 3) return; + + AllDifferentConstraint* constraint = new AllDifferentConstraint( + variables, model->GetOrCreate(), + model->GetOrCreate(), model->GetOrCreate()); + constraint->RegisterWith(model->GetOrCreate()); + model->TakeOwnership(constraint); + }; +} + +AllDifferentConstraint::AllDifferentConstraint( + std::vector variables, IntegerEncoder* encoder, + Trail* trail, IntegerTrail* integer_trail) + : num_variables_(variables.size()), + variables_(std::move(variables)), + trail_(trail), + integer_trail_(integer_trail) { + // Initialize literals cache. + int64 min_value = kint64max; + int64 max_value = kint64min; + variable_min_value_.resize(num_variables_); + variable_max_value_.resize(num_variables_); + variable_literal_index_.resize(num_variables_); + int num_fixed_variables = 0; + for (int x = 0; x < num_variables_; x++) { + variable_min_value_[x] = integer_trail_->LowerBound(variables_[x]).value(); + variable_max_value_[x] = integer_trail_->UpperBound(variables_[x]).value(); + + // Compute value range of all variables. + min_value = std::min(min_value, variable_min_value_[x]); + max_value = std::max(max_value, variable_max_value_[x]); + + // FullyEncode does not like 1-value domains, handle this case first. + // TODO(user): Prune now, ignore these variables during solving. + if (variable_min_value_[x] == variable_max_value_[x]) { + num_fixed_variables++; + variable_literal_index_[x].push_back(kTrueLiteralIndex); + continue; + } + + // Force full encoding if not already done. + if (!encoder->VariableIsFullyEncoded(variables_[x])) { + encoder->FullyEncodeVariable( + variables_[x], integer_trail_->InitialVariableDomain(variables_[x])); + } + + // Fill cache with literals, default value is kFalseLiteralIndex. + int64 size = variable_max_value_[x] - variable_min_value_[x] + 1; + variable_literal_index_[x].resize(size, kFalseLiteralIndex); + for (const auto& entry : encoder->FullDomainEncoding(variables_[x])) { + int64 value = entry.value.value(); + // Can happen because of initial propagation! + if (value < variable_min_value_[x] || variable_max_value_[x] < value) { + continue; + } + variable_literal_index_[x][value - variable_min_value_[x]] = + entry.literal.Index(); + } + } + min_all_values_ = min_value; + num_all_values_ = max_value - min_value + 1; + + successor_.resize(num_variables_); + variable_to_value_.assign(num_variables_, -1); + visiting_.resize(num_variables_); + variable_visited_from_.resize(num_variables_); + residual_graph_successors_.resize(num_variables_ + num_all_values_ + 1); + component_number_.resize(num_variables_ + num_all_values_ + 1); +} + +void AllDifferentConstraint::RegisterWith(GenericLiteralWatcher* watcher) { + const int id = watcher->Register(this); + watcher->SetPropagatorPriority(id, 2); + for (const auto& literal_indices : variable_literal_index_) { + for (const LiteralIndex li : literal_indices) { + // Watch only unbound literals. + if (li >= 0 && + !trail_->Assignment().VariableIsAssigned(Literal(li).Variable())) { + watcher->WatchLiteral(Literal(li), id); + watcher->WatchLiteral(Literal(li).Negated(), id); + } + } + } +} + +LiteralIndex AllDifferentConstraint::VariableLiteralIndexOf(int x, + int64 value) { + return (value < variable_min_value_[x] || variable_max_value_[x] < value) + ? kFalseLiteralIndex + : variable_literal_index_[x][value - variable_min_value_[x]]; +} + +inline bool AllDifferentConstraint::VariableHasPossibleValue(int x, + int64 value) { + LiteralIndex li = VariableLiteralIndexOf(x, value); + if (li == kFalseLiteralIndex) return false; + if (li == kTrueLiteralIndex) return true; + DCHECK_GE(li, 0); + return !trail_->Assignment().LiteralIsFalse(Literal(li)); +} + +bool AllDifferentConstraint::MakeAugmentingPath(int start) { + // Do a BFS and use visiting_ as a queue, with num_visited pointing + // at its begin() and num_to_visit its end(). + // To switch to the augmenting path once a nonmatched value was found, + // we remember the BFS tree in variable_visited_from_. + int num_to_visit = 0; + int num_visited = 0; + // Enqueue start. + visiting_[num_to_visit++] = start; + variable_visited_[start] = true; + variable_visited_from_[start] = -1; + + while (num_visited < num_to_visit) { + // Dequeue node to visit. + const int node = visiting_[num_visited++]; + + for (const int value : successor_[node]) { + if (value_visited_[value]) continue; + value_visited_[value] = true; + if (value_to_variable_[value] == -1) { + // value is not matched: change path from node to start, and return. + int path_node = node; + int path_value = value; + while (path_node != -1) { + int old_value = variable_to_value_[path_node]; + variable_to_value_[path_node] = path_value; + value_to_variable_[path_value] = path_node; + path_node = variable_visited_from_[path_node]; + path_value = old_value; + } + return true; + } else { + // Enqueue node matched to value. + const int next_node = value_to_variable_[value]; + variable_visited_[next_node] = true; + visiting_[num_to_visit++] = next_node; + variable_visited_from_[next_node] = node; + } + } + } + return false; +} + +// The algorithm copies the solver state to successor_, which is used to compute +// a matching. If all variables can be matched, it generates the residual graph +// in separate vectors, computes its SCCs, and filters variable -> value if +// variable is not in the same SCC as value. +// Explanations for failure and filtering are fine-grained: +// failure is explained by a Hall set, i.e. dom(variables) \subseteq {values}, +// with |variables| < |values|; filtering is explained by the Hall set that +// would happen if the variable was assigned to the value. +// +// TODO(user): If needed, there are several ways performance could be +// improved. +// If copying the variable state is too costly, it could be maintained instead. +// If the propagator has too many fruitless calls (without failing/pruning), +// we can remember the O(n) arcs used in the matching and the SCC decomposition, +// and guard calls to Propagate() if these arcs are still valid. +bool AllDifferentConstraint::Propagate() { + // Copy variable state to graph state. + prev_matching_ = variable_to_value_; + value_to_variable_.assign(num_all_values_, -1); + variable_to_value_.assign(num_variables_, -1); + for (int x = 0; x < num_variables_; x++) { + successor_[x].clear(); + const int64 min_value = integer_trail_->LowerBound(variables_[x]).value(); + const int64 max_value = integer_trail_->UpperBound(variables_[x]).value(); + for (int64 value = min_value; value <= max_value; value++) { + if (VariableHasPossibleValue(x, value)) { + const int offset_value = value - min_all_values_; + // Forward-checking should propagate x != value. + successor_[x].push_back(offset_value); + } + } + if (successor_[x].size() == 1) { + const int offset_value = successor_[x][0]; + if (value_to_variable_[offset_value] == -1) { + value_to_variable_[offset_value] = x; + variable_to_value_[x] = offset_value; + } + } + } + + // If forward-checking should propagate something, wait for it to propagate. + for (int x = 0; x < num_variables_; x++) { + for (const int offset_value : successor_[x]) { + if (value_to_variable_[offset_value] != -1 && + value_to_variable_[offset_value] != x) { + return true; + } + } + } + + // Seed with previous matching. + for (int x = 0; x < num_variables_; x++) { + if (variable_to_value_[x] != -1) continue; + const int prev_value = prev_matching_[x]; + if (prev_value == -1 || value_to_variable_[prev_value] != -1) continue; + + if (VariableHasPossibleValue(x, prev_matching_[x] + min_all_values_)) { + variable_to_value_[x] = prev_matching_[x]; + value_to_variable_[prev_matching_[x]] = x; + } + } + + // Compute max matching. + int x = 0; + for (; x < num_variables_; x++) { + if (variable_to_value_[x] == -1) { + value_visited_.assign(num_all_values_, false); + variable_visited_.assign(num_variables_, false); + MakeAugmentingPath(x); + } + if (variable_to_value_[x] == -1) break; // No augmenting path exists. + } + + // Fail if covering variables impossible. + // Explain with the forbidden parts of the graph that prevent + // MakeAugmentingPath from increasing the matching size. + if (x < num_variables_) { + // For now explain all forbidden arcs. + std::vector* conflict = trail_->MutableConflict(); + conflict->clear(); + for (int y = 0; y < num_variables_; y++) { + if (!variable_visited_[y]) continue; + for (int value = variable_min_value_[y]; value <= variable_max_value_[y]; + value++) { + const LiteralIndex li = VariableLiteralIndexOf(y, value); + if (li >= 0 && !value_visited_[value - min_all_values_]) { + DCHECK(trail_->Assignment().LiteralIsFalse(Literal(li))); + conflict->push_back(Literal(li)); + } + } + } + return false; + } + + // The current matching is a valid solution, now try to filter values. + // Build residual graph, compute its SCCs. + for (int x = 0; x < num_variables_; x++) { + residual_graph_successors_[x].clear(); + for (const int succ : successor_[x]) { + if (succ != variable_to_value_[x]) { + residual_graph_successors_[x].push_back(num_variables_ + succ); + } + } + } + for (int offset_value = 0; offset_value < num_all_values_; offset_value++) { + residual_graph_successors_[num_variables_ + offset_value].clear(); + if (value_to_variable_[offset_value] != -1) { + residual_graph_successors_[num_variables_ + offset_value].push_back( + value_to_variable_[offset_value]); + } + } + const int dummy_node = num_variables_ + num_all_values_; + residual_graph_successors_[dummy_node].clear(); + if (num_variables_ < num_all_values_) { + for (int x = 0; x < num_variables_; x++) { + residual_graph_successors_[dummy_node].push_back(x); + } + for (int offset_value = 0; offset_value < num_all_values_; offset_value++) { + if (value_to_variable_[offset_value] == -1) { + residual_graph_successors_[num_variables_ + offset_value].push_back( + dummy_node); + } + } + } + + // Compute SCCs, make node -> component map. + std::vector> components; + FindStronglyConnectedComponents( + static_cast(residual_graph_successors_.size()), + residual_graph_successors_, &components); + const int num_components = components.size(); + for (int i = 0; i < num_components; i++) { + for (const int node : components[i]) { + component_number_[node] = i; + } + } + + // Remove arcs var -> val where SCC(var) -/->* SCC(val). + for (int x = 0; x < num_variables_; x++) { + if (successor_[x].size() == 1) continue; + for (const int offset_value : successor_[x]) { + const int value_node = offset_value + num_variables_; + if (variable_to_value_[x] != offset_value && + component_number_[x] != component_number_[value_node] && + VariableHasPossibleValue(x, offset_value + min_all_values_)) { + // We can deduce that x != value. To explain, force x == offset_value, + // then find another assignment for the variable matched to + // offset_value. It will fail: explaining why is the same as + // explaining failure as above, and it is an explanation of x != value. + value_visited_.assign(num_all_values_, false); + variable_visited_.assign(num_variables_, false); + // Undo x -> old_value and old_variable -> offset_value. + const int old_variable = value_to_variable_[offset_value]; + variable_to_value_[old_variable] = -1; + const int old_value = variable_to_value_[x]; + value_to_variable_[old_value] = -1; + variable_to_value_[x] = offset_value; + value_to_variable_[offset_value] = x; + + value_visited_[offset_value] = true; + MakeAugmentingPath(old_variable); + DCHECK_EQ(variable_to_value_[old_variable], -1); // No reassignment. + + // TODO(user): use a local temp vector, it is cleaner than reusing + // the one from MutableConflict(). + std::vector* reason = trail_->MutableConflict(); + reason->clear(); + for (int y = 0; y < num_variables_; y++) { + if (!variable_visited_[y]) continue; + for (int value = variable_min_value_[y]; + value <= variable_max_value_[y]; value++) { + const LiteralIndex li = VariableLiteralIndexOf(y, value); + if (li >= 0 && !value_visited_[value - min_all_values_]) { + DCHECK(!VariableHasPossibleValue(y, value)); + reason->push_back(Literal(li)); + } + } + } + + const int index = trail_->Index(); + LiteralIndex li = + VariableLiteralIndexOf(x, offset_value + min_all_values_); + DCHECK_NE(li, kTrueLiteralIndex); + DCHECK_NE(li, kFalseLiteralIndex); + + const Literal deduction = Literal(li).Negated(); + trail_->Enqueue(deduction, AssignmentType::kCachedReason); + *trail_->GetVectorToStoreReason(index) = *reason; + trail_->NotifyThatReasonIsCached(deduction.Variable()); + return true; + } + } + } + + return true; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/all_different.h b/ortools/sat/all_different.h new file mode 100644 index 0000000000..5ccc0fbfa0 --- /dev/null +++ b/ortools/sat/all_different.h @@ -0,0 +1,113 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_ALL_DIFFERENT_H_ +#define OR_TOOLS_SAT_ALL_DIFFERENT_H_ + +#include + +#include "ortools/sat/integer.h" +#include "ortools/sat/model.h" +#include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" + +namespace operations_research { +namespace sat { + +// This constraint forces all variables to take different values. +// It uses the matching algorithm described in Regin at AAAI1994: +// "A filtering algorithm for constraints of difference in CSPs". +// This propagator is meant to be used as a complement to an alldifferent +// decomposition: DO NOT USE WITHOUT A BINARY ALLDIFFERENT. +// Doing the filtering that the decomposition can do with an appropriate +// algorithm should be cheaper and yield more accurate explanations. +// This will fully encode variables. +std::function AllDifferentAC( + const std::vector& variables); + +class AllDifferentConstraint : PropagatorInterface { + public: + AllDifferentConstraint(std::vector variables, + IntegerEncoder* encoder, Trail* trail, + IntegerTrail* integer_trail); + + bool Propagate() final; + void RegisterWith(GenericLiteralWatcher* watcher); + + private: + // MakeAugmentingPath() is a step in Ford-Fulkerson's augmenting path + // algorithm. It changes its current internal state (see vectors below) + // to assign a value to the start vertex using an augmenting path. + // If it is not possible, it keeps variable_to_value_[start] to -1 and returns + // false, otherwise it modifies the current assignment and returns true. + // It uses value/variable_visited to mark the nodes it visits during its + // search: one can use this information to generate an explanation of failure, + // or manipulate it to create what-if scenarios without modifying successor_. + bool MakeAugmentingPath(int start); + + // Accessors to the cache of literals. + inline LiteralIndex VariableLiteralIndexOf(int x, int64 value); + inline bool VariableHasPossibleValue(int x, int64 value); + + // This caches all literals of the fully encoded variables. + // Values of a given variable are 0-indexed using offsets variable_min_value_, + // the set of all values is globally offset using offset min_all_values_. + // TODO(user): compare this encoding to a sparser hash_map encoding. + const int num_variables_; + const std::vector variables_; + int64 min_all_values_; + int64 num_all_values_; + std::vector variable_min_value_; + std::vector variable_max_value_; + std::vector> variable_literal_index_; + + // Internal state of MakeAugmentingPath(). + // value_to_variable_ and variable_to_value_ represent the current assignment; + // -1 means not assigned. Otherwise, + // variable_to_value_[var] = value <=> value_to_variable_[value] = var. + std::vector> successor_; + std::vector value_visited_; + std::vector variable_visited_; + std::vector value_to_variable_; + std::vector variable_to_value_; + std::vector prev_matching_; + std::vector visiting_; + std::vector variable_visited_from_; + + // Internal state of ComputeSCCs(). + // Variable nodes are indexed by [0, num_variables_), + // value nodes by [num_variables_, num_variables_ + num_all_values_), + // and a dummy node with index num_variables_ + num_all_values_ is added. + // The graph passed to ComputeSCCs() is the residual of the possible graph + // by the current matching, i.e. its arcs are: + // _ (var, val) if val \in dom(var) and var not matched to val, + // _ (val, var) if var matched to val, + // _ (val, dummy) if val not matched to any variable, + // _ (dummy, var) for all variables. + // In the original paper, forbidden arcs are identified by detecting that they + // are not in any alternating cycle or alternating path starting at a + // free vertex. Adding the dummy node allows to factor the alternating path + // part in the alternating cycle, and filter with only the SCC decomposition. + // When num_variables_ == num_all_values_, the dummy node is useless, + // we add it anyway to simplify the code. + std::vector> residual_graph_successors_; + std::vector component_number_; + + Trail* trail_; + IntegerTrail* integer_trail_; +}; + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_ALL_DIFFERENT_H_ diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index acc5cefac7..dc268eb81a 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -428,7 +428,7 @@ void BinaryImplicationGraph::MinimizeConflictFirst( // first UIP conflict. void BinaryImplicationGraph::MinimizeConflictFirstWithTransitiveReduction( const Trail& trail, std::vector* conflict, - SparseBitset* marked, RandomBase* random) { + SparseBitset* marked, random_engine_t* random) { SCOPED_TIME_STAT(&stats_); const LiteralIndex root_literal_index = conflict->front().NegatedIndex(); is_marked_.ClearAndResize(LiteralIndex(implications_.size())); @@ -441,8 +441,7 @@ void BinaryImplicationGraph::MinimizeConflictFirstWithTransitiveReduction( // a => b and remove b, a must be before b in direct_implications. Note that // a std::reverse() could work too. But randomization seems to work better. // Probably because it has other impact on the search tree. - std::random_shuffle(direct_implications.begin(), direct_implications.end(), - *random); + std::shuffle(direct_implications.begin(), direct_implications.end(), *random); dfs_stack_.clear(); for (const Literal l : direct_implications) { if (is_marked_[l.Index()]) { diff --git a/ortools/sat/clause.h b/ortools/sat/clause.h index b4395dccd5..39061b5aa2 100644 --- a/ortools/sat/clause.h +++ b/ortools/sat/clause.h @@ -32,12 +32,11 @@ #include "ortools/base/int_type.h" #include "ortools/base/int_type_indexed_vector.h" #include "ortools/base/hash.h" -#include "ortools/base/span.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/util/bitset.h" +#include "ortools/util/random_engine.h" #include "ortools/util/stats.h" -#include "ortools/base/random.h" namespace operations_research { namespace sat { @@ -378,7 +377,7 @@ class BinaryImplicationGraph : public SatPropagator { SparseBitset* marked); void MinimizeConflictFirstWithTransitiveReduction( const Trail& trail, std::vector* c, - SparseBitset* marked, RandomBase* random); + SparseBitset* marked, random_engine_t* random); // This must only be called at decision level 0 after all the possible // propagations. It: diff --git a/ortools/sat/cp_constraints.cc b/ortools/sat/cp_constraints.cc index 981a79f200..ffc160aad6 100644 --- a/ortools/sat/cp_constraints.cc +++ b/ortools/sat/cp_constraints.cc @@ -190,22 +190,20 @@ std::function AllDifferent( std::unordered_set fixed_values; // First, we fully encode all the given integer variables. - IntegerEncoder* encoder = model->GetOrCreate(); for (const IntegerVariable var : vars) { - if (!encoder->VariableIsFullyEncoded(var)) { - const IntegerValue lb(model->Get(LowerBound(var))); - const IntegerValue ub(model->Get(UpperBound(var))); - if (lb == ub) { - fixed_values.insert(lb); - } else { - encoder->FullyEncodeVariable(var, lb, ub); - } + const IntegerValue lb(model->Get(LowerBound(var))); + const IntegerValue ub(model->Get(UpperBound(var))); + if (lb == ub) { + fixed_values.insert(lb); + } else { + model->Add(FullyEncodeVariable(var)); } } // Then we construct a mapping value -> List of literal each indicating // that a given variable takes this value. std::unordered_map> value_to_literals; + IntegerEncoder* encoder = model->GetOrCreate(); for (const IntegerVariable var : vars) { if (!encoder->VariableIsFullyEncoded(var)) continue; for (const auto& entry : encoder->FullDomainEncoding(var)) { diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 3c01fb9ab9..15cda0f18d 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -17,7 +17,6 @@ #include #include "ortools/base/join.h" -#include "ortools/base/hash.h" #include "ortools/base/map_util.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/util/saturated_arithmetic.h" @@ -368,7 +367,7 @@ class ConstraintChecker { bool ElementConstraintIsFeasible(const CpModelProto& model, const ConstraintProto& ct) { const int index = Value(ct.element().index()); - return Value(ct.element().vars(index)) == Value(ct.element().target()); + return Value(ct.element().vars().Get(index)) == Value(ct.element().target()); } bool TableConstraintIsFeasible(const CpModelProto& model, @@ -448,6 +447,11 @@ class ConstraintChecker { bool SolutionIsFeasible(const CpModelProto& model, const std::vector& variable_values) { + if (variable_values.size() != model.variables_size()) { + VLOG(1) << "Wrong number of variables in the solution vector"; + return false; + } + // Check that all values fall in the variable domains. for (int i = 0; i < model.variables_size(); ++i) { if (!DomainInProtoContains(model.variables(i), variable_values[i])) { diff --git a/ortools/sat/cp_model_checker.h b/ortools/sat/cp_model_checker.h index 34e70e9403..ab78a29e41 100644 --- a/ortools/sat/cp_model_checker.h +++ b/ortools/sat/cp_model_checker.h @@ -14,6 +14,7 @@ #ifndef OR_TOOLS_SAT_CP_MODEL_CHECKER_H_ #define OR_TOOLS_SAT_CP_MODEL_CHECKER_H_ +#include "ortools/base/hash.h" #include "ortools/base/integral_types.h" #include "ortools/sat/cp_model.pb.h" diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index eaeb1d5e10..75a39b9873 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -21,7 +21,6 @@ #include #include "ortools/base/join.h" -#include "ortools/base/hash.h" #include "ortools/base/map_util.h" #include "ortools/base/stl_util.h" #include "ortools/sat/cp_model_checker.h" @@ -386,6 +385,14 @@ bool PresolveBoolOr(ConstraintProto* ct, PresolveContext* context) { context->SetLiteralToTrue(new_literals.Get(0)); return RemoveConstraint(ct, context); } + if (new_literals.size() == 2) { + // For consistency, we move all "implication" into half-reified bool_and. + // TODO(user): merge by enforcement literal and detect implication cycles. + context->UpdateRuleStats("bool_or: implications"); + ct->add_enforcement_literal(NegatedRef(new_literals.Get(0))); + ct->mutable_bool_and()->add_literals(new_literals.Get(1)); + return changed; + } ct->mutable_bool_or()->mutable_literals()->Swap(&new_literals); if (changed) context->UpdateRuleStats("bool_or: fixed literals"); @@ -421,9 +428,6 @@ bool PresolveBoolAnd(ConstraintProto* ct, PresolveContext* context) { } if (new_literals.empty()) return RemoveConstraint(ct, context); - if (new_literals.size() == 1) { - context->UpdateRuleStats("TODO bool_and: equality"); - } ct->mutable_bool_and()->mutable_literals()->Swap(&new_literals); if (changed) context->UpdateRuleStats("bool_and: fixed literals"); @@ -897,6 +901,8 @@ bool PresolveLinearIntoClauses(ConstraintProto* ct, PresolveContext* context) { } // Detect clauses and reified ands. + // TODO(user): split an == 1 constraint or similar into a clause and a <= 1 + // constraint? const std::vector domain = ReadDomain(arg); DCHECK(!domain.empty()); if (offset + min_coeff > domain.back().end) { @@ -1140,6 +1146,31 @@ bool PresolveTable(ConstraintProto* ct, PresolveContext* context) { return false; } +bool PresolveAllDiff(ConstraintProto* ct, PresolveContext* context) { + if (HasEnforcementLiteral(*ct)) return false; + const int size = ct->all_diff().vars_size(); + if (size == 0) { + context->UpdateRuleStats("all_diff: empty constraint"); + return RemoveConstraint(ct, context); + } + if (size == 1) { + context->UpdateRuleStats("all_diff: only one variable"); + return RemoveConstraint(ct, context); + } + + bool contains_fixed_variable = false; + for (int i = 0; i < size; ++i) { + if (context->domains[PositiveRef(ct->all_diff().vars(i))].IsFixed()) { + contains_fixed_variable = true; + break; + } + } + if (contains_fixed_variable) { + context->UpdateRuleStats("TODO all_diff: fixed variables"); + } + return false; +} + } // namespace. // ============================================================================= @@ -1260,6 +1291,9 @@ void PresolveCpModel(const CpModelProto& initial_model, case ConstraintProto::ConstraintCase::kTable: changed |= PresolveTable(ct, &context); break; + case ConstraintProto::ConstraintCase::kAllDiff: + changed |= PresolveAllDiff(ct, &context); + break; default: break; } diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 190b7df994..216b2c0b34 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -20,6 +20,7 @@ #include "ortools/base/join.h" #include "ortools/base/stl_util.h" #include "ortools/graph/connectivity.h" +#include "ortools/sat/all_different.h" #include "ortools/sat/cp_model_checker.h" #include "ortools/sat/cp_model_presolve.h" #include "ortools/sat/cp_model_utils.h" @@ -348,7 +349,23 @@ void LoadLinearConstraint(const ConstraintProto& ct, ModelWithMapping* m) { void LoadAllDiffConstraint(const ConstraintProto& ct, ModelWithMapping* m) { const std::vector vars = m->Integers(ct.all_diff().vars()); - m->Add(AllDifferentOnBounds(vars)); + // TODO(user): Find out which alldifferent to use depending on model. + // If some domain is too large, use bounds reasoning. + IntegerTrail* integer_trail = m->GetOrCreate(); + int64 max_domain_size = 0; + for (const IntegerVariable var : vars) { + IntegerValue lb = integer_trail->LowerBound(var); + IntegerValue ub = integer_trail->UpperBound(var); + int64 domain_size = ub.value() - lb.value(); + max_domain_size = std::max(max_domain_size, domain_size); + } + + if (max_domain_size < 1024) { + m->Add(AllDifferent(vars)); + m->Add(AllDifferentAC(vars)); + } else { + m->Add(AllDifferentOnBounds(vars)); + } } void LoadIntProdConstraint(const ConstraintProto& ct, ModelWithMapping* m) { @@ -1077,9 +1094,7 @@ CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, return response; } - // Register the global LP constraint. - // TODO(user): Computes the connected components, and use one constraint per - // component. There is also no need for a constraint with just one equation. + // Linearize some part of the problem and register LP constraint(s). if (parameters.use_global_lp_constraint()) { AddLPConstraints(model_proto, &m); } @@ -1248,6 +1263,15 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { var_proto); } Model postsolve_model; + + // Postosolve parameters. + // TODO(user): this problem is usually trivial, but we may still want to + // impose a time limit or copy some of the parameters passed by the user. + { + SatParameters params; + params.set_use_global_lp_constraint(false); + postsolve_model.Add(operations_research::sat::NewSatParameters(params)); + } const CpSolverResponse postsolve_response = SolveCpModelWithoutPresolve(mapping_proto, &postsolve_model); CHECK_EQ(postsolve_response.status(), CpSolverStatus::MODEL_SAT); diff --git a/ortools/sat/cp_model_utils.h b/ortools/sat/cp_model_utils.h index 2f274c646f..9a109cd5b6 100644 --- a/ortools/sat/cp_model_utils.h +++ b/ortools/sat/cp_model_utils.h @@ -14,7 +14,6 @@ #ifndef OR_TOOLS_SAT_CP_MODEL_UTILS_H_ #define OR_TOOLS_SAT_CP_MODEL_UTILS_H_ -#include #include #include "ortools/base/logging.h" diff --git a/ortools/sat/cumulative.cc b/ortools/sat/cumulative.cc index ac753e5901..3e8a25ddde 100644 --- a/ortools/sat/cumulative.cc +++ b/ortools/sat/cumulative.cc @@ -42,6 +42,7 @@ std::function Cumulative( if (intervals->MaxSize(vars[i]) == 0) continue; if (intervals->MinSize(vars[i]) > 0) { + if (demands[i] == capacity) continue; if (intervals->IsOptional(vars[i])) { model->Add(ConditionalLowerOrEqual( demands[i], capacity, intervals->IsPresentLiteral(vars[i]))); diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 402789c797..888acd12aa 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -28,12 +28,19 @@ std::vector NegationOf( return result; } -void IntegerEncoder::FullyEncodeVariable(IntegerVariable i_var, - std::vector values) { +void IntegerEncoder::FullyEncodeVariable( + IntegerVariable i_var, const std::vector& domain) { + CHECK(!VariableIsFullyEncoded(i_var)); CHECK_EQ(0, sat_solver_->CurrentDecisionLevel()); - CHECK(!values.empty()); // UNSAT problem. We don't deal with that here. + CHECK(!domain.empty()); // UNSAT problem. We don't deal with that here. - STLSortAndRemoveDuplicates(&values); + std::vector values; + for (const ClosedInterval interval : domain) { + for (IntegerValue v(interval.start); v <= interval.end; ++v) { + values.push_back(v); + CHECK_LT(values.size(), 100000) << "Domain too large for full encoding."; + } + } // TODO(user): This case is annoying, not sure yet how to best fix the // variable. There is certainly no need to create a Boolean variable, but @@ -42,74 +49,28 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable i_var, // the caller to deal with this case. CHECK_NE(values.size(), 1); - // If the variable has already been fully encoded, we set to false the - // literals that cannot be true anymore. We also log a warning because ideally - // these intersection should happen in the presolve. - if (ContainsKey(full_encoding_index_, i_var)) { - int num_fixed = 0; - std::unordered_set to_interset(values.begin(), values.end()); - const std::vector& encoding = FullDomainEncoding(i_var); - for (const ValueLiteralPair& p : encoding) { - if (!ContainsKey(to_interset, p.value)) { - // TODO(user): also remove this entry from encoding. - ++num_fixed; - sat_solver_->AddUnitClause(p.literal.Negated()); - } - } - if (num_fixed > 0) { - LOG(WARNING) << "Domain intersection removed " << num_fixed << " values " - << "(out of " << encoding.size() << ")."; - } - return; - } - - std::vector encoding; + std::vector literals; if (values.size() == 2) { const BooleanVariable var = sat_solver_->NewBooleanVariable(); - encoding.push_back({values[0], Literal(var, true)}); - encoding.push_back({values[1], Literal(var, false)}); + literals.push_back(Literal(var, true)); + literals.push_back(Literal(var, false)); } else { - std::vector cst; - for (const IntegerValue value : values) { + for (int i = 0; i < values.size(); ++i) { const BooleanVariable var = sat_solver_->NewBooleanVariable(); - encoding.push_back({value, Literal(var, true)}); - cst.push_back(LiteralWithCoeff(Literal(var, true), Coefficient(1))); + literals.push_back(Literal(var, true)); } - CHECK(sat_solver_->AddLinearConstraint(true, sat::Coefficient(1), true, - sat::Coefficient(1), &cst)); } - - full_encoding_index_[i_var] = full_encoding_.size(); - full_encoding_.push_back(encoding); // copy because we need it below. - - // Deal with NegationOf(i_var). - // - // TODO(user): This seems a bit wasted, but it does simplify the code at a - // somehow small cost. - std::reverse(encoding.begin(), encoding.end()); - for (auto& entry : encoding) { - entry.value = -entry.value; // Reverse the value. - } - full_encoding_index_[NegationOf(i_var)] = full_encoding_.size(); - full_encoding_.push_back(std::move(encoding)); + return FullyEncodeVariableUsingGivenLiterals(i_var, literals, values); } -void IntegerEncoder::FullyEncodeVariable(IntegerVariable i_var, IntegerValue lb, - IntegerValue ub) { - // TODO(user): optimize the code if it ever become needed. - CHECK_LE(ub - lb, 10000) << "Large domain for full encoding! investigate."; - std::vector values; - for (IntegerValue value = lb; value <= ub; ++value) values.push_back(value); - return FullyEncodeVariable(i_var, std::move(values)); -} - -// TODO(user): merge the common code with FullyEncodeVariable(). void IntegerEncoder::FullyEncodeVariableUsingGivenLiterals( IntegerVariable i_var, const std::vector& literals, const std::vector& values) { CHECK(!VariableIsFullyEncoded(i_var)); + CHECK(!literals.empty()); + CHECK_NE(literals.size(), 1); - // Sort the literal. + // Sort the literals by values. std::vector encoding; std::vector cst; for (int i = 0; i < values.size(); ++i) { @@ -120,11 +81,46 @@ void IntegerEncoder::FullyEncodeVariableUsingGivenLiterals( } std::sort(encoding.begin(), encoding.end()); - // We need the <= 1 constraint, and the >= 1 part is cheap. Note that the - // solver will discard it if it is of size 2 and contains a literal and its - // negation. - CHECK(sat_solver_->AddLinearConstraint(true, sat::Coefficient(1), true, - sat::Coefficient(1), &cst)); + // Create the associated literal (<= and >=) in order (best for the + // implications between them). Note that we only create literals like this for + // value inside the domain. This is nice since these will be the only kind of + // literal pushed by Enqueue() (we look at the domain there). + for (int i = 0; i + 1 < encoding.size(); ++i) { + const IntegerLiteral i_lit = + IntegerLiteral::LowerOrEqual(i_var, encoding[i].value); + const IntegerLiteral i_lit_negated = + IntegerLiteral::GreaterOrEqual(i_var, encoding[i + 1].value); + if (i == 0) { + // Special case for the start. + HalfAssociateGivenLiteral(i_lit, encoding[0].literal); + HalfAssociateGivenLiteral(i_lit_negated, encoding[0].literal.Negated()); + } else if (i + 2 == encoding.size()) { + // Special case for the end. + HalfAssociateGivenLiteral(i_lit, encoding.back().literal.Negated()); + HalfAssociateGivenLiteral(i_lit_negated, encoding.back().literal); + } else { + // Normal case. + if (!LiteralIsAssociated(i_lit) || !LiteralIsAssociated(i_lit_negated)) { + const BooleanVariable new_var = sat_solver_->NewBooleanVariable(); + const Literal literal(new_var, true); + HalfAssociateGivenLiteral(i_lit, literal); + HalfAssociateGivenLiteral(i_lit_negated, literal.Negated()); + } + } + } + + // Now that all literals are created, wire them together using + // (X == v) <=> (X >= v) and (X <= v). + for (int i = 1; i + 1 < encoding.size(); ++i) { + const ValueLiteralPair pair = encoding[i]; + const Literal a(GetAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(i_var, pair.value))); + const Literal b( + GetAssociatedLiteral(IntegerLiteral::LowerOrEqual(i_var, pair.value))); + sat_solver_->AddBinaryClause(a, pair.literal.Negated()); + sat_solver_->AddBinaryClause(b, pair.literal.Negated()); + sat_solver_->AddProblemClause({a.Negated(), b.Negated(), pair.literal}); + } full_encoding_index_[i_var] = full_encoding_.size(); full_encoding_.push_back(encoding); // copy because we need it below. @@ -141,6 +137,9 @@ void IntegerEncoder::FullyEncodeVariableUsingGivenLiterals( full_encoding_.push_back(std::move(encoding)); } +// Note that by not inserting the literal in "order" we can in the worst case +// use twice as much implication (2 by literals) instead of only one between +// consecutive literals. void IntegerEncoder::AddImplications(IntegerLiteral i_lit, Literal literal) { if (i_lit.var >= encoding_by_var_.size()) { encoding_by_var_.resize(i_lit.var + 1); @@ -174,27 +173,9 @@ void IntegerEncoder::AddImplications(IntegerLiteral i_lit, Literal literal) { map_ref[i_lit.bound] = literal; } -void IntegerEncoder::AssociateGivenLiteral(IntegerLiteral i_lit, - Literal literal) { - // Resize reverse encoding. - const int new_size = - 1 + std::max(literal.Index(), literal.NegatedIndex()).value(); - if (new_size > reverse_encoding_.size()) reverse_encoding_.resize(new_size); - - // Associate the new literal to i_lit. - AddImplications(i_lit, literal); - reverse_encoding_[literal.Index()].push_back(i_lit); - - // Add its negation and associated it with i_lit.Negated(). - // - // TODO(user): This seems to work for optional variables, but I am not - // 100% sure why!! I think it works because these literals can only appear - // in a conflict if the presence literal of the optional variables is true. - AddImplications(i_lit.Negated(), literal.Negated()); - reverse_encoding_[literal.NegatedIndex()].push_back(i_lit.Negated()); -} Literal IntegerEncoder::CreateAssociatedLiteral(IntegerLiteral i_lit) { + CHECK(!LiteralIsAssociated(i_lit)); ++num_created_variables_; const BooleanVariable new_var = sat_solver_->NewBooleanVariable(); const Literal literal(new_var, true); @@ -202,6 +183,37 @@ Literal IntegerEncoder::CreateAssociatedLiteral(IntegerLiteral i_lit) { return literal; } +void IntegerEncoder::AssociateGivenLiteral(IntegerLiteral i_lit, + Literal literal) { + // TODO(user): convert it to a "domain compatible one". + CHECK(!LiteralIsAssociated(i_lit)); + HalfAssociateGivenLiteral(i_lit, literal); + HalfAssociateGivenLiteral(i_lit.Negated(), literal.Negated()); +} + +// TODO(user): The hard constraints we add between associated literals seems to +// work for optional variables, but I am not 100% sure why!! I think it works +// because these literals can only appear in a conflict if the presence literal +// of the optional variables is true. +void IntegerEncoder::HalfAssociateGivenLiteral(IntegerLiteral i_lit, + Literal literal) { + // Resize reverse encoding. + const int new_size = 1 + literal.Index().value(); + if (new_size > reverse_encoding_.size()) reverse_encoding_.resize(new_size); + + // Associate the new literal to i_lit. + if (!LiteralIsAssociated(i_lit)) { + AddImplications(i_lit, literal); + reverse_encoding_[literal.Index()].push_back(i_lit); + } else { + const Literal associated(GetAssociatedLiteral(i_lit)); + if (associated != literal) { + sat_solver_->AddBinaryClause(literal, associated.Negated()); + sat_solver_->AddBinaryClause(literal.Negated(), associated); + } + } +} + bool IntegerEncoder::LiteralIsAssociated(IntegerLiteral i) const { if (i.var >= encoding_by_var_.size()) return false; const std::map& encoding = @@ -218,14 +230,14 @@ LiteralIndex IntegerEncoder::GetAssociatedLiteral(IntegerLiteral i) { return result->second.Index(); } -Literal IntegerEncoder::GetOrCreateAssociatedLiteral(IntegerLiteral i) { - if (i.var < encoding_by_var_.size()) { +Literal IntegerEncoder::GetOrCreateAssociatedLiteral(IntegerLiteral i_lit) { + if (i_lit.var < encoding_by_var_.size()) { const std::map& encoding = - encoding_by_var_[IntegerVariable(i.var)]; - const auto it = encoding.find(i.bound); + encoding_by_var_[IntegerVariable(i_lit.var)]; + const auto it = encoding.find(i_lit.bound); if (it != encoding.end()) return it->second; } - return CreateAssociatedLiteral(i); + return CreateAssociatedLiteral(i_lit); } LiteralIndex IntegerEncoder::SearchForLiteralAtOrBefore( @@ -250,118 +262,26 @@ bool IntegerTrail::Propagate(Trail* trail) { CHECK_EQ(trail->CurrentDecisionLevel(), integer_decision_levels_.size()); } - // Value encoder. - // - // TODO(user): There is no need to maintain the bounds of such variable if - // they are never used in any constraint! - // - // Algorithm: - // 1/ See if new variables are fully encoded and initialize them. - // 2/ In the loop below, each time a "min" variable was assigned to false, - // update the associated variable bounds, and change the watched "min". - // This step is in O(num variables at false between the old and new min). - // - // The data structure are reversible. - watched_min_.SetLevel(trail->CurrentDecisionLevel()); - current_min_.SetLevel(trail->CurrentDecisionLevel()); + // This is used to map any integer literal out of the initial variable domain + // into one that use one of the domain value. var_to_current_lb_interval_index_.SetLevel(trail->CurrentDecisionLevel()); - if (encoder_->GetFullyEncodedVariables().size() != num_encoded_variables_) { - num_encoded_variables_ = encoder_->GetFullyEncodedVariables().size(); - - // for now this is only supported at level zero. Otherwise we need to - // inspect the trail to properly compute all the min. - // - // TODO(user): Don't rescan all the variables from scratch, we could only - // scan the new ones. But then we need a mecanism to detect the new ones. - CHECK_EQ(trail->CurrentDecisionLevel(), 0); - for (const auto& entry : encoder_->GetFullyEncodedVariables()) { - const IntegerVariable var = entry.first; - if (IsCurrentlyIgnored(var)) continue; - - // This variable was already added and will be processed below. - // Note that this is important, otherwise we may call many times - // watched_min_.Add() on the same literal. - if (ContainsKey(current_min_, var)) continue; - - const auto& encoding = encoder_->FullDomainEncoding(var); - for (int i = 0; i < encoding.size(); ++i) { - if (!trail_->Assignment().LiteralIsFalse(encoding[i].literal)) { - if (!trail_->Assignment().LiteralIsTrue(encoding[i].literal)) { - watched_min_.Add(encoding[i].literal.NegatedIndex(), var); - } - current_min_.Set(var, i); - - // No reason because we are at level zero. - if (!Enqueue(IntegerLiteral::GreaterOrEqual(var, encoding[i].value), - {}, {})) { - return false; - } - break; - } - } - } - } // Process all the "associated" literals and Enqueue() the corresponding // bounds. while (propagation_trail_index_ < trail->Index()) { const Literal literal = (*trail)[propagation_trail_index_++]; - - // Bound encoder. for (const IntegerLiteral i_lit : encoder_->GetIntegerLiterals(literal)) { if (IsCurrentlyIgnored(i_lit.Var())) continue; // The reason is simply the associated literal. if (!Enqueue(i_lit, {literal.Negated()}, {})) return false; } - - // Value encoder. - for (const IntegerVariable var : watched_min_.Values(literal.Index())) { - DCHECK(!IsOptional(var)) << "Not supported yet"; - - // A watched min value just became false. Note that because current_min_ - // is also updated by Enqueue(), it may be larger than the watched min. - const int min = current_min_.FindOrDie(var); - int i = min; - const auto& encoding = encoder_->FullDomainEncoding(var); - tmp_literals_reason_.clear(); - for (; i < encoding.size(); ++i) { - if (!trail_->Assignment().LiteralIsFalse(encoding[i].literal)) break; - tmp_literals_reason_.push_back(encoding[i].literal); - } - - // Note(user): we enforce a "== 1" on the encoding literals, but the - // clause forcing at least one of them to be true may not have propagated - // in some cases (because we loop in the integer propagators before - // calling the clause propagator again). - if (i == encoding.size()) { - return ReportConflict(tmp_literals_reason_, {LowerBoundAsLiteral(var)}); - } - - // Note that we don't need to delete the old watched min: - // - its literal has been assigned, so it will not be queried again until - // backtrack. - // - When backtracked over, it will already be set correctly. It seems - // less efficient to delete it now and add it back on backtrack. - watched_min_.Add(encoding[i].literal.NegatedIndex(), var); - if (i > min) { - // Note that we also need the fact that all smaller value are false - // for the propagation. We use the current lower bound for that. - current_min_.Set(var, i); - if (!Enqueue(IntegerLiteral::GreaterOrEqual(var, encoding[i].value), - tmp_literals_reason_, {LowerBoundAsLiteral(var)})) { - return false; - } - } - } } return true; } void IntegerTrail::Untrail(const Trail& trail, int literal_trail_index) { - watched_min_.SetLevel(trail.CurrentDecisionLevel()); - current_min_.SetLevel(trail.CurrentDecisionLevel()); var_to_current_lb_interval_index_.SetLevel(trail.CurrentDecisionLevel()); propagation_trail_index_ = std::min(propagation_trail_index_, literal_trail_index); @@ -466,6 +386,70 @@ std::vector IntegerTrail::InitialVariableDomain( } } +bool IntegerTrail::UpdateInitialDomain(IntegerVariable var, + std::vector domain) { + domain = + IntersectionOfSortedDisjointIntervals(domain, InitialVariableDomain(var)); + if (domain.empty()) return false; + + // TODO(user): A bit inefficient as this recreate a vector for no reason. + if (domain == InitialVariableDomain(var)) return true; + + CHECK(Enqueue( + IntegerLiteral::GreaterOrEqual(var, IntegerValue(domain.front().start)), + {}, {})); + CHECK(Enqueue( + IntegerLiteral::LowerOrEqual(var, IntegerValue(domain.back().end)), {}, + {})); + + // TODO(user): reuse the memory if possible. + if (ContainsKey(var_to_current_lb_interval_index_, var)) { + var_to_current_lb_interval_index_.EraseOrDie(var); + var_to_end_interval_index_.erase(var); + var_to_current_lb_interval_index_.EraseOrDie(NegationOf(var)); + var_to_end_interval_index_.erase(NegationOf(var)); + } + if (domain.size() > 1) { + var_to_current_lb_interval_index_.Set(var, all_intervals_.size()); + for (const ClosedInterval interval : domain) { + all_intervals_.push_back(interval); + } + InsertOrDie(&var_to_end_interval_index_, var, all_intervals_.size()); + + // Copy for the negated variable. + var_to_current_lb_interval_index_.Set(NegationOf(var), + all_intervals_.size()); + for (const ClosedInterval interval : ::gtl::reversed_view(domain)) { + all_intervals_.push_back({-interval.end, -interval.start}); + } + InsertOrDie(&var_to_end_interval_index_, NegationOf(var), + all_intervals_.size()); + } + + // If the variable is fully encoded, set to false excluded literals. + if (encoder_->VariableIsFullyEncoded(var)) { + int i = 0; + int num_fixed = 0; + const auto encoding = encoder_->FullDomainEncoding(var); + for (const auto pair : encoding) { + while (pair.value > domain[i].end && i < domain.size()) ++i; + if (i == domain.size() || pair.value < domain[i].start) { + // Set the literal to false; + ++num_fixed; + if (trail_->Assignment().LiteralIsTrue(pair.literal)) return false; + if (!trail_->Assignment().LiteralIsFalse(pair.literal)) { + trail_->EnqueueWithUnitReason(pair.literal.Negated()); + } + } + } + if (num_fixed > 0) { + VLOG(1) << "Domain intersection removed " << num_fixed << " values " + << "(out of " << encoding.size() << ")."; + } + } + return true; +} + IntegerVariable IntegerTrail::GetOrCreateConstantIntegerVariable( IntegerValue value) { auto insert = constant_map_.insert(std::make_pair(value, kNoIntegerVariable)); @@ -636,48 +620,6 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit, // For the EnqueueWithSameReasonAs() mechanism. BooleanVariable first_propagated_variable = kNoBooleanVariable; - // Deal with fully encoded variable. We want to do that first because this may - // make the IntegerLiteral bound stronger. - const int min_index = FindWithDefault(current_min_, var, -1); - if (min_index >= 0) { - DCHECK(!IsOptional(var)) - << "Fully encoded optional variable are not yet supported."; - - // Recover the current min, and propagate to false all the values that - // are in [min, i_lit.value). All these literals have the same reason, so - // we use the "same reason as" mecanism. - // - // TODO(user): We could go even further if the next literals are set to - // false (but they need to be added to the reason). - const auto& encoding = encoder_->FullDomainEncoding(var); - if (i_lit.bound > encoding[min_index].value) { - int i = min_index; - for (; i < encoding.size(); ++i) { - if (i_lit.bound <= encoding[i].value) break; - const Literal literal = encoding[i].literal.Negated(); - if (!EnqueueAssociatedLiteral(literal, i_lit, literal_reason, - integer_reason, - &first_propagated_variable)) { - return false; - } - } - - if (i == encoding.size()) { - // Conflict: no possible values left. - return ReportConflict(literal_reason, integer_reason); - } else { - // We have a new min. Note that watched_min_ will be updated on the next - // call to Propagate() since we just pushed the watched literal if it - // wasn't already set to false. - current_min_.Set(var, i); - - // Adjust the bound of i_lit ! - CHECK_GE(encoding[i].value, i_lit.bound); - i_lit.bound = encoding[i].value.value(); - } - } - } - // Check if the integer variable has an empty domain. if (i_lit.bound > UpperBound(var)) { // We relax the upper bound as much as possible to still have a conflict. @@ -1121,12 +1063,15 @@ bool GenericLiteralWatcher::Propagate(Trail* trail) { } // If the propagator pushed a literal, we have two options. - // - // TODO(user): expose the parameter. Early experiments are counter - // intuitive and seems to indicate that it is better not to re-run the SAT - // propagators each time we push a literal. if (trail->Index() > old_boolean_timestamp) { - const bool run_sat_propagators_at_higher_priority = false; + // Important: for now we need to re-run the clauses propagator each time + // we push a new literal because some propagator like the arc consistent + // all diff relies on this. + // + // However, on some problem, it seems to work better to not do that. One + // possible reason is that the reason of a "natural" propagation might + // be better than one we learned. + const bool run_sat_propagators_at_higher_priority = true; if (run_sat_propagators_at_higher_priority) { // We exit in order to rerun all SAT only propagators first. Note that // since a literal was pushed we are guaranteed to be called again, diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 874ffe6a67..ab08ca9999 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -187,46 +187,36 @@ class IntegerEncoder { return encoder; } - // This has 3 effects: - // 1/ It restricts the given variable to only take values amongst the given - // ones. - // 2/ It creates one Boolean variable per value that convey the fact that the - // var is equal to this value iff the Boolean is true. If there is only - // 2 values, then just one Boolean variable is created. For more than two - // values, a constraint is also added to enforce that exactly one Boolean - // variable is true. - // 3/ The encoding for NegationOf(var) is automatically created too. It reuses - // the same Boolean variable as the encoding of var. + // Fully encode a variable. This can be called only once. // - // Calling this more than once will take the intersection of all the given - // values arguments. However, this is not optimal because the first calls may - // creates new Boolean variables that will later be fixed, so we log a warning - // when this happen. Ideally, the intersection should be done in a presolve - // step to be as efficient as possible here. + // Important: this should really only be called with + // integer_trail_->InitialVariableDomain() which can be updated with + // integer_trail_->UpdateInitialDomain(). // - // Note(user): There is currently no relation here between - // FullyEncodeVariable() and CreateAssociatedLiteral(). However the - // IntegerTrail class will automatically link the two representations and do - // the right thing. + // TODO(user): clean this up by enforcing this programmatically. // - // Note(user): Calling this with just one value will cause a CHECK fail. One - // need to fix the IntegerVariable inside the IntegerTrail instead of calling - // this. + // This creates new Booleans variables as needed: + // 1) num_values for the literals X == value. Except when there is just + // two value in which case only one variable is created. + // 2) num_values - 3 for the literals X >= value or X <= value (using their + // negation). The -3 comes from the fact that we can reuse the equality + // literals for the two extreme points. + // + // The encoding for NegationOf(var) is automatically created too. It reuses + // the same Boolean variable as the encoding of var. + // + // Note(user): Calling this with just one value will cause a CHECK fail. + // We don't really want to create a fixed Boolean. // // TODO(user): It is currently only possible to call that at the decision - // level zero. This is Checked. + // level zero because we cannot add ternary clause in the middle of the + // search (for now). This is Checked. void FullyEncodeVariable(IntegerVariable var, - std::vector values); - void FullyEncodeVariable(IntegerVariable var, IntegerValue lb, - IntegerValue ub); + const std::vector& domain); // Similar to FullyEncodeVariable() but use the given literal for each values. // This can only be called on variable that are not fully encoded yet, This is - // checked. - // - // Note that duplicate values are supported, but exactly one literal must be - // true at the same time. The "exactly one" constraint will implicitely be - // enforced by the code in IntegerTrail. + // checked. Duplicates values are not supported. void FullyEncodeVariableUsingGivenLiterals( IntegerVariable var, const std::vector& literals, const std::vector& values); @@ -284,15 +274,20 @@ class IntegerEncoder { Literal CreateAssociatedLiteral(IntegerLiteral i_lit); void AssociateGivenLiteral(IntegerLiteral i_lit, Literal wanted); + // Same as CreateAssociatedLiteral() but safe to call if already created. + Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit); + + // Only add the equivalence between i_lit and literal, if there is already an + // associated literal with i_lit, this make literal and this associated + // literal equivalent. + void HalfAssociateGivenLiteral(IntegerLiteral i_lit, Literal literal); + // Return true iff the given integer literal is associated. bool LiteralIsAssociated(IntegerLiteral i_lit) const; // Returns the associated literal or kNoLiteralIndex. LiteralIndex GetAssociatedLiteral(IntegerLiteral i_lit); - // Same as CreateAssociatedLiteral() but safe to call if already created. - Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit); - // Returns the IntegerLiterals that were associated with the given Literal. const InlinedIntegerLiteralVector& GetIntegerLiterals(Literal lit) const { if (lit.Index() >= reverse_encoding_.size()) { @@ -392,6 +387,13 @@ class IntegerTrail : public SatPropagator { // propagations, but not if the domain is more complex. std::vector InitialVariableDomain(IntegerVariable var) const; + // Takes the intersection with the current initial variable domain. + // TODO(user): There is some memory inefficiency if this is called many time + // because of the underlying data structure we use. In practice, when used + // with a presolve, this is not often used, so that is fine though. + bool UpdateInitialDomain(IntegerVariable var, + std::vector domain); + // Same as AddIntegerVariable(value, value), but this is a bit more efficient // because it reuses another constant with the same value if its exist. // @@ -598,14 +600,6 @@ class IntegerTrail : public SatPropagator { // The "is_ignored" literal of the optional variables or kNoLiteralIndex. ITIVector is_ignored_literals_; - // Data used to support the propagation of fully encoded variable. We keep - // for each variable the index in encoder_.GetDomainEncoding() of the first - // literal that is not assigned to false, and call this the "min". - int64 num_encoded_variables_ = 0; - std::vector tmp_literals_reason_; - RevGrowingMultiMap watched_min_; - RevMap> current_min_; - // This is only filled for variables with a domain more complex than a single // interval of values. All intervals are stored in a vector, and we keep // indices to the current interval of the lower bound, and to the end index @@ -1097,15 +1091,8 @@ FullyEncodeVariable(IntegerVariable var) { return [=](Model* model) { IntegerEncoder* encoder = model->GetOrCreate(); if (!encoder->VariableIsFullyEncoded(var)) { - IntegerTrail* integer_trail = model->GetOrCreate(); - std::vector values; - for (const ClosedInterval interval : - integer_trail->InitialVariableDomain(var)) { - for (IntegerValue v(interval.start); v <= interval.end; ++v) { - values.push_back(v); - } - } - encoder->FullyEncodeVariable(var, values); + encoder->FullyEncodeVariable( + var, model->GetOrCreate()->InitialVariableDomain(var)); } return encoder->FullDomainEncoding(var); }; diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index b10986b178..ce341d9b61 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -13,6 +13,10 @@ #include "ortools/sat/integer_expr.h" +#include + +#include "ortools/base/stl_util.h" + namespace operations_research { namespace sat { @@ -442,5 +446,48 @@ void DivisionPropagator::RegisterWith(GenericLiteralWatcher* watcher) { watcher->WatchIntegerVariable(c_, id); } +std::function IsOneOf(IntegerVariable var, + const std::vector& selectors, + const std::vector& values) { + return [=](Model* model) { + IntegerTrail* integer_trail = model->GetOrCreate(); + IntegerEncoder* encoder = model->GetOrCreate(); + if (encoder->VariableIsFullyEncoded(var)) { + LOG(FATAL) << "TODO(fdid): Not implemented."; + } + + CHECK(!values.empty()); + CHECK_EQ(values.size(), selectors.size()); + std::vector unique_values; + std::unordered_map> value_to_selector; + for (int i = 0; i < values.size(); ++i) { + unique_values.push_back(values[i].value()); + value_to_selector[values[i].value()].push_back(selectors[i]); + } + STLSortAndRemoveDuplicates(&unique_values); + + integer_trail->UpdateInitialDomain( + var, SortedDisjointIntervalsFromValues(unique_values)); + if (unique_values.size() == 1) { + model->Add(ClauseConstraint(selectors)); + return; + } + + std::vector new_selectors; + for (const int64 v : unique_values) { + if (value_to_selector[v].size() == 1) { + new_selectors.push_back(value_to_selector[v][0]); + } else { + const Literal l(model->Add(NewBooleanVariable()), true); + model->Add(ReifiedBoolOr(value_to_selector[v], l)); + new_selectors.push_back(l); + } + } + encoder->FullyEncodeVariableUsingGivenLiterals( + var, new_selectors, + std::vector(unique_values.begin(), unique_values.end())); + }; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index e73634638a..a05f879a13 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -459,35 +459,9 @@ inline std::function NewMax( // Expresses the fact that an existing integer variable is equal to one of // the given values, each selected by a given literal. -inline std::function IsOneOf( - IntegerVariable var, const std::vector& selectors, - const std::vector& values) { - return [=](Model* model) { - CHECK(!values.empty()); - CHECK_EQ(values.size(), selectors.size()); - IntegerValue min_value = values[0]; - IntegerValue max_value = values[1]; - for (const IntegerValue v : values) { - min_value = std::min(min_value, v); - max_value = std::max(max_value, v); - } - IntegerTrail* integer_trail = model->GetOrCreate(); - CHECK(integer_trail->Enqueue(IntegerLiteral::GreaterOrEqual(var, min_value), - {}, {})); - CHECK(integer_trail->Enqueue(IntegerLiteral::LowerOrEqual(var, max_value), - {}, {})); - - IntegerEncoder* encoder = model->GetOrCreate(); - if (!encoder->VariableIsFullyEncoded(var)) { - encoder->FullyEncodeVariableUsingGivenLiterals(var, selectors, values); - } else { - // TODO(user): copy the sat_fz_solver code of the int element here. - // And use this function instead because the first branch will be more - // efficient). - LOG(FATAL) << "TODO(fdid): Not implemented."; - } - }; -} +std::function IsOneOf(IntegerVariable var, + const std::vector& selectors, + const std::vector& values); template void RegisterAndTransferOwnership(Model* model, T* ct) { diff --git a/ortools/sat/optimization.cc b/ortools/sat/optimization.cc index 736a1295dd..e0fec55f07 100644 --- a/ortools/sat/optimization.cc +++ b/ortools/sat/optimization.cc @@ -1126,7 +1126,7 @@ SatSolver::Status MinimizeWithCoreAndLazyEncoding( IntegerValue best_objective = integer_trail->UpperBound(objective_var); const auto process_solution = [&]() { const IntegerValue objective(model->Get(Value(objective_var))); - if (objective >= best_objective) return true; + if (objective >= best_objective && num_solutions > 0) return true; ++num_solutions; best_objective = objective; diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 2e206d8a82..a58ee84a80 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -142,7 +142,7 @@ void SatSolver::SetParameters(const SatParameters& parameters) { parameters_ = parameters; clauses_propagator_.SetParameters(parameters); pb_constraints_.SetParameters(parameters); - random_.Reset(parameters_.random_seed()); + random_.seed(parameters_.random_seed()); InitRestart(); time_limit_ = TimeLimit::FromParameters(parameters_); dl_running_average_.Reset(parameters_.restart_running_window_size()); @@ -256,7 +256,7 @@ bool SatSolver::AddLinearConstraintInternal( SCOPED_TIME_STAT(&stats_); DCHECK(BooleanLinearExpressionIsCanonical(cst)); if (rhs < 0) return SetModelUnsat(); // Unsatisfiable constraint. - if (rhs >= max_value) return true; // Always satisfied constraint. + if (rhs >= max_value) return true; // Always satisfied constraint. // Update the weighted_sign_. // TODO(user): special case the rhs = 0 which just fix variables... @@ -1644,14 +1644,18 @@ Literal SatSolver::NextBranch() { // Choose the variable. BooleanVariable var; const double ratio = parameters_.random_branches_ratio(); - if (ratio != 0.0 && random_.RandDouble() < ratio) { + auto zero_to_one = [this]() { + return std::uniform_real_distribution()(random_); + }; + if (ratio != 0.0 && zero_to_one() < ratio) { ++counters_.num_random_branches; while (true) { // TODO(user): This may not be super efficient if almost all the // variables are assigned. - var = BooleanVariable( - (*var_ordering_.Raw())[random_.Uniform(var_ordering_.Raw()->size())] - - &queue_elements_.front()); + std::uniform_int_distribution index_dist( + 0, var_ordering_.Raw()->size() - 1); + var = BooleanVariable((*var_ordering_.Raw())[index_dist(random_)] - + &queue_elements_.front()); if (!trail_->Assignment().VariableIsAssigned(var)) break; pq_need_update_for_var_at_trail_index_.Set(trail_->Info(var).trail_index); var_ordering_.Remove(&queue_elements_[var]); @@ -1670,8 +1674,8 @@ Literal SatSolver::NextBranch() { // Choose its polarity (i.e. True of False). const double random_ratio = parameters_.random_polarity_ratio(); - if (random_ratio != 0.0 && random_.RandDouble() < random_ratio) { - return Literal(var, random_.OneIn(2)); + if (random_ratio != 0.0 && zero_to_one() < random_ratio) { + return Literal(var, std::uniform_int_distribution(0, 1)(random_)); } return Literal(var, var_use_phase_saving_[var] ? trail_->Info(var).last_polarity @@ -1693,7 +1697,7 @@ void SatSolver::ResetPolarity(BooleanVariable from) { initial_polarity = false; break; case SatParameters::POLARITY_RANDOM: - initial_polarity = random_.OneIn(2); + initial_polarity = std::uniform_int_distribution(0, 1)(random_); break; case SatParameters::POLARITY_WEIGHTED_SIGN: initial_polarity = weighted_sign_[var] > 0; @@ -1739,7 +1743,7 @@ void SatSolver::InitializeVariableOrdering() { std::reverse(variables.begin(), variables.end()); break; case SatParameters::IN_RANDOM_ORDER: - std::random_shuffle(variables.begin(), variables.end(), random_); + std::shuffle(variables.begin(), variables.end(), random_); break; } @@ -2186,7 +2190,7 @@ void SatSolver::ComputePBConflict(int max_trail_index, // The sum of the literal with level <= backjump_level must propagate. std::vector sum_for_le_level(backjump_level + 2, Coefficient(0)); std::vector max_coeff_for_ge_level(backjump_level + 2, - Coefficient(0)); + Coefficient(0)); int size = 0; Coefficient max_sum(0); for (BooleanVariable var : conflict->PossibleNonZeros()) { diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index c429836126..0a4306c98e 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -39,10 +39,10 @@ #include "ortools/sat/pb_constraint.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/util/bitset.h" +#include "ortools/util/random_engine.h" #include "ortools/util/running_stat.h" #include "ortools/util/stats.h" #include "ortools/util/time_limit.h" -#include "ortools/base/random.h" #include "ortools/base/adjustable_priority_queue.h" namespace operations_research { @@ -893,7 +893,7 @@ class SatSolver { VariableWithSameReasonIdentifier same_reason_identifier_; // A random number generator. - mutable MTRandom random_; + mutable random_engine_t random_; // Temporary vector used by AddProblemClause(). std::vector tmp_pb_constraint_; diff --git a/ortools/sat/table.cc b/ortools/sat/table.cc index 1df9b6f811..2ea61822d6 100644 --- a/ortools/sat/table.cc +++ b/ortools/sat/table.cc @@ -24,14 +24,13 @@ namespace sat { namespace { -// Transpose the given "matrix" and transform the value to IntegerValue. -std::vector> Transpose( +// Transposes the given "matrix". +std::vector> Transpose( const std::vector>& tuples) { CHECK(!tuples.empty()); const int n = tuples.size(); const int m = tuples[0].size(); - std::vector> transpose( - m, std::vector(n)); + std::vector> transpose(m, std::vector(n)); for (int i = 0; i < n; ++i) { CHECK_EQ(m, tuples[i].size()); for (int j = 0; j < m; ++j) { @@ -53,40 +52,21 @@ std::unordered_map GetEncoding(IntegerVariable var, Model void FilterValues(IntegerVariable var, Model* model, std::unordered_set* values) { - const int64 lb = model->Get(LowerBound(var)); - const int64 ub = model->Get(UpperBound(var)); - - IntegerEncoder* encoder = model->GetOrCreate(); - const VariablesAssignment& assignment = - model->GetOrCreate()->Assignment(); - if (encoder->VariableIsFullyEncoded(var)) { - const auto encoding = GetEncoding(var, model); - for (auto it = values->begin(); it != values->end();) { - const int64 v = *it; - auto copy = it++; - if (v < lb || v > ub || !ContainsKey(encoding, IntegerValue(v))) { - values->erase(copy); - } else { - const Literal literal = FindOrDie(encoding, IntegerValue(v)); - if (assignment.LiteralIsFalse(literal)) { - values->erase(copy); - } - } - } - } else { - for (auto it = values->begin(); it != values->end();) { - const int64 v = *it; - auto copy = it++; - if (v < lb || v > ub) { - values->erase(copy); - } + std::vector domain = + model->Get()->InitialVariableDomain(var); + for (auto it = values->begin(); it != values->end();) { + const int64 v = *it; + auto copy = it++; + // TODO(user): quadratic! improve. + if (!SortedDisjointIntervalsContain(domain, v)) { + values->erase(copy); } } } // Add the implications and clauses to link one column of a table to the Literal // controling if the lines are possible or not. The column has the given values, -// and the Literal of the column variable can be retreived using the encoding +// and the Literal of the column variable can be retrieved using the encoding // map. void ProcessOneColumn(const std::vector& line_literals, const std::vector& values, @@ -168,19 +148,23 @@ std::function TableConstraint( } // Fully encode the variables using all the values appearing in the tuples. - IntegerEncoder* encoder = model->GetOrCreate(); + IntegerTrail* interger_trail = model->GetOrCreate(); std::unordered_map encoding; - const std::vector> tr_tuples = - Transpose(new_tuples); + const std::vector> tr_tuples = Transpose(new_tuples); for (int i = 0; i < n; ++i) { - const IntegerValue first = tr_tuples[i].front(); + const int64 first = tr_tuples[i].front(); if (std::all_of(tr_tuples[i].begin(), tr_tuples[i].end(), - [first](IntegerValue v) { return v == first; })) { - model->Add(Equality(vars[i], first.value())); + [first](int64 v) { return v == first; })) { + model->Add(Equality(vars[i], first)); } else { - encoder->FullyEncodeVariable(vars[i], tr_tuples[i]); + interger_trail->UpdateInitialDomain( + vars[i], SortedDisjointIntervalsFromValues(tr_tuples[i])); + model->Add(FullyEncodeVariable(vars[i])); encoding = GetEncoding(vars[i], model); - ProcessOneColumn(tuple_literals, tr_tuples[i], encoding, model); + ProcessOneColumn( + tuple_literals, + std::vector(tr_tuples[i].begin(), tr_tuples[i].end()), + encoding, model); } } }; @@ -292,7 +276,7 @@ std::function TransitionConstraint( const std::vector>& automata, int64 initial_state, const std::vector& final_states) { return [=](Model* model) { - IntegerEncoder* encoder = model->GetOrCreate(); + IntegerTrail* integer_trail = model->GetOrCreate(); const int n = vars.size(); CHECK_GT(n, 0) << "No variables in TransitionConstraint()."; @@ -311,22 +295,12 @@ std::function TransitionConstraint( // Construct a table with the possible values of each vars. std::vector> possible_values(n); - const VariablesAssignment& assignment = - model->GetOrCreate()->Assignment(); for (int time = 0; time < n; ++time) { - if (encoder->VariableIsFullyEncoded(vars[time])) { - for (const auto& entry : encoder->FullDomainEncoding(vars[time])) { - if (!assignment.LiteralIsFalse(entry.literal)) { - possible_values[time].insert(entry.value.value()); - } - } - } else { - const int64 lb = model->Get(LowerBound(vars[time])); - const int64 ub = model->Get(UpperBound(vars[time])); - for (const std::vector& transition : automata) { - if (lb <= transition[1] && transition[1] <= ub) { - possible_values[time].insert(transition[1]); - } + const auto domain = integer_trail->InitialVariableDomain(vars[time]); + for (const std::vector& transition : automata) { + // TODO(user): quadratic algo, improve! + if (SortedDisjointIntervalsContain(domain, transition[1])) { + possible_values[time].insert(transition[1]); } } } @@ -399,8 +373,12 @@ std::function TransitionConstraint( encoding.clear(); if (s.size() > 1) { - std::vector values(s.begin(), s.end()); - encoder->FullyEncodeVariable(vars[time], values); + std::vector values; + values.reserve(s.size()); + for (IntegerValue v : s) values.push_back(v.value()); + integer_trail->UpdateInitialDomain( + vars[time], SortedDisjointIntervalsFromValues(values)); + model->Add(FullyEncodeVariable(vars[time])); encoding = GetEncoding(vars[time], model); } else { // Fix vars[time] to its unique possible value. diff --git a/ortools/util/BUILD b/ortools/util/BUILD index f447972100..ce95c25c25 100644 --- a/ortools/util/BUILD +++ b/ortools/util/BUILD @@ -26,6 +26,12 @@ cc_library( deps = ["//ortools/base:map_util"], ) +cc_library( + name = "random_engine", + hdrs = ["random_engine.h"], + deps = ["//ortools/base:map_util"], +) + cc_library( name = "bitset", srcs = ["bitset.cc"], diff --git a/ortools/util/file_util.cc b/ortools/util/file_util.cc new file mode 100644 index 0000000000..a2aadd4f0d --- /dev/null +++ b/ortools/util/file_util.cc @@ -0,0 +1,75 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/util/file_util.h" + +#include "ortools/base/logging.h" +#include "ortools/base/file.h" +#include "google/protobuf/io/zero_copy_stream_impl_lite.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +namespace operations_research { + +using ::google::protobuf::Descriptor; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Reflection; +using ::google::protobuf::TextFormat; + +bool ReadFileToProto(const std::string& filename, google::protobuf::Message* proto) { + std::string data; + CHECK_OK(file::GetContents(filename, &data, file::Defaults())); + // Note that gzipped files are currently not supported. + // Try binary format first, then text format, then JSON, then give up. + if (proto->ParseFromString(data)) return true; + if (google::protobuf::TextFormat::ParseFromString(data, proto)) return true; + LOG(WARNING) << "Could not parse protocol buffer"; + return false; +} + +bool WriteProtoToFile(const std::string& filename, const google::protobuf::Message& proto, + ProtoWriteFormat proto_write_format, bool gzipped) { + // Note that gzipped files are currently not supported. + gzipped = false; + + std::string file_type_suffix; + std::string output_string; + google::protobuf::io::StringOutputStream stream(&output_string); + switch (proto_write_format) { + case ProtoWriteFormat::kProtoBinary: + if (!proto.SerializeToZeroCopyStream(&stream)) { + LOG(WARNING) << "Serialize to stream failed."; + return false; + } + file_type_suffix = ".bin"; + break; + case ProtoWriteFormat::kProtoText: + if (!google::protobuf::TextFormat::PrintToString(proto, &output_string)) { + LOG(WARNING) << "Printing to std::string failed."; + return false; + } + break; + } + const std::string output_filename = StrCat(filename, file_type_suffix); + VLOG(1) << "Writing " << output_string.size() << " bytes to " + << output_filename; + if (!file::SetContents(output_filename, output_string, file::Defaults()) + .ok()) { + LOG(WARNING) << "Writing to file failed."; + return false; + } + return true; +} + +} // namespace operations_research diff --git a/ortools/util/file_util.h b/ortools/util/file_util.h new file mode 100644 index 0000000000..8ab07a9572 --- /dev/null +++ b/ortools/util/file_util.h @@ -0,0 +1,123 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_UTIL_FILE_UTIL_H_ +#define OR_TOOLS_UTIL_FILE_UTIL_H_ + +#include +#include + +#include "ortools/base/file.h" +#include "ortools/base/stringpiece.h" +#include "ortools/base/recordio.h" +#include "google/protobuf/message.h" + +namespace operations_research { + +// Reads a proto from a file. Supports the following formats: binary, text, +// JSON, all of those optionally gzipped. Returns false on failure. +bool ReadFileToProto(const std::string& filename, google::protobuf::Message* proto); + +template +Proto ReadFileToProtoOrDie(const std::string& filename) { + Proto proto; + CHECK(ReadFileToProto(filename, &proto)); + return proto; +} + +enum class ProtoWriteFormat { kProtoText, kProtoBinary, kJson }; + +// Writes a proto to a file. Supports the following formats: binary, text, JSON, +// all of those optionally gzipped. Returns false on failure. +// If 'proto_write_format' is kProtoBinary, ".bin" is appended to file_name. If +// 'proto_write_format' is kJson, ".json" is appended to file_name. If 'gzipped' +// is true, ".gz" is appended to file_name. +bool WriteProtoToFile(const std::string& filename, const google::protobuf::Message& proto, + ProtoWriteFormat proto_write_format, bool gzipped); + +namespace internal { +// General method to read expected_num_records from a file. If +// expected_num_records is -1, then reads all records from the file. If not, +// dies if the file doesn't contain exactly expected_num_records. +template +std::vector ReadNumRecords(File* file, int expected_num_records) { + recordio::RecordReader reader(file); + std::vector protos; + Proto proto; + int num_read = 0; + while (num_read != expected_num_records && + reader.ReadProtocolMessage(&proto)) { + protos.push_back(proto); + ++num_read; + } + + /* CHECK(reader.AtEOF(false) && reader.Close()) */ + /* << "File '" << file->filename() */ + /* << "'was not fully read, or something went wrong when closing " */ + /* "it. Is it the right format? (RecordIO of Protocol Buffers)."; */ + + if (expected_num_records >= 0) { + CHECK_EQ(num_read, expected_num_records) + << "There were less than the expected " << expected_num_records + << " in the file."; + } + + return protos; +} + +// Ditto, taking a filename as argument. +template +std::vector ReadNumRecords(const std::string& filename, + int expected_num_records) { + // return ReadNumRecords(file::OpenOrDie(filename, "r", file::Defaults()), + // expected_num_records); + return ReadNumRecords(File::OpenOrDie(filename, "r")); +} +} // namespace internal + +// Reads all records in Proto format in 'file'. Silently does nothing if the +// file is empty. Dies if the file doesn't exist or contains something else than +// protos encoded in RecordIO format. +template +std::vector ReadAllRecordsOrDie(StringPiece filename) { + return internal::ReadNumRecords(filename, -1); +} +template +std::vector ReadAllRecordsOrDie(File* file) { + return internal::ReadNumRecords(file, -1); +} + +// Reads one record in Proto format in 'file'. Dies if the file doesn't exist, +// doesn't contain exactly one record, or contains something else than protos +// encoded in RecordIO format. +template +Proto ReadOneRecordOrDie(StringPiece filename) { + Proto p; + p.Swap(&internal::ReadNumRecords(filename, 1)[0]); + return p; +} + +// Writes all records in Proto format to 'file'. Dies if it is unable to open +// the file or write to it. +template +void WriteRecordsOrDie(const std::string filename, + const std::vector& protos) { + recordio::RecordWriter writer(File::OpenOrDie(filename, "w")); + for (const Proto& proto : protos) { + CHECK(writer.WriteProtocolMessage(proto)); + } +} + +} // namespace operations_research + +#endif // OR_TOOLS_UTIL_FILE_UTIL_H_ diff --git a/ortools/util/proto_tools.cc b/ortools/util/proto_tools.cc index 2d5118f587..4b30d2814d 100644 --- a/ortools/util/proto_tools.cc +++ b/ortools/util/proto_tools.cc @@ -13,12 +13,10 @@ #include "ortools/util/proto_tools.h" -#include "ortools/base/logging.h" -#include "ortools/base/file.h" -#include "google/protobuf/io/zero_copy_stream_impl_lite.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" +#include "ortools/base/join.h" namespace operations_research { @@ -27,51 +25,6 @@ using ::google::protobuf::FieldDescriptor; using ::google::protobuf::Reflection; using ::google::protobuf::TextFormat; -bool ReadFileToProto(const std::string& file_name, google::protobuf::Message* proto) { - std::string data; - CHECK_OK(file::GetContents(file_name, &data, file::Defaults())); - // Note that gzipped files are currently not supported. - // Try binary format first, then text format, then JSON, then give up. - if (proto->ParseFromString(data)) return true; - if (google::protobuf::TextFormat::ParseFromString(data, proto)) return true; - LOG(WARNING) << "Could not parse protocol buffer"; - return false; -} - -bool WriteProtoToFile(const std::string& file_name, const google::protobuf::Message& proto, - ProtoWriteFormat proto_write_format, bool gzipped) { - // Note that gzipped files are currently not supported. - gzipped = false; - - std::string file_type_suffix; - std::string output_string; - google::protobuf::io::StringOutputStream stream(&output_string); - switch (proto_write_format) { - case ProtoWriteFormat::kProtoBinary: - if (!proto.SerializeToZeroCopyStream(&stream)) { - LOG(WARNING) << "Serialize to stream failed."; - return false; - } - file_type_suffix = ".bin"; - break; - case ProtoWriteFormat::kProtoText: - if (!google::protobuf::TextFormat::PrintToString(proto, &output_string)) { - LOG(WARNING) << "Printing to std::string failed."; - return false; - } - break; - } - const std::string output_file_name = StrCat(file_name, file_type_suffix); - VLOG(1) << "Writing " << output_string.size() << " bytes to " - << output_file_name; - if (!file::SetContents(output_file_name, output_string, file::Defaults()) - .ok()) { - LOG(WARNING) << "Writing to file failed."; - return false; - } - return true; -} - namespace { void WriteFullProtocolMessage(const google::protobuf::Message& message, int indent_level, std::string* out) { diff --git a/ortools/util/proto_tools.h b/ortools/util/proto_tools.h index 6bc5216a23..3a4a7b5e5f 100644 --- a/ortools/util/proto_tools.h +++ b/ortools/util/proto_tools.h @@ -15,24 +15,10 @@ #define OR_TOOLS_UTIL_PROTO_TOOLS_H_ #include + #include "google/protobuf/message.h" namespace operations_research { - -enum class ProtoWriteFormat { kProtoText, kProtoBinary, kJson }; - -// Exactly like file::ReadFileToProto() but also supports GZipped files and -// JSON. -bool ReadFileToProto(const std::string& file_name, google::protobuf::Message* proto); - -// Like file::WriteProtoToFile() or file::WriteProtoToASCIIFile(), but also -// supports JSON and GZipped output. -// If 'proto_write_format' is kProtoBinary, ".bin" is appended to file_name. -// If 'proto_write_format' is kJson, ".json" is appended to file_name. -// If 'gzipped' is true, ".gz" is appended to file_name. -bool WriteProtoToFile(const std::string& file_name, const google::protobuf::Message& proto, - ProtoWriteFormat proto_write_format, bool gzipped); - // Prints a proto2 message as a std::string, it behaves like TextFormat::Print() // but also prints the default values of unset fields which is useful for // printing parameters. diff --git a/ortools/util/random_engine.h b/ortools/util/random_engine.h new file mode 100644 index 0000000000..dd91af6643 --- /dev/null +++ b/ortools/util/random_engine.h @@ -0,0 +1,27 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Defines the random engine type to use within operations_research code. + +#ifndef OR_TOOLS_UTIL_RANDOM_ENGINE_H_ +#define OR_TOOLS_UTIL_RANDOM_ENGINE_H_ + +#include + +namespace operations_research { + +using random_engine_t = std::default_random_engine; + +} // namespace operations_research + +#endif // OR_TOOLS_UTIL_RANDOM_ENGINE_H_