diff --git a/ortools/base/inlined_vector.h b/ortools/base/inlined_vector.h index dc70fb04e4..ad6f0413c7 100644 --- a/ortools/base/inlined_vector.h +++ b/ortools/base/inlined_vector.h @@ -134,6 +134,15 @@ class InlinedVector { u_.data[kSize - 1] = 0; } + template + void assign( + InputIterator range_start, InputIterator range_end, + typename std::enable_if::value>::type* = + NULL) { + clear(); + AppendRange(range_start, range_end); + } + // Return the ith element // REQUIRES: 0 <= i < size() const value_type& at(size_t i) const { diff --git a/ortools/base/join.cc b/ortools/base/join.cc index d4556f982d..ab63f8651a 100644 --- a/ortools/base/join.cc +++ b/ortools/base/join.cc @@ -11,9 +11,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "ortools/base/join.h" #include #include "ortools/base/basictypes.h" -#include "ortools/base/join.h" #include "ortools/base/string_view.h" #include "ortools/base/stringprintf.h" diff --git a/ortools/base/map_util.h b/ortools/base/map_util.h index 9b52aa18ad..385d8a1094 100644 --- a/ortools/base/map_util.h +++ b/ortools/base/map_util.h @@ -161,6 +161,26 @@ const typename Collection::value_type::second_type& FindOrDie( return it->second; } +// Same as FindOrDie above, but doesn't log the key on failure. +template +const typename Collection::value_type::second_type& FindOrDieNoPrint( + const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + CHECK(it != collection.end()) << "Map key not found"; + return it->second; +} + +// Same as above, but returns a non-const reference. +template +typename Collection::value_type::second_type& FindOrDieNoPrint( + Collection& collection, // NOLINT + const typename Collection::value_type::first_type& key) { + typename Collection::iterator it = collection.find(key); + CHECK(it != collection.end()) << "Map key not found"; + return it->second; +} + // Lookup a key in a map or std::unordered_map, insert it if it is not present. // Returns a reference to the value associated with the key. template diff --git a/ortools/bop/integral_solver.cc b/ortools/bop/integral_solver.cc index 39f6bda5da..57476bac83 100644 --- a/ortools/bop/integral_solver.cc +++ b/ortools/bop/integral_solver.cc @@ -617,36 +617,32 @@ void IntegralProblemConverter::ConvertAllConstraints( linear_problem.constraint_lower_bounds()[row]; if (lower_bound != -kInfinity) { const Fractional offset_lower_bound = lower_bound - offset; - if (offset_lower_bound * scaling_factor > - static_cast(kint64max)) { + const double offset_scaled_lower_bound = + round(offset_lower_bound * scaling_factor - bound_error); + if (offset_scaled_lower_bound >= static_cast(kint64max)) { LOG(WARNING) << "A constraint is trivially unsatisfiable."; return; } - if (offset_lower_bound * scaling_factor > - -static_cast(kint64max)) { + if (offset_scaled_lower_bound > -static_cast(kint64max)) { // Otherwise, the constraint is not needed. constraint->set_lower_bound( - static_cast( - round(offset_lower_bound * scaling_factor - bound_error)) / - gcd); + static_cast(offset_scaled_lower_bound) / gcd); } } const Fractional upper_bound = linear_problem.constraint_upper_bounds()[row]; if (upper_bound != kInfinity) { const Fractional offset_upper_bound = upper_bound - offset; - if (offset_upper_bound * scaling_factor < - -static_cast(kint64max)) { + const double offset_scaled_upper_bound = + round(offset_upper_bound * scaling_factor + bound_error); + if (offset_scaled_upper_bound <= -static_cast(kint64max)) { LOG(WARNING) << "A constraint is trivially unsatisfiable."; return; } - if (offset_upper_bound * scaling_factor < - static_cast(kint64max)) { + if (offset_scaled_upper_bound < static_cast(kint64max)) { // Otherwise, the constraint is not needed. constraint->set_upper_bound( - static_cast( - round(offset_upper_bound * scaling_factor + bound_error)) / - gcd); + static_cast(offset_scaled_upper_bound) / gcd); } } } diff --git a/ortools/flatzinc/cp_model_fz_solver.cc b/ortools/flatzinc/cp_model_fz_solver.cc index aafd60d073..e3a1e24dd2 100644 --- a/ortools/flatzinc/cp_model_fz_solver.cc +++ b/ortools/flatzinc/cp_model_fz_solver.cc @@ -423,6 +423,50 @@ void CpModelProtoWithMapping::FillConstraint(const fz::Constraint& fz_ct, } ++index; } + } else if (fz_ct.type == "inverse") { + auto* arg = ct->mutable_inverse(); + + const auto direct_variables = LookupVars(fz_ct.arguments[0]); + const auto inverse_variables = LookupVars(fz_ct.arguments[1]); + + const int num_variables = + std::min(direct_variables.size(), inverse_variables.size()); + + // Try to auto-detect if it is zero or one based. + bool found_zero = false; + bool found_size = false; + for (fz::IntegerVariable* const var : fz_ct.arguments[0].variables) { + if (var->domain.Min() == 0) found_zero = true; + if (var->domain.Max() == num_variables) found_size = true; + } + for (fz::IntegerVariable* const var : fz_ct.arguments[1].variables) { + if (var->domain.Min() == 0) found_zero = true; + if (var->domain.Max() == num_variables) found_size = true; + } + + // Add a dummy constant variable at zero if the indexing is one based. + const bool is_one_based = !found_zero || found_size; + const int offset = is_one_based ? 1 : 0; + + if (is_one_based) arg->add_f_direct(LookupConstant(0)); + for (const int var : direct_variables) { + arg->add_f_direct(var); + // Intersect domains with offset + [0, num_variables). + FillDomain(IntersectionOfSortedDisjointIntervals( + ReadDomain(proto.variables(var)), + {{offset, num_variables - 1 + offset}}), + proto.mutable_variables(var)); + } + + if (is_one_based) arg->add_f_inverse(LookupConstant(0)); + for (const int var : inverse_variables) { + arg->add_f_inverse(var); + // Intersect domains with offset + [0, num_variables). + FillDomain(IntersectionOfSortedDisjointIntervals( + ReadDomain(proto.variables(var)), + {{offset, num_variables - 1 + offset}}), + proto.mutable_variables(var)); + } } else if (fz_ct.type == "cumulative") { const std::vector starts = LookupVars(fz_ct.arguments[0]); const std::vector durations = LookupVars(fz_ct.arguments[1]); diff --git a/ortools/flatzinc/parser.yy b/ortools/flatzinc/parser.yy index 0ff3644ad1..e58f50ddd1 100644 --- a/ortools/flatzinc/parser.yy +++ b/ortools/flatzinc/parser.yy @@ -36,7 +36,7 @@ typedef operations_research::fz::LexerInfo YYSTYPE; // Code in the implementation file. %code { // MOE:begin_strip -#include "absl/ortools/base/string_view_utils.h" +#include "ortools/base/string_view_utils.h" // MOE:end_strip #include "ortools/flatzinc/parser_util.cc" diff --git a/ortools/flatzinc/sat_fz_solver.cc b/ortools/flatzinc/sat_fz_solver.cc index b09025421b..6c2d7e1e46 100644 --- a/ortools/flatzinc/sat_fz_solver.cc +++ b/ortools/flatzinc/sat_fz_solver.cc @@ -89,10 +89,9 @@ IntegerVariable SatModel::LookupVar(fz::IntegerVariable* var) { // Otherwise, this must be a Boolean and we must construct the IntegerVariable // associated with it. const Literal lit = FindOrDie(bool_map, var); - const IntegerVariable int_var = model.Add(NewIntegerVariable(0, 1)); + const IntegerVariable int_var = + model.GetOrCreate()->GetIntegerView(lit); InsertOrDie(&var_map, var, int_var); - model.GetOrCreate()->FullyEncodeVariableUsingGivenLiterals( - int_var, {lit.Negated(), lit}, {IntegerValue(0), IntegerValue(1)}); return int_var; } diff --git a/ortools/glop/BUILD b/ortools/glop/BUILD index f4b202f81b..e6cbdfbdcb 100644 --- a/ortools/glop/BUILD +++ b/ortools/glop/BUILD @@ -77,6 +77,7 @@ cc_library( "//ortools/util:file_util", "//ortools/util:fp_utils", "//ortools/util:iterators", + "//ortools/util:random_engine", "//ortools/util:stats", "//ortools/util:time_limit", ], diff --git a/ortools/linear_solver/linear_solver.cc b/ortools/linear_solver/linear_solver.cc index e5a2c6d36d..c7dea9550c 100644 --- a/ortools/linear_solver/linear_solver.cc +++ b/ortools/linear_solver/linear_solver.cc @@ -265,7 +265,11 @@ double MPObjective::BestBound() const { double MPVariable::solution_value() const { if (!interface_->CheckSolutionIsSynchronizedAndExists()) return 0.0; - return integer_ ? round(solution_value_) : solution_value_; + // If the underlying solver supports integer variables, and this is an integer + // variable, we round the solution value (i.e., clients usually expect precise + // integer values for integer variables). + return (integer_ && interface_->IsMIP()) ? round(solution_value_) + : solution_value_; } double MPVariable::unrounded_solution_value() const { @@ -901,6 +905,13 @@ bool MPSolver::HasInfeasibleConstraints() const { return hasInfeasibleConstraints; } +bool MPSolver::HasIntegerVariables() const { + for (const MPVariable* const variable : variables_) { + if (variable->integer()) return true; + } + return false; +} + MPSolver::ResultStatus MPSolver::Solve() { MPSolverParameters default_param; return Solve(default_param); @@ -1055,7 +1066,7 @@ bool MPSolver::VerifySolution(double tolerance, bool log_errors) const { } } // Check integrality. - if (var.integer()) { + if (IsMIP() && var.integer()) { if (fabs(value - round(value)) > tolerance) { ++num_errors; max_observed_error = @@ -1065,6 +1076,12 @@ bool MPSolver::VerifySolution(double tolerance, bool log_errors) const { } } } + if (!IsMIP() && HasIntegerVariables()) { + LOG_IF(INFO, log_errors) << "Skipped variable integrality check, because " + << "a continuous relaxation of the model was " + << "solved (i.e., the selected solver does not " + << "support integer variables)."; + } // Verify constraints. const std::vector activities = ComputeConstraintActivities(); diff --git a/ortools/linear_solver/linear_solver.h b/ortools/linear_solver/linear_solver.h index fb5dc7e3f8..d8ba4c62ca 100644 --- a/ortools/linear_solver/linear_solver.h +++ b/ortools/linear_solver/linear_solver.h @@ -565,6 +565,7 @@ class MPSolver { friend class MPSolverInterface; friend class GLOPInterface; friend class BopInterface; + friend class SatInterface; friend class KnapsackInterface; // Debugging: verify that the given MPVariable* belongs to this solver. @@ -580,6 +581,9 @@ class MPSolver { // Returns true if the model has constraints with lower bound > upper bound. bool HasInfeasibleConstraints() const; + // Returns true if the model has at least 1 integer variable. + bool HasIntegerVariables() const; + // The name of the linear programming problem. const std::string name_; @@ -726,6 +730,7 @@ class MPObjective { friend class CplexInterface; friend class GLOPInterface; friend class BopInterface; + friend class SatInterface; friend class KnapsackInterface; // Constructor. An objective points to a single MPSolverInterface @@ -801,6 +806,7 @@ class MPVariable { friend class GLOPInterface; friend class MPVariableSolutionValueTest; friend class BopInterface; + friend class SatInterface; friend class KnapsackInterface; // Constructor. A variable points to a single MPSolverInterface that @@ -901,6 +907,7 @@ class MPConstraint { friend class CplexInterface; friend class GLOPInterface; friend class BopInterface; + friend class SatInterface; friend class KnapsackInterface; // Constructor. A constraint points to a single MPSolverInterface diff --git a/ortools/linear_solver/linear_solver.proto b/ortools/linear_solver/linear_solver.proto index 54cb0f4d62..cf0177325c 100644 --- a/ortools/linear_solver/linear_solver.proto +++ b/ortools/linear_solver/linear_solver.proto @@ -227,6 +227,7 @@ message MPModelRequest { GUROBI_MIXED_INTEGER_PROGRAMMING = 7; // Commercial, needs a valid license. CPLEX_MIXED_INTEGER_PROGRAMMING = 11; // Commercial, needs a valid license. BOP_INTEGER_PROGRAMMING = 12; + SAT_INTEGER_PROGRAMMING = 14; KNAPSACK_MIXED_INTEGER_PROGRAMMING = 13; } diff --git a/ortools/lp_data/lp_data.cc b/ortools/lp_data/lp_data.cc index 02ca21f3b6..1910ce046b 100644 --- a/ortools/lp_data/lp_data.cc +++ b/ortools/lp_data/lp_data.cc @@ -27,6 +27,7 @@ #include "ortools/lp_data/lp_print_utils.h" #include "ortools/lp_data/lp_utils.h" #include "ortools/lp_data/matrix_utils.h" +#include "ortools/lp_data/permutation.h" namespace operations_research { namespace glop { @@ -817,6 +818,58 @@ void LinearProgram::PopulateFromLinearProgram( first_slack_variable_ = linear_program.first_slack_variable_; } +void LinearProgram::PopulateFromPermutedLinearProgram( + const LinearProgram& lp, const RowPermutation& row_permutation, + const ColumnPermutation& col_permutation) { + DCHECK(lp.IsCleanedUp()); + DCHECK_EQ(row_permutation.size(), lp.num_constraints()); + DCHECK_EQ(col_permutation.size(), lp.num_variables()); + DCHECK_EQ(lp.GetFirstSlackVariable(), kInvalidCol); + Clear(); + + // Populate matrix coefficients. + ColumnPermutation inverse_col_permutation; + inverse_col_permutation.PopulateFromInverse(col_permutation); + matrix_.PopulateFromPermutedMatrix(lp.matrix_, row_permutation, + inverse_col_permutation); + ClearTransposeMatrix(); + + // Populate constraints. + ApplyPermutation(row_permutation, lp.constraint_lower_bounds(), + &constraint_lower_bounds_); + ApplyPermutation(row_permutation, lp.constraint_upper_bounds(), + &constraint_upper_bounds_); + + // Populate variables. + ApplyPermutation(col_permutation, lp.objective_coefficients(), + &objective_coefficients_); + ApplyPermutation(col_permutation, lp.variable_lower_bounds(), + &variable_lower_bounds_); + ApplyPermutation(col_permutation, lp.variable_upper_bounds(), + &variable_upper_bounds_); + ApplyPermutation(col_permutation, lp.variable_types(), &variable_types_); + integer_variables_list_is_consistent_ = false; + + // There is no vector based accessor to names, because they may be created + // on the fly. + constraint_names_.resize(lp.num_constraints()); + for (RowIndex old_row(0); old_row < lp.num_constraints(); ++old_row) { + const RowIndex new_row = row_permutation[old_row]; + constraint_names_[new_row] = lp.constraint_names_[old_row]; + } + variable_names_.resize(lp.num_variables()); + for (ColIndex old_col(0); old_col < lp.num_variables(); ++old_col) { + const ColIndex new_col = col_permutation[old_col]; + variable_names_[new_col] = lp.variable_names_[old_col]; + } + + // Populate singular fields. + maximize_ = lp.maximize_; + objective_offset_ = lp.objective_offset_; + objective_scaling_factor_ = lp.objective_scaling_factor_; + name_ = lp.name_; +} + void LinearProgram::PopulateFromLinearProgramVariables( const LinearProgram& linear_program) { matrix_.PopulateFromZero(RowIndex(0), linear_program.num_variables()); @@ -1268,6 +1321,57 @@ bool LinearProgram::IsInEquationForm() const { IsRightMostSquareMatrixIdentity(matrix_); } +bool LinearProgram::BoundsOfIntegerVariablesAreInteger( + Fractional tolerance) const { + for (const ColIndex col : IntegerVariablesList()) { + if ((IsFinite(variable_lower_bounds_[col]) && + !IsIntegerWithinTolerance(variable_lower_bounds_[col], tolerance)) || + (IsFinite(variable_upper_bounds_[col]) && + !IsIntegerWithinTolerance(variable_upper_bounds_[col], tolerance))) { + VLOG(1) << "Bounds of variable " << col.value() << " are non-integer (" + << variable_lower_bounds_[col] << ", " + << variable_upper_bounds_[col] << ")."; + return false; + } + } + return true; +} + +bool LinearProgram::BoundsOfIntegerConstraintsAreInteger( + Fractional tolerance) const { + // Using transpose for this is faster (complexity = O(number of non zeros in + // matrix)) than directly iterating through entries (complexity = O(number of + // constraints * number of variables)). + const SparseMatrix& transpose = GetTransposeSparseMatrix(); + for (RowIndex row = RowIndex(0); row < num_constraints(); ++row) { + bool integer_constraint = true; + for (const SparseColumn::Entry var : transpose.column(RowToColIndex(row))) { + if (!IsVariableInteger(RowToColIndex(var.row()))) { + integer_constraint = false; + break; + } + if (!IsIntegerWithinTolerance(var.coefficient(), tolerance)) { + integer_constraint = false; + break; + } + } + if (integer_constraint) { + if ((IsFinite(constraint_lower_bounds_[row]) && + !IsIntegerWithinTolerance(constraint_lower_bounds_[row], + tolerance)) || + (IsFinite(constraint_upper_bounds_[row]) && + !IsIntegerWithinTolerance(constraint_upper_bounds_[row], + tolerance))) { + VLOG(1) << "Bounds of constraint " << row.value() + << " are non-integer (" << constraint_lower_bounds_[row] << ", " + << constraint_upper_bounds_[row] << ")."; + return false; + } + } + } + return true; +} + // -------------------------------------------------------- // ProblemSolution // -------------------------------------------------------- diff --git a/ortools/lp_data/lp_data.h b/ortools/lp_data/lp_data.h index 2c6dd6ed6a..096e276e6c 100644 --- a/ortools/lp_data/lp_data.h +++ b/ortools/lp_data/lp_data.h @@ -107,14 +107,6 @@ class LinearProgram { // Set the type of the variable. void SetVariableType(ColIndex col, VariableType type); - // Records the fact that the variable at column col must only take integer - // values. - // Note(user): For the time being, this is not handled. The continuous - // relaxation of the problem (with integrality constraints removed) is solved - // instead. - // TODO(user): Improve the support of integer variables. - void SetVariableIntegrality(ColIndex col, bool is_integer); - // Returns whether the variable at column col is constrained to be integer. bool IsVariableInteger(ColIndex col) const; @@ -413,6 +405,13 @@ class LinearProgram { // Populates the calling object with the given LinearProgram. void PopulateFromLinearProgram(const LinearProgram& linear_program); + // Populates the calling object with the given LinearProgram while permuting + // variables and constraints. This is useful mainly for testing to generate + // a model with the same optimal objective value. + void PopulateFromPermutedLinearProgram( + const LinearProgram& lp, const RowPermutation& row_permutation, + const ColumnPermutation& col_permutation); + // Populates the calling object with the variables of the given LinearProgram. // The function preserves the bounds, the integrality, the names of the // variables and their objective coefficients. No constraints are copied (the @@ -484,6 +483,14 @@ class LinearProgram { // of the literature. bool IsInEquationForm() const; + // Returns true if all integer variables in the linear program have strictly + // integer bounds. + bool BoundsOfIntegerVariablesAreInteger(Fractional tolerance) const; + + // Returns true if all integer constraints in the linear program have strictly + // integer bounds. + bool BoundsOfIntegerConstraintsAreInteger(Fractional tolerance) const; + private: // A helper function that updates the vectors integer_variables_list_, // binary_variables_list_, and non_binary_variables_list_. diff --git a/ortools/lp_data/lp_types.h b/ortools/lp_data/lp_types.h index a78d293677..c8c570e494 100644 --- a/ortools/lp_data/lp_types.h +++ b/ortools/lp_data/lp_types.h @@ -246,8 +246,7 @@ ConstraintStatus VariableToConstraintStatus(VariableStatus status); // to use the index type for the size. // // TODO(user): This should probably move into ITIVector, but note that this -// version is more strict and does not allow any other size types nor a resize() -// or creation with a default value. +// version is more strict and does not allow any other size types. template class StrictITIVector : public ITIVector { public: @@ -260,14 +259,21 @@ class StrictITIVector : public ITIVector { : ParentType(init_list.begin(), init_list.end()) {} #endif StrictITIVector() : ParentType() {} + explicit StrictITIVector(IntType size) : ParentType(size.value()) {} StrictITIVector(IntType size, const T& v) : ParentType(size.value(), v) {} template StrictITIVector(InputIteratorType first, InputIteratorType last) : ParentType(first, last) {} + + void resize(IntType size) { ParentType::resize(size.value()); } void resize(IntType size, const T& v) { ParentType::resize(size.value(), v); } + void assign(IntType size, const T& v) { ParentType::assign(size.value(), v); } + IntType size() const { return IntType(ParentType::size()); } + IntType capacity() const { return IntType(ParentType::capacity()); } + // Since calls to resize() must use a default value, we introduce a new // function for convenience to reduce the size of a vector. void resize_down(IntType size) { diff --git a/ortools/lp_data/sparse.cc b/ortools/lp_data/sparse.cc index 2b60b5c9ab..eeff1792ec 100644 --- a/ortools/lp_data/sparse.cc +++ b/ortools/lp_data/sparse.cc @@ -279,7 +279,7 @@ void SparseMatrix::DeleteColumns(const DenseBooleanRow& columns_to_delete) { ++new_index; } } - columns_.resize_down(new_index); + columns_.resize(new_index); } void SparseMatrix::DeleteRows(RowIndex new_num_rows, diff --git a/ortools/sat/BUILD b/ortools/sat/BUILD index 9d0e6a6478..4316ba8ae2 100644 --- a/ortools/sat/BUILD +++ b/ortools/sat/BUILD @@ -57,8 +57,8 @@ cc_library( ":cp_model_utils", "//ortools/base", "//ortools/base:hash", - "//ortools/base:strings", "//ortools/base:map_util", + "//ortools/base:strings", "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", ], @@ -71,9 +71,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":all_different", + ":cp_model_cc_proto", ":cp_model_checker", ":cp_model_presolve", - ":cp_model_cc_proto", ":cp_model_utils", ":cumulative", ":disjunctive", @@ -85,9 +85,9 @@ cc_library( ":sat_solver", ":table", "//ortools/base", - "//ortools/base:strings", "//ortools/base:hash", "//ortools/base:stl_util", + "//ortools/base:strings", "//ortools/graph:connectivity", ], ) @@ -98,14 +98,14 @@ cc_library( hdrs = ["cp_model_presolve.h"], visibility = ["//visibility:public"], deps = [ - ":cp_model_checker", ":cp_model_cc_proto", + ":cp_model_checker", ":cp_model_utils", "//ortools/base", "//ortools/base:hash", - "//ortools/base:strings", "//ortools/base:map_util", "//ortools/base:stl_util", + "//ortools/base:strings", "//ortools/util:affine_relation", "//ortools/util:bitset", "//ortools/util:sorted_interval_list", @@ -184,6 +184,7 @@ cc_library( "//ortools/base:inlined_vector", "//ortools/base:int_type", "//ortools/base:int_type_indexed_vector", + "//ortools/base:random", "//ortools/base:span", "//ortools/base:stl_util", "//ortools/base:strings", @@ -191,7 +192,6 @@ cc_library( "//ortools/util:random_engine", "//ortools/util:stats", "//ortools/util:time_limit", - "//ortools/base:random", ], ) @@ -330,12 +330,13 @@ cc_library( name = "all_different", srcs = ["all_different.cc"], hdrs = ["all_different.h"], - deps = [ + deps = [ ":integer", ":model", ":sat_base", ":sat_solver", "//ortools/base:strongly_connected_components", + "//ortools/util:sort", ], ) @@ -354,6 +355,7 @@ cc_library( hdrs = ["disjunctive.h"], copts = [W_FLOAT_CONVERSION], deps = [ + ":all_different", ":cp_constraints", ":integer", ":intervals", diff --git a/ortools/sat/all_different.cc b/ortools/sat/all_different.cc index 2704f3cab9..fff9c94937 100644 --- a/ortools/sat/all_different.cc +++ b/ortools/sat/all_different.cc @@ -117,8 +117,7 @@ AllDifferentConstraint::AllDifferentConstraint( // Force full encoding if not already done. if (!encoder->VariableIsFullyEncoded(variables_[x])) { - encoder->FullyEncodeVariable( - variables_[x], integer_trail_->InitialVariableDomain(variables_[x])); + encoder->FullyEncodeVariable(variables_[x]); } // Fill cache with literals, default value is kFalseLiteralIndex. diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index 49bcf181f0..842da5cfe3 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -148,6 +148,13 @@ message TableConstraintProto { bool negated = 3; } +// The two arrays of variable each represent a function, the second is the +// inverse of the first: f_direct[i] == j <=> f_inverse[j] == i. +message InverseConstraintProto { + repeated int32 f_direct = 1; + repeated int32 f_inverse = 2; +} + // This constraint forces a sequence of variables to be accepted by an automata. message AutomataConstraintProto { // A state is identified by a non-negative number. It is preferable to keep @@ -198,6 +205,7 @@ message ConstraintProto { CircuitConstraintProto circuit = 15; TableConstraintProto table = 16; AutomataConstraintProto automata = 17; + InverseConstraintProto inverse = 18; // Constraints on intervals. // @@ -207,10 +215,10 @@ message ConstraintProto { // // TODO(user): Explain what happen for intervals of size zero. Some // constraints ignore them, other do take them into account. - IntervalConstraintProto interval = 18; - CumulativeConstraintProto cumulative = 19; + IntervalConstraintProto interval = 19; NoOverlapConstraintProto no_overlap = 20; NoOverlap2DConstraintProto no_overlap_2d = 21; + CumulativeConstraintProto cumulative = 22; } } diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 52de016194..db8c8a4baf 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -368,7 +368,7 @@ class ConstraintChecker { const ConstraintProto& ct) { const int index = Value(ct.element().index()); return Value(ct.element().vars().Get(index)) == - Value(ct.element().target()); + Value(ct.element().target()); } bool TableConstraintIsFeasible(const CpModelProto& model, @@ -440,6 +440,19 @@ class ConstraintChecker { return num_visited + num_inactive == num_nodes; } + bool InverseConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + const int num_variables = ct.inverse().f_direct_size(); + if (num_variables != ct.inverse().f_inverse_size()) return false; + // Check that f_inverse(f_direct(i)) == i; this is sufficient. + for (int i = 0; i < num_variables; i++) { + const int fi = Value(ct.inverse().f_direct(i)); + if (fi < 0 || num_variables <= fi) return false; + if (i != Value(ct.inverse().f_inverse(fi))) return false; + } + return true; + } + private: std::vector variable_values_; }; @@ -529,6 +542,9 @@ bool SolutionIsFeasible(const CpModelProto& model, case ConstraintProto::ConstraintCase::kCircuit: is_feasible = checker.CircuitConstraintIsFeasible(model, ct); break; + case ConstraintProto::ConstraintCase::kInverse: + is_feasible = checker.InverseConstraintIsFeasible(model, ct); + break; case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: // Empty constraint is always feasible. break; diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index b7ccb3920e..42abeacf33 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -31,6 +31,7 @@ #include "ortools/sat/optimization.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/table.h" +#include "ortools/util/saturated_arithmetic.h" namespace operations_research { namespace sat { @@ -207,7 +208,16 @@ class ModelWithMapping { return result; } + // Returns true if we should not load this constraint. This is mainly used to + // skip constraints that correspond to a basic encoding detected by + // ExtractEncoding(). + bool IgnoreConstraint(const ConstraintProto* ct) const { + return ContainsKey(ct_to_ignore_, ct); + } + private: + void ExtractEncoding(const CpModelProto& model_proto); + Model* model_; // Note that only the variables used by at leat one constraint will be @@ -218,6 +228,10 @@ class ModelWithMapping { // Used to return a feasible solution for the unused variables. std::vector lower_bounds_; + + // Set of constraints to ignore because they where already dealt with by + // ExtractEncoding(). + std::unordered_set ct_to_ignore_; }; template @@ -225,6 +239,188 @@ std::vector ValuesFromProto(const Values& values) { return std::vector(values.begin(), values.end()); } +// Returns the size of the given domain capped to int64max. +int64 DomainSize(const std::vector& domain) { + int64 size = 0; + for (const ClosedInterval interval : domain) { + size += operations_research::CapAdd( + 1, operations_research::CapSub(interval.end, interval.start)); + } + return size; +} + +// The logic assumes that the linear constraints have been presolved, so that +// equality with a domain bound have been converted to <= or >= and so that we +// never have any trivial inequalities. +void ModelWithMapping::ExtractEncoding(const CpModelProto& model_proto) { + IntegerEncoder* encoder = GetOrCreate(); + + // Detection of literal equivalent to (i_var == value). We collect all the + // half-reified constraint lit => equality or lit => inequality for a given + // variable, and we will later sort them to detect equivalence. + struct EqualityDetectionHelper { + const ConstraintProto* ct; + sat::Literal literal; + int64 value; + bool is_equality; // false if != instead. + + bool operator<(const EqualityDetectionHelper& o) const { + if (literal.Variable() == o.literal.Variable()) { + if (value == o.value) return is_equality && !o.is_equality; + return value < o.value; + } + return literal.Variable() < o.literal.Variable(); + } + }; + std::vector> var_to_equalities( + model_proto.variables_size()); + + // Detection of literal equivalent to (i_var >= bound). We also collect + // all the half-refied part and we will sort the vector for detection of the + // equivalence. + struct InequalityDetectionHelper { + const ConstraintProto* ct; + sat::Literal literal; + IntegerLiteral i_lit; + + bool operator<(const InequalityDetectionHelper& o) const { + if (literal.Variable() == o.literal.Variable()) { + return i_lit.Var() < o.i_lit.Var(); + } + return literal.Variable() < o.literal.Variable(); + } + }; + std::vector inequalities; + + // Loop over all contraints and fill var_to_equalities and inequalities. + for (const ConstraintProto& ct : model_proto.constraints()) { + // For now, we only look at linear constraints with one term and an + // enforcement literal. + if (ct.enforcement_literal().empty()) continue; + if (ct.constraint_case() != ConstraintProto::ConstraintCase::kLinear) { + continue; + } + if (ct.linear().vars_size() != 1) continue; + + const sat::Literal enforcement_literal = Literal(ct.enforcement_literal(0)); + const int ref = ct.linear().vars(0); + const int var = PositiveRef(ref); + const auto rhs = InverseMultiplicationOfSortedDisjointIntervals( + ReadDomain(ct.linear()), + ct.linear().coeffs(0) * (RefIsPositive(ref) ? 1 : -1)); + + // Detect enforcement_literal => (var >= value or var <= value). + if (rhs.size() == 1) { + // We relax by 1 because we may take the negation of the rhs above. + if (rhs[0].end >= kint64max - 1) { + inequalities.push_back({&ct, enforcement_literal, + IntegerLiteral::GreaterOrEqual( + Integer(var), IntegerValue(rhs[0].start))}); + } else if (rhs[0].start <= kint64min + 1) { + inequalities.push_back({&ct, enforcement_literal, + IntegerLiteral::LowerOrEqual( + Integer(var), IntegerValue(rhs[0].end))}); + } + } + + // Detect enforcement_literal => (var == value or var != value). + // + // Note that for domain with 2 values like [0, 1], we will detect both == 0 + // and != 1. Similarly, for a domain in [min, max], we should both detect + // (== min) and (<= min), and both detect (== max) and (>= max). + const auto domain = ReadDomain(model_proto.variables(var)); + { + const auto inter = IntersectionOfSortedDisjointIntervals(domain, rhs); + if (inter.size() == 1 && inter[0].start == inter[0].end) { + var_to_equalities[var].push_back( + {&ct, enforcement_literal, inter[0].start, true}); + } + } + { + const auto inter = IntersectionOfSortedDisjointIntervals( + domain, ComplementOfSortedDisjointIntervals(rhs)); + if (inter.size() == 1 && inter[0].start == inter[0].end) { + var_to_equalities[var].push_back( + {&ct, enforcement_literal, inter[0].start, false}); + } + } + } + + // Detect Literal <=> X >= value + int num_inequalities = 0; + std::sort(inequalities.begin(), inequalities.end()); + for (int i = 0; i + 1 < inequalities.size(); i++) { + if (inequalities[i].literal != inequalities[i + 1].literal.Negated()) { + continue; + } + const auto pair_a = encoder->Canonicalize(inequalities[i].i_lit); + const auto pair_b = encoder->Canonicalize(inequalities[i + 1].i_lit); + if (pair_a.first == pair_b.second) { + ++num_inequalities; + encoder->AssociateToIntegerLiteral(inequalities[i].literal, + inequalities[i].i_lit); + ct_to_ignore_.insert(inequalities[i].ct); + ct_to_ignore_.insert(inequalities[i + 1].ct); + } + } + if (!inequalities.empty()) { + VLOG(1) << num_inequalities << " literals associated to VAR >= value (cts: " + << inequalities.size() << ")"; + } + + // Detect Literal <=> X == value and fully encoded variables. + int num_constraints = 0; + int num_equalities = 0; + int num_fully_encoded = 0; + int num_partially_encoded = 0; + for (int i = 0; i < var_to_equalities.size(); ++i) { + std::vector& encoding = var_to_equalities[i]; + std::sort(encoding.begin(), encoding.end()); + if (encoding.empty()) continue; + num_constraints += encoding.size(); + + std::unordered_set values; + for (int j = 0; j + 1 < encoding.size(); j++) { + if ((encoding[j].value != encoding[j + 1].value) || + (encoding[j].literal != encoding[j + 1].literal.Negated()) || + (encoding[j].is_equality != true) || + (encoding[j + 1].is_equality != false)) { + continue; + } + + ++num_equalities; + encoder->AssociateToIntegerEqualValue(encoding[j].literal, integers_[i], + IntegerValue(encoding[j].value)); + ct_to_ignore_.insert(encoding[j].ct); + ct_to_ignore_.insert(encoding[j + 1].ct); + values.insert(encoding[j].value); + } + + // Detect fully encoded variables and mark them as such. + // + // TODO(user): Also fully encode variable that are almost fully encoded. + const std::vector domain = + ReadDomain(model_proto.variables(i)); + if (DomainSize(domain) == values.size()) { + ++num_fully_encoded; + encoder->FullyEncodeVariable(integers_[i]); + } else { + ++num_partially_encoded; + } + } + if (num_constraints > 0) { + VLOG(1) << num_equalities + << " literals associated to VAR == value (cts: " << num_constraints + << ")"; + } + if (num_fully_encoded > 0) { + VLOG(1) << "num_fully_encoded_variables: " << num_fully_encoded; + } + if (num_partially_encoded > 0) { + VLOG(1) << "num_partially_encoded_variables: " << num_partially_encoded; + } +} + // Extracts all the used variables in the CpModelProto and creates a sat::Model // representation for them. ModelWithMapping::ModelWithMapping(const CpModelProto& model_proto, @@ -240,10 +436,6 @@ ModelWithMapping::ModelWithMapping(const CpModelProto& model_proto, lower_bounds_[i] = model_proto.variables(i).domain(0); } - // TODO(user): Detect integers that are the negation of other variable. This - // cannot be simplified by the presolve in the current proto format. - std::vector domain; - std::vector domain_is_boolean; for (const int i : usage.integers) { const auto& var_proto = model_proto.variables(i); integers_[i] = Add(NewIntegerVariable(ReadDomain(var_proto))); @@ -259,23 +451,260 @@ ModelWithMapping::ModelWithMapping(const CpModelProto& model_proto, for (const int i : usage.booleans) { booleans_[i] = Add(NewBooleanVariable()); - - // We need to fix the Boolean if the domain of the integer variable do not - // contain 0 or contains only zero! Note that this case should not appear - // once the model is presolved. - std::vector domain = - ValuesFromProto(model_proto.variables(i).domain()); - if (domain[0] == 0 && domain[1] == 0) { + const auto domain = ReadDomain(model_proto.variables(i)); + CHECK_EQ(domain.size(), 1); + if (domain[0].start == 0 && domain[0].end == 0) { // Fix to false. Add(ClauseConstraint({sat::Literal(booleans_[i], false)})); - } else if (!DomainInProtoContains(model_proto.variables(i), 0)) { + } else if (domain[0].start == 1 && domain[0].end == 1) { // Fix to true. Add(ClauseConstraint({sat::Literal(booleans_[i], true)})); } else if (integers_[i] != kNoIntegerVariable) { - Add(ReifiedInInterval(integers_[i], 0, 0, - sat::Literal(booleans_[i], false))); + // Associate with corresponding integer variable. + const sat::Literal lit(booleans_[i], true); + GetOrCreate()->FullyEncodeVariableUsingGivenLiterals( + integers_[i], {lit.Negated(), lit}, + {IntegerValue(0), IntegerValue(1)}); } } + + ModelWithMapping::ExtractEncoding(model_proto); +} + +// ============================================================================= +// A class that detects when variables should be fully encoded by computing a +// fixed point. +// ============================================================================= + +// This class is designed to be used over a ModelWithMapping, it will ask the +// underlying Model to fully encode IntegerVariables of the model using +// constraint processors PropagateConstraintXXX(), until no such processor wants +// to fully encode a variable. The workflow is to call PropagateFullEncoding() +// on a set of constraints, then ComputeFixedPoint() to launch the fixed point +// computation. +class FullEncodingFixedPointComputer { + public: + explicit FullEncodingFixedPointComputer(ModelWithMapping* model) + : model_(model), integer_encoder_(model->GetOrCreate()) {} + + // We only add to the propagation queue variable that are fully encoded. + // Note that if a variable was already added once, we never add it again. + void ComputeFixedPoint() { + // Make sure all fully encoded variables of interest are in the queue. + for (int v = 0; v < variable_watchers_.size(); v++) { + if (!variable_watchers_[v].empty() && IsFullyEncoded(v)) { + AddVariableToPropagationQueue(v); + } + } + // Propagate until no additional variable can be fully encoded. + while (!variables_to_propagate_.empty()) { + const int variable = variables_to_propagate_.back(); + variables_to_propagate_.pop_back(); + for (const ConstraintProto* ct : variable_watchers_[variable]) { + if (ContainsKey(constraint_is_finished_, ct)) continue; + const bool finished = PropagateFullEncoding(ct); + if (finished) constraint_is_finished_.insert(ct); + } + } + } + + // Return true if the constraint is finished encoding what its wants. + bool PropagateFullEncoding(const ConstraintProto* ct) { + switch (ct->constraint_case()) { + case ConstraintProto::ConstraintProto::kElement: + return PropagateElement(ct); + case ConstraintProto::ConstraintProto::kTable: + return PropagateTable(ct); + case ConstraintProto::ConstraintProto::kAutomata: + return PropagateAutomata(ct); + case ConstraintProto::ConstraintProto::kCircuit: + return PropagateCircuit(ct); + case ConstraintProto::ConstraintProto::kInverse: + return PropagateInverse(ct); + case ConstraintProto::ConstraintProto::kLinear: + return PropagateLinear(ct); + default: + return true; + } + } + + private: + // Constraint ct is interested by (full-encoding) state of variable. + void Register(const ConstraintProto* ct, int variable) { + variable = PositiveRef(variable); + if (!ContainsKey(constraint_is_registered_, ct)) { + constraint_is_registered_.insert(ct); + } + if (variable_watchers_.size() <= variable) { + variable_watchers_.resize(variable + 1); + variable_was_added_in_to_propagate_.resize(variable + 1); + } + variable_watchers_[variable].push_back(ct); + } + + void AddVariableToPropagationQueue(int variable) { + variable = PositiveRef(variable); + if (variable_was_added_in_to_propagate_.size() <= variable) { + variable_watchers_.resize(variable + 1); + variable_was_added_in_to_propagate_.resize(variable + 1); + } + if (!variable_was_added_in_to_propagate_[variable]) { + variable_was_added_in_to_propagate_[variable] = true; + variables_to_propagate_.push_back(variable); + } + } + + // Note that we always consider a fixed variable to be fully encoded here. + const bool IsFullyEncoded(int v) { + const IntegerVariable variable = model_->Integer(v); + return model_->Get(IsFixed(variable)) || + integer_encoder_->VariableIsFullyEncoded(variable); + } + + void FullyEncode(int v) { + v = PositiveRef(v); + const IntegerVariable variable = model_->Integer(v); + if (!model_->Get(IsFixed(variable))) { + model_->Add(FullyEncodeVariable(variable)); + } + AddVariableToPropagationQueue(v); + } + + bool PropagateElement(const ConstraintProto* ct); + bool PropagateTable(const ConstraintProto* ct); + bool PropagateAutomata(const ConstraintProto* ct); + bool PropagateCircuit(const ConstraintProto* ct); + bool PropagateInverse(const ConstraintProto* ct); + bool PropagateLinear(const ConstraintProto* ct); + + ModelWithMapping* model_; + IntegerEncoder* integer_encoder_; + + std::vector variable_was_added_in_to_propagate_; + std::vector variables_to_propagate_; + std::vector> variable_watchers_; + + std::unordered_set constraint_is_finished_; + std::unordered_set constraint_is_registered_; +}; + +bool FullEncodingFixedPointComputer::PropagateElement( + const ConstraintProto* ct) { + // Index must always be full encoded. + FullyEncode(ct->element().index()); + + // If target is a constant or fully encoded, variables must be fully encoded. + const int target = ct->element().target(); + if (IsFullyEncoded(target)) { + for (const int v : ct->element().vars()) FullyEncode(v); + } + + // If all non-target variables are fully encoded, target must be too. + bool all_variables_are_fully_encoded = true; + for (const int v : ct->element().vars()) { + if (v == target) continue; + if (!IsFullyEncoded(v)) { + all_variables_are_fully_encoded = false; + break; + } + } + if (all_variables_are_fully_encoded) { + if (!IsFullyEncoded(target)) FullyEncode(target); + return true; + } + + // If some variables are not fully encoded, register on those. + if (!ContainsKey(constraint_is_registered_, ct)) { + for (const int v : ct->element().vars()) Register(ct, v); + Register(ct, target); + } + return false; +} + +// If a constraint uses its variables in a symbolic (vs. numeric) manner, +// always encode its variables. +bool FullEncodingFixedPointComputer::PropagateTable(const ConstraintProto* ct) { + if (ct->table().negated()) return true; + for (const int variable : ct->table().vars()) { + FullyEncode(variable); + } + return true; +} + +bool FullEncodingFixedPointComputer::PropagateAutomata( + const ConstraintProto* ct) { + for (const int variable : ct->automata().vars()) { + FullyEncode(variable); + } + return true; +} + +bool FullEncodingFixedPointComputer::PropagateCircuit( + const ConstraintProto* ct) { + for (const int variable : ct->circuit().nexts()) { + FullyEncode(variable); + } + return true; +} + +bool FullEncodingFixedPointComputer::PropagateInverse( + const ConstraintProto* ct) { + for (const int variable : ct->inverse().f_direct()) { + FullyEncode(variable); + } + for (const int variable : ct->inverse().f_inverse()) { + FullyEncode(variable); + } + return true; +} + +bool FullEncodingFixedPointComputer::PropagateLinear( + const ConstraintProto* ct) { + // Only act when the constraint is an equality. + if (ct->linear().domain(0) != ct->linear().domain(1)) return true; + + // If some domain is too large, abort; + if (!ContainsKey(constraint_is_registered_, ct)) { + for (const int v : ct->linear().vars()) { + const IntegerVariable var = model_->Integer(v); + IntegerTrail* integer_trail = model_->GetOrCreate(); + const IntegerValue lb = integer_trail->LowerBound(var); + const IntegerValue ub = integer_trail->UpperBound(var); + if (ub - lb > 1024) return true; // Arbitrary limit value. + } + } + + if (HasEnforcementLiteral(*ct)) { + // Fully encode x in half-reified equality b => x == constant. + const auto& vars = ct->linear().vars(); + if (vars.size() == 1) FullyEncode(vars.Get(0)); + return true; + } else { + // If all variables but one are fully encoded, + // force the last one to be fully encoded. + int variable_not_fully_encoded; + int num_fully_encoded = 0; + for (const int var : ct->linear().vars()) { + if (IsFullyEncoded(var)) + num_fully_encoded++; + else + variable_not_fully_encoded = var; + } + const int num_vars = ct->linear().vars_size(); + if (num_fully_encoded == num_vars - 1) { + FullyEncode(variable_not_fully_encoded); + return true; + } + if (num_fully_encoded == num_vars) return true; + + // Register on remaining variables if not already done. + if (!ContainsKey(constraint_is_registered_, ct)) { + for (const int var : ct->linear().vars()) { + if (!IsFullyEncoded(var)) Register(ct, var); + } + } + return false; + } } // ============================================================================= @@ -349,18 +778,22 @@ void LoadLinearConstraint(const ConstraintProto& ct, ModelWithMapping* m) { void LoadAllDiffConstraint(const ConstraintProto& ct, ModelWithMapping* m) { const std::vector vars = m->Integers(ct.all_diff().vars()); - // TODO(user): Find out which alldifferent to use depending on model. - // If some domain is too large, use bounds reasoning. + // If all variables are fully encoded and domains are not too large, use + // arc-consistent reasoning. Otherwise, use bounds-consistent reasoning. IntegerTrail* integer_trail = m->GetOrCreate(); + IntegerEncoder* encoder = m->GetOrCreate(); + int num_fully_encoded = 0; int64 max_domain_size = 0; - for (const IntegerVariable var : vars) { - IntegerValue lb = integer_trail->LowerBound(var); - IntegerValue ub = integer_trail->UpperBound(var); + for (const IntegerVariable variable : vars) { + if (encoder->VariableIsFullyEncoded(variable)) num_fully_encoded++; + + IntegerValue lb = integer_trail->LowerBound(variable); + IntegerValue ub = integer_trail->UpperBound(variable); int64 domain_size = ub.value() - lb.value(); max_domain_size = std::max(max_domain_size, domain_size); } - if (max_domain_size < 1024) { + if (num_fully_encoded == vars.size() && max_domain_size < 1024) { m->Add(AllDifferentBinary(vars)); m->Add(AllDifferentAC(vars)); } else { @@ -428,13 +861,14 @@ void LoadCumulativeConstraint(const ConstraintProto& ct, ModelWithMapping* m) { // TODO(user): Be more efficient when the element().vars() are constants. // Ideally we should avoid creating them as integer variable... -void LoadElementConstraint(const ConstraintProto& ct, ModelWithMapping* m) { +void LoadElementConstraintBounds(const ConstraintProto& ct, + ModelWithMapping* m) { const IntegerVariable index = m->Integer(ct.element().index()); const IntegerVariable target = m->Integer(ct.element().target()); const std::vector vars = m->Integers(ct.element().vars()); IntegerTrail* integer_trail = m->GetOrCreate(); - if (integer_trail->LowerBound(index) == integer_trail->UpperBound(index)) { + if (m->Get(IsFixed(index))) { const int64 value = integer_trail->LowerBound(index).value(); m->Add(Equality(target, vars[value])); return; @@ -461,6 +895,115 @@ void LoadElementConstraint(const ConstraintProto& ct, ModelWithMapping* m) { m->Add(PartialIsOneOfVar(target, possible_vars, selectors)); } +// Arc-Consistent encoding of the element constraint as SAT clauses. +// The constraint enforces vars[index] == target. +// +// The AC propagation can be decomposed in three rules: +// Rule 1: dom(index) == i => dom(vars[i]) == dom(target). +// Rule 2: dom(target) \subseteq \Union_{i \in dom(index)} dom(vars[i]). +// Rule 3: dom(index) \subseteq { i | |dom(vars[i]) \inter dom(target)| > 0 }. +// +// We encode this in a way similar to the table constraint, except that the +// set of admissible tuples is not explicit. +// First, we add Booleans selected[i][value] <=> (index == i /\ vars[i] == +// value). Rules 1 and 2 are enforced by target == value <=> \Or_{i} +// selected[i][value]. Rule 3 is enforced by index == i <=> \Or_{value} +// selected[i][value]. +void LoadElementConstraintAC(const ConstraintProto& ct, ModelWithMapping* m) { + const IntegerVariable index = m->Integer(ct.element().index()); + const IntegerVariable target = m->Integer(ct.element().target()); + const std::vector vars = m->Integers(ct.element().vars()); + + IntegerTrail* integer_trail = m->GetOrCreate(); + if (m->Get(IsFixed(index))) { + const int64 value = integer_trail->LowerBound(index).value(); + m->Add(Equality(target, vars[value])); + return; + } + + // Make map target_value -> literal. + if (m->Get(IsFixed(target))) { + return LoadElementConstraintBounds(ct, m); + } + std::unordered_map target_map; + const auto target_encoding = m->Add(FullyEncodeVariable(target)); + for (const auto literal_value : target_encoding) { + target_map[literal_value.value] = literal_value.literal; + } + + // For i \in index and value in vars[i], make (index == i /\ vars[i] == value) + // literals and store them by value in vectors. + std::unordered_map> value_to_literals; + const auto index_encoding = m->Add(FullyEncodeVariable(index)); + for (const auto literal_value : index_encoding) { + const int i = literal_value.value.value(); + const Literal i_lit = literal_value.literal; + + // Special case where vars[i] == value /\ i_lit is actually i_lit. + if (m->Get(IsFixed(vars[i]))) { + value_to_literals[integer_trail->LowerBound(vars[i])].push_back(i_lit); + continue; + } + + const auto var_encoding = m->Add(FullyEncodeVariable(vars[i])); + std::vector var_selected_literals; + for (const auto var_literal_value : var_encoding) { + const IntegerValue value = var_literal_value.value; + const Literal var_is_value = var_literal_value.literal; + + if (!ContainsKey(target_map, value)) { + // No need to add to value_to_literals, selected[i][value] is always + // false. + m->Add(Implication(i_lit, var_is_value.Negated())); + continue; + } + + const Literal var_is_value_and_selected = + Literal(m->Add(NewBooleanVariable()), true); + m->Add(ReifiedBoolAnd({i_lit, var_is_value}, var_is_value_and_selected)); + value_to_literals[value].push_back(var_is_value_and_selected); + var_selected_literals.push_back(var_is_value_and_selected); + } + // index == i <=> \Or_{value} selected[i][value]. + m->Add(ReifiedBoolOr(var_selected_literals, i_lit)); + } + + // target == value <=> \Or_{i \in index} (vars[i] == value /\ index == i). + for (const auto& entry : target_map) { + const IntegerValue value = entry.first; + const Literal target_is_value = entry.second; + + if (!ContainsKey(value_to_literals, value)) { + m->Add(ClauseConstraint({target_is_value.Negated()})); + } else { + m->Add(ReifiedBoolOr(value_to_literals[value], target_is_value)); + } + } +} + +void LoadElementConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + IntegerEncoder* encoder = m->GetOrCreate(); + + const int target = ct.element().target(); + const IntegerVariable target_var = m->Integer(target); + const bool target_is_AC = m->Get(IsFixed(target_var)) || + encoder->VariableIsFullyEncoded(target_var); + + int num_AC_variables = 0; + const int num_vars = ct.element().vars().size(); + for (const int v : ct.element().vars()) { + IntegerVariable variable = m->Integer(v); + const bool is_full = + m->Get(IsFixed(variable)) || encoder->VariableIsFullyEncoded(variable); + if (is_full) num_AC_variables++; + } + if (target_is_AC || num_AC_variables >= num_vars - 1) { + LoadElementConstraintAC(ct, m); + } else { + LoadElementConstraintBounds(ct, m); + } +} + void LoadTableConstraint(const ConstraintProto& ct, ModelWithMapping* m) { const std::vector vars = m->Integers(ct.table().vars()); const std::vector values = ValuesFromProto(ct.table().values()); @@ -510,7 +1053,7 @@ void LoadCircuitConstraint(const ConstraintProto& ct, ModelWithMapping* m) { graph[i][m->Get(Value(nexts[i]))] = kTrueLiteralIndex; continue; } else { - const auto encoding = m->Add(FullyEncodeVariable((nexts[i]))); + const auto encoding = m->Add(FullyEncodeVariable(nexts[i])); for (const auto& entry : encoding) { graph[i][entry.value.value()] = entry.literal.Index(); } @@ -519,6 +1062,70 @@ void LoadCircuitConstraint(const ConstraintProto& ct, ModelWithMapping* m) { m->Add(SubcircuitConstraint(graph)); } +void LoadInverseConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + // Fully encode both arrays of variables, encode the constraint using Boolean + // equalities: f_direct[i] == j <=> f_inverse[j] == i. + const int num_variables = ct.inverse().f_direct_size(); + CHECK_EQ(num_variables, ct.inverse().f_inverse_size()); + const std::vector direct = + m->Integers(ct.inverse().f_direct()); + const std::vector inverse = + m->Integers(ct.inverse().f_inverse()); + + // Fill LiteralIndex matrices. + std::vector> matrix_direct( + num_variables, + std::vector(num_variables, kFalseLiteralIndex)); + + std::vector> matrix_inverse( + num_variables, + std::vector(num_variables, kFalseLiteralIndex)); + + auto fill_matrix = [&m](std::vector>& matrix, + const std::vector& variables) { + const int num_variables = variables.size(); + for (int i = 0; i < num_variables; i++) { + if (m->Get(IsFixed(variables[i]))) { + matrix[i][m->Get(Value(variables[i]))] = kTrueLiteralIndex; + } else { + const auto encoding = m->Add(FullyEncodeVariable(variables[i])); + for (const auto literal_value : encoding) { + matrix[i][literal_value.value.value()] = + literal_value.literal.Index(); + } + } + } + }; + + fill_matrix(matrix_direct, direct); + fill_matrix(matrix_inverse, inverse); + + // matrix_direct should be the transpose of matrix_inverse. + for (int i = 0; i < num_variables; i++) { + for (int j = 0; j < num_variables; j++) { + LiteralIndex l_ij = matrix_direct[i][j]; + LiteralIndex l_ji = matrix_inverse[j][i]; + if (l_ij >= 0 && l_ji >= 0) { + // l_ij <=> l_ji. + m->Add(ClauseConstraint({Literal(l_ij), Literal(l_ji).Negated()})); + m->Add(ClauseConstraint({Literal(l_ij).Negated(), Literal(l_ji)})); + } else if (l_ij < 0 && l_ji < 0) { + // Problem infeasible if l_ij != l_ji, otherwise nothing to add. + if (l_ij != l_ji) { + m->Add(ClauseConstraint({})); + return; + } + } else { + // One of the LiteralIndex is fixed, let it be l_ij. + if (l_ij > l_ji) std::swap(l_ij, l_ji); + const Literal lit = Literal(l_ji); + m->Add(ClauseConstraint( + {l_ij == kFalseLiteralIndex ? lit.Negated() : lit})); + } + } + } +} + // Makes the std::string fit in one line by cutting it in the middle if necessary. std::string Summarize(const std::string& input) { if (input.size() < 105) return input; @@ -715,6 +1322,9 @@ bool LoadConstraint(const ConstraintProto& ct, ModelWithMapping* m) { case ConstraintProto::ConstraintProto::kCircuit: LoadCircuitConstraint(ct, m); return true; + case ConstraintProto::ConstraintProto::kInverse: + LoadInverseConstraint(ct, m); + return true; default: return false; } @@ -980,6 +1590,8 @@ const std::function ConstructSearchStrategy( }; } +// TODO(user): Also consider linear inequality where the objective is minimized +// in the good direction. void ExtractLinearObjective(const CpModelProto& model_proto, ModelWithMapping* m, std::vector* linear_vars, @@ -1040,10 +1652,8 @@ void ExtractLinearObjective(const CpModelProto& model_proto, } } -} // namespace - -CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, - Model* model) { +CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, + Model* model) { // Timing. WallTimer wall_timer; UserTimer user_timer; @@ -1054,6 +1664,9 @@ CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, CpSolverResponse response; response.set_status(CpSolverStatus::MODEL_INVALID); + // We will add them all at once after model_proto is loaded. + model->GetOrCreate()->DisableImplicationBetweenLiteral(); + // Instanciate all the needed variables. const VariableUsage usage = ComputeVariableUsage(model_proto); ModelWithMapping m(model_proto, usage, model); @@ -1061,10 +1674,23 @@ CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, const SatParameters& parameters = model->GetOrCreate()->parameters(); + // Force some variables to be fully encoded. + FullEncodingFixedPointComputer fixpoint(&m); + for (const ConstraintProto& ct : model_proto.constraints()) { + fixpoint.PropagateFullEncoding(&ct); + } + fixpoint.ComputeFixedPoint(); + // Load the constraints. std::set unsupported_types; Trail* trail = model->GetOrCreate(); + int num_ignored_constraints = 0; for (const ConstraintProto& ct : model_proto.constraints()) { + if (m.IgnoreConstraint(&ct)) { + ++num_ignored_constraints; + continue; + } + const int old_num_fixed = trail->Index(); if (!LoadConstraint(ct, &m)) { unsupported_types.insert(ConstraintCaseName(ct.constraint_case())); @@ -1086,6 +1712,9 @@ CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, break; } } + if (num_ignored_constraints > 0) { + VLOG(1) << num_ignored_constraints << " constraints where skipped."; + } if (!unsupported_types.empty()) { VLOG(1) << "There is unsuported constraints types in this model: "; for (const std::string& type : unsupported_types) { @@ -1099,11 +1728,22 @@ CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, AddLPConstraints(model_proto, &m); } + model->GetOrCreate() + ->AddAllImplicationsBetweenAssociatedLiterals(); + // Initialize the search strategy function. std::function next_decision; if (model_proto.search_strategy().empty()) { std::vector decisions; for (const int i : usage.integers) { + if (!model_proto.objectives().empty()) { + // Make sure we try to fix the objective to its lowest value first. + const int obj = model_proto.objectives(0).objective_var(); + if (PositiveRef(i) == PositiveRef(obj)) { + decisions.push_back(m.Integer(obj)); + continue; + } + } decisions.push_back(m.Integer(i)); } next_decision = FirstUnassignedVarAtItsMinHeuristic(decisions, model); @@ -1160,9 +1800,15 @@ CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, std::vector linear_vars; std::vector linear_coeffs; ExtractLinearObjective(model_proto, &m, &linear_vars, &linear_coeffs); - status = MinimizeWithCoreAndLazyEncoding( - VLOG_IS_ON(1), objective_var, linear_vars, linear_coeffs, - next_decision, solution_observer, model); + if (parameters.optimize_with_max_hs()) { + status = MinimizeWithHittingSetAndLazyEncoding( + VLOG_IS_ON(1), objective_var, linear_vars, linear_coeffs, + next_decision, solution_observer, model); + } else { + status = MinimizeWithCoreAndLazyEncoding( + VLOG_IS_ON(1), objective_var, linear_vars, linear_coeffs, + next_decision, solution_observer, model); + } } else { status = MinimizeIntegerVariableWithLinearScanAndLazyEncoding( /*log_info=*/false, objective_var, next_decision, solution_observer, @@ -1219,6 +1865,24 @@ CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, return response; } +} // namespace + +CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, + Model* model) { + // Validate model_proto. + // TODO(user): provide an option to skip this step for speed? + { + const std::string error = ValidateCpModel(model_proto); + if (!error.empty()) { + VLOG(1) << error; + CpSolverResponse response; + response.set_status(CpSolverStatus::MODEL_INVALID); + return response; + } + } + return SolveCpModelInternal(model_proto, model); +} + CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { // Validate model_proto. // TODO(user): provide an option to skip this step for speed? @@ -1273,7 +1937,7 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { postsolve_model.Add(operations_research::sat::NewSatParameters(params)); } const CpSolverResponse postsolve_response = - SolveCpModelWithoutPresolve(mapping_proto, &postsolve_model); + SolveCpModelInternal(mapping_proto, &postsolve_model); CHECK_EQ(postsolve_response.status(), CpSolverStatus::MODEL_SAT); response.clear_solution(); response.clear_solution_lower_bounds(); diff --git a/ortools/sat/cp_model_utils.cc b/ortools/sat/cp_model_utils.cc index e9ef00645b..52e21b4796 100644 --- a/ortools/sat/cp_model_utils.cc +++ b/ortools/sat/cp_model_utils.cc @@ -73,6 +73,10 @@ void AddReferencesUsedByConstraint(const ConstraintProto& ct, case ConstraintProto::ConstraintCase::kCircuit: AddIndices(ct.circuit().nexts(), &output->variables); break; + case ConstraintProto::ConstraintCase::kInverse: + AddIndices(ct.inverse().f_direct(), &output->variables); + AddIndices(ct.inverse().f_inverse(), &output->variables); + break; case ConstraintProto::ConstraintCase::kTable: AddIndices(ct.table().vars(), &output->variables); break; @@ -145,6 +149,8 @@ void ApplyToAllLiteralIndices(const std::function& f, break; case ConstraintProto::ConstraintCase::kCircuit: break; + case ConstraintProto::ConstraintCase::kInverse: + break; case ConstraintProto::ConstraintCase::kTable: break; case ConstraintProto::ConstraintCase::kAutomata: @@ -205,6 +211,10 @@ void ApplyToAllVariableIndices(const std::function& f, case ConstraintProto::ConstraintCase::kCircuit: APPLY_TO_REPEATED_FIELD(circuit, nexts); break; + case ConstraintProto::ConstraintCase::kInverse: + APPLY_TO_REPEATED_FIELD(inverse, f_direct); + APPLY_TO_REPEATED_FIELD(inverse, f_inverse); + break; case ConstraintProto::ConstraintCase::kTable: APPLY_TO_REPEATED_FIELD(table, vars); break; @@ -256,6 +266,8 @@ void ApplyToAllIntervalIndices(const std::function& f, break; case ConstraintProto::ConstraintCase::kCircuit: break; + case ConstraintProto::ConstraintCase::kInverse: + break; case ConstraintProto::ConstraintCase::kTable: break; case ConstraintProto::ConstraintCase::kAutomata: @@ -306,6 +318,8 @@ std::string ConstraintCaseName(ConstraintProto::ConstraintCase constraint_case) return "kElement"; case ConstraintProto::ConstraintCase::kCircuit: return "kCircuit"; + case ConstraintProto::ConstraintCase::kInverse: + return "kInverse"; case ConstraintProto::ConstraintCase::kTable: return "kTable"; case ConstraintProto::ConstraintCase::kAutomata: diff --git a/ortools/sat/cumulative.cc b/ortools/sat/cumulative.cc index 3e8a25ddde..8f59277284 100644 --- a/ortools/sat/cumulative.cc +++ b/ortools/sat/cumulative.cc @@ -54,7 +54,7 @@ std::function Cumulative( // At this point, we know that the duration variable is not fixed. const Literal size_condition = - encoder->CreateAssociatedLiteral(IntegerLiteral::GreaterOrEqual( + encoder->GetOrCreateAssociatedLiteral(IntegerLiteral::GreaterOrEqual( intervals->SizeVar(vars[i]), IntegerValue(1))); if (intervals->IsOptional(vars[i])) { @@ -203,11 +203,11 @@ std::function CumulativeTimeDecomposition( } // Task t overlaps time. - consume_condition.push_back(encoder->CreateAssociatedLiteral( + consume_condition.push_back(encoder->GetOrCreateAssociatedLiteral( IntegerLiteral::LowerOrEqual(start_vars[t], IntegerValue(time)))); - consume_condition.push_back( - encoder->CreateAssociatedLiteral(IntegerLiteral::GreaterOrEqual( - end_vars[t], IntegerValue(time + 1)))); + consume_condition.push_back(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(end_vars[t], + IntegerValue(time + 1)))); model->Add(ReifiedBoolAnd(consume_condition, consume)); diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 888acd12aa..31d1fd2ae4 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -28,56 +28,55 @@ std::vector NegationOf( return result; } -void IntegerEncoder::FullyEncodeVariable( - IntegerVariable i_var, const std::vector& domain) { - CHECK(!VariableIsFullyEncoded(i_var)); +void IntegerEncoder::FullyEncodeVariable(IntegerVariable var) { + CHECK(!VariableIsFullyEncoded(var)); CHECK_EQ(0, sat_solver_->CurrentDecisionLevel()); - CHECK(!domain.empty()); // UNSAT problem. We don't deal with that here. + CHECK(!(*domains_)[var].empty()); // UNSAT. We don't deal with that here. std::vector values; - for (const ClosedInterval interval : domain) { + for (const ClosedInterval interval : (*domains_)[var]) { for (IntegerValue v(interval.start); v <= interval.end; ++v) { values.push_back(v); CHECK_LT(values.size(), 100000) << "Domain too large for full encoding."; } } - // TODO(user): This case is annoying, not sure yet how to best fix the - // variable. There is certainly no need to create a Boolean variable, but - // one needs to talk to IntegerTrail to fix the variable and we don't want - // the encoder to depend on this. So for now we fail here and it is up to - // the caller to deal with this case. + // TODO(user): This case is annoying, so for now we want the caller to deal + // with it, hence the CHECK. We do not want to create a fixed Boolean + // variable, but we also do not want to complexify the API of + // FullDomainEncoding(). CHECK_NE(values.size(), 1); std::vector literals; if (values.size() == 2) { - const BooleanVariable var = sat_solver_->NewBooleanVariable(); - literals.push_back(Literal(var, true)); - literals.push_back(Literal(var, false)); + literals.push_back(GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, values[0]))); + literals.push_back(literals.back().Negated()); } else { for (int i = 0; i < values.size(); ++i) { - const BooleanVariable var = sat_solver_->NewBooleanVariable(); - literals.push_back(Literal(var, true)); + const std::pair key{var, values[i]}; + if (ContainsKey(equality_to_associated_literal_, key)) { + literals.push_back(equality_to_associated_literal_[key]); + } else { + literals.push_back(Literal(sat_solver_->NewBooleanVariable(), true)); + } } } - return FullyEncodeVariableUsingGivenLiterals(i_var, literals, values); + return FullyEncodeVariableUsingGivenLiterals(var, literals, values); } void IntegerEncoder::FullyEncodeVariableUsingGivenLiterals( - IntegerVariable i_var, const std::vector& literals, + IntegerVariable var, const std::vector& literals, const std::vector& values) { - CHECK(!VariableIsFullyEncoded(i_var)); + CHECK(!VariableIsFullyEncoded(var)); CHECK(!literals.empty()); CHECK_NE(literals.size(), 1); // Sort the literals by values. std::vector encoding; - std::vector cst; + encoding.reserve(values.size()); for (int i = 0; i < values.size(); ++i) { - const Literal literal = literals[i]; - const IntegerValue value = values[i]; - encoding.push_back({value, literal}); - cst.push_back(LiteralWithCoeff(literal, Coefficient(1))); + encoding.push_back({values[i], literals[i]}); } std::sort(encoding.begin(), encoding.end()); @@ -87,9 +86,9 @@ void IntegerEncoder::FullyEncodeVariableUsingGivenLiterals( // literal pushed by Enqueue() (we look at the domain there). for (int i = 0; i + 1 < encoding.size(); ++i) { const IntegerLiteral i_lit = - IntegerLiteral::LowerOrEqual(i_var, encoding[i].value); + IntegerLiteral::LowerOrEqual(var, encoding[i].value); const IntegerLiteral i_lit_negated = - IntegerLiteral::GreaterOrEqual(i_var, encoding[i + 1].value); + IntegerLiteral::GreaterOrEqual(var, encoding[i + 1].value); if (i == 0) { // Special case for the start. HalfAssociateGivenLiteral(i_lit, encoding[0].literal); @@ -101,31 +100,26 @@ void IntegerEncoder::FullyEncodeVariableUsingGivenLiterals( } else { // Normal case. if (!LiteralIsAssociated(i_lit) || !LiteralIsAssociated(i_lit_negated)) { - const BooleanVariable new_var = sat_solver_->NewBooleanVariable(); - const Literal literal(new_var, true); - HalfAssociateGivenLiteral(i_lit, literal); - HalfAssociateGivenLiteral(i_lit_negated, literal.Negated()); + const BooleanVariable b = sat_solver_->NewBooleanVariable(); + HalfAssociateGivenLiteral(i_lit, Literal(b, true)); + HalfAssociateGivenLiteral(i_lit_negated, Literal(b, false)); } } } // Now that all literals are created, wire them together using // (X == v) <=> (X >= v) and (X <= v). + // + // TODO(user): this is currently in O(n^2) which is potentially bad even if + // we do it only once per variable. for (int i = 1; i + 1 < encoding.size(); ++i) { - const ValueLiteralPair pair = encoding[i]; - const Literal a(GetAssociatedLiteral( - IntegerLiteral::GreaterOrEqual(i_var, pair.value))); - const Literal b( - GetAssociatedLiteral(IntegerLiteral::LowerOrEqual(i_var, pair.value))); - sat_solver_->AddBinaryClause(a, pair.literal.Negated()); - sat_solver_->AddBinaryClause(b, pair.literal.Negated()); - sat_solver_->AddProblemClause({a.Negated(), b.Negated(), pair.literal}); + AssociateToIntegerEqualValue(encoding[i].literal, var, encoding[i].value); } - full_encoding_index_[i_var] = full_encoding_.size(); + full_encoding_index_[var] = full_encoding_.size(); full_encoding_.push_back(encoding); // copy because we need it below. - // Deal with NegationOf(i_var). + // Deal with NegationOf(var). // // TODO(user): This seems a bit wasted, but it does simplify the code at a // somehow small cost. @@ -133,7 +127,7 @@ void IntegerEncoder::FullyEncodeVariableUsingGivenLiterals( for (auto& entry : encoding) { entry.value = -entry.value; // Reverse the value. } - full_encoding_index_[NegationOf(i_var)] = full_encoding_.size(); + full_encoding_index_[NegationOf(var)] = full_encoding_.size(); full_encoding_.push_back(std::move(encoding)); } @@ -149,23 +143,25 @@ void IntegerEncoder::AddImplications(IntegerLiteral i_lit, Literal literal) { encoding_by_var_[IntegerVariable(i_lit.var)]; CHECK(!ContainsKey(map_ref, i_lit.bound)); - auto after_it = map_ref.lower_bound(i_lit.bound); - if (after_it != map_ref.end()) { - // Literal(after) => literal - if (sat_solver_->CurrentDecisionLevel() == 0) { - sat_solver_->AddBinaryClause(after_it->second.Negated(), literal); - } else { - sat_solver_->AddBinaryClauseDuringSearch(after_it->second.Negated(), - literal); + if (add_implications_) { + auto after_it = map_ref.lower_bound(i_lit.bound); + if (after_it != map_ref.end()) { + // Literal(after) => literal + if (sat_solver_->CurrentDecisionLevel() == 0) { + sat_solver_->AddBinaryClause(after_it->second.Negated(), literal); + } else { + sat_solver_->AddBinaryClauseDuringSearch(after_it->second.Negated(), + literal); + } } - } - if (after_it != map_ref.begin()) { - // literal => Literal(before) - if (sat_solver_->CurrentDecisionLevel() == 0) { - sat_solver_->AddBinaryClause(literal.Negated(), (--after_it)->second); - } else { - sat_solver_->AddBinaryClauseDuringSearch(literal.Negated(), - (--after_it)->second); + if (after_it != map_ref.begin()) { + // literal => Literal(before) + if (sat_solver_->CurrentDecisionLevel() == 0) { + sat_solver_->AddBinaryClause(literal.Negated(), (--after_it)->second); + } else { + sat_solver_->AddBinaryClauseDuringSearch(literal.Negated(), + (--after_it)->second); + } } } @@ -173,22 +169,112 @@ void IntegerEncoder::AddImplications(IntegerLiteral i_lit, Literal literal) { map_ref[i_lit.bound] = literal; } +void IntegerEncoder::AddAllImplicationsBetweenAssociatedLiterals() { + CHECK_EQ(0, sat_solver_->CurrentDecisionLevel()); + add_implications_ = true; + for (const std::map& encoding : encoding_by_var_) { + LiteralIndex previous = kNoLiteralIndex; + for (const auto value_literal : encoding) { + const Literal lit = value_literal.second; + if (previous != kNoLiteralIndex) { + // lit => previous. + sat_solver_->AddBinaryClause(lit.Negated(), Literal(previous)); + } + previous = lit.Index(); + } + } +} + +std::pair IntegerEncoder::Canonicalize( + IntegerLiteral i_lit) const { + const IntegerVariable var(i_lit.var); + IntegerValue after(i_lit.bound); + IntegerValue before(i_lit.bound - 1); + CHECK_GE(before, (*domains_)[var].front().start); + CHECK_LE(after, (*domains_)[var].back().end); + int64 previous = kint64min; + for (const ClosedInterval& interval : (*domains_)[var]) { + if (before > previous && before < interval.start) before = previous; + if (after > previous && after < interval.start) after = interval.start; + if (after <= interval.end) break; + previous = interval.end; + } + return {IntegerLiteral::GreaterOrEqual(var, after), + IntegerLiteral::LowerOrEqual(var, before)}; +} + +Literal IntegerEncoder::GetOrCreateAssociatedLiteral(IntegerLiteral i_lit) { + const IntegerLiteral new_lit = Canonicalize(i_lit).first; + if (new_lit.var < encoding_by_var_.size()) { + const std::map& encoding = + encoding_by_var_[IntegerVariable(new_lit.var)]; + const auto it = encoding.find(new_lit.bound); + if (it != encoding.end()) return it->second; + } -Literal IntegerEncoder::CreateAssociatedLiteral(IntegerLiteral i_lit) { - CHECK(!LiteralIsAssociated(i_lit)); ++num_created_variables_; const BooleanVariable new_var = sat_solver_->NewBooleanVariable(); const Literal literal(new_var, true); - AssociateGivenLiteral(i_lit, literal); + AssociateToIntegerLiteral(literal, new_lit); return literal; } -void IntegerEncoder::AssociateGivenLiteral(IntegerLiteral i_lit, - Literal literal) { - // TODO(user): convert it to a "domain compatible one". - CHECK(!LiteralIsAssociated(i_lit)); - HalfAssociateGivenLiteral(i_lit, literal); - HalfAssociateGivenLiteral(i_lit.Negated(), literal.Negated()); +void IntegerEncoder::AssociateToIntegerLiteral(Literal literal, + IntegerLiteral i_lit) { + const auto& domain = (*domains_)[i_lit.Var()]; + if (i_lit.Bound() <= domain.front().start) { + sat_solver_->AddUnitClause(literal); + } else if (i_lit.Bound() > domain.back().end) { + sat_solver_->AddUnitClause(literal.Negated()); + } else { + const auto pair = Canonicalize(i_lit); + HalfAssociateGivenLiteral(pair.first, literal); + HalfAssociateGivenLiteral(pair.second, literal.Negated()); + } +} + +void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal, + IntegerVariable var, + IntegerValue value) { + const auto& domain = (*domains_)[var]; + const std::pair key{var, value}; + if (!SortedDisjointIntervalsContain(domain, value.value())) { + sat_solver_->AddUnitClause(literal.Negated()); + } else if (value == domain.front().start && value == domain.back().end) { + sat_solver_->AddUnitClause(literal); // fixed variable. + } else if (value == domain.front().start) { + AssociateToIntegerLiteral(literal, + IntegerLiteral::LowerOrEqual(var, value)); + if (!ContainsKey(equality_to_associated_literal_, key)) { + equality_to_associated_literal_[key] = literal; + } + } else if (value == domain.back().end) { + AssociateToIntegerLiteral(literal, + IntegerLiteral::GreaterOrEqual(var, value)); + if (!ContainsKey(equality_to_associated_literal_, key)) { + equality_to_associated_literal_[key] = literal; + } + } else { + // If this key is already associated, make the two literals equal. + if (ContainsKey(equality_to_associated_literal_, key)) { + const Literal representative = equality_to_associated_literal_[key]; + if (representative != literal) { + sat_solver_->AddBinaryClause(literal, representative.Negated()); + sat_solver_->AddBinaryClause(literal.Negated(), representative); + } + return; + } + equality_to_associated_literal_[key] = literal; + + // (var == value) <=> (var >= value) and (var <= value). + const Literal a(GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, value))); + const Literal b( + GetOrCreateAssociatedLiteral(IntegerLiteral::LowerOrEqual(var, value))); + sat_solver_->AddBinaryClause(a, literal.Negated()); + sat_solver_->AddBinaryClause(b, literal.Negated()); + sat_solver_->AddProblemClause({a.Negated(), b.Negated(), literal}); + } } // TODO(user): The hard constraints we add between associated literals seems to @@ -230,16 +316,6 @@ LiteralIndex IntegerEncoder::GetAssociatedLiteral(IntegerLiteral i) { return result->second.Index(); } -Literal IntegerEncoder::GetOrCreateAssociatedLiteral(IntegerLiteral i_lit) { - if (i_lit.var < encoding_by_var_.size()) { - const std::map& encoding = - encoding_by_var_[IntegerVariable(i_lit.var)]; - const auto it = encoding.find(i_lit.bound); - if (it != encoding.end()) return it->second; - } - return CreateAssociatedLiteral(i_lit); -} - LiteralIndex IntegerEncoder::SearchForLiteralAtOrBefore( IntegerLiteral i) const { // We take the element before the upper_bound() which is either the encoding @@ -325,6 +401,7 @@ IntegerVariable IntegerTrail::AddIntegerVariable(IntegerValue lower_bound, is_ignored_literals_.push_back(kNoLiteralIndex); vars_.push_back({lower_bound, static_cast(integer_trail_.size())}); integer_trail_.push_back({lower_bound, i.value()}); + domains_->push_back({{lower_bound.value(), upper_bound.value()}}); // TODO(user): the is_ignored_literals_ Booleans are currently always the same // for a variable and its negation. So it may be better not to store it twice @@ -333,6 +410,7 @@ IntegerVariable IntegerTrail::AddIntegerVariable(IntegerValue lower_bound, is_ignored_literals_.push_back(kNoLiteralIndex); vars_.push_back({-upper_bound, static_cast(integer_trail_.size())}); integer_trail_.push_back({-upper_bound, NegationOf(i).value()}); + domains_->push_back({{-upper_bound.value(), -lower_bound.value()}}); for (SparseBitset* w : watchers_) { w->Resize(NumIntegerVariables()); @@ -346,25 +424,7 @@ IntegerVariable IntegerTrail::AddIntegerVariable( CHECK(IntervalsAreSortedAndDisjoint(domain)); const IntegerVariable var = AddIntegerVariable( IntegerValue(domain.front().start), IntegerValue(domain.back().end)); - - // We only stores the vector of closed intervals for "complex" domain. - if (domain.size() > 1) { - var_to_current_lb_interval_index_.Set(var, all_intervals_.size()); - for (const ClosedInterval interval : domain) { - all_intervals_.push_back(interval); - } - InsertOrDie(&var_to_end_interval_index_, var, all_intervals_.size()); - - // Copy for the negated variable. - var_to_current_lb_interval_index_.Set(NegationOf(var), - all_intervals_.size()); - for (const ClosedInterval interval : ::gtl::reversed_view(domain)) { - all_intervals_.push_back({-interval.end, -interval.start}); - } - InsertOrDie(&var_to_end_interval_index_, NegationOf(var), - all_intervals_.size()); - } - + CHECK(UpdateInitialDomain(var, domain)); return var; } @@ -388,12 +448,21 @@ std::vector IntegerTrail::InitialVariableDomain( bool IntegerTrail::UpdateInitialDomain(IntegerVariable var, std::vector domain) { - domain = - IntersectionOfSortedDisjointIntervals(domain, InitialVariableDomain(var)); + // TODO(user): A bit inefficient as this recreate a vector for no reason. + const std::vector old_domain = InitialVariableDomain(var); + if (old_domain == domain) return true; + + domain = IntersectionOfSortedDisjointIntervals(domain, old_domain); if (domain.empty()) return false; - // TODO(user): A bit inefficient as this recreate a vector for no reason. - if (domain == InitialVariableDomain(var)) return true; + // TODO(user): we currently keep the domain in domains_ but also in + // all_intervals_ for domains with holes. Try to consolidate both structures. + (*domains_)[var].assign(domain.begin(), domain.end()); + { + std::vector temp = + NegationOfSortedDisjointIntervals(domain); + (*domains_)[NegationOf(var)].assign(temp.begin(), temp.end()); + } CHECK(Enqueue( IntegerLiteral::GreaterOrEqual(var, IntegerValue(domain.front().start)), @@ -432,7 +501,7 @@ bool IntegerTrail::UpdateInitialDomain(IntegerVariable var, int num_fixed = 0; const auto encoding = encoder_->FullDomainEncoding(var); for (const auto pair : encoding) { - while (pair.value > domain[i].end && i < domain.size()) ++i; + while (i < domain.size() && pair.value > domain[i].end) ++i; if (i == domain.size() || pair.value < domain[i].start) { // Set the literal to false; ++num_fixed; @@ -597,7 +666,11 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit, const IntegerVariable var(i_lit.var); // If the domain of var is not a single intervals and i_lit.bound fall into a - // "hole", we increase it to the next possible value. + // "hole", we increase it to the next possible value. This ensure that we + // never Enqueue() non-canonical literals. See also Canonicalize(). + // + // Note: The literals in the reason are not necessarily canonical, but then + // we always map these to enqueued literals during conflict resolution. { int interval_index = FindWithDefault(var_to_current_lb_interval_index_, var, -1); @@ -672,6 +745,13 @@ bool IntegerTrail::Enqueue(IntegerLiteral i_lit, } // Enqueue the strongest associated Boolean literal implied by this one. + // Because we linked all such literal with implications, all the one before + // will be propagated by the SAT solver. + // + // TODO(user): It might be simply better and more efficient to simply enqueue + // all of them here. We have also more liberty to choose the explanation we + // want. A drawback might be that the implications might not be used in the + // binary conflict minimization algo. const LiteralIndex literal_index = encoder_->SearchForLiteralAtOrBefore(i_lit); if (literal_index != kNoLiteralIndex) { diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index ab08ca9999..934d0dfd81 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -153,6 +153,17 @@ inline std::ostream& operator<<(std::ostream& os, IntegerLiteral i_lit) { using InlinedIntegerLiteralVector = gtl::InlinedVector; +// A singleton that holds the INITIAL integer variable domains. +struct IntegerDomains + : public ITIVector> { + static IntegerDomains* CreateInModel(Model* model) { + IntegerDomains* domains = new IntegerDomains(); + model->TakeOwnership(domains); + return domains; + } +}; + // Each integer variable x will be associated with a set of literals encoding // (x >= v) for some values of v. This class maintains the relationship between // the integer variables and such literals which can be created by a call to @@ -173,27 +184,22 @@ using InlinedIntegerLiteralVector = gtl::InlinedVector; // though. class IntegerEncoder { public: - explicit IntegerEncoder(SatSolver* sat_solver) - : sat_solver_(sat_solver), num_created_variables_(0) {} + IntegerEncoder(SatSolver* sat_solver, IntegerDomains* domains) + : sat_solver_(sat_solver), domains_(domains), num_created_variables_(0) {} ~IntegerEncoder() { VLOG(1) << "#variables created = " << num_created_variables_; } static IntegerEncoder* CreateInModel(Model* model) { - SatSolver* sat_solver = model->GetOrCreate(); - IntegerEncoder* encoder = new IntegerEncoder(sat_solver); + IntegerEncoder* encoder = new IntegerEncoder( + model->GetOrCreate(), model->GetOrCreate()); model->TakeOwnership(encoder); return encoder; } - // Fully encode a variable. This can be called only once. - // - // Important: this should really only be called with - // integer_trail_->InitialVariableDomain() which can be updated with - // integer_trail_->UpdateInitialDomain(). - // - // TODO(user): clean this up by enforcing this programmatically. + // Fully encode a variable using its current initial domain. + // This can be called only once. // // This creates new Booleans variables as needed: // 1) num_values for the literals X == value. Except when there is just @@ -205,14 +211,13 @@ class IntegerEncoder { // The encoding for NegationOf(var) is automatically created too. It reuses // the same Boolean variable as the encoding of var. // - // Note(user): Calling this with just one value will cause a CHECK fail. - // We don't really want to create a fixed Boolean. + // Note(user): Calling this on fixed variables will cause a CHECK fail. We + // don't really want to create a fixed Boolean. // // TODO(user): It is currently only possible to call that at the decision // level zero because we cannot add ternary clause in the middle of the // search (for now). This is Checked. - void FullyEncodeVariable(IntegerVariable var, - const std::vector& domain); + void FullyEncodeVariable(IntegerVariable var); // Similar to FullyEncodeVariable() but use the given literal for each values. // This can only be called on variable that are not fully encoded yet, This is @@ -249,45 +254,56 @@ class IntegerEncoder { return ContainsKey(full_encoding_index_, var); } - // Returns the set of variable encoded as the keys in a map. The map values - // only have an internal meaning. The set of encoded variables is returned - // with this "weird" api for efficiency. - const std::unordered_map& GetFullyEncodedVariables() const { - return full_encoding_index_; - } + // Returns the "canonical" (i_lit, negation of i_lit) pair. This mainly + // deal with domain with initial hole like [1,2][5,6] so that if one ask + // for x <= 3, this get canonicalized in the pair (x <= 2, x >= 5). + // + // Note that it is an error to call this with a literal that is trivially true + // or trivially false according to the initial variable domain. This is + // CHECKed to make sure we don't create wasteful literal. + // + // TODO(user): This is linear in the domain "complexity", we can do better if + // needed. + std::pair Canonicalize( + IntegerLiteral i_lit) const; - // Creates a new Boolean variable 'var' such that + // Returns, after creating it if needed, a Boolean literal such that: // - if true, then the IntegerLiteral is true. // - if false, then the negated IntegerLiteral is true. // - // Returns Literal(var, true). + // Note that this "canonicalize" the given literal first. // // This add the proper implications with the two "neighbor" literals of this // one if they exist. This is the "list encoding" in: Thibaut Feydy, Peter J. // Stuckey, "Lazy Clause Generation Reengineered", CP 2009. - // - // It is an error to call this with an already created literal. This is - // Checked. - // - // The second version use the given literal instead of creating a new - // variable. - Literal CreateAssociatedLiteral(IntegerLiteral i_lit); - void AssociateGivenLiteral(IntegerLiteral i_lit, Literal wanted); - - // Same as CreateAssociatedLiteral() but safe to call if already created. Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit); - // Only add the equivalence between i_lit and literal, if there is already an - // associated literal with i_lit, this make literal and this associated - // literal equivalent. - void HalfAssociateGivenLiteral(IntegerLiteral i_lit, Literal literal); + // Associates the Boolean literal to (X >= bound) or (X == value). If a + // literal was already associated to this fact, this will add an equality + // constraints between both literals. If the fact is trivially true or false, + // this will fix the given literal. + void AssociateToIntegerLiteral(Literal literal, IntegerLiteral i_lit); + void AssociateToIntegerEqualValue(Literal literal, IntegerVariable var, + IntegerValue value); - // Return true iff the given integer literal is associated. + // Returns true iff the given integer literal is associated. The second + // version returns the associated literal or kNoLiteralIndex. Note that none + // of these function call Canonicalize() first for speed, so it is possible + // that this returns false even though GetOrCreateAssociatedLiteral() would + // not create a new literal. bool LiteralIsAssociated(IntegerLiteral i_lit) const; - - // Returns the associated literal or kNoLiteralIndex. LiteralIndex GetAssociatedLiteral(IntegerLiteral i_lit); + // Advanced usage. It is more efficient to create the associated literals in + // order, but it might be anoying to do so. Instead, you can first call + // DisableImplicationBetweenLiteral() and when you are done creating all the + // associated literals, you can call (only at level zero) + // AddAllImplicationsBetweenAssociatedLiterals() which will also turn back on + // the implications between literals for the one that will be added + // afterwards. + void DisableImplicationBetweenLiteral() { add_implications_ = false; } + void AddAllImplicationsBetweenAssociatedLiterals(); + // Returns the IntegerLiterals that were associated with the given Literal. const InlinedIntegerLiteralVector& GetIntegerLiterals(Literal lit) const { if (lit.Index() >= reverse_encoding_.size()) { @@ -307,12 +323,20 @@ class IntegerEncoder { LiteralIndex SearchForLiteralAtOrBefore(IntegerLiteral i) const; private: + // Only add the equivalence between i_lit and literal, if there is already an + // associated literal with i_lit, this make literal and this associated + // literal equivalent. + void HalfAssociateGivenLiteral(IntegerLiteral i_lit, Literal literal); + // Adds the new associated_lit to encoding_by_var_. // Adds the implications: Literal(before) <= associated_lit <= Literal(after). void AddImplications(IntegerLiteral i, Literal associated_lit); SatSolver* sat_solver_; - int64 num_created_variables_; + IntegerDomains* domains_; + + bool add_implications_ = true; + int64 num_created_variables_ = 0; // We keep all the literals associated to an Integer variable in a map ordered // by bound (so we can properly add implications between the literals @@ -323,6 +347,12 @@ class IntegerEncoder { const InlinedIntegerLiteralVector empty_integer_literal_vector_; ITIVector reverse_encoding_; + // Mapping (variable == value) -> associated literal. Note that even if + // there is more than one literal associated to the same fact, we just keep + // the first one that was added. + std::unordered_map, Literal> + equality_to_associated_literal_; + // Full domain encoding. The map contains the index in full_encoding_ of // the fully encoded variable. Each entry in full_encoding_ is sorted by // IntegerValue and contains the encoding of one IntegerVariable. @@ -337,17 +367,19 @@ class IntegerEncoder { // to maintain the reason for each propagation. class IntegerTrail : public SatPropagator { public: - IntegerTrail(IntegerEncoder* encoder, Trail* trail) + IntegerTrail(IntegerDomains* domains, IntegerEncoder* encoder, Trail* trail) : SatPropagator("IntegerTrail"), num_enqueues_(0), + domains_(domains), encoder_(encoder), trail_(trail) {} ~IntegerTrail() final {} static IntegerTrail* CreateInModel(Model* model) { + IntegerDomains* domains = model->GetOrCreate(); IntegerEncoder* encoder = model->GetOrCreate(); Trail* trail = model->GetOrCreate(); - IntegerTrail* integer_trail = new IntegerTrail(encoder, trail); + IntegerTrail* integer_trail = new IntegerTrail(domains, encoder, trail); model->GetOrCreate()->AddPropagator( std::unique_ptr(integer_trail)); return integer_trail; @@ -629,6 +661,7 @@ class IntegerTrail : public SatPropagator { std::vector*> watchers_; + IntegerDomains* domains_; IntegerEncoder* encoder_; Trail* trail_; @@ -1024,13 +1057,7 @@ inline std::function Equality(IntegerVariable v, int64 value) { // Associate the given literal to the given integer inequality. inline std::function Equality(IntegerLiteral i, Literal l) { return [=](Model* model) { - IntegerEncoder* encoder = model->GetOrCreate(); - if (encoder->LiteralIsAssociated(i)) { - const Literal current = encoder->GetOrCreateAssociatedLiteral(i); - model->Add(Equality(current, l)); - } else { - encoder->AssociateGivenLiteral(i, l); - } + model->GetOrCreate()->AssociateToIntegerLiteral(l, i); }; } @@ -1091,8 +1118,7 @@ FullyEncodeVariable(IntegerVariable var) { return [=](Model* model) { IntegerEncoder* encoder = model->GetOrCreate(); if (!encoder->VariableIsFullyEncoded(var)) { - encoder->FullyEncodeVariable( - var, model->GetOrCreate()->InitialVariableDomain(var)); + encoder->FullyEncodeVariable(var); } return encoder->FullDomainEncoding(var); }; diff --git a/ortools/sat/lp_utils.cc b/ortools/sat/lp_utils.cc index 906e97a3ad..88f439da36 100644 --- a/ortools/sat/lp_utils.cc +++ b/ortools/sat/lp_utils.cc @@ -168,11 +168,11 @@ bool ConvertMPModelProtoToCpModelProto(const MPModelProto& mp_model, } // Display the error/scaling without taking into account the objective first. - LOG(INFO) << "Maximum constraint coefficient relative error: " - << max_relative_coeff_error; - LOG(INFO) << "Maximum constraint worst-case sum absolute error: " - << max_scaled_sum_error; - LOG(INFO) << "Maximum constraint scaling factor: " << max_scaling_factor; + VLOG(1) << "Maximum constraint coefficient relative error: " + << max_relative_coeff_error; + VLOG(1) << "Maximum constraint worst-case sum absolute error: " + << max_scaled_sum_error; + VLOG(1) << "Maximum constraint scaling factor: " << max_scaling_factor; // Add the objective. We use kint64max / 2 because the objective_var will // also be added to the objective constraint. @@ -188,7 +188,7 @@ bool ConvertMPModelProtoToCpModelProto(const MPModelProto& mp_model, lower_bounds.push_back(var_proto.domain(0)); upper_bounds.push_back(var_proto.domain(var_proto.domain_size() - 1)); } - if (!coefficients.empty()) { + if (!coefficients.empty() || mp_model.objective_offset() != 0.0) { GetBestScalingOfDoublesToInt64(coefficients, lower_bounds, upper_bounds, kMaxObjective, &scaling_factor, &relative_coeff_error, &scaled_sum_error); @@ -197,11 +197,10 @@ bool ConvertMPModelProtoToCpModelProto(const MPModelProto& mp_model, std::max(relative_coeff_error, max_relative_coeff_error); // Display the objective error/scaling. - LOG(INFO) << "objective coefficient relative error: " - << relative_coeff_error; - LOG(INFO) << "objective worst-case absolute error: " - << scaled_sum_error / scaling_factor; - LOG(INFO) << "objective scaling factor: " << scaling_factor / gcd; + VLOG(1) << "objective coefficient relative error: " << relative_coeff_error; + VLOG(1) << "objective worst-case absolute error: " + << scaled_sum_error / scaling_factor; + VLOG(1) << "objective scaling factor: " << scaling_factor / gcd; // Note that here we set the scaling factor for the inverse operation of // getting the "true" objective value from the scaled one. Hence the @@ -222,8 +221,8 @@ bool ConvertMPModelProtoToCpModelProto(const MPModelProto& mp_model, auto* objective_constraint = cp_model->add_constraints(); auto* objective_arg = objective_constraint->mutable_linear(); objective_constraint->set_name("objective"); - objective_arg->add_domain(mp_model.maximize() ? 0 : kint64min); - objective_arg->add_domain(mp_model.maximize() ? kint64max : 0); + objective_arg->add_domain(0); + objective_arg->add_domain(0); for (int i = 0; i < num_variables; ++i) { const MPVariableProto& mp_var = mp_model.variable(i); const int64 value = @@ -243,6 +242,7 @@ bool ConvertMPModelProtoToCpModelProto(const MPModelProto& mp_model, if (mp_model.maximize()) { objective->set_objective_var(-objective->objective_var() - 1); objective->set_scaling_factor(-objective->scaling_factor()); + objective->set_offset(-objective->offset()); } } diff --git a/ortools/sat/lp_utils.h b/ortools/sat/lp_utils.h index d7d6bea9e9..085d8a954a 100644 --- a/ortools/sat/lp_utils.h +++ b/ortools/sat/lp_utils.h @@ -60,7 +60,7 @@ int FixVariablesFromSat(const SatSolver& solver, glop::LinearProgram* lp); // polarity choices. The variable must have the same index in the solved lp // problem and in SAT for this to make sense. // -// Returns false if a problem occured while trying to solve the lp. +// Returns false if a problem occurred while trying to solve the lp. bool SolveLpAndUseSolutionForSatAssignmentPreference( const glop::LinearProgram& lp, SatSolver* sat_solver, double max_time_in_seconds); diff --git a/ortools/sat/optimization.cc b/ortools/sat/optimization.cc index e0fec55f07..bfe1dc5639 100644 --- a/ortools/sat/optimization.cc +++ b/ortools/sat/optimization.cc @@ -18,6 +18,10 @@ #include "ortools/base/stringprintf.h" #include "google/protobuf/descriptor.h" +#if defined(USE_CBC) || defined(USE_SCIP) +#include "ortools/linear_solver/linear_solver.h" +#include "ortools/linear_solver/linear_solver.pb.h" +#endif // defined(USE_CBC) || defined(USE_SCIP) #include "ortools/sat/encoding.h" #include "ortools/sat/integer_expr.h" #include "ortools/sat/util.h" @@ -1108,6 +1112,52 @@ SatSolver::Status MinimizeIntegerVariableWithLinearScanAndLazyEncoding( return result; } +namespace { + +// If the given model is unsat under the given assumptions, returns one or more +// non-overlapping set of assumptions, each set making the problem infeasible on +// its own (the cores). +// +// The returned status can be either: +// - ASSUMPTIONS_UNSAT if the set of returned core perfectly cover the given +// assumptions, in this case, we don't bother trying to find a SAT solution +// with no assumptions. +// - MODEL_SAT if after finding zero or more core we have a solution. +// - LIMIT_REACHED if we reached the time-limit before one of the two status +// above could be decided. +SatSolver::Status FindCores(std::vector assumptions, + const std::function& next_decision, + Model* model, + std::vector>* cores) { + cores->clear(); + SatSolver* sat_solver = model->GetOrCreate(); + do { + const SatSolver::Status result = + SolveIntegerProblemWithLazyEncoding(assumptions, next_decision, model); + if (result != SatSolver::ASSUMPTIONS_UNSAT) return result; + std::vector core = sat_solver->GetLastIncompatibleDecisions(); + if (sat_solver->parameters().minimize_core()) { + MinimizeCore(sat_solver, &core); + } + CHECK(!core.empty()); + cores->push_back(core); + if (!sat_solver->parameters().find_multiple_cores()) break; + + // Remove from assumptions all the one in the core and see if we can find + // another core. + int new_size = 0; + std::set temp(core.begin(), core.end()); + for (int i = 0; i < assumptions.size(); ++i) { + if (ContainsKey(temp, assumptions[i])) continue; + assumptions[new_size++] = assumptions[i]; + } + assumptions.resize(new_size); + } while (!assumptions.empty()); + return SatSolver::ASSUMPTIONS_UNSAT; +} + +} // namespace + SatSolver::Status MinimizeWithCoreAndLazyEncoding( bool log_info, IntegerVariable objective_var, const std::vector& variables, @@ -1152,9 +1202,7 @@ SatSolver::Status MinimizeWithCoreAndLazyEncoding( struct ObjectiveTerm { IntegerVariable var; IntegerValue weight; - - // These fields are only used for logging/debugging. - int depth; + int depth; // only for logging/debugging. IntegerValue old_var_lb; }; std::vector terms; @@ -1245,84 +1293,119 @@ SatSolver::Status MinimizeWithCoreAndLazyEncoding( } // Solve under the assumptions. - result = - SolveIntegerProblemWithLazyEncoding(assumptions, next_decision, model); + std::vector> cores; + result = FindCores(assumptions, next_decision, model, &cores); if (result == SatSolver::MODEL_SAT) { - if (!process_solution()) { - result = SatSolver::MODEL_UNSAT; - break; + process_solution(); + if (cores.empty()) { + // If not all assumptions where taken, continue with a lower stratified + // bound. Otherwise we have an optimal solution. + stratified_threshold = next_stratified_threshold; + if (stratified_threshold == 0) break; + --iter; // "false" iteration, the lower bound does not increase. + continue; } - - // If not all assumptions where taken, continue with a lower stratified - // bound. Otherwise we have an optimal solution. - stratified_threshold = next_stratified_threshold; - if (stratified_threshold == 0) break; - --iter; // "false" iteration, the lower bound does not increase. - continue; + } else if (result != SatSolver::ASSUMPTIONS_UNSAT) { + break; } - if (result != SatSolver::ASSUMPTIONS_UNSAT) break; - - // We have a new core. - std::vector core = sat_solver->GetLastIncompatibleDecisions(); - if (sat_solver->parameters().minimize_core()) { - MinimizeCore(sat_solver, &core); - } - CHECK(!core.empty()); - - // This just increase the lower-bound of the corresponding node, which - // should already be done by the solver. - if (core.size() == 1) continue; sat_solver->Backtrack(0); sat_solver->SetAssumptionLevel(0); + for (const std::vector& core : cores) { + // This just increase the lower-bound of the corresponding node, which + // should already be done by the solver. + if (core.size() == 1) continue; - // Compute the min weight of all the terms in the core. The lower bound will - // be increased by that much because at least one assumption in the core - // must be true. This is also why we can start at 1 for new_var_lb. - IntegerValue min_weight = kMaxIntegerValue; - IntegerValue max_weight(0); - IntegerValue new_var_lb(1); - IntegerValue new_var_ub(0); - int new_depth = 0; - for (const Literal lit : core) { - const int index = FindOrDie(assumption_to_term_index, lit.Index()); - min_weight = std::min(min_weight, terms[index].weight); - max_weight = std::max(max_weight, terms[index].weight); - new_depth = std::max(new_depth, terms[index].depth + 1); - new_var_lb += integer_trail->LowerBound(terms[index].var); - new_var_ub += integer_trail->UpperBound(terms[index].var); - CHECK_EQ(terms[index].old_var_lb, - integer_trail->LowerBound(terms[index].var)); - } - max_depth = std::max(max_depth, new_depth); - if (log_info) { - LOG(INFO) << StringPrintf( - " core:%zu weight:[%lld,%lld] domain:[%lld,%lld] depth:%d", - core.size(), min_weight.value(), max_weight.value(), - new_var_lb.value(), new_var_ub.value(), new_depth); - } - - // We will "transfer" min_weight from all the variables of the core - // to a new variable. - const IntegerVariable new_var = - model->Add(NewIntegerVariable(new_var_lb.value(), new_var_ub.value())); - terms.push_back({new_var, min_weight, new_depth}); - - // Sum variables in the core <= new_var. - // TODO(user): Experiment with FixedWeightedSum() instead. - { - std::vector constraint_vars; - std::vector constraint_coeffs; + // Compute the min weight of all the terms in the core. The lower bound + // will be increased by that much because at least one assumption in the + // core must be true. This is also why we can start at 1 for new_var_lb. + bool ignore_this_core = false; + IntegerValue min_weight = kMaxIntegerValue; + IntegerValue max_weight(0); + IntegerValue new_var_lb(1); + IntegerValue new_var_ub(0); + int new_depth = 0; for (const Literal lit : core) { const int index = FindOrDie(assumption_to_term_index, lit.Index()); - terms[index].weight -= min_weight; - constraint_vars.push_back(terms[index].var); - constraint_coeffs.push_back(1); + min_weight = std::min(min_weight, terms[index].weight); + max_weight = std::max(max_weight, terms[index].weight); + new_depth = std::max(new_depth, terms[index].depth + 1); + new_var_lb += integer_trail->LowerBound(terms[index].var); + new_var_ub += integer_trail->UpperBound(terms[index].var); + + // When this happen, the core is now trivially "minimized" by the new + // bound on this variable, so there is no point in adding it. + if (terms[index].old_var_lb < + integer_trail->LowerBound(terms[index].var)) { + ignore_this_core = true; + break; + } + } + if (ignore_this_core) continue; + + max_depth = std::max(max_depth, new_depth); + if (log_info) { + LOG(INFO) << StringPrintf( + " core:%zu weight:[%lld,%lld] domain:[%lld,%lld] depth:%d", + core.size(), min_weight.value(), max_weight.value(), + new_var_lb.value(), new_var_ub.value(), new_depth); + } + + // We will "transfer" min_weight from all the variables of the core + // to a new variable. + const IntegerVariable new_var = model->Add( + NewIntegerVariable(new_var_lb.value(), new_var_ub.value())); + terms.push_back({new_var, min_weight, new_depth}); + + // Sum variables in the core <= new_var. + // TODO(user): Experiment with FixedWeightedSum() instead. + { + std::vector constraint_vars; + std::vector constraint_coeffs; + for (const Literal lit : core) { + const int index = FindOrDie(assumption_to_term_index, lit.Index()); + terms[index].weight -= min_weight; + constraint_vars.push_back(terms[index].var); + constraint_coeffs.push_back(1); + } + constraint_vars.push_back(new_var); + constraint_coeffs.push_back(-1); + model->Add( + WeightedSumLowerOrEqual(constraint_vars, constraint_coeffs, 0)); + } + + // Find out the true lower bound of new_var. This is called "cover + // optimization" in the max-SAT literature. + // + // TODO(user): Do more experiments to decide if this is better. This + // approach kind of mix the basic linear-scan one with the core based + // approach. + if (/* DISABLES CODE */ (false)) { + IntegerValue best = new_var_ub; + + // Simple linear scan algorithm to find the optimal of new_var. + while (best > new_var_lb) { + const Literal a = integer_encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(new_var, best - 1)); + result = + SolveIntegerProblemWithLazyEncoding({a}, next_decision, model); + if (result != SatSolver::MODEL_SAT) break; + best = integer_trail->LowerBound(new_var); + if (!process_solution()) { + result = SatSolver::MODEL_UNSAT; + break; + } + } + if (result == SatSolver::ASSUMPTIONS_UNSAT) { + sat_solver->Backtrack(0); + sat_solver->SetAssumptionLevel(0); + if (!integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual(new_var, best), {}, {})) { + result = SatSolver::MODEL_UNSAT; + break; + } + } } - constraint_vars.push_back(new_var); - constraint_coeffs.push_back(-1); - model->Add( - WeightedSumLowerOrEqual(constraint_vars, constraint_coeffs, 0)); } // Re-express the objective with the new terms. @@ -1340,35 +1423,217 @@ SatSolver::Status MinimizeWithCoreAndLazyEncoding( constraint_coeffs.push_back(-1); model->Add(FixedWeightedSum(constraint_vars, constraint_coeffs, 0)); } + } - // Find out the true lower bound of new_var. This is called "cover - // optimization" in the max-SAT literature. + // Returns MODEL_SAT if we found the optimal. + return num_solutions > 0 && result == SatSolver::MODEL_UNSAT + ? SatSolver::MODEL_SAT + : result; +} + +#if defined(USE_CBC) || defined(USE_SCIP) +// TODO(user): take the MPModelRequest or MPModelProto directly, so that we can +// have initial constraints! +// +// TODO(user): remove code duplication with MinimizeWithCoreAndLazyEncoding(); +SatSolver::Status MinimizeWithHittingSetAndLazyEncoding( + bool log_info, IntegerVariable objective_var, + std::vector variables, + std::vector coefficients, + const std::function& next_decision, + const std::function& feasible_solution_observer, + Model* model) { + SatSolver* sat_solver = model->GetOrCreate(); + IntegerTrail* integer_trail = model->GetOrCreate(); + IntegerEncoder* integer_encoder = model->GetOrCreate(); + + // This will be called each time a feasible solution is found. + int num_solutions = 0; + IntegerValue best_objective = integer_trail->UpperBound(objective_var); + const auto process_solution = [&]() { + const IntegerValue objective(model->Get(Value(objective_var))); + if (objective >= best_objective) return; + ++num_solutions; + best_objective = objective; + if (feasible_solution_observer != nullptr) { + feasible_solution_observer(*model); + } + }; + + // This is the "generalized" hitting set problem we will solve. Each time + // we find a core, a new constraint will be added to this problem. + MPModelRequest request; + #if defined(USE_SCIP) + request.set_solver_type(MPModelRequest::SCIP_MIXED_INTEGER_PROGRAMMING); + #else // USE_CBC + request.set_solver_type(MPModelRequest::CBC_MIXED_INTEGER_PROGRAMMING); + #endif // USE_CBC or USE_SCIP + + MPModelProto& hs_model = *request.mutable_model(); + const int num_variables = variables.size(); + for (int i = 0; i < num_variables; ++i) { + if (coefficients[i] < 0) { + variables[i] = NegationOf(variables[i]); + coefficients[i] = -coefficients[i]; + } + const IntegerVariable var = variables[i]; + MPVariableProto* var_proto = hs_model.add_variable(); + var_proto->set_lower_bound(integer_trail->LowerBound(var).value()); + var_proto->set_upper_bound(integer_trail->UpperBound(var).value()); + var_proto->set_objective_coefficient(coefficients[i].value()); + var_proto->set_is_integer(true); + } + + // The MIP solver. + #if defined(USE_SCIP) + MPSolver solver("HS solver", MPSolver::SCIP_MIXED_INTEGER_PROGRAMMING); + #else // USE_CBC + MPSolver solver("HS solver", MPSolver::CBC_MIXED_INTEGER_PROGRAMMING); + #endif // USE_CBC or USE_SCIP + MPSolutionResponse response; + + // This is used by the "stratified" approach. We will only consider terms with + // a weight not lower than this threshold. The threshold will decrease as the + // algorithm progress. + IntegerValue stratified_threshold = kMaxIntegerValue; + + // TODO(user): The core is returned in the same order as the assumptions, + // so we don't really need this map, we could just do a linear scan to + // recover which node are part of the core. + std::map assumption_to_index; + + // New Booleans variable in the MIP model to represent X >= cte. + std::map, int> created_var; + + // Start the algorithm. + SatSolver::Status result; + for (int iter = 0;; ++iter) { + // TODO(user): Even though we keep the same solver, currently the solve is + // not really done incrementally. It might be hard to improve though. // - // TODO(user): Do more experiments to decide if this is better. This - // approach kind of mix the basic linear-scan one with the core based - // approach. - if (/* DISABLES CODE */ (false)) { - IntegerValue best = new_var_ub; + // TODO(user): deal with time limit. + solver.SolveWithProto(request, &response); + CHECK_EQ(response.status(), MPSolverResponseStatus::MPSOLVER_OPTIMAL); + if (log_info) { + LOG(INFO) << "constraints: " << hs_model.constraint_size() + << " variables: " << hs_model.variable_size() + << " mip_lower_bound: " << response.objective_value() + << " strat: " << stratified_threshold; + } - // Simple linear scan algorithm to find the optimal of new_var. - while (best > new_var_lb) { - const Literal a = integer_encoder->GetOrCreateAssociatedLiteral( - IntegerLiteral::LowerOrEqual(new_var, best - 1)); - result = SolveIntegerProblemWithLazyEncoding({a}, next_decision, model); - if (result != SatSolver::MODEL_SAT) break; - best = integer_trail->LowerBound(new_var); - if (!process_solution()) { - result = SatSolver::MODEL_UNSAT; - break; - } + // Update the objective lower bound with our current bound. + // + // Note(user): This is not needed for correctness, but it might cause + // more propagation and is nice to have for reporting/logging purpose. + if (!integer_trail->Enqueue( + IntegerLiteral::GreaterOrEqual( + objective_var, + IntegerValue(static_cast(response.objective_value()))), + {}, {})) { + result = SatSolver::MODEL_UNSAT; + break; + } + + sat_solver->Backtrack(0); + sat_solver->SetAssumptionLevel(0); + std::vector assumptions; + assumption_to_index.clear(); + IntegerValue next_stratified_threshold(0); + for (int i = 0; i < num_variables; ++i) { + const IntegerValue hs_value( + static_cast(response.variable_value(i))); + if (hs_value == integer_trail->UpperBound(variables[i])) continue; + + // Only consider the terms above the threshold. + if (coefficients[i] < stratified_threshold) { + next_stratified_threshold = + std::max(next_stratified_threshold, coefficients[i]); + } else { + assumptions.push_back(integer_encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(variables[i], hs_value))); + InsertOrDie(&assumption_to_index, assumptions.back().Index(), i); } - if (result == SatSolver::ASSUMPTIONS_UNSAT) { - sat_solver->Backtrack(0); - sat_solver->SetAssumptionLevel(0); - if (!integer_trail->Enqueue( - IntegerLiteral::GreaterOrEqual(new_var, best), {}, {})) { - result = SatSolver::MODEL_UNSAT; - break; + } + + // No assumptions with the current stratified_threshold? use the new one. + if (assumptions.empty()) { + if (next_stratified_threshold > 0) { + CHECK_LT(next_stratified_threshold, stratified_threshold); + stratified_threshold = next_stratified_threshold; + --iter; // "false" iteration, the lower bound does not increase. + continue; + } else { + result = + num_solutions > 0 ? SatSolver::MODEL_SAT : SatSolver::MODEL_UNSAT; + break; + } + } + + // TODO(user): we could also randomly shuffle the assumptions to find more + // cores for only one MIP solve. + std::vector> cores; + result = FindCores(assumptions, next_decision, model, &cores); + if (result == SatSolver::MODEL_SAT) { + process_solution(); + if (cores.empty()) { + // If not all assumptions where taken, continue with a lower stratified + // bound. Otherwise we have an optimal solution. + stratified_threshold = next_stratified_threshold; + if (stratified_threshold == 0) break; + --iter; // "false" iteration, the lower bound does not increase. + continue; + } + } else if (result != SatSolver::ASSUMPTIONS_UNSAT) { + break; + } + + sat_solver->Backtrack(0); + sat_solver->SetAssumptionLevel(0); + for (const std::vector& core : cores) { + if (core.size() == 1) { + const int index = FindOrDie(assumption_to_index, core.front().Index()); + hs_model.mutable_variable(index)->set_lower_bound( + integer_trail->LowerBound(variables[index]).value()); + continue; + } + + // Add the corresponding constraint to hs_model. + MPConstraintProto* ct = hs_model.add_constraint(); + ct->set_lower_bound(1.0); + for (const Literal lit : core) { + const int index = FindOrDie(assumption_to_index, lit.Index()); + const double lb = integer_trail->LowerBound(variables[index]).value(); + const double hs_value = response.variable_value(index); + if (hs_value == lb) { + ct->add_var_index(index); + ct->add_coefficient(1.0); + ct->set_lower_bound(ct->lower_bound() + lb); + } else { + // TODO(user): if we have just one variable whose hs_value is not at + // its lower bound, then we can add a cut that remove the current + // solution (on the core) without the need to introduce this extra + // variable. + std::pair key = {index, hs_value}; + if (!ContainsKey(created_var, key)) { + const int new_var_index = hs_model.variable_size(); + created_var[key] = new_var_index; + + MPVariableProto* new_var = hs_model.add_variable(); + new_var->set_lower_bound(0); + new_var->set_upper_bound(1); + new_var->set_is_integer(true); + + // (new_var == 1) => x > hs_value. + // (x - lb) - (hs_value - lb + 1) * new_var >= 0. + MPConstraintProto* implication = hs_model.add_constraint(); + implication->set_lower_bound(lb); + implication->add_var_index(index); + implication->add_coefficient(1.0); + implication->add_var_index(new_var_index); + implication->add_coefficient(lb - hs_value - 1); + } + ct->add_var_index(FindOrDieNoPrint(created_var, key)); + ct->add_coefficient(1.0); } } } @@ -1379,6 +1644,7 @@ SatSolver::Status MinimizeWithCoreAndLazyEncoding( ? SatSolver::MODEL_SAT : result; } +#endif // defined(USE_CBC) || defined(USE_SCIP) SatSolver::Status MinimizeWeightedLiteralSumWithCoreAndLazyEncoding( bool log_info, const std::vector& literals, diff --git a/ortools/sat/optimization.h b/ortools/sat/optimization.h index 32a4fecbfc..27ead2ad82 100644 --- a/ortools/sat/optimization.h +++ b/ortools/sat/optimization.h @@ -136,7 +136,7 @@ SatSolver::Status MinimizeIntegerVariableWithLinearScanAndLazyEncoding( // sum of the given variables using the given coefficients. // // TODO(user): It is not needed to have objective_var and the linear objective -// constraint encoded in the model. Remove this preconditions in order to +// constraint encoded in the model. Remove this precondition in order to // improve the solving time. SatSolver::Status MinimizeWithCoreAndLazyEncoding( bool log_info, IntegerVariable objective_var, @@ -146,6 +146,30 @@ SatSolver::Status MinimizeWithCoreAndLazyEncoding( const std::function& feasible_solution_observer, Model* model); +#if defined(USE_CBC) || defined(USE_SCIP) +// Generalization of the max-HS algorithm (HS stands for Hitting Set). This is +// similar to MinimizeWithCoreAndLazyEncoding() but it uses an hybrid approach +// with a MIP solver to handle the discovered infeasibility cores. +// +// See, Jessica Davies and Fahiem Bacchus, "Solving MAXSAT by Solving a +// Sequence of Simpler SAT Instances", +// http://www.cs.toronto.edu/~jdavies/daviesCP11.pdf" +// +// Note that an implementation of this approach won the 2016 max-SAT competition +// on the industrial category, see +// http://maxsat.ia.udl.cat/results/#wpms-industrial +// +// TODO(user): This function brings dependency to the SCIP MIP solver which is +// quite big, maybe we should find a way not to do that. +SatSolver::Status MinimizeWithHittingSetAndLazyEncoding( + bool log_info, IntegerVariable objective_var, + std::vector variables, + std::vector coefficients, + const std::function& next_decision, + const std::function& feasible_solution_observer, + Model* model); +#endif // defined(USE_CBC) || defined(USE_SCIP) + // Similar to MinimizeIntegerVariableWithLinearScanAndLazyEncoding() but use // a core based approach. Note that this require the objective to be given as // a weighted sum of literals diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index 2c652b21ed..86a1e460f3 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -16,8 +16,8 @@ #ifndef OR_TOOLS_SAT_SAT_BASE_H_ #define OR_TOOLS_SAT_SAT_BASE_H_ -#include #include +#include #include #include #include diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index 4c4b716b9e..5e06d1baee 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -18,7 +18,7 @@ package operations_research.sat; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 84 +// NEXT TAG: 86 message SatParameters { // ========================================================================== // Branching and polarity @@ -387,6 +387,10 @@ message SatParameters { // Whether we use a simple heuristic to try to minimize an UNSAT core. optional bool minimize_core = 50 [default = true]; + // Wheter we try to find more independent cores for a given set of assumptions + // in the core based max-SAT algorithms. + optional bool find_multiple_cores = 84 [default = true]; + // In what order do we add the assumptions in a core-based max-sat algorithm enum MaxSatAssumptionOrder { DEFAULT_ASSUMPTION_ORDER = 0; @@ -485,4 +489,12 @@ message SatParameters { // use a core-based approach (like in max-SAT) when we try to increase the // lower bound instead. optional bool optimize_with_core = 83 [default = false]; + + // This has no effect if optimize_with_core is false. If true, use a different + // core-based algorithm similar to the max-HS algo for max-SAT. This is a + // hybrid MIP/CP approach and it uses a MIP solver in addition to the CP/SAT + // one. This is also related to the PhD work of tobyodavies@ + // "Automatic Logic-Based Benders Decomposition with MiniZinc" + // http://aaai.org/ocs/index.php/AAAI/AAAI17/paper/view/14489 + optional bool optimize_with_max_hs = 85 [default = false]; } diff --git a/ortools/sat/table.cc b/ortools/sat/table.cc index 2ea61822d6..0e23aa1fb9 100644 --- a/ortools/sat/table.cc +++ b/ortools/sat/table.cc @@ -208,19 +208,29 @@ std::function NegatedTableConstraintWithoutFullEncoding( std::vector clause; for (const std::vector& tuple : tuples) { clause.clear(); + bool add = true; for (int i = 0; i < n; ++i) { const int64 value = tuple[i]; - if (value > model->Get(LowerBound(vars[i]))) { + const int64 lb = model->Get(LowerBound(vars[i])); + const int64 ub = model->Get(UpperBound(vars[i])); + CHECK_LT(lb, ub); + // TODO(user): test the full initial domain instead of just checking + // the bounds. + if (value < lb || value > ub) { + add = false; + break; + } + if (value > lb) { clause.push_back(encoder->GetOrCreateAssociatedLiteral( IntegerLiteral::LowerOrEqual(vars[i], IntegerValue(value - 1)))); } - if (value < model->Get(UpperBound(vars[i]))) { + if (value < ub) { clause.push_back(encoder->GetOrCreateAssociatedLiteral( IntegerLiteral::GreaterOrEqual(vars[i], IntegerValue(value + 1)))); } } - model->Add(ClauseConstraint(clause)); + if (add) model->Add(ClauseConstraint(clause)); } }; } diff --git a/ortools/util/BUILD b/ortools/util/BUILD index d719fa2dff..ce5fd9941c 100644 --- a/ortools/util/BUILD +++ b/ortools/util/BUILD @@ -135,6 +135,7 @@ cc_library( deps = [ ":saturated_arithmetic", "//ortools/base", + "//ortools/base:span", "//ortools/base:strings", ], ) diff --git a/ortools/util/sorted_interval_list.cc b/ortools/util/sorted_interval_list.cc index 0c8f782529..804f9c2e45 100644 --- a/ortools/util/sorted_interval_list.cc +++ b/ortools/util/sorted_interval_list.cc @@ -75,8 +75,8 @@ bool IntervalsAreSortedAndDisjoint( return true; } -bool SortedDisjointIntervalsContain( - const std::vector& intervals, int64 value) { +bool SortedDisjointIntervalsContain(gtl::Span intervals, + int64 value) { for (const ClosedInterval& interval : intervals) { if (interval.start <= value && interval.end >= value) return true; } diff --git a/ortools/util/sorted_interval_list.h b/ortools/util/sorted_interval_list.h index 8f4d06f53a..511acf26b9 100644 --- a/ortools/util/sorted_interval_list.h +++ b/ortools/util/sorted_interval_list.h @@ -20,6 +20,7 @@ #include #include "ortools/base/integral_types.h" +#include "ortools/base/span.h" namespace operations_research { @@ -70,8 +71,8 @@ bool IntervalsAreSortedAndDisjoint( // // TODO(user): This works in O(n), but could be made to work in O(log n) for // long list of intervals. -bool SortedDisjointIntervalsContain( - const std::vector& intervals, int64 value); +bool SortedDisjointIntervalsContain(gtl::Span intervals, + int64 value); // Returns the intersection of two lists of sorted disjoint intervals in a // sorted disjoint interval form.