diff --git a/makefiles/Makefile.gen.mk b/makefiles/Makefile.gen.mk index de32978a12..6ae2bf2d9a 100644 --- a/makefiles/Makefile.gen.mk +++ b/makefiles/Makefile.gen.mk @@ -3,7 +3,6 @@ BASE_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -80,7 +79,6 @@ $(SRC_DIR)/ortools/base/join.h: \ $(SRC_DIR)/ortools/base/stringpiece.h $(SRC_DIR)/ortools/base/logging.h: \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/macros.h @@ -192,11 +190,6 @@ $(OBJ_DIR)/base/join.$O: \ $(SRC_DIR)/ortools/base/stringprintf.h $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sbase$Sjoin.cc $(OBJ_OUT)$(OBJ_DIR)$Sbase$Sjoin.$O -$(OBJ_DIR)/base/logging.$O: \ - $(SRC_DIR)/ortools/base/logging.cc \ - $(SRC_DIR)/ortools/base/logging.h - $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sbase$Slogging.cc $(OBJ_OUT)$(OBJ_DIR)$Sbase$Slogging.$O - $(OBJ_DIR)/base/mutex.$O: \ $(SRC_DIR)/ortools/base/mutex.cc \ $(SRC_DIR)/ortools/base/mutex.h @@ -264,7 +257,6 @@ UTIL_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -291,6 +283,11 @@ UTIL_LIB_OBJS = \ $(OBJ_DIR)/util/time_limit.$O \ $(OBJ_DIR)/util/xml_helper.$O +$(SRC_DIR)/ortools/util/affine_relation.h: \ + $(SRC_DIR)/ortools/base/iterator_adaptors.h \ + $(SRC_DIR)/ortools/base/logging.h \ + $(SRC_DIR)/ortools/base/macros.h + $(SRC_DIR)/ortools/util/bitset.h: \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/integral_types.h \ @@ -514,7 +511,6 @@ LP_DATA_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -739,7 +735,6 @@ GLOP_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -1057,7 +1052,6 @@ GRAPH_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -1281,7 +1275,6 @@ ALGORITHMS_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -1375,15 +1368,6 @@ $(OBJ_DIR)/algorithms/hungarian.$O: \ $(SRC_DIR)/ortools/algorithms/hungarian.h $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Salgorithms$Shungarian.cc $(OBJ_OUT)$(OBJ_DIR)$Salgorithms$Shungarian.$O -$(OBJ_DIR)/algorithms/hungarian_test.$O: \ - $(SRC_DIR)/ortools/algorithms/hungarian_test.cc \ - $(SRC_DIR)/ortools/algorithms/hungarian.h \ - $(SRC_DIR)/ortools/base/integral_types.h \ - $(SRC_DIR)/ortools/base/macros.h \ - $(SRC_DIR)/ortools/base/map_util.h \ - $(SRC_DIR)/ortools/base/random.h - $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Salgorithms$Shungarian_test.cc $(OBJ_OUT)$(OBJ_DIR)$Salgorithms$Shungarian_test.$O - $(OBJ_DIR)/algorithms/knapsack_solver.$O: \ $(SRC_DIR)/ortools/algorithms/knapsack_solver.cc \ $(SRC_DIR)/ortools/algorithms/knapsack_solver.h \ @@ -1405,6 +1389,7 @@ SAT_DEPS = \ $(GEN_DIR)/ortools/sat/boolean_problem.pb.h \ $(SRC_DIR)/ortools/sat/clause.h \ $(SRC_DIR)/ortools/sat/cp_constraints.h \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h \ $(SRC_DIR)/ortools/sat/drat.h \ $(SRC_DIR)/ortools/sat/integer_expr.h \ $(SRC_DIR)/ortools/sat/integer.h \ @@ -1421,7 +1406,6 @@ SAT_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -1474,6 +1458,10 @@ SAT_LIB_OBJS = \ $(OBJ_DIR)/sat/boolean_problem.$O \ $(OBJ_DIR)/sat/clause.$O \ $(OBJ_DIR)/sat/cp_constraints.$O \ + $(OBJ_DIR)/sat/cp_model_checker.$O \ + $(OBJ_DIR)/sat/cp_model_presolve.$O \ + $(OBJ_DIR)/sat/cp_model_solver.$O \ + $(OBJ_DIR)/sat/cp_model_utils.$O \ $(OBJ_DIR)/sat/cumulative.$O \ $(OBJ_DIR)/sat/disjunctive.$O \ $(OBJ_DIR)/sat/drat.$O \ @@ -1498,6 +1486,7 @@ SAT_LIB_OBJS = \ $(OBJ_DIR)/sat/timetable_edgefinding.$O \ $(OBJ_DIR)/sat/util.$O \ $(OBJ_DIR)/sat/boolean_problem.pb.$O \ + $(OBJ_DIR)/sat/cp_model.pb.$O \ $(OBJ_DIR)/sat/sat_parameters.pb.$O $(SRC_DIR)/ortools/sat/boolean_problem.h: \ @@ -1526,6 +1515,21 @@ $(SRC_DIR)/ortools/sat/cp_constraints.h: \ $(SRC_DIR)/ortools/sat/model.h \ $(SRC_DIR)/ortools/util/sorted_interval_list.h +$(SRC_DIR)/ortools/sat/cp_model_checker.h: \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h + +$(SRC_DIR)/ortools/sat/cp_model_presolve.h: \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h + +$(SRC_DIR)/ortools/sat/cp_model_solver.h: \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h \ + $(SRC_DIR)/ortools/sat/integer.h \ + $(SRC_DIR)/ortools/sat/model.h + +$(SRC_DIR)/ortools/sat/cp_model_utils.h: \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h \ + $(SRC_DIR)/ortools/util/sorted_interval_list.h + $(SRC_DIR)/ortools/sat/cumulative.h: \ $(SRC_DIR)/ortools/sat/integer.h \ $(SRC_DIR)/ortools/sat/intervals.h \ @@ -1595,6 +1599,7 @@ $(SRC_DIR)/ortools/sat/linear_programming_constraint.h: \ $(SRC_DIR)/ortools/sat/lp_utils.h: \ $(GEN_DIR)/ortools/sat/boolean_problem.pb.h \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h \ $(SRC_DIR)/ortools/sat/sat_solver.h \ $(SRC_DIR)/ortools/lp_data/lp_data.h \ $(GEN_DIR)/ortools/linear_solver/linear_solver.pb.h @@ -1726,6 +1731,54 @@ $(OBJ_DIR)/sat/cp_constraints.$O: \ $(SRC_DIR)/ortools/util/sort.h $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Scp_constraints.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Scp_constraints.$O +$(OBJ_DIR)/sat/cp_model_checker.$O: \ + $(SRC_DIR)/ortools/sat/cp_model_checker.cc \ + $(SRC_DIR)/ortools/sat/cp_model_checker.h \ + $(SRC_DIR)/ortools/sat/cp_model_utils.h \ + $(SRC_DIR)/ortools/base/join.h \ + $(SRC_DIR)/ortools/base/map_util.h \ + $(SRC_DIR)/ortools/util/saturated_arithmetic.h \ + $(SRC_DIR)/ortools/util/sorted_interval_list.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Scp_model_checker.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Scp_model_checker.$O + +$(OBJ_DIR)/sat/cp_model_presolve.$O: \ + $(SRC_DIR)/ortools/sat/cp_model_presolve.cc \ + $(SRC_DIR)/ortools/sat/cp_model_checker.h \ + $(SRC_DIR)/ortools/sat/cp_model_presolve.h \ + $(SRC_DIR)/ortools/sat/cp_model_utils.h \ + $(SRC_DIR)/ortools/base/join.h \ + $(SRC_DIR)/ortools/base/map_util.h \ + $(SRC_DIR)/ortools/base/stl_util.h \ + $(SRC_DIR)/ortools/util/affine_relation.h \ + $(SRC_DIR)/ortools/util/bitset.h \ + $(SRC_DIR)/ortools/util/sorted_interval_list.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Scp_model_presolve.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Scp_model_presolve.$O + +$(OBJ_DIR)/sat/cp_model_solver.$O: \ + $(SRC_DIR)/ortools/sat/cp_model_solver.cc \ + $(SRC_DIR)/ortools/sat/cp_model_checker.h \ + $(SRC_DIR)/ortools/sat/cp_model_presolve.h \ + $(SRC_DIR)/ortools/sat/cp_model_solver.h \ + $(SRC_DIR)/ortools/sat/cp_model_utils.h \ + $(SRC_DIR)/ortools/sat/cumulative.h \ + $(SRC_DIR)/ortools/sat/disjunctive.h \ + $(SRC_DIR)/ortools/sat/intervals.h \ + $(SRC_DIR)/ortools/sat/linear_programming_constraint.h \ + $(SRC_DIR)/ortools/sat/optimization.h \ + $(SRC_DIR)/ortools/sat/sat_solver.h \ + $(SRC_DIR)/ortools/sat/table.h \ + $(SRC_DIR)/ortools/base/join.h \ + $(SRC_DIR)/ortools/base/stl_util.h \ + $(SRC_DIR)/ortools/base/timer.h \ + $(SRC_DIR)/ortools/graph/connectivity.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Scp_model_solver.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Scp_model_solver.$O + +$(OBJ_DIR)/sat/cp_model_utils.$O: \ + $(SRC_DIR)/ortools/sat/cp_model_utils.cc \ + $(SRC_DIR)/ortools/sat/cp_model_utils.h \ + $(SRC_DIR)/ortools/base/stl_util.h + $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Ssat$Scp_model_utils.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Scp_model_utils.$O + $(OBJ_DIR)/sat/cumulative.$O: \ $(SRC_DIR)/ortools/sat/cumulative.cc \ $(SRC_DIR)/ortools/sat/cumulative.h \ @@ -1803,6 +1856,7 @@ $(OBJ_DIR)/sat/no_cycle.$O: \ $(OBJ_DIR)/sat/optimization.$O: \ $(SRC_DIR)/ortools/sat/optimization.cc \ $(SRC_DIR)/ortools/sat/encoding.h \ + $(SRC_DIR)/ortools/sat/integer_expr.h \ $(SRC_DIR)/ortools/sat/optimization.h \ $(SRC_DIR)/ortools/sat/util.h \ $(SRC_DIR)/ortools/base/stringprintf.h @@ -1900,6 +1954,14 @@ $(GEN_DIR)/ortools/sat/boolean_problem.pb.h: $(GEN_DIR)/ortools/sat/boolean_prob $(OBJ_DIR)/sat/boolean_problem.pb.$O: $(GEN_DIR)/ortools/sat/boolean_problem.pb.cc $(CCC) $(CFLAGS) -c $(GEN_DIR)/ortools/sat/boolean_problem.pb.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Sboolean_problem.pb.$O +$(GEN_DIR)/ortools/sat/cp_model.pb.cc: $(SRC_DIR)/ortools/sat/cp_model.proto + $(PROTOBUF_DIR)/bin/protoc --proto_path=$(INC_DIR) --cpp_out=$(GEN_DIR) $(SRC_DIR)/ortools/sat/cp_model.proto + +$(GEN_DIR)/ortools/sat/cp_model.pb.h: $(GEN_DIR)/ortools/sat/cp_model.pb.cc + +$(OBJ_DIR)/sat/cp_model.pb.$O: $(GEN_DIR)/ortools/sat/cp_model.pb.cc + $(CCC) $(CFLAGS) -c $(GEN_DIR)/ortools/sat/cp_model.pb.cc $(OBJ_OUT)$(OBJ_DIR)$Ssat$Scp_model.pb.$O + $(GEN_DIR)/ortools/sat/sat_parameters.pb.cc: $(SRC_DIR)/ortools/sat/sat_parameters.proto $(PROTOBUF_DIR)/bin/protoc --proto_path=$(INC_DIR) --cpp_out=$(GEN_DIR) $(SRC_DIR)/ortools/sat/sat_parameters.proto @@ -1919,7 +1981,6 @@ BOP_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -1959,6 +2020,7 @@ BOP_DEPS = \ $(GEN_DIR)/ortools/sat/boolean_problem.pb.h \ $(SRC_DIR)/ortools/sat/clause.h \ $(SRC_DIR)/ortools/sat/cp_constraints.h \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h \ $(SRC_DIR)/ortools/sat/drat.h \ $(SRC_DIR)/ortools/sat/integer_expr.h \ $(SRC_DIR)/ortools/sat/integer.h \ @@ -2227,7 +2289,6 @@ LP_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -2473,7 +2534,6 @@ CP_DEPS = \ $(SRC_DIR)/ortools/base/basictypes.h \ $(SRC_DIR)/ortools/base/callback.h \ $(SRC_DIR)/ortools/base/casts.h \ - $(SRC_DIR)/ortools/base/commandlineflags.h \ $(SRC_DIR)/ortools/base/file.h \ $(SRC_DIR)/ortools/base/integral_types.h \ $(SRC_DIR)/ortools/base/int_type.h \ @@ -2501,6 +2561,7 @@ CP_DEPS = \ $(GEN_DIR)/ortools/sat/boolean_problem.pb.h \ $(SRC_DIR)/ortools/sat/clause.h \ $(SRC_DIR)/ortools/sat/cp_constraints.h \ + $(GEN_DIR)/ortools/sat/cp_model.pb.h \ $(SRC_DIR)/ortools/sat/drat.h \ $(SRC_DIR)/ortools/sat/integer_expr.h \ $(SRC_DIR)/ortools/sat/integer.h \ @@ -3240,3 +3301,4 @@ $(GEN_DIR)/ortools/constraint_solver/solver_parameters.pb.h: $(GEN_DIR)/ortools/ $(OBJ_DIR)/constraint_solver/solver_parameters.pb.$O: $(GEN_DIR)/ortools/constraint_solver/solver_parameters.pb.cc $(CCC) $(CFLAGS) -c $(GEN_DIR)/ortools/constraint_solver/solver_parameters.pb.cc $(OBJ_OUT)$(OBJ_DIR)$Sconstraint_solver$Ssolver_parameters.pb.$O + diff --git a/ortools/algorithms/knapsack_solver.h b/ortools/algorithms/knapsack_solver.h index 194cbc53c5..28674ce8db 100644 --- a/ortools/algorithms/knapsack_solver.h +++ b/ortools/algorithms/knapsack_solver.h @@ -152,6 +152,7 @@ class KnapsackSolver { // obtained might not be optimal if the limit is reached. void set_time_limit(double time_limit_seconds) { time_limit_seconds_ = time_limit_seconds; + time_limit_.reset(new TimeLimit(time_limit_seconds_)); } private: diff --git a/ortools/base/logging.cc b/ortools/base/logging.cc deleted file mode 100644 index 4396bff436..0000000000 --- a/ortools/base/logging.cc +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2010-2014 Google -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "ortools/base/logging.h" - -DEFINE_int32(log_level, 0, "Log level (0 is the default)."); -DEFINE_bool(log_prefix, true, - "Prefix all log lines with the date, source file and line number."); - -namespace operations_research { -DateLogger::DateLogger() { -#if defined(_MSC_VER) - _tzset(); -#endif -} - -char* const DateLogger::HumanDate() { -#if defined(_MSC_VER) - _strtime_s(buffer_, sizeof(buffer_)); -#else - time_t time_value = time(NULL); - struct tm now; - localtime_r(&time_value, &now); - snprintf(buffer_, sizeof(buffer_), "%02d:%02d:%02d", now.tm_hour, - now.tm_min, now.tm_sec); -#endif - return buffer_; -} -} // namespace operations_research diff --git a/ortools/constraint_solver/routing.cc b/ortools/constraint_solver/routing.cc index 3360528477..2e9ee14417 100644 --- a/ortools/constraint_solver/routing.cc +++ b/ortools/constraint_solver/routing.cc @@ -658,6 +658,8 @@ RoutingSearchParameters RoutingModel::DefaultSearchParameters() { static const char* const kSearchParameters = "first_solution_strategy: AUTOMATIC " "use_filtered_first_solution_strategy: true " + "savings_neighbors_ratio: 0 " + "savings_add_reverse_arcs: false " "local_search_operators {" " use_relocate: true" " use_relocate_pair: true" @@ -2457,8 +2459,13 @@ void GetVehicleClasses(const RoutingModel& model, // heuristic for Vehicle Routing Problem. class SavingsBuilder : public DecisionBuilder { public: - SavingsBuilder(RoutingModel* const model, bool check_assignment) - : model_(model), check_assignment_(check_assignment) {} + SavingsBuilder(RoutingModel* const model, double savings_neighbors_ratio, + bool check_assignment) + : model_(model), + savings_neighbors_ratio_(savings_neighbors_ratio > 0 + ? std::min(savings_neighbors_ratio, 1.0) + : 1), + check_assignment_(check_assignment) {} ~SavingsBuilder() override {} Decision* Next(Solver* const solver) override { @@ -2488,10 +2495,10 @@ class SavingsBuilder : public DecisionBuilder { neighbors_.resize(nodes_number_); route_shape_parameter_ = FLAGS_savings_route_shape_parameter; - int64 savings_filter_neighbors = FLAGS_savings_filter_neighbors; + int64 savings_filter_neighbors = + std::max(1.0, model_->nodes() * savings_neighbors_ratio_); int64 savings_filter_radius = FLAGS_savings_filter_radius; - if (!savings_filter_neighbors && !savings_filter_radius) { - savings_filter_neighbors = model_->nodes(); + if (!savings_filter_radius) { savings_filter_radius = -1; } @@ -2550,6 +2557,7 @@ class SavingsBuilder : public DecisionBuilder { } RoutingModel* const model_; + const double savings_neighbors_ratio_; std::unique_ptr route_constructor_; const bool check_assignment_; std::vector dimensions_; @@ -4197,20 +4205,25 @@ void RoutingModel::CreateFirstSolutionDecisionBuilders( first_solution_decision_builders_ [FirstSolutionStrategy::BEST_INSERTION]); // Savings + const double savings_neighbors_ratio = + search_parameters.savings_neighbors_ratio(); if (search_parameters.use_filtered_first_solution_strategy()) { first_solution_filtered_decision_builders_[FirstSolutionStrategy::SAVINGS] = solver_->RevAlloc(new SavingsFilteredDecisionBuilder( - this, FLAGS_savings_filter_neighbors, + this, savings_neighbors_ratio, + search_parameters.savings_add_reverse_arcs(), GetOrCreateFeasibilityFilters())); first_solution_decision_builders_[FirstSolutionStrategy::SAVINGS] = solver_->Try(first_solution_filtered_decision_builders_ [FirstSolutionStrategy::SAVINGS], - solver_->RevAlloc(new SavingsBuilder(this, true))); + solver_->RevAlloc(new SavingsBuilder( + this, savings_neighbors_ratio, true))); } else { first_solution_decision_builders_[FirstSolutionStrategy::SAVINGS] = - solver_->RevAlloc(new SavingsBuilder(this, true)); - DecisionBuilder* savings_builder = - solver_->RevAlloc(new SavingsBuilder(this, false)); + solver_->RevAlloc( + new SavingsBuilder(this, savings_neighbors_ratio, true)); + DecisionBuilder* savings_builder = solver_->RevAlloc( + new SavingsBuilder(this, savings_neighbors_ratio, false)); first_solution_decision_builders_[FirstSolutionStrategy::SAVINGS] = solver_->Try( savings_builder, diff --git a/ortools/constraint_solver/routing.h b/ortools/constraint_solver/routing.h index d1ce0be04b..65ea9bea5f 100644 --- a/ortools/constraint_solver/routing.h +++ b/ortools/constraint_solver/routing.h @@ -2077,30 +2077,33 @@ class ComparatorCheapestAdditionFilteredDecisionBuilder // are taken into account. class SavingsFilteredDecisionBuilder : public RoutingFilteredDecisionBuilder { public: - // If savings_neighbors > 0 then for each node only its 'saving_neighbors' + // If savings_neighbors_ratio > 0 then for each node only this ratio of its // neighbors leading to the smallest arc costs are considered. + // Furthermore, if add_reverse_arcs is true, the neighborhood relationships + // are always considered symmetrically. SavingsFilteredDecisionBuilder( - RoutingModel* model, int64 saving_neighbors, - const std::vector& filters); + RoutingModel* model, double savings_neighbors_ratio, + bool add_reverse_arcs, const std::vector& filters); ~SavingsFilteredDecisionBuilder() override {} bool BuildSolution() override; private: typedef std::pair Saving; - // Computes saving values for all node pairs and vehicle cost classes. The - // saving index attached to each saving value is an index used to + // Computes saving values for all node pairs and vehicle types (see + // ComputeVehicleTypes()). + // The saving index attached to each saving value is an index used to // store and recover the node pair to which the value is linked (cf. the // index conversion methods below). - std::vector ComputeSavings() const; + std::vector ComputeSavings(); // Builds a saving from a saving value, a cost class and two nodes. - Saving BuildSaving(int64 saving, int cost_class, int before_node, + Saving BuildSaving(int64 saving, int vehicle_type, int before_node, int after_node) const { - return std::make_pair( - saving, cost_class * size_squared_ + before_node * Size() + after_node); + return std::make_pair(saving, vehicle_type * size_squared_ + + before_node * Size() + after_node); } // Returns the cost class from a saving. - int64 GetCostClassFromSaving(const Saving& saving) const { + int64 GetVehicleTypeFromSaving(const Saving& saving) const { return saving.second / size_squared_; } // Returns the "before node" from a saving. @@ -2114,8 +2117,21 @@ class SavingsFilteredDecisionBuilder : public RoutingFilteredDecisionBuilder { // Returns the saving value from a saving. int64 GetSavingValue(const Saving& saving) const { return saving.first; } - const int64 saving_neighbors_; + // Computes the vehicle type of every vehicle and stores it in + // type_index_of_vehicle_. A "vehicle type" consists of the set of vehicles + // having the same cost class and start/end nodes, therefore the same savings + // value for each arc. + // The vehicles corresponding to each vehicle type index are stored in + // vehicles_per_vehicle_type_. + void ComputeVehicleTypes(); + + const double savings_neighbors_ratio_; + const bool add_reverse_arcs_; int64 size_squared_; + std::vector type_index_of_vehicle_; + // clang-format off + std::vector > vehicles_per_vehicle_type_; + // clang-format on }; // Christofides addition heuristic. Initially created to solve TSPs, extended to diff --git a/ortools/constraint_solver/routing_flags.cc b/ortools/constraint_solver/routing_flags.cc index 759366d033..0910c8cfae 100644 --- a/ortools/constraint_solver/routing_flags.cc +++ b/ortools/constraint_solver/routing_flags.cc @@ -75,6 +75,12 @@ DEFINE_string(routing_first_solution, "", "in the code to get a full list."); DEFINE_bool(routing_use_filtered_first_solutions, true, "Use filtered version of first solution heuristics if available."); +DEFINE_double(savings_neighbors_ratio, 0, + "Ratio of neighbors to consider for each node when " + "constructing the savings."); +DEFINE_bool(savings_add_reverse_arcs, false, + "Add savings related to reverse arcs when finding the nearest " + "neighbors of the nodes."); DEFINE_bool(routing_dfs, false, "Routing: use a complete depth-first search."); DEFINE_int64(routing_optimization_step, 1, "Optimization step."); @@ -130,6 +136,8 @@ void SetFirstSolutionStrategyFromFlags(RoutingSearchParameters* parameters) { } parameters->set_use_filtered_first_solution_strategy( FLAGS_routing_use_filtered_first_solutions); + parameters->set_savings_neighbors_ratio(FLAGS_savings_neighbors_ratio); + parameters->set_savings_add_reverse_arcs(FLAGS_savings_add_reverse_arcs); } void SetLocalSearchMetaheuristicFromFlags(RoutingSearchParameters* parameters) { diff --git a/ortools/constraint_solver/routing_parameters.proto b/ortools/constraint_solver/routing_parameters.proto index 6414688427..935ff13dc8 100644 --- a/ortools/constraint_solver/routing_parameters.proto +++ b/ortools/constraint_solver/routing_parameters.proto @@ -33,6 +33,14 @@ message RoutingSearchParameters { // modified unless you know what you are doing. // Use filtered version of first solution strategy if available. bool use_filtered_first_solution_strategy = 2; + // Parameters specific to the Savings first solution heuristic. + // Ratio (between 0 and 1) of neighbors to consider for each node when + // constructing the savings. If <= 0 or greater than 1, its value is + // considered to be 1.0. + double savings_neighbors_ratio = 14; + // Add savings related to reverse arcs when finding the nearest neighbors + // of the nodes. + bool savings_add_reverse_arcs = 15; // Local search neighborhood operators used to build a solutions neighborhood. message LocalSearchNeighborhoodOperators { diff --git a/ortools/constraint_solver/routing_search.cc b/ortools/constraint_solver/routing_search.cc index f9f23c9139..68604bd055 100644 --- a/ortools/constraint_solver/routing_search.cc +++ b/ortools/constraint_solver/routing_search.cc @@ -2472,10 +2472,13 @@ void ComparatorCheapestAdditionFilteredDecisionBuilder::SortPossibleNexts( // SavingsFilteredDecisionBuilder SavingsFilteredDecisionBuilder::SavingsFilteredDecisionBuilder( - RoutingModel* model, int64 saving_neighbors, + RoutingModel* model, double savings_neighbors_ratio, bool add_reverse_arcs, const std::vector& filters) : RoutingFilteredDecisionBuilder(model, filters), - saving_neighbors_(saving_neighbors), + savings_neighbors_ratio_(savings_neighbors_ratio > 0 + ? std::min(savings_neighbors_ratio, 1.0) + : 1), + add_reverse_arcs_(add_reverse_arcs), size_squared_(0) {} bool SavingsFilteredDecisionBuilder::BuildSolution() { @@ -2485,33 +2488,44 @@ bool SavingsFilteredDecisionBuilder::BuildSolution() { const int size = model()->Size(); size_squared_ = size * size; std::vector savings = ComputeSavings(); - // Store savings for each incoming and outgoing node and by cost class. This + const int vehicle_types = vehicles_per_vehicle_type_.size(); + DCHECK_GT(vehicle_types, 0); + // Store savings for each incoming and outgoing node and by vehicle type. This // is necessary to quickly extend partial chains without scanning all savings. - const int cost_classes = model()->GetCostClassesCount(); - std::vector> in_savings(size * cost_classes); - std::vector> out_savings(size * cost_classes); + std::vector> in_savings_indices(size * vehicle_types); + std::vector> out_savings_indices(size * vehicle_types); for (int i = 0; i < savings.size(); ++i) { const Saving& saving = savings[i]; - const int cost_class_offset = GetCostClassFromSaving(saving) * size; + const int vehicle_type_offset = GetVehicleTypeFromSaving(saving) * size; const int before_node = GetBeforeNodeFromSaving(saving); - in_savings[cost_class_offset + before_node].push_back(i); + in_savings_indices[vehicle_type_offset + before_node].push_back(i); const int after_node = GetAfterNodeFromSaving(saving); - out_savings[cost_class_offset + after_node].push_back(i); + out_savings_indices[vehicle_type_offset + after_node].push_back(i); } + // For each vehicle type, sort vehicles by decreasing vehicle fixed cost. + // Vehicles with the same fixed cost are sorted by decreasing vehicle index. + std::vector fixed_cost_of_vehicle(model()->vehicles()); + for (int vehicle = 0; vehicle < model()->vehicles(); vehicle++) { + fixed_cost_of_vehicle[vehicle] = model()->GetFixedCostOfVehicle(vehicle); + } + for (int type = 0; type < vehicle_types; type++) { + std::vector& sorted_vehicles = vehicles_per_vehicle_type_[type]; + std::stable_sort(sorted_vehicles.begin(), sorted_vehicles.end(), + [&fixed_cost_of_vehicle](int v1, int v2) { + return fixed_cost_of_vehicle[v1] < + fixed_cost_of_vehicle[v2]; + }); + std::reverse(sorted_vehicles.begin(), sorted_vehicles.end()); + } + // Build routes from savings. - std::vector closed(model()->vehicles(), false); for (const Saving& saving : savings) { // First find the best saving to start a new route. - const int cost_class = GetCostClassFromSaving(saving); - int vehicle = -1; - for (int v = 0; v < model()->vehicles(); ++v) { - if (!closed[v] && - model()->GetCostClassIndexOfVehicle(v).value() == cost_class) { - vehicle = v; - break; - } - } - if (vehicle == -1) continue; + const int type = GetVehicleTypeFromSaving(saving); + std::vector& sorted_vehicles = vehicles_per_vehicle_type_[type]; + if (sorted_vehicles.empty()) continue; + int vehicle = sorted_vehicles.back(); + int before_node = GetBeforeNodeFromSaving(saving); int after_node = GetAfterNodeFromSaving(saving); if (!Contains(before_node) && !Contains(after_node)) { @@ -2522,19 +2536,49 @@ bool SavingsFilteredDecisionBuilder::BuildSolution() { SetValue(after_node, end); if (Commit()) { // Then extend the route from both ends of the partial route. - closed[vehicle] = true; + sorted_vehicles.pop_back(); int in_index = 0; int out_index = 0; - const int saving_offset = cost_class * size; - while (in_index < in_savings[saving_offset + after_node].size() && - out_index < out_savings[saving_offset + before_node].size()) { - const Saving& in_saving = - savings[in_savings[saving_offset + after_node][in_index]]; - const Saving& out_saving = - savings[out_savings[saving_offset + before_node][out_index]]; - if (GetSavingValue(in_saving) < GetSavingValue(out_saving)) { + const int saving_offset = type * size; + + while (in_index < + in_savings_indices[saving_offset + after_node].size() || + out_index < + out_savings_indices[saving_offset + before_node].size()) { + // First determine how to extend the route. + int before_before_node = -1; + int after_after_node = -1; + if (in_index < + in_savings_indices[saving_offset + after_node].size()) { + const Saving& in_saving = + savings[in_savings_indices[saving_offset + after_node] + [in_index]]; + if (out_index < + out_savings_indices[saving_offset + before_node].size()) { + const Saving& out_saving = + savings[out_savings_indices[saving_offset + before_node] + [out_index]]; + if (GetSavingValue(in_saving) < GetSavingValue(out_saving)) { + // Should extend after after_node + after_after_node = GetAfterNodeFromSaving(in_saving); + } else { + // Should extend before before_node + before_before_node = GetBeforeNodeFromSaving(out_saving); + } + } else { + // Should extend after after_node + after_after_node = GetAfterNodeFromSaving(in_saving); + } + } else { + // Should extend before before_node + before_before_node = GetBeforeNodeFromSaving( + savings[out_savings_indices[saving_offset + before_node] + [out_index]]); + } + // Extend the route + if (after_after_node != -1) { + DCHECK_EQ(before_before_node, -1); // Extending after after_node - const int after_after_node = GetAfterNodeFromSaving(in_saving); if (!Contains(after_after_node)) { SetValue(after_node, after_after_node); SetValue(after_after_node, end); @@ -2549,7 +2593,7 @@ bool SavingsFilteredDecisionBuilder::BuildSolution() { } } else { // Extending before before_node - const int before_before_node = GetBeforeNodeFromSaving(out_saving); + CHECK_GE(before_before_node, 0); if (!Contains(before_before_node)) { SetValue(start, before_before_node); SetValue(before_before_node, before_node); @@ -2571,52 +2615,127 @@ bool SavingsFilteredDecisionBuilder::BuildSolution() { return Commit(); } +void SavingsFilteredDecisionBuilder::ComputeVehicleTypes() { + type_index_of_vehicle_.clear(); + const int nodes = model()->nodes(); + const int nodes_squared = nodes * nodes; + const int vehicles = model()->vehicles(); + type_index_of_vehicle_.resize(vehicles); + + vehicles_per_vehicle_type_.clear(); + std::unordered_map type_to_type_index; + + for (int v = 0; v < vehicles; v++) { + const int start = model()->IndexToNode(model()->Start(v)).value(); + const int end = model()->IndexToNode(model()->End(v)).value(); + const int cost_class = model()->GetCostClassIndexOfVehicle(v).value(); + const int64 type = cost_class * nodes_squared + start * nodes + end; + + const auto& vehicle_type_added = type_to_type_index.insert( + std::make_pair(type, type_to_type_index.size())); + + const int index = vehicle_type_added.first->second; + + if (vehicle_type_added.second) { + // Type was not indexed yet. + DCHECK_EQ(vehicles_per_vehicle_type_.size(), index); + vehicles_per_vehicle_type_.push_back({v}); + } else { + // Type already indexed. + DCHECK_LT(index, vehicles_per_vehicle_type_.size()); + vehicles_per_vehicle_type_[index].push_back(v); + } + type_index_of_vehicle_[v] = index; + } +} + +// Computes and returns the savings related to each pair of non-start and +// non-end nodes. The savings value for an arc a-->b for a vehicle starting at +// node s and ending at node e is: +// saving = cost(s-->a-->e) + cost(s-->b-->e) - cost(s-->a-->b-->e), i.e. +// saving = cost(a-->e) + cost(s-->b) - cost(a-->b) +// The higher this saving value, the better the arc. +// Here, the value stored for the savings in the output vector is -saving, and +// the vector is therefore sorted in increasing order (the lower -saving, +// the better). std::vector -SavingsFilteredDecisionBuilder::ComputeSavings() const { +SavingsFilteredDecisionBuilder::ComputeSavings() { + ComputeVehicleTypes(); const int size = model()->Size(); - const int64 saving_neighbors = - saving_neighbors_ <= 0 ? size : saving_neighbors_; - const int num_cost_classes = model()->GetCostClassesCount(); + + const int64 saving_neighbors = std::max(1.0, size * savings_neighbors_ratio_); + + const int num_vehicle_types = vehicles_per_vehicle_type_.size(); std::vector savings; - savings.reserve(num_cost_classes * size * saving_neighbors); - std::vector class_covered(num_cost_classes, false); - for (int vehicle = 0; vehicle < model()->vehicles(); ++vehicle) { + savings.reserve(num_vehicle_types * size * saving_neighbors); + + for (int type = 0; type < num_vehicle_types; ++type) { + const std::vector& vehicles = vehicles_per_vehicle_type_[type]; + if (vehicles.empty()) { + continue; + } + const int vehicle = vehicles.front(); const int64 cost_class = model()->GetCostClassIndexOfVehicle(vehicle).value(); - if (!class_covered[cost_class]) { - class_covered[cost_class] = true; - const int64 start = model()->Start(vehicle); - const int64 end = model()->End(vehicle); - for (int before_node = 0; before_node < size; ++before_node) { - if (!Contains(before_node) && !model()->IsEnd(before_node) && - !model()->IsStart(before_node)) { - const int64 in_saving = - model()->GetArcCostForClass(before_node, end, cost_class); - std::vector> - costed_after_nodes; - costed_after_nodes.reserve(size); - for (int after_node = 0; after_node < size; ++after_node) { - if (after_node != before_node && !Contains(after_node) && - !model()->IsEnd(after_node) && !model()->IsStart(after_node)) { - costed_after_nodes.push_back( - std::make_pair(model()->GetArcCostForClass( - before_node, after_node, cost_class), - after_node)); - } + const int64 start = model()->Start(vehicle); + const int64 end = model()->End(vehicle); + const int64 fixed_cost = model()->GetFixedCostOfVehicle(vehicle); + + // TODO(user): deal with the add_reverse_arcs_ flag more efficiently. + std::vector arc_added; + if (add_reverse_arcs_) { + arc_added.resize(size * size, false); + } + for (int before_node = 0; before_node < size; ++before_node) { + if (!Contains(before_node) && !model()->IsEnd(before_node) && + !model()->IsStart(before_node)) { + const int64 in_saving = + model()->GetArcCostForClass(before_node, end, cost_class); + std::vector> + costed_after_nodes; + costed_after_nodes.reserve(size); + for (int after_node = 0; after_node < size; ++after_node) { + if (after_node != before_node && !Contains(after_node) && + !model()->IsEnd(after_node) && !model()->IsStart(after_node)) { + costed_after_nodes.push_back( + std::make_pair(model()->GetArcCostForClass( + before_node, after_node, cost_class), + after_node)); } - if (saving_neighbors < size) { - std::nth_element(costed_after_nodes.begin(), - costed_after_nodes.begin() + saving_neighbors, - costed_after_nodes.end()); - costed_after_nodes.resize(saving_neighbors); + } + if (saving_neighbors < size) { + std::nth_element(costed_after_nodes.begin(), + costed_after_nodes.begin() + saving_neighbors, + costed_after_nodes.end()); + costed_after_nodes.resize(saving_neighbors); + } + for (const auto& costed_after_node : costed_after_nodes) { + const int64 after_node = costed_after_node.second; + if (add_reverse_arcs_ && arc_added[before_node * size + after_node]) { + DCHECK(arc_added[after_node * size + before_node]); + continue; } - for (const auto& after_node : costed_after_nodes) { - const int64 saving = CapSub( - CapAdd(in_saving, model()->GetArcCostForClass( - start, after_node.second, cost_class)), - after_node.first); - savings.push_back(BuildSaving(-saving, cost_class, before_node, - after_node.second)); + + const int64 saving = + CapSub(CapAdd(in_saving, model()->GetArcCostForClass( + start, after_node, cost_class)), + CapAdd(costed_after_node.first, fixed_cost)); + savings.push_back( + BuildSaving(-saving, type, before_node, after_node)); + + if (add_reverse_arcs_) { + // Also add after->before savings. + arc_added[before_node * size + after_node] = true; + arc_added[after_node * size + before_node] = true; + const int64 second_cost = model()->GetArcCostForClass( + after_node, before_node, cost_class); + const int64 second_saving = CapSub( + CapAdd(model()->GetArcCostForClass(after_node, end, cost_class), + model()->GetArcCostForClass(start, before_node, + cost_class)), + CapAdd(second_cost, fixed_cost)); + savings.push_back( + BuildSaving(-second_saving, type, after_node, before_node)); } } } diff --git a/ortools/flatzinc/fz.cc b/ortools/flatzinc/fz.cc index 53b1873fbb..b0fecdf902 100644 --- a/ortools/flatzinc/fz.cc +++ b/ortools/flatzinc/fz.cc @@ -30,6 +30,7 @@ #include "ortools/base/timer.h" #include "ortools/base/threadpool.h" #include "ortools/base/commandlineflags.h" +#include "ortools/flatzinc/cp_model_fz_solver.h" #include "ortools/flatzinc/logging.h" #include "ortools/flatzinc/model.h" #include "ortools/flatzinc/parser.h" @@ -66,9 +67,11 @@ DEFINE_bool( "Increase verbosity of the impact based search when used in free search."); DEFINE_bool(verbose_mt, false, "Verbose Multi-Thread."); DEFINE_bool(use_fz_sat, false, "Use the SAT/CP solver."); +DEFINE_bool(use_cp_model, false, "Use the SAT/CP solver through CpModel."); DEFINE_string(fz_model_name, "stdin", "Define problem name when reading from stdin."); +// TODO(user): Remove when using ABCL in open-source. DECLARE_bool(log_prefix); DECLARE_bool(fz_use_sat); @@ -319,11 +322,17 @@ int main(int argc, char** argv) { operations_research::fz::ParseFlatzincModel(input, !FLAGS_read_from_stdin); - if (FLAGS_use_fz_sat) { + if (FLAGS_use_fz_sat || FLAGS_use_cp_model) { bool interrupt_solve = false; - operations_research::sat::SolveWithSat( - model, operations_research::fz::SingleThreadParameters(), - &interrupt_solve); + if (FLAGS_use_fz_sat) { + operations_research::sat::SolveWithSat( + model, operations_research::fz::SingleThreadParameters(), + &interrupt_solve); + } else { + operations_research::sat::SolveFzWithCpModelProto( + model, operations_research::fz::SingleThreadParameters(), + &interrupt_solve); + } } else { operations_research::fz::Solve(model); } diff --git a/ortools/linear_solver/linear_solver_natural_api.py b/ortools/linear_solver/linear_solver_natural_api.py index 417c5aa1bf..7f4850dc9f 100644 --- a/ortools/linear_solver/linear_solver_natural_api.py +++ b/ortools/linear_solver/linear_solver_natural_api.py @@ -187,7 +187,7 @@ class SumArray(LinearExpr): """Represents the sum of a list of LinearExpr.""" def __init__(self, array): - self.__array = map(CastToLinExp, array) + self.__array = [CastToLinExp(elem) for elem in array] def __str__(self): return '({})'.format(' + '.join(map(str, self.__array))) diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto new file mode 100644 index 0000000000..2dc4fa1ab6 --- /dev/null +++ b/ortools/sat/cp_model.proto @@ -0,0 +1,364 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Proto describing a general Constraint Programming (CP) problem. + +syntax = "proto3"; + +package operations_research.sat; + +// An integer variable. +// +// It will be referred to by an int32 corresponding to its index in a +// CpModelProto variables field. +// +// Depending on the context, a reference to a variable whose domain is in [0, 1] +// can also be seen as a Boolean that will be true if the variable value is 1 +// and false if it is 0. When used in this context, the field name will always +// contain the word "literal". +// +// Negative reference (advanced usage): to simplify the creation of a model and +// for efficiency reasons, all the "literal" or "variable" fields can also +// contain a negative index. A negative index i will refer to the negation of +// the integer variable at index -i -1 or to NOT the literal at the same index. +// +// Ex: A variable index 4 will refer to the integer variable model.variables(4) +// and an index of -5 will refer to the negation of the same variable. A literal +// index 4 will refer to the logical fact that model.variable(4) == 1 and a +// literal index of -5 will refer to the logical fact model.variable(4) == 0. +message IntegerVariableProto { + // For debug/logging only. Can be empty. + string name = 1; + + // The variable domain given as a sorted list of n disjoint intervals + // [min, max] and encoded as [min_0, max_0, ..., min_{n-1}, max_{n-1}]. + // + // The most common example being just [min, max]. + // If min == max, then this is a constant variable. + // + // We have: + // - domain_size() is always even. + // - min == domain.front(); + // - max == domain.back(); + // - for all i < n : min_i <= max_i + // - for all i < n-1 : max_i + 1 < min_{i+1}. + repeated int64 domain = 2; +} + +// Argument of the constraints of the form OP(literals). +message BoolArgumentProto { + repeated int32 literals = 1; +} + +// Argument of the constraints of the form target_var = OP(vars). +message IntegerArgumentProto { + int32 target = 1; + repeated int32 vars = 2; +} + +// All variables must take different values. +message AllDifferentConstraintProto { + repeated int32 vars = 1; +} + +// The linear sum vars[i] * coeffs[i] must fall in the given domain. The domain +// has the same format as the one in IntegerVariableProto. +// +// Note that the validation code currently checks using the domain of the +// involved variables that the sum can always be computed without integer +// overflow and throws an error otherwise. +message LinearConstraintProto { + repeated int32 vars = 1; + repeated int64 coeffs = 2; // Same size as vars. + repeated int64 domain = 3; +} + +// The constraint target = vars[index]. +// This enforces that index takes one of the value in [0, vars_size()). +message ElementConstraintProto { + int32 index = 1; + int32 target = 2; + repeated int32 vars = 3; +} + +// This "special" constraint not only enforces (start + size == end) but can +// also be refered by other constraints using this "interval" concept. +message IntervalConstraintProto { + int32 start = 1; + int32 end = 2; + int32 size = 3; +} + +// All the intervals (index of IntervalConstraintProto) must be disjoint. +// Note that interval are interpreted as [start, end). This is also known as +// a disjunctive constraint in scheduling. +message NoOverlapConstraintProto { + repeated int32 intervals = 1; +} + +// The boxes defined by [start_x, end_x) * [start_y, end_y) cannot overlap. +message NoOverlap2DConstraintProto { + repeated int32 x_intervals = 1; + repeated int32 y_intervals = 2; // Same size as x_intervals. +} + +// The sum of the demands * durations of the intervals at each interval point +// cannot exceed a capacity. Note that intervals are interpreted as [start, end) +// and as such intervals like [2,3) and [3,4) do not overlap for the point of +// view of this constraint. +// +// Note: this enforce that all start <= end. +message CumulativeConstraintProto { + int32 capacity = 1; + repeated int32 intervals = 2; + repeated int32 demands = 3; // Same size as intervals. +} + +// The "next" variable of a node i represents its successor in a graph. Any +// value that fall outside [0, n = next_variables.size()) or is a self-loop +// (next[i] == i) takes the special meaning of no-successor. +// +// The circuit constraint enforces that all nodes with a valid successor must +// form a circuit, that is a single loop that goes through all of them. Note +// that to enforce that a node has a valid successor, it is sufficient to +// reduce its domain to only valid values. +message CircuitConstraintProto { + repeated int32 nexts = 2; +} + +// The values of the n-tuple formed by the given variables can only be one of +// the listed n-tuples in values. The n-tuples are encoded in a flattened way: +// [tuple0_v0, tuple0_v1, ..., tuple0_v{n-1}, tuple1_v0, ...]. +message TableConstraintProto { + repeated int32 vars = 1; + repeated int64 values = 2; + + // If true, the meaning is "negated", that is we forbid any of the given + // tuple from a feasible assignment. + bool negated = 3; +} + +// This constraint forces a sequence of variables to be accepted by an automata. +message AutomataConstraintProto { + // A state is identified by a non-negative number. It is preferable to keep + // all the states dense in says [0, num_states). The automata starts at + // starting_state and must finish in any of the final states. + int64 starting_state = 2; + repeated int64 final_states = 3; + + // List of transitions (all 3 vectors have the same size). Both tail and head + // are states, label is any variable value. No two outgoing transitions from + // the same state can have the same label. + repeated int64 transition_tail = 4; + repeated int64 transition_head = 5; + repeated int64 transition_label = 6; + + // The sequence of variables. The automata is ran for vars_size() "steps" and + // the value of vars[i] corresponds to the transition label at step i. + repeated int32 vars = 7; +} + +message ConstraintProto { + // For debug/logging only. Can be empty. + string name = 1; + + // This should contain at most one literal. If there is one, then the + // constraint must be true when this literal is true, if it is false, then it + // doesn't matter. This is also called half-reification. To have an + // equivalence between a literal and a constraint (full reification), one must + // add both a constraint (controled by a literal l) and its negation + // (controlled by the negation of l). + repeated int32 enforcement_literal = 2; + + // The actual constraint with its arguments. + oneof constraint { + BoolArgumentProto bool_or = 3; // OR(literals) is true. + BoolArgumentProto bool_and = 4; // AND(literals) is true. + BoolArgumentProto bool_xor = 5; // XOR(literals) is true. + + IntegerArgumentProto int_div = 7; // target = vars[0] / vars[1] + IntegerArgumentProto int_mod = 8; // target = vars[0] % vars[1] + IntegerArgumentProto int_max = 9; // target = MAX(vars) + IntegerArgumentProto int_min = 10; // target = MIN(vars) + IntegerArgumentProto int_prod = 11; // target = PROD(vars) + + LinearConstraintProto linear = 12; + AllDifferentConstraintProto all_diff = 13; + ElementConstraintProto element = 14; + CircuitConstraintProto circuit = 15; + TableConstraintProto table = 16; + AutomataConstraintProto automata = 17; + + // Constraints on intervals. + // + // The first constraint defines what an "interval" is and the other + // constraints use references to it. All the intervals that have an + // enforcement_literal set to false are ignored by these constraints. + // + // TODO(user): Explain what happen for intervals of size zero. Some + // constraints ignore them, other do take them into account. + IntervalConstraintProto interval = 18; + CumulativeConstraintProto cumulative = 19; + NoOverlapConstraintProto no_overlap = 20; + NoOverlap2DConstraintProto no_overlap_2d = 21; + } +} + +// Optimization objective. +// +// This is in a message because decision problems don't have any objective. +message CpObjectiveProto { + // Index of the variable to minimize. + // + // For a maximization problem, one can refer to the negation of the real + // objective and set a scaling_factor to -1. + int32 objective_var = 1; + + // The displayed objective is always: + // scaling_factor * (Value(objective_var) + offset). + // This is needed to have a consistent objective after presolve or when + // scaling a double problem to express it with integers. + // + // Note that if scaling_factor is zero, then it is assumed to be 1, so that by + // default these fields have no effect. + double offset = 2; + double scaling_factor = 3; +} + +// Define the strategy to follow when the solver needs to take a new decision. +// Note that this strategy is only defined on a subset of variables. +message DecisionStrategyProto { + // The variables to be considered for the next decision. The order matter and + // is always used as a tie-breaker after the variable selection strategy + // criteria defined below. + repeated int32 variables = 1; + + // The order in which the variables above should be considered. Note that only + // variables that are not already fixed are considered. + // + // TODO(user): extend as needed. + enum VariableSelectionStrategy { + CHOOSE_FIRST = 0; + CHOOSE_LOWEST_MIN = 1; + CHOOSE_HIGHEST_MAX = 2; + CHOOSE_MIN_DOMAIN_SIZE = 3; + CHOOSE_MAX_DOMAIN_SIZE = 4; + } + VariableSelectionStrategy variable_selection_strategy = 2; + + // Once a variable has been choosen, this enum describe what decision is taken + // on its domain. + // + // TODO(user): extend as needed. + enum DomainReductionStrategy { + SELECT_MIN_VALUE = 0; + SELECT_MAX_VALUE = 1; + SELECT_LOWER_HALF = 2; + SELECT_UPPER_HALF = 3; + } + DomainReductionStrategy domain_reduction_strategy = 3; +} + +// A constraint programming problem. +message CpModelProto { + // For debug/logging only. Can be empty. + string name = 1; + + // The associated Protos should be referred by their index in these fields. + repeated IntegerVariableProto variables = 2; + repeated ConstraintProto constraints = 3; + + // The objective to minimize. Can be empty for pure decision problems. + // Note that we can have more than one objective for the cases where we want + // to optimize them in lexicographic order, or if we want to list the Pareto + // optimal solutions. + repeated CpObjectiveProto objectives = 4; + + // Defines the strategy that the solver should follow when the "fixed_search" + // parameters is set to true. Note that this strategy is also used as an + // heuristic when we are not in fixed search. + // + // If empty, the solver will try to assign all variables to their min value in + // the order of their appearance in the variables field above. Otherwise, it + // will assign the variables in each DecisionStrategyProto according to the + // order defined there and only move to the next proto in this field once all + // variables from the previous one have been assigned. + // + // Advanced Usage: if not all variables appears, the solver will not try to + // assign the missing ones. Thus, at the end of the search, not all variables + // may be fixed and this is why the solution_lower_bounds and + // solution_upper_bounds fields in the CpSolverResponse are for. + repeated DecisionStrategyProto search_strategy = 5; +} + +// The status returned by a solver trying to solve a CpModelProto. +enum CpSolverStatus { + // The status of the model is still unknown. A search limit as been reached + // before any of the status below could be decided. + UNKNOWN = 0; + + // The given CpModelProto didn't pass the validation step. You can get a + // detailed error by calling ValidateCpModel(model_proto). + MODEL_INVALID = 1; + + // A feasible solution as been found. For an optimization problem, we still + // don't know if it is the optimal one though. + MODEL_SAT = 2; + + // The problem as been proven to be UNSAT. No feasible solution exists. + MODEL_UNSAT = 3; + + // An optimal feasible solution has been found. + OPTIMAL = 4; +} + +// The response returned by a solver trying to solve a CpModelProto. +message CpSolverResponse { + // The status of the solve. + CpSolverStatus status = 1; + + // A feasible solution to the given problem. Depending on the returned status + // it may be optimal or just feasible. This is in one-to-one correspondance + // with a CpModelProto::variables repeated field and list the values of all + // the variables. + repeated int64 solution = 2; + + // Advanced usage. + // + // If the problem has some variables that are not fixed at the end of the + // search (because of a particular search strategy in the CpModelProto) then + // this will be used instead of filling the solution above. The two fields + // will then contains the lower and upper bound of each variable as they where + // when the best "solution" was found. + repeated int64 solution_lower_bounds = 18; + repeated int64 solution_upper_bounds = 19; + + // Only make sense for an optimization problem and if solution is non-empty. + // The objective value of the returned solution. + double objective_value = 3; + + // Only make sense for an optimization problem. A proven lower-bound on the + // objective for a minimization problem, or a proven upper-bound for a + // maximization problem. + double best_objective_bound = 4; + + // Some statistics about the solve. + int64 num_booleans = 10; + int64 num_conflicts = 11; + int64 num_branches = 12; + int64 num_binary_propagations = 13; + int64 num_integer_propagations = 14; + double wall_time = 15; + double user_time = 16; + double deterministic_time = 17; +} diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc new file mode 100644 index 0000000000..e1a7e3741a --- /dev/null +++ b/ortools/sat/cp_model_checker.cc @@ -0,0 +1,543 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_model_checker.h" + +#include +#include + +#include "ortools/base/join.h" +#include "ortools/base/hash.h" +#include "ortools/base/map_util.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/util/saturated_arithmetic.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { +namespace { + +// ============================================================================= +// CpModelProto validation. +// ============================================================================= + +// If the std::string returned by "statement" is not empty, returns it. +#define RETURN_IF_NOT_EMPTY(statement) \ + do { \ + const std::string error_message = statement; \ + if (!error_message.empty()) return error_message; \ + } while (false) + +template +bool DomainInProtoIsValid(const ProtoWithDomain& proto) { + std::vector domain; + for (int i = 0; i < proto.domain_size(); i += 2) { + domain.push_back({proto.domain(i), proto.domain(i + 1)}); + } + return IntervalsAreSortedAndDisjoint(domain); +} + +std::string ValidateIntegerVariable(const CpModelProto& model, int v) { + const IntegerVariableProto& proto = model.variables(v); + if (proto.domain_size() == 0) { + return StrCat("var #", v, + " has no domain(): ", proto.ShortDebugString()); + } + if (proto.domain_size() % 2 != 0) { + return StrCat( + "var #", v, " has an odd domain() size: ", proto.ShortDebugString()); + } + if (!DomainInProtoIsValid(proto)) { + return StrCat("var #", v, " has and invalid domain() format: ", + proto.ShortDebugString()); + } + return ""; +} + +bool VariableReferenceIsValid(const CpModelProto& model, int reference) { + return std::max(-reference - 1, reference) < model.variables_size(); +} + +bool LiteralReferenceIsValid(const CpModelProto& model, int reference) { + if (std::max(-reference - 1, reference) >= model.variables_size()) { + return false; + } + const auto& var_proto = model.variables(PositiveRef(reference)); + const int64 min_domain = var_proto.domain(0); + const int64 max_domain = var_proto.domain(var_proto.domain_size() - 1); + return min_domain >= 0 && max_domain <= 1; +} + +std::string ValidateArgumentReferencesInConstraint(const CpModelProto& model, + int c) { + const ConstraintProto& ct = model.constraints(c); + IndexReferences references; + AddReferencesUsedByConstraint(ct, &references); + for (const int v : references.variables) { + if (!VariableReferenceIsValid(model, v)) { + return StrCat("Out of bound integer variable ", v, + " in constraint #", c, " : ", ct.ShortDebugString()); + } + } + if (ct.enforcement_literal_size() > 1) { + return StrCat("More than one enforcement_literal in constraint #", c, + " : ", ct.ShortDebugString()); + } + if (ct.enforcement_literal_size() == 1) { + const int lit = ct.enforcement_literal(0); + if (!LiteralReferenceIsValid(model, lit)) { + return StrCat("Invalid enforcement literal ", lit, + " in constraint #", c, " : ", ct.ShortDebugString()); + } + } + for (const int lit : references.literals) { + if (!LiteralReferenceIsValid(model, lit)) { + return StrCat("Invalid literal ", lit, " in constraint #", c, " : ", + ct.ShortDebugString()); + } + } + for (const int i : references.intervals) { + if (i < 0 || i >= model.constraints_size()) { + return StrCat("Out of bound interval ", i, " in constraint #", c, + " : ", ct.ShortDebugString()); + } + if (model.constraints(i).constraint_case() != + ConstraintProto::ConstraintCase::kInterval) { + return StrCat( + "Interval ", i, + " does not refer to an interval constraint. Problematic constraint #", + c, " : ", ct.ShortDebugString()); + } + } + return ""; +} + +std::string ValidateLinearConstraint(const CpModelProto& model, + const ConstraintProto& ct) { + const LinearConstraintProto& arg = ct.linear(); + int64 sum_min = 0; + int64 sum_max = 0; + for (int i = 0; i < arg.vars_size(); ++i) { + const int ref = arg.vars(i); + const auto& var_proto = model.variables(PositiveRef(ref)); + const int64 min_domain = var_proto.domain(0); + const int64 max_domain = var_proto.domain(var_proto.domain_size() - 1); + const int64 coeff = RefIsPositive(ref) ? arg.coeffs(i) : -arg.coeffs(i); + const int64 prod1 = CapProd(min_domain, coeff); + const int64 prod2 = CapProd(max_domain, coeff); + + // Note that we use min/max with zero to disallow "alternative" terms and + // be sure that we cannot have an overflow if we do the computation in a + // different order. + sum_min = CapAdd(sum_min, std::min(0ll, std::min(prod1, prod2))); + sum_max = CapAdd(sum_max, std::max(0ll, std::max(prod1, prod2))); + for (const int64 v : {prod1, prod2, sum_min, sum_max}) { + if (v == kint64max || v == kint64min) { + return "Possible integer overflow in constraint: " + ct.DebugString(); + } + } + } + return ""; +} + +} // namespace + +std::string ValidateCpModel(const CpModelProto& model) { + for (int v = 0; v < model.variables_size(); ++v) { + RETURN_IF_NOT_EMPTY(ValidateIntegerVariable(model, v)); + } + for (int c = 0; c < model.constraints_size(); ++c) { + RETURN_IF_NOT_EMPTY(ValidateArgumentReferencesInConstraint(model, c)); + + // Other non-generic validations. + // TODO(user): validate all constraints. + const ConstraintProto& ct = model.constraints(c); + const ConstraintProto::ConstraintCase type = ct.constraint_case(); + switch (type) { + case ConstraintProto::ConstraintCase::kLinear: + if (!DomainInProtoIsValid(ct.linear())) { + return StrCat("Invalid domain in constraint #", c, " : ", + ct.ShortDebugString()); + } + if (ct.linear().coeffs_size() != ct.linear().vars_size()) { + return StrCat("coeffs_size() != vars_size() in constraint #", c, + " : ", ct.ShortDebugString()); + } + RETURN_IF_NOT_EMPTY(ValidateLinearConstraint(model, ct)); + break; + case ConstraintProto::ConstraintCase::kCumulative: + if (ct.cumulative().intervals_size() != + ct.cumulative().demands_size()) { + return StrCat( + "intervals_size() != demands_size() in constraint #", c, " : ", + ct.ShortDebugString()); + } + break; + default: + break; + } + } + for (const CpObjectiveProto& objective : model.objectives()) { + const int v = objective.objective_var(); + if (!VariableReferenceIsValid(model, v)) { + return StrCat("Out of bound objective variable ", v, " : ", + objective.ShortDebugString()); + } + } + + return ""; +} + +#undef RETURN_IF_NOT_EMPTY + +// ============================================================================= +// Solution Feasibility. +// ============================================================================= + +namespace { + +class ConstraintChecker { + public: + explicit ConstraintChecker(const std::vector& variable_values) + : variable_values_(variable_values) {} + + bool LiteralIsTrue(int l) const { + if (l >= 0) return variable_values_[l] != 0; + return variable_values_[-l - 1] == 0; + } + + bool LiteralIsFalse(int l) const { return !LiteralIsTrue(l); } + + int64 Value(int var) const { + if (var >= 0) return variable_values_[var]; + return -variable_values_[-var - 1]; + } + + bool BoolOrConstraintIsFeasible(const ConstraintProto& ct) { + for (const int lit : ct.bool_or().literals()) { + if (LiteralIsTrue(lit)) return true; + } + return false; + } + + bool BoolAndConstraintIsFeasible(const ConstraintProto& ct) { + for (const int lit : ct.bool_and().literals()) { + if (LiteralIsFalse(lit)) return false; + } + return true; + } + + bool BoolXorConstraintIsFeasible(const ConstraintProto& ct) { + int sum = 0; + for (const int lit : ct.bool_xor().literals()) { + sum ^= LiteralIsTrue(lit) ? 1 : 0; + } + return sum == 1; + } + + // TODO(user): deal with integer overflows. + bool LinearConstraintIsFeasible(const ConstraintProto& ct) { + int64 sum = 0; + const int num_variables = ct.linear().coeffs_size(); + for (int i = 0; i < num_variables; ++i) { + sum += Value(ct.linear().vars(i)) * ct.linear().coeffs(i); + } + return DomainInProtoContains(ct.linear(), sum); + } + + bool IntMaxConstraintIsFeasible(const ConstraintProto& ct) { + const int64 max = Value(ct.int_max().target()); + int64 actual_max = kint64min; + for (int i = 0; i < ct.int_max().vars_size(); ++i) { + actual_max = std::max(actual_max, Value(ct.int_max().vars(i))); + } + return max == actual_max; + } + + bool IntProdConstraintIsFeasible(const ConstraintProto& ct) { + const int64 prod = Value(ct.int_prod().target()); + int64 actual_prod = 1; + for (int i = 0; i < ct.int_prod().vars_size(); ++i) { + actual_prod *= Value(ct.int_prod().vars(i)); + } + return prod == actual_prod; + } + + bool IntDivConstraintIsFeasible(const ConstraintProto& ct) { + return Value(ct.int_div().target()) == + Value(ct.int_div().vars(0)) / Value(ct.int_div().vars(1)); + } + + bool IntMinConstraintIsFeasible(const ConstraintProto& ct) { + const int64 min = Value(ct.int_min().target()); + int64 actual_min = kint64max; + for (int i = 0; i < ct.int_min().vars_size(); ++i) { + actual_min = std::min(actual_min, Value(ct.int_min().vars(i))); + } + return min == actual_min; + } + + bool AllDiffConstraintIsFeasible(const ConstraintProto& ct) { + std::unordered_set values; + for (const int v : ct.all_diff().vars()) { + if (ContainsKey(values, Value(v))) return false; + values.insert(Value(v)); + } + return true; + } + + bool IntervalConstraintIsFeasible(const ConstraintProto& ct) { + const int64 size = Value(ct.interval().size()); + if (size < 0) return false; + return Value(ct.interval().start()) + size == Value(ct.interval().end()); + } + + bool NoOverlapConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + std::vector> start_durations_pairs; + for (const int i : ct.no_overlap().intervals()) { + const IntervalConstraintProto& interval = model.constraints(i).interval(); + start_durations_pairs.push_back( + {Value(interval.start()), Value(interval.size())}); + } + std::sort(start_durations_pairs.begin(), start_durations_pairs.end()); + int64 previous_end = kint64min; + for (const auto pair : start_durations_pairs) { + if (pair.first < previous_end) return false; + previous_end = pair.first + pair.second; + } + return true; + } + + bool IntervalsAreDisjoint(const CpModelProto& model, + const IntervalConstraintProto& interval1, + const IntervalConstraintProto& interval2) { + return Value(interval1.end()) <= Value(interval2.start()) || + Value(interval2.end()) <= Value(interval1.start()); + } + + bool NoOverlap2DConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + const auto& arg = ct.no_overlap_2d(); + const int num_intervals = arg.x_intervals_size(); + for (int i = 0; i < num_intervals; ++i) { + for (int j = i + 1; j < num_intervals; ++j) { + if (!IntervalsAreDisjoint( + model, model.constraints(arg.x_intervals(i)).interval(), + model.constraints(arg.x_intervals(j)).interval()) && + !IntervalsAreDisjoint( + model, model.constraints(arg.y_intervals(i)).interval(), + model.constraints(arg.y_intervals(j)).interval())) { + return false; + } + } + } + return true; + } + + bool CumulativeConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + // TODO(user, fdid): Improve complexity for large durations. + const int64 capacity = Value(ct.cumulative().capacity()); + const int num_intervals = ct.cumulative().intervals_size(); + std::unordered_map usage; + for (int i = 0; i < num_intervals; ++i) { + const IntervalConstraintProto& interval = + model.constraints(ct.cumulative().intervals(i)).interval(); + const int64 start = Value(interval.start()); + const int64 duration = Value(interval.size()); + const int64 demand = Value(ct.cumulative().demands(i)); + for (int64 t = start; t < start + duration; ++t) { + usage[t] += demand; + if (usage[t] > capacity) return false; + } + } + return true; + } + + bool ElementConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + const int index = Value(ct.element().index()); + return Value(ct.element().vars(index)) == Value(ct.element().target()); + } + + bool TableConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + const int size = ct.table().vars_size(); + if (size == 0) return true; + for (int row_start = 0; row_start < ct.table().values_size(); + row_start += size) { + int i = 0; + while (Value(ct.table().vars(i)) == ct.table().values(row_start + i)) { + ++i; + if (i == size) return !ct.table().negated(); + } + } + return ct.table().negated(); + } + + bool AutomataConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + // Build the transition table {tail, label} -> head. + std::unordered_map, int64> transition_map; + const int num_transitions = ct.automata().transition_tail().size(); + for (int i = 0; i < num_transitions; ++i) { + transition_map[{ct.automata().transition_tail(i), + ct.automata().transition_label(i)}] = + ct.automata().transition_head(i); + } + + // Walk the automata. + int64 current_state = ct.automata().starting_state(); + const int num_steps = ct.automata().vars_size(); + for (int i = 0; i < num_steps; ++i) { + const std::pair key = {current_state, + Value(ct.automata().vars(i))}; + CHECK(ContainsKey(transition_map, key)); + current_state = transition_map[key]; + } + + // Check we are now in a final state. + for (const int64 final : ct.automata().final_states()) { + if (current_state == final) return true; + } + return false; + } + + bool CircuitConstraintIsFeasible(const CpModelProto& model, + const ConstraintProto& ct) { + const int num_nodes = ct.circuit().nexts_size(); + int num_inactive = 0; + int last_active = 0; + for (int i = 0; i < num_nodes; ++i) { + const int value = Value(ct.circuit().nexts(i)); + if (value < 0 || value == i || value >= num_nodes) { + ++num_inactive; + } else { + last_active = i; + } + } + if (num_inactive == num_nodes) return true; + + std::vector visited(num_nodes, false); + int current = last_active; + int num_visited = 0; + while (!visited[current]) { + ++num_visited; + visited[current] = true; + current = Value(ct.circuit().nexts(current)); + } + return num_visited + num_inactive == num_nodes; + } + + private: + std::vector variable_values_; +}; + +} // namespace + +bool SolutionIsFeasible(const CpModelProto& model, + const std::vector& variable_values) { + // Check that all values fall in the variable domains. + for (int i = 0; i < model.variables_size(); ++i) { + if (!DomainInProtoContains(model.variables(i), variable_values[i])) { + LOG(ERROR) << "Variable #" << i << " has value " << variable_values[i] + << " which do not fall in its domain: " + << model.variables(i).ShortDebugString(); + return false; + } + } + + CHECK_EQ(variable_values.size(), model.variables_size()); + ConstraintChecker checker(variable_values); + + for (int c = 0; c < model.constraints_size(); ++c) { + const ConstraintProto& ct = model.constraints(c); + + // We skip optional constraints that are not present. + if (ct.enforcement_literal_size() > 0 && + checker.LiteralIsFalse(ct.enforcement_literal(0))) { + continue; + } + + bool is_feasible = true; + const ConstraintProto::ConstraintCase type = ct.constraint_case(); + switch (type) { + case ConstraintProto::ConstraintCase::kBoolOr: + is_feasible = checker.BoolOrConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kBoolAnd: + is_feasible = checker.BoolAndConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kBoolXor: + is_feasible = checker.BoolAndConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kLinear: + is_feasible = checker.LinearConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kIntProd: + is_feasible = checker.IntProdConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kIntDiv: + is_feasible = checker.IntDivConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kIntMin: + is_feasible = checker.IntMinConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kIntMax: + is_feasible = checker.IntMaxConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kAllDiff: + is_feasible = checker.AllDiffConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kInterval: + is_feasible = checker.IntervalConstraintIsFeasible(ct); + break; + case ConstraintProto::ConstraintCase::kNoOverlap: + is_feasible = checker.NoOverlapConstraintIsFeasible(model, ct); + break; + case ConstraintProto::ConstraintCase::kNoOverlap2D: + is_feasible = checker.NoOverlap2DConstraintIsFeasible(model, ct); + break; + case ConstraintProto::ConstraintCase::kCumulative: + is_feasible = checker.CumulativeConstraintIsFeasible(model, ct); + break; + case ConstraintProto::ConstraintCase::kElement: + is_feasible = checker.ElementConstraintIsFeasible(model, ct); + break; + case ConstraintProto::ConstraintCase::kTable: + is_feasible = checker.TableConstraintIsFeasible(model, ct); + break; + case ConstraintProto::ConstraintCase::kAutomata: + is_feasible = checker.AutomataConstraintIsFeasible(model, ct); + break; + case ConstraintProto::ConstraintCase::kCircuit: + is_feasible = checker.CircuitConstraintIsFeasible(model, ct); + break; + case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: + // Empty constraint is always feasible. + break; + default: + LOG(FATAL) << "Unuspported constraint: " << ConstraintCaseName(type); + } + if (!is_feasible) { + LOG(ERROR) << "Failing constraint #" << c << " : " + << model.constraints(c).ShortDebugString(); + return false; + } + } + return true; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_checker.h b/ortools/sat/cp_model_checker.h new file mode 100644 index 0000000000..34e70e9403 --- /dev/null +++ b/ortools/sat/cp_model_checker.h @@ -0,0 +1,39 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_CP_MODEL_CHECKER_H_ +#define OR_TOOLS_SAT_CP_MODEL_CHECKER_H_ + +#include "ortools/base/integral_types.h" +#include "ortools/sat/cp_model.pb.h" + +namespace operations_research { +namespace sat { + +// Verifies that the given model satisfies all the properties described in the +// proto comments. Returns an empty std::string if it is the case, otherwise fails at +// the first error and returns a human-readable description of the issue. +// +// TODO(user): Add any needed overflow validation. +std::string ValidateCpModel(const CpModelProto& model); + +// Verifies that the given variable assignment is a feasible solution of the +// given model. The values vector should be in one to one correspondance with +// the model.variables() list of variables. +bool SolutionIsFeasible(const CpModelProto& model, + const std::vector& variable_values); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_CP_MODEL_CHECKER_H_ diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc new file mode 100644 index 0000000000..ad8fef5288 --- /dev/null +++ b/ortools/sat/cp_model_presolve.cc @@ -0,0 +1,1528 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_model_presolve.h" + +#include +#include +#include +#include +#include + +#include "ortools/base/join.h" +#include "ortools/base/hash.h" +#include "ortools/base/map_util.h" +#include "ortools/base/stl_util.h" +#include "ortools/sat/cp_model_checker.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/util/affine_relation.h" +#include "ortools/util/bitset.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { +namespace { + +// ============================================================================= +// Utilities. +// ============================================================================= + +// An in-memory representation of a variable domain with convenient functions. +class Domain { + public: + // This takes a pointer to an "external" SparseBitset whose position "id" will + // be set to true each time this domain changes. + Domain(const IntegerVariableProto& var_proto, int id, + SparseBitset* watcher) + : id_(id), + watcher_(watcher), + sorted_disjoint_intervals_(ReadDomain(var_proto)) {} + + bool IsEmpty() const { return sorted_disjoint_intervals_.empty(); } + int64 Min() const { return sorted_disjoint_intervals_.front().start; } + int64 Max() const { return sorted_disjoint_intervals_.back().end; } + + bool Contains(int64 value) const { + return SortedDisjointIntervalsContain(sorted_disjoint_intervals_, value); + } + + bool IsFixedTo(int64 value) const { + if (IsEmpty()) return false; + return Min() == value && Max() == value; + } + + bool IsFixed() const { + if (IsEmpty()) return false; + return Min() == Max(); + } + + // Returns true iff the domain changed. + bool IntersectWith(const std::vector& intervals) { + tmp_ = IntersectionOfSortedDisjointIntervals(sorted_disjoint_intervals_, + intervals); + if (tmp_ != sorted_disjoint_intervals_) { + watcher_->Set(id_); + sorted_disjoint_intervals_ = tmp_; + return true; + } + return false; + } + bool IntersectWith(const Domain& domain) { + return IntersectWith(domain.sorted_disjoint_intervals_); + } + + // This works in O(n). + // TODO(user): Move to O(log(n)) if needed. + void RemovePoint(int64 point) { + CHECK_NE(point, kint64min); + CHECK_NE(point, kint64max); + if (Contains(point)) { + watcher_->Set(id_); + IntersectWith({{kint64min, point - 1}, {point + 1, kint64max}}); + } + } + + void CopyToIntegerVariableProto(IntegerVariableProto* proto) const { + FillDomain(sorted_disjoint_intervals_, proto); + } + + const std::vector& InternalRepresentation() const { + return sorted_disjoint_intervals_; + } + + void NotifyChanged() { watcher_->Set(id_); } + + private: + const int id_; + SparseBitset* watcher_; + std::vector sorted_disjoint_intervals_; + std::vector tmp_; +}; + +// Returns the sorted list of variables used by a constraint. +std::vector UsedVariables(const ConstraintProto& ct) { + IndexReferences references; + AddReferencesUsedByConstraint(ct, &references); + + std::vector used_variables; + for (const int var : references.variables) { + used_variables.push_back(PositiveRef(var)); + } + for (const int lit : references.literals) { + used_variables.push_back(PositiveRef(lit)); + } + if (HasEnforcementLiteral(ct)) { + used_variables.push_back(PositiveRef(ct.enforcement_literal(0))); + } + STLSortAndRemoveDuplicates(&used_variables); + return used_variables; +} + +struct PresolveContext { + bool LiteralIsTrue(int lit) const { + if (lit >= 0) return !domains[lit].Contains(0); + return domains[NegatedRef(lit)].IsFixedTo(0); + } + bool LiteralIsFalse(int lit) const { return LiteralIsTrue(NegatedRef(lit)); } + + void SetLiteralToFalse(int lit) { + const int var = PositiveRef(lit); + if (lit >= 0) { + domains[var].IntersectWith({{0ll, 0ll}}); + } else { + domains[var].RemovePoint(0ll); + } + if (domains[var].IsEmpty()) { + is_unsat = true; + } + } + void SetLiteralToTrue(int lit) { return SetLiteralToFalse(NegatedRef(lit)); } + + void UpdateRuleStats(const std::string& name) { stats_by_rule_name[name]++; } + + void UpdateConstraintVariableUsage(int c) { + if (c >= constraint_to_vars.size()) constraint_to_vars.resize(c + 1); + const ConstraintProto& ct = working_model->constraints(c); + for (const int v : constraint_to_vars[c]) var_to_constraints[v].erase(c); + constraint_to_vars[c] = UsedVariables(ct); + for (const int v : constraint_to_vars[c]) var_to_constraints[v].insert(c); + } + + int64 MinOf(int ref) const { + return RefIsPositive(ref) ? domains[ref].Min() + : -domains[PositiveRef(ref)].Max(); + } + int64 MaxOf(int ref) const { + return RefIsPositive(ref) ? domains[ref].Max() + : -domains[PositiveRef(ref)].Min(); + } + + // Regroups fixed variables with the same value. + // TODO(user): Also regroup cte and -cte? + void ExploitFixedDomain(int var) { + CHECK(domains[var].IsFixed()); + const int min = domains[var].Min(); + if (ContainsKey(constant_to_ref, min)) { + const int representative = constant_to_ref[min]; + if (representative != var) { + affine_relations.TryAdd(var, representative, 1, 0); + var_equiv_relations.TryAdd(var, representative, 1, 0); + } + } else { + constant_to_ref[min] = var; + } + } + + // Adds the relation (ref_x = coeff * ref_y + offset) to the repository. + void AddAffineRelation(const ConstraintProto& ct, int ref_x, int ref_y, + int64 coeff, int64 offset) { + int x = PositiveRef(ref_x); + int y = PositiveRef(ref_y); + if (domains[x].IsFixed() || domains[y].IsFixed()) return; + + int64 c = RefIsPositive(ref_x) == RefIsPositive(ref_y) ? coeff : -coeff; + int64 o = RefIsPositive(ref_x) ? offset : -offset; + + // If a Boolean variable (one with domain [0, 1]) appear in this affine + // equivalence class, then we want its representative to be Boolean. Note + // that this is always possible because a Boolean variable can never be + // equal to a multiple of another if std::abs(coeff) is greater than 1 and + // if it is not fixed to zero. This is important because it allows to simply + // use the same representative for any referenced literals. + const int rep_x = affine_relations.Get(x).representative; + const int rep_y = affine_relations.Get(y).representative; + bool force = false; + if (domains[rep_y].Min() == 0 && domains[rep_y].Max() == 1) { + // We force the new representative to be rep_y. + force = true; + } else if (domains[rep_x].Min() == 0 && domains[rep_x].Max() == 1) { + // We force the new representative to be rep_x. + force = true; + std::swap(x, y); + CHECK_EQ(std::abs(coeff), 1); // Would be fixed to zero otherwise. + if (coeff == 1) o = -o; + } + + // TODO(user): can we force the rep and remove the GetAffineRelation()? + bool added = force ? affine_relations.TryAddInGivenOrder(x, y, c, o) + : affine_relations.TryAdd(x, y, c, o); + if ((c == 1 || c == -1) && o == 0) { + added |= force ? var_equiv_relations.TryAddInGivenOrder(x, y, c, o) + : var_equiv_relations.TryAdd(x, y, c, o); + } + if (added) { + // The domain didn't change, but this notification allows to re-process + // any constraint containing these variables. + domains[x].NotifyChanged(); + domains[y].NotifyChanged(); + affine_constraints.insert(&ct); + } + } + + // This makes sure that the affine relation only uses one of the + // representative from the var_equiv_relations. + AffineRelation::Relation GetAffineRelation(int var) { + CHECK(RefIsPositive(var)); + AffineRelation::Relation r = affine_relations.Get(var); + AffineRelation::Relation o = var_equiv_relations.Get(r.representative); + r.representative = o.representative; + if (o.coeff == -1) r.coeff = -r.coeff; + return r; + } + + // Returns the current domain of the given variable reference. + std::vector GetRefDomain(int ref) const { + if (RefIsPositive(ref)) return domains[ref].InternalRepresentation(); + return NegationOfSortedDisjointIntervals( + domains[PositiveRef(ref)].InternalRepresentation()); + } + + // The current domain of each variables. + std::vector domains; + + // This regroup all the affine relations between variables. Note that the + // constraints used to detect such relations will not be removed from the + // model at detection time (thus allowing proper domain propagation). However, + // if the arity of a variable becomes one, then such constraint will be + // removed. + AffineRelation affine_relations; + AffineRelation var_equiv_relations; + + // Set of constraint that implies an "affine relation". We need to mark them, + // because we can't simplify them using the relation they added. + std::unordered_set affine_constraints; + + // For each constant variable appearing in the model, we maintain a reference + // variable with the same constant value. If two variables end up having the + // same fixed value, then we can detect it using this and add a new + // equivalence relation. See TestAndExploitFixedDomain(). + std::unordered_map constant_to_ref; + + // Variable <-> constraint graph. + // The vector list is sorted and contains unique elements. + // + // Important: To properly handle the objective, var_to_constraints[objective] + // contains -1 so that if the objective appear in only one constraint, the + // constraint cannot be simplified. + // + // TODO(user): Make this private? + std::vector> constraint_to_vars; + std::vector> var_to_constraints; + + CpModelProto* working_model; + CpModelProto* mapping_model; + + // Initially false, and set to true on the first inconsistency. + bool is_unsat = false; + + // Just used to display statistics on the presolve rules that were used. + std::unordered_map stats_by_rule_name; + + // Temporary storage for PresolveLinear(). + std::vector> tmp_term_domains; + std::vector> tmp_left_domains; +}; + +// ============================================================================= +// Presolve functions. +// +// They should return false only if the constraint <-> variable graph didn't +// change. This is just an optimization, returning true is always correct. +// +// TODO(user): it migth be better to simply move all these functions to the +// PresolveContext class. +// ============================================================================= + +MUST_USE_RESULT bool RemoveConstraint(ConstraintProto* ct, + PresolveContext* context) { + ct->Clear(); + return true; +} + +MUST_USE_RESULT bool MarkConstraintAsFalse(ConstraintProto* ct, + PresolveContext* context) { + if (HasEnforcementLiteral(*ct)) { + context->SetLiteralToFalse(ct->enforcement_literal(0)); + } else { + context->is_unsat = true; + } + return RemoveConstraint(ct, context); +} + +bool PresolveEnforcementLiteral(ConstraintProto* ct, PresolveContext* context) { + if (!HasEnforcementLiteral(*ct)) return false; + const int literal = ct->enforcement_literal(0); + if (context->LiteralIsTrue(literal)) { + context->UpdateRuleStats("true enforcement literal"); + ct->clear_enforcement_literal(); + return true; + } + if (context->LiteralIsFalse(literal)) { + context->UpdateRuleStats("false enforcement literal"); + return RemoveConstraint(ct, context); + } + if (context->var_to_constraints[PositiveRef(literal)].size() == 1) { + // We can simply set it to false and ignore the constraint in this case. + context->UpdateRuleStats("enforcement literal not used"); + context->SetLiteralToFalse(literal); + return RemoveConstraint(ct, context); + } + return false; +} + +bool PresolveBoolOr(ConstraintProto* ct, PresolveContext* context) { + // Move the enforcement literal inside the clause if any. + if (HasEnforcementLiteral(*ct)) { + // Note that we do not mark this as changed though since the literal in the + // constraint are the same. + context->UpdateRuleStats("bool_or: removed enforcement literal"); + ct->mutable_bool_or()->add_literals(NegatedRef(ct->enforcement_literal(0))); + ct->clear_enforcement_literal(); + } + + // Inspects the literals and deal with fixed ones. + // + // TODO(user): detect if one literal is the negation of another in which + // case the constraint is true. Remove duplicates too. Do the same for + // the PresolveBoolAnd() function. + bool changed = false; + google::protobuf::RepeatedField new_literals; + for (const int literal : ct->bool_or().literals()) { + if (context->LiteralIsFalse(literal)) { + changed = true; + continue; + } + if (context->LiteralIsTrue(literal)) { + context->UpdateRuleStats("bool_or: always true"); + return RemoveConstraint(ct, context); + } + // We can just set the variable to true in this case since it is not + // used in any other constraint (note that we artifically bump the + // objective var usage by 1). + if (context->var_to_constraints[PositiveRef(literal)].size() == 1) { + context->SetLiteralToTrue(literal); + return RemoveConstraint(ct, context); + } + new_literals.Add(literal); + } + + if (new_literals.empty()) { + context->UpdateRuleStats("bool_or: empty"); + return MarkConstraintAsFalse(ct, context); + } + if (new_literals.size() == 1) { + context->UpdateRuleStats("bool_or: only one literal"); + context->SetLiteralToTrue(new_literals.Get(0)); + return RemoveConstraint(ct, context); + } + + ct->mutable_bool_or()->mutable_literals()->Swap(&new_literals); + if (changed) context->UpdateRuleStats("bool_or: fixed literals"); + return changed; +} + +bool PresolveBoolAnd(ConstraintProto* ct, PresolveContext* context) { + if (!HasEnforcementLiteral(*ct)) { + context->UpdateRuleStats("bool_and: non-reified."); + for (const int literal : ct->bool_and().literals()) { + context->SetLiteralToTrue(literal); + } + return RemoveConstraint(ct, context); + } + + bool changed = false; + google::protobuf::RepeatedField new_literals; + for (const int literal : ct->bool_and().literals()) { + if (context->LiteralIsFalse(literal)) { + context->UpdateRuleStats("bool_and: always false"); + return MarkConstraintAsFalse(ct, context); + } + if (context->LiteralIsTrue(literal)) { + changed = true; + continue; + } + if (context->var_to_constraints[PositiveRef(literal)].size() == 1) { + changed = true; + context->SetLiteralToTrue(literal); + continue; + } + new_literals.Add(literal); + } + + if (new_literals.empty()) return RemoveConstraint(ct, context); + if (new_literals.size() == 1) { + context->UpdateRuleStats("TODO bool_and: equality"); + } + + ct->mutable_bool_and()->mutable_literals()->Swap(&new_literals); + if (changed) context->UpdateRuleStats("bool_and: fixed literals"); + return changed; +} + +bool PresolveIntMax(ConstraintProto* ct, PresolveContext* context) { + if (ct->int_max().vars().empty()) { + return MarkConstraintAsFalse(ct, context); + } + + const int target_ref = ct->int_max().target(); + const int target_var = PositiveRef(target_ref); + + // Pass 1, compute the infered min of the target, and remove duplicates. + int64 target_min = context->MinOf(target_ref); + bool contains_target_ref = false; + std::set used_ref; + int new_size = 0; + std::string old = ct->DebugString(); + for (const int ref : ct->int_max().vars()) { + if (ref == target_ref) contains_target_ref = true; + if (ContainsKey(used_ref, ref)) continue; + if (ContainsKey(used_ref, NegatedRef(ref)) || + ref == NegatedRef(target_ref)) { + target_min = std::max(target_min, 0ll); + } + used_ref.insert(ref); + ct->mutable_int_max()->set_vars(new_size++, ref); + target_min = std::max(target_min, context->MinOf(ref)); + } + if (new_size < ct->int_max().vars_size()) { + context->UpdateRuleStats("int_max: removed dup"); + } + ct->mutable_int_max()->mutable_vars()->Truncate(new_size); + if (contains_target_ref) { + context->UpdateRuleStats("int_max: x = std::max(x, ...)"); + for (const int ref : ct->int_max().vars()) { + if (ref == target_ref) continue; + ConstraintProto* new_ct = context->working_model->add_constraints(); + *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); + auto* arg = new_ct->mutable_linear(); + arg->add_vars(target_ref); + arg->add_coeffs(1); + arg->add_vars(ref); + arg->add_coeffs(-1); + arg->add_domain(0); + arg->add_domain(kint64max); + } + return RemoveConstraint(ct, context); + } + + // Update the target domain. + bool domain_reduced = false; + if (!HasEnforcementLiteral(*ct)) { + std::vector infered_domain; + for (const int ref : ct->int_max().vars()) { + infered_domain = UnionOfSortedDisjointIntervals( + infered_domain, + IntersectionOfSortedDisjointIntervals(context->GetRefDomain(ref), + {{target_min, kint64max}})); + } + if (!RefIsPositive(target_ref)) { + infered_domain = NegationOfSortedDisjointIntervals(infered_domain); + } + domain_reduced |= + context->domains[target_var].IntersectWith(infered_domain); + } + + // Pass 2, update the argument domains. Filter them eventually. + new_size = 0; + const int size = ct->int_max().vars_size(); + const int64 target_max = context->MaxOf(target_ref); + for (const int ref : ct->int_max().vars()) { + if (!HasEnforcementLiteral(*ct)) { + if (RefIsPositive(ref)) { + domain_reduced |= context->domains[PositiveRef(ref)].IntersectWith( + {{kint64min, target_max}}); + } else { + domain_reduced |= context->domains[PositiveRef(ref)].IntersectWith( + {{-target_max, kint64max}}); + } + } + if (context->MaxOf(ref) >= target_min) { + ct->mutable_int_max()->set_vars(new_size++, ref); + } + } + if (domain_reduced) { + context->UpdateRuleStats("int_max: reduced domains"); + } + + bool modified = false; + if (new_size < size) { + context->UpdateRuleStats("int_max: removed variables"); + ct->mutable_int_max()->mutable_vars()->Truncate(new_size); + modified = true; + } + + // Note that we do that after the domains have been reduced. + // TODO(user): Even in the reified case we could do something. + // TODO(user): If the domains have holes, it is possible we will only detect + // UNSAT at postsolve time, that might be an issue. + if (new_size > 0 && !HasEnforcementLiteral(*ct) && + context->var_to_constraints[target_var].size() == 1) { + context->UpdateRuleStats("int_max: singleton target"); + *(context->mapping_model->add_constraints()) = *ct; + return RemoveConstraint(ct, context); + } + if (new_size == 1) { + // Convert to an equality. Note that we create a new constraint otherwise it + // might not be processed again. + context->UpdateRuleStats("int_max: converted to equality"); + ConstraintProto* new_ct = context->working_model->add_constraints(); + *new_ct = *ct; // copy name and potential reification. + auto* arg = new_ct->mutable_linear(); + arg->add_vars(target_ref); + arg->add_coeffs(1); + arg->add_vars(ct->int_max().vars(0)); + arg->add_coeffs(-1); + arg->add_domain(0); + arg->add_domain(0); + return RemoveConstraint(ct, context); + } + return modified; +} + +bool PresolveIntMin(ConstraintProto* ct, PresolveContext* context) { + const auto copy = ct->int_min(); + ct->mutable_int_max()->set_target(NegatedRef(copy.target())); + for (const int ref : copy.vars()) { + ct->mutable_int_max()->add_vars(NegatedRef(ref)); + } + return PresolveIntMax(ct, context); +} + +bool PresolveIntProd(ConstraintProto* ct, PresolveContext* context) { + // For now, we only presolve the case where all variable are Booleans. + const int target_ref = ct->int_prod().target(); + if (!RefIsPositive(target_ref)) return false; + for (const int var : ct->int_prod().vars()) { + if (!RefIsPositive(var)) return false; + if (context->domains[var].Min() != 0) return false; + if (context->domains[var].Max() != 1) return false; + } + + // This is a bool constraint! + context->UpdateRuleStats("int_prod: converted to reified bool_and"); + { + ConstraintProto* new_ct = context->working_model->add_constraints(); + new_ct->add_enforcement_literal(target_ref); + auto* arg = new_ct->mutable_bool_and(); + for (const int var : ct->int_prod().vars()) { + arg->add_literals(var); + } + } + { + ConstraintProto* new_ct = context->working_model->add_constraints(); + auto* arg = new_ct->mutable_bool_or(); + arg->add_literals(target_ref); + for (const int var : ct->int_prod().vars()) { + arg->add_literals(NegatedRef(var)); + } + } + return RemoveConstraint(ct, context); +} + +bool ExploitEquivalenceRelations(ConstraintProto* ct, + PresolveContext* context) { + if (ContainsKey(context->affine_constraints, ct)) return false; + bool changed = false; + + // Remap equal and negated variables to their representative. + ApplyToAllVariableIndices( + [&changed, context](int* ref) { + const int var = PositiveRef(*ref); + const AffineRelation::Relation r = + context->var_equiv_relations.Get(var); + if (r.representative != var) { + CHECK_EQ(r.offset, 0); + CHECK_EQ(std::abs(r.coeff), 1); + *ref = (r.coeff == 1) == RefIsPositive(*ref) + ? r.representative + : NegatedRef(r.representative); + changed = true; + } + }, + ct); + + // Remap literal and negated literal to their representative. + ApplyToAllLiteralIndices( + [&changed, context](int* ref) { + const int var = PositiveRef(*ref); + const AffineRelation::Relation r = context->GetAffineRelation(var); + if (r.representative != var) { + const bool is_positive = (r.offset == 0 && r.coeff == 1); + CHECK(is_positive || r.offset == 1 && r.coeff == -1 || + context->domains[var].IsFixed()); + *ref = (is_positive == RefIsPositive(*ref)) + ? r.representative + : NegatedRef(r.representative); + changed = true; + } + }, + ct); + return changed; +} + +bool PresolveLinear(ConstraintProto* ct, PresolveContext* context) { + bool var_was_removed = false; + bool var_constraint_graph_changed = false; + std::vector rhs = ReadDomain(ct->linear()); + + // First regroup the terms on the same variables and sum the fixed ones. + // Note that we use a map to sort the variables and because we expect most + // constraints to be small. + // + // TODO(user): move the map in context to reuse its memory. Add a quick pass + // to skip most of the work below if the constraint is already in canonical + // form (strictly increasing var, no-fixed var, gcd = 1). + int64 sum_of_fixed_terms = 0; + std::map var_to_coeff; + const LinearConstraintProto& arg = ct->linear(); + const bool was_affine = ContainsKey(context->affine_constraints, ct); + for (int i = 0; i < arg.vars_size(); ++i) { + const int var = PositiveRef(arg.vars(i)); + const int64 coeff = + RefIsPositive(arg.vars(i)) ? arg.coeffs(i) : -arg.coeffs(i); + if (coeff == 0) continue; + if (context->domains[var].IsFixed()) { + sum_of_fixed_terms += coeff * context->domains[var].Min(); + } else { + if (!was_affine && context->var_to_constraints[var].size() == 1) { + bool success; + const auto term_domain = PreciseMultiplicationOfSortedDisjointIntervals( + context->domains[var].InternalRepresentation(), -coeff, &success); + if (success) { + // Note that we can't do that if we loose information in the + // multiplication above because the new domain might not be as strict + // as the initial constraint otherwise. TODO(user): because of the + // addition, it might be possible to cover more cases though. + var_was_removed = true; + rhs = AdditionOfSortedDisjointIntervals(rhs, term_domain); + continue; + } + } + + if (!was_affine) { + const AffineRelation::Relation r = context->GetAffineRelation(var); + if (r.representative != var) { + var_constraint_graph_changed = true; + sum_of_fixed_terms += coeff * r.offset; + } + var_to_coeff[r.representative] += coeff * r.coeff; + if (var_to_coeff[r.representative] == 0) { + var_to_coeff.erase(r.representative); + } + } else { + var_to_coeff[var] += coeff; + if (var_to_coeff[var] == 0) var_to_coeff.erase(var); + } + } + } + if (var_was_removed) { + context->UpdateRuleStats("linear: singleton column"); + // TODO(user): we could add the constraint to mapping_model only once + // instead of adding a reduced version of it each time a new singleton + // variable appear in the same constraint later. That would work but would + // also force the postsolve to take search decisions... + *(context->mapping_model->add_constraints()) = *ct; + } + + // Compute the GCD of all coefficients. + int64 gcd = 1; + bool first_coeff = true; + for (const auto entry : var_to_coeff) { + // GCD(gcd, coeff) = GCD(coeff, gcd % coeff); + int64 coeff = std::abs(entry.second); + if (first_coeff) { + if (coeff != 0) { + first_coeff = false; + gcd = coeff; + } + continue; + } + while (coeff != 0) { + const int64 r = gcd % coeff; + gcd = coeff; + coeff = r; + } + if (gcd == 1) break; + } + if (gcd > 1) { + context->UpdateRuleStats("linear: divide by GCD"); + } + + if (var_to_coeff.size() < arg.vars_size()) { + context->UpdateRuleStats("linear: fixed or dup variables"); + var_constraint_graph_changed = true; + } + + // Rewrite the constraint in canonical form and update rhs (it will be copied + // to the constraint later). + if (sum_of_fixed_terms != 0) { + rhs = AdditionOfSortedDisjointIntervals( + rhs, {{-sum_of_fixed_terms, -sum_of_fixed_terms}}); + } + if (gcd > 1) { + rhs = InverseMultiplicationOfSortedDisjointIntervals(rhs, gcd); + } + ct->mutable_linear()->clear_vars(); + ct->mutable_linear()->clear_coeffs(); + for (const auto entry : var_to_coeff) { + CHECK(RefIsPositive(entry.first)); + ct->mutable_linear()->add_vars(entry.first); + ct->mutable_linear()->add_coeffs(entry.second / gcd); + } + + // Empty constraint? + if (ct->linear().vars().empty()) { + context->UpdateRuleStats("linear: empty"); + if (SortedDisjointIntervalsContain(rhs, 0)) { + return RemoveConstraint(ct, context); + } else { + return MarkConstraintAsFalse(ct, context); + } + } + + // Size one constraint? + if (ct->linear().vars().size() == 1 && !HasEnforcementLiteral(*ct)) { + const int64 coeff = + RefIsPositive(arg.vars(0)) ? arg.coeffs(0) : -arg.coeffs(0); + context->UpdateRuleStats("linear: size one"); + const int var = PositiveRef(arg.vars(0)); + if (coeff == 1) { + context->domains[var].IntersectWith(rhs); + } else { + DCHECK_EQ(coeff, -1); // Because of the GCD above. + context->domains[var].IntersectWith( + NegationOfSortedDisjointIntervals(rhs)); + } + return RemoveConstraint(ct, context); + } + + // Compute the implied rhs bounds from the variable ones. + const int kDomainComplexityLimit = 100; + auto& term_domains = context->tmp_term_domains; + auto& left_domains = context->tmp_left_domains; + const int num_vars = arg.vars_size(); + term_domains.resize(num_vars + 1); + left_domains.resize(num_vars + 1); + left_domains[0] = {{0, 0}}; + for (int i = 0; i < num_vars; ++i) { + const int var = PositiveRef(arg.vars(i)); + const int64 coeff = arg.coeffs(i); + const auto& domain = context->domains[var].InternalRepresentation(); + + // TODO(user): Try PreciseMultiplicationOfSortedDisjointIntervals() if + // the size is reasonable. + term_domains[i] = MultiplicationOfSortedDisjointIntervals(domain, coeff); + left_domains[i + 1] = + AdditionOfSortedDisjointIntervals(left_domains[i], term_domains[i]); + if (left_domains[i + 1].size() > kDomainComplexityLimit) { + // We take a super-set, otherwise it will be too slow. + // TODO(user): We could be smarter in how we compute this if we allow for + // more than one intervals. + left_domains[i + 1] = { + {left_domains[i + 1].front().start, left_domains[i + 1].back().end}}; + } + } + const std::vector& implied_rhs = left_domains[num_vars]; + + // Abort if intersection is empty. + const std::vector restricted_rhs = + IntersectionOfSortedDisjointIntervals(rhs, implied_rhs); + if (restricted_rhs.empty()) { + context->UpdateRuleStats("linear: infeasible"); + return MarkConstraintAsFalse(ct, context); + } + + // Relax the constraint rhs for faster propagation. + // TODO(user): add an IntersectionIsEmpty() function. + rhs.clear(); + for (const ClosedInterval i : UnionOfSortedDisjointIntervals( + restricted_rhs, ComplementOfSortedDisjointIntervals(implied_rhs))) { + if (!IntersectionOfSortedDisjointIntervals({i}, restricted_rhs).empty()) { + rhs.push_back(i); + } + } + if (rhs.size() == 1 && rhs[0].start == kint64min && rhs[0].end == kint64max) { + context->UpdateRuleStats("linear: always true"); + return RemoveConstraint(ct, context); + } + if (rhs != ReadDomain(ct->linear())) { + context->UpdateRuleStats("linear: simplified rhs"); + } + FillDomain(rhs, ct->mutable_linear()); + + // Propagate the variable bounds. + if (!HasEnforcementLiteral(*ct)) { + bool new_bounds = false; + std::vector new_domain; + std::vector right_domain = {{0, 0}}; + term_domains[num_vars] = NegationOfSortedDisjointIntervals(rhs); + for (int i = num_vars - 1; i >= 0; --i) { + right_domain = + AdditionOfSortedDisjointIntervals(right_domain, term_domains[i + 1]); + if (right_domain.size() > kDomainComplexityLimit) { + // We take a super-set, otherwise it will be too slow. + right_domain = {{right_domain.front().start, right_domain.back().end}}; + } + new_domain = InverseMultiplicationOfSortedDisjointIntervals( + AdditionOfSortedDisjointIntervals(left_domains[i], right_domain), + -arg.coeffs(i)); + if (context->domains[arg.vars(i)].IntersectWith(new_domain)) { + new_bounds = true; + } + } + if (new_bounds) { + context->UpdateRuleStats("linear: reduced variable domains"); + } + } + + // Detect affine relation. + // + // TODO(user): it might be better to first add only the affine relation with + // a coefficient of magnitude 1, and later the one with larger coeffs. + if (!was_affine && !HasEnforcementLiteral(*ct)) { + const LinearConstraintProto& arg = ct->linear(); + const int64 rhs_min = rhs.front().start; + const int64 rhs_max = rhs.back().end; + if (rhs_min == rhs_max && arg.vars_size() == 2) { + const int v1 = arg.vars(0); + const int v2 = arg.vars(1); + const int64 coeff1 = arg.coeffs(0); + const int64 coeff2 = arg.coeffs(1); + if (coeff1 == 1) { + context->AddAffineRelation(*ct, v1, v2, -coeff2, rhs_max); + } else if (coeff2 == 1) { + context->AddAffineRelation(*ct, v2, v1, -coeff1, rhs_max); + } else if (coeff1 == -1) { + context->AddAffineRelation(*ct, v1, v2, coeff2, -rhs_max); + } else if (coeff2 == -1) { + context->AddAffineRelation(*ct, v2, v1, coeff1, -rhs_max); + } + } + } + return var_constraint_graph_changed; +} + +// Convert small linear constraint involving only Booleans to clauses. +bool PresolveLinearIntoClauses(ConstraintProto* ct, PresolveContext* context) { + // TODO(user): the alternative to mark any newly created constraints might + // be better. + if (ContainsKey(context->affine_constraints, ct)) return false; + const LinearConstraintProto& arg = ct->linear(); + const int num_vars = arg.vars_size(); + int64 min_coeff = kint64max; + int64 offset = 0; + for (int i = 0; i < num_vars; ++i) { + const int var = PositiveRef(arg.vars(i)); + if (context->domains[var].Min() != 0) return false; + if (context->domains[var].Max() != 1) return false; + const int64 coeff = arg.coeffs(i); + if (coeff > 0) { + min_coeff = std::min(min_coeff, coeff); + } else { + // We replace the Boolean ref, by a ref to its negation (1 - x). + offset += coeff; + min_coeff = std::min(min_coeff, -coeff); + } + } + + // Detect clauses and reified ands. + const std::vector domain = ReadDomain(arg); + DCHECK(!domain.empty()); + if (offset + min_coeff > domain.back().end) { + // All Boolean are false if the reified literal is true. + context->UpdateRuleStats("linear: reified and"); + const auto copy = arg; + ct->mutable_bool_and()->clear_literals(); + for (int i = 0; i < num_vars; ++i) { + ct->mutable_bool_and()->add_literals( + copy.coeffs(i) > 0 ? NegatedRef(copy.vars(i)) : copy.vars(i)); + } + return PresolveBoolAnd(ct, context); + } else if (offset + min_coeff >= domain[0].start && + domain[0].end == kint64max) { + // At least one Boolean is true. + context->UpdateRuleStats("linear: clause"); + const auto copy = arg; + ct->mutable_bool_or()->clear_literals(); + for (int i = 0; i < num_vars; ++i) { + ct->mutable_bool_or()->add_literals( + copy.coeffs(i) > 0 ? copy.vars(i) : NegatedRef(copy.vars(i))); + } + return PresolveBoolOr(ct, context); + } + + // Expand small expression into clause. + if (num_vars > 3) return false; + context->UpdateRuleStats("linear: small Boolean expression"); + + // Enumerate all possible value of the Booleans and add a clause if constraint + // is false. TODO(user): the encoding could be made better in some cases. + const int max_mask = (1 << arg.vars_size()); + for (int mask = 0; mask < max_mask; ++mask) { + int64 value = 0; + for (int i = 0; i < num_vars; ++i) { + if ((mask >> i) & 1) value += arg.coeffs(i); + } + if (SortedDisjointIntervalsContain(domain, value)) continue; + + // Add a new clause to exclude this bad assignment. + ConstraintProto* new_ct = context->working_model->add_constraints(); + auto* new_arg = new_ct->mutable_bool_or(); + if (HasEnforcementLiteral(*ct)) { + new_ct->add_enforcement_literal(ct->enforcement_literal(0)); + } + for (int i = 0; i < num_vars; ++i) { + new_arg->add_literals(((mask >> i) & 1) ? NegatedRef(arg.vars(i)) + : arg.vars(i)); + } + } + + return RemoveConstraint(ct, context); +} + +bool PresolveInterval(ConstraintProto* ct, PresolveContext* context) { + // TODO(user): find a way to not care about this by extending the context API. + if (!RefIsPositive(ct->interval().start())) return false; + if (!RefIsPositive(ct->interval().end())) return false; + if (!RefIsPositive(ct->interval().size())) return false; + + Domain& start = context->domains[PositiveRef(ct->interval().start())]; + Domain& end = context->domains[PositiveRef(ct->interval().end())]; + Domain& size = context->domains[PositiveRef(ct->interval().size())]; + bool changed = false; + changed |= end.IntersectWith(AdditionOfSortedDisjointIntervals( + start.InternalRepresentation(), size.InternalRepresentation())); + changed |= start.IntersectWith(AdditionOfSortedDisjointIntervals( + end.InternalRepresentation(), + NegationOfSortedDisjointIntervals(size.InternalRepresentation()))); + changed |= size.IntersectWith(AdditionOfSortedDisjointIntervals( + end.InternalRepresentation(), + NegationOfSortedDisjointIntervals(start.InternalRepresentation()))); + if (changed) { + context->UpdateRuleStats("interval: reduced domains"); + } + + if (size.IsFixed()) { + // We add it even if the interval is optional. + // TODO(user): we must verify that all the variable of an optional interval + // do not appear in a constraint which is not reified by the same literal. + context->AddAffineRelation(*ct, ct->interval().end(), + ct->interval().start(), 1, size.Min()); + } + + // This never change the constraint-variable graph. + return false; +} + +bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { + const int index_ref = ct->element().index(); + if (context->var_to_constraints[PositiveRef(index_ref)].size() == 1) { + context->UpdateRuleStats("TODO element: index not used elsewhere"); + } + const int target_ref = ct->element().target(); + if (context->var_to_constraints[PositiveRef(target_ref)].size() == 1) { + context->UpdateRuleStats("TODO element: target not used elsewhere"); + } + + if (HasEnforcementLiteral(*ct)) return false; + + bool reduced_index_domain = false; + std::vector infered_domain; + const std::vector target_dom = + context->GetRefDomain(target_ref); + for (const ClosedInterval interval : context->GetRefDomain(index_ref)) { + for (int i = interval.start; i <= interval.end; ++i) { + CHECK_GE(i, 0); + CHECK_LE(i, ct->element().vars_size()); + const int ref = ct->element().vars(i); + const auto& domain = context->GetRefDomain(ref); + if (IntersectionOfSortedDisjointIntervals(target_dom, domain).empty()) { + context->domains[PositiveRef(index_ref)].RemovePoint( + RefIsPositive(index_ref) ? i : -i); + reduced_index_domain = true; + } else { + infered_domain = UnionOfSortedDisjointIntervals(infered_domain, domain); + } + } + } + if (reduced_index_domain) { + context->UpdateRuleStats("element: reduced index domain"); + } + if (!RefIsPositive(target_ref)) { + infered_domain = NegationOfSortedDisjointIntervals(infered_domain); + } + if (context->domains[PositiveRef(target_ref)].IntersectWith(infered_domain)) { + context->UpdateRuleStats("element: reduced target domain"); + } + return false; +} + +bool PresolveTable(ConstraintProto* ct, PresolveContext* context) { + if (HasEnforcementLiteral(*ct)) return false; + if (ct->table().negated()) return false; + if (ct->table().vars().empty()) { + context->UpdateRuleStats("table: empty constraint"); + return RemoveConstraint(ct, context); + } + + // Filter the unreachable tuples. + // + // TODO(user): this is not supper efficient. Optimize if needed. + const int num_vars = ct->table().vars_size(); + const int num_tuples = ct->table().values_size() / num_vars; + std::vector tuple(num_vars); + std::vector> new_tuples; + new_tuples.reserve(num_tuples); + std::vector> new_domains(num_vars); + for (int i = 0; i < num_tuples; ++i) { + bool delete_row = false; + std::string tmp; + for (int j = 0; j < num_vars; ++j) { + const int ref = ct->table().vars(j); + const int64 v = ct->table().values(i * num_vars + j); + tuple[j] = v; + if (!context->domains[PositiveRef(ref)].Contains( + RefIsPositive(ref) ? v : -v)) { + delete_row = true; + break; + } + } + if (delete_row) continue; + new_tuples.push_back(tuple); + for (int j = 0; j < num_vars; ++j) { + const int ref = ct->table().vars(j); + const int64 v = tuple[j]; + new_domains[j].insert(RefIsPositive(ref) ? v : -v); + } + } + STLSortAndRemoveDuplicates(&new_tuples); + + // Update the list of tuples if needed. + if (new_tuples.size() < num_tuples) { + ct->mutable_table()->clear_values(); + for (const std::vector& t : new_tuples) { + for (const int64 v : t) { + ct->mutable_table()->add_values(v); + } + } + context->UpdateRuleStats("table: removed rows"); + } + + // Filter the variable domains. + bool changed = false; + for (int j = 0; j < num_vars; ++j) { + const int ref = ct->table().vars(j); + changed |= context->domains[PositiveRef(ref)].IntersectWith( + SortedDisjointIntervalsFromValues( + std::vector(new_domains[j].begin(), new_domains[j].end()))); + } + if (changed) { + context->UpdateRuleStats("table: reduced variable domains"); + } + if (num_vars == 1) { + // Now that we properly update the domain, we can remove the constraint. + context->UpdateRuleStats("table: only one column!"); + return RemoveConstraint(ct, context); + } + + // Check that the table is not complete or just here to exclude a few tuples. + int64 prod = 1; + for (int j = 0; j < num_vars; ++j) prod *= new_domains[j].size(); + if (prod == new_tuples.size()) { + context->UpdateRuleStats("table: all tuples!"); + return RemoveConstraint(ct, context); + } + + // Convert to the negated table if we gain a lot of entries by doing so. + // Note however that currently the negated table do not propagate as much as + // it could. + if (new_tuples.size() > 0.7 * prod) { + // Enumerate all tuples. + std::vector> var_to_values(num_vars); + for (int j = 0; j < num_vars; ++j) { + var_to_values[j].assign(new_domains[j].begin(), new_domains[j].end()); + } + std::vector> all_tuples(prod); + for (int i = 0; i < prod; ++i) { + all_tuples[i].resize(num_vars); + int index = i; + for (int j = 0; j < num_vars; ++j) { + all_tuples[i][j] = var_to_values[j][index % var_to_values[j].size()]; + index /= var_to_values[j].size(); + } + } + STLSortAndRemoveDuplicates(&all_tuples); + + // Compute the complement of new_tuples. + std::vector> diff(prod - new_tuples.size()); + std::set_difference(all_tuples.begin(), all_tuples.end(), + new_tuples.begin(), new_tuples.end(), diff.begin()); + + // Negate the constraint. + ct->mutable_table()->set_negated(!ct->table().negated()); + ct->mutable_table()->clear_values(); + for (const std::vector& t : diff) { + for (const int64 v : t) ct->mutable_table()->add_values(v); + } + context->UpdateRuleStats("table: negated"); + } + return false; +} + +} // namespace. + +// ============================================================================= +// Public API. +// ============================================================================= + +// The presolve works as follow: +// +// First stage: +// We will process all active constraints until a fix point is reached. During +// this stage: +// - Variable will never be deleted, but their domain will be reduced. +// - Constraint will never be deleted (they will be marked as empty if needed). +// - New variables and new constraints can be added after the existing ones. +// - Constraints are added only when needed to the mapping_problem if they are +// needed during the postsolve. +// +// Second stage: +// - All the variables domain will be copied to the mapping_model. +// - Everything will be remapped so that only the variables appearing in some +// constraints will be kept and their index will be in [0, num_new_variables). +void PresolveCpModel(const CpModelProto& initial_model, + CpModelProto* presolved_model, CpModelProto* mapping_model, + std::vector* postsolve_mapping) { + // The list of modified domain. + SparseBitset modified_domains(initial_model.variables_size()); + + PresolveContext context; + for (int i = 0; i < initial_model.variables_size(); ++i) { + context.domains.push_back( + Domain(initial_model.variables(i), i, &modified_domains)); + if (context.domains[i].IsFixed()) context.ExploitFixedDomain(i); + } + context.working_model = presolved_model; + context.mapping_model = mapping_model; + *presolved_model = initial_model; + + // We copy the search strategy from the initial_model to mapping_model. + for (const auto& decision_strategy : initial_model.search_strategy()) { + *mapping_model->add_search_strategy() = decision_strategy; + } + + // The queue of "active" constraints, initialized to all of them. + std::vector in_queue(initial_model.constraints_size(), true); + std::deque queue(initial_model.constraints_size()); + std::iota(queue.begin(), queue.end(), 0); + + // This is used for constraint having unique variables in them (i.e. not + // appearing anywhere else) to not call the presolve more than once for this + // reason. + std::unordered_set> var_constraint_pair_already_called; + + // Initialize the constraint <-> variable graph. + context.constraint_to_vars.resize(initial_model.constraints_size()); + context.var_to_constraints.resize(initial_model.variables_size()); + for (int c = 0; c < initial_model.constraints_size(); ++c) { + context.UpdateConstraintVariableUsage(c); + } + + // Hack for the objective so that it is never considered to appear in only one + // constraint. + for (const CpObjectiveProto& obj : initial_model.objectives()) { + context.var_to_constraints[PositiveRef(obj.objective_var())].insert(-1); + } + + while (!queue.empty() && !context.is_unsat) { + while (!queue.empty() && !context.is_unsat) { + const int c = queue.front(); + in_queue[c] = false; + queue.pop_front(); + + const int old_num_constraint = context.working_model->constraints_size(); + ConstraintProto* ct = context.working_model->mutable_constraints(c); + + // Special generic presolve for reified constraint. + bool changed = PresolveEnforcementLiteral(ct, &context); + + // Generic presolve to exploit variable/literal equivalence. + changed |= ExploitEquivalenceRelations(ct, &context); + + // Because the functions below relies on proper usage stats, we need + // to update it now. + if (changed) { + context.UpdateConstraintVariableUsage(c); + changed = false; + } + + // Call the presolve function for this constraint if any. + switch (ct->constraint_case()) { + case ConstraintProto::ConstraintCase::kBoolOr: + changed |= PresolveBoolOr(ct, &context); + break; + case ConstraintProto::ConstraintCase::kBoolAnd: + changed |= PresolveBoolAnd(ct, &context); + break; + case ConstraintProto::ConstraintCase::kIntMax: + changed |= PresolveIntMax(ct, &context); + break; + case ConstraintProto::ConstraintCase::kIntMin: + changed |= PresolveIntMin(ct, &context); + break; + case ConstraintProto::ConstraintCase::kIntProd: + changed |= PresolveIntProd(ct, &context); + break; + case ConstraintProto::ConstraintCase::kLinear: + changed |= PresolveLinear(ct, &context); + if (ct->constraint_case() == + ConstraintProto::ConstraintCase::kLinear) { + changed |= PresolveLinearIntoClauses(ct, &context); + } + break; + case ConstraintProto::ConstraintCase::kInterval: + changed |= PresolveInterval(ct, &context); + break; + case ConstraintProto::ConstraintCase::kElement: + changed |= PresolveElement(ct, &context); + break; + case ConstraintProto::ConstraintCase::kTable: + changed |= PresolveTable(ct, &context); + break; + default: + break; + } + + // Update the variable <-> constraint graph if needed and add any new + // constraint to the queue of active constraint. + const int new_num_constraints = context.working_model->constraints_size(); + if (!changed) { + CHECK_EQ(new_num_constraints, old_num_constraint); + continue; + } + context.UpdateConstraintVariableUsage(c); + if (new_num_constraints > old_num_constraint) { + context.constraint_to_vars.resize(new_num_constraints); + in_queue.resize(new_num_constraints, true); + for (int c = old_num_constraint; c < new_num_constraints; ++c) { + queue.push_back(c); + context.UpdateConstraintVariableUsage(c); + } + } + } + + // Re-add to the queue constraints that have unique variables. Note that to + // not enter an infinite loop, we call each (var, constraint) pair at most + // once. + for (int v = 0; v < context.var_to_constraints.size(); ++v) { + const std::unordered_set& constraints = + context.var_to_constraints[v]; + if (constraints.size() != 1) continue; + const int c = *constraints.begin(); + if (c < 0) continue; + if (ContainsKey(var_constraint_pair_already_called, + std::pair(v, c))) { + continue; + } + var_constraint_pair_already_called.insert({v, c}); + if (!in_queue[c]) { + in_queue[c] = true; + queue.push_back(c); + } + } + + // Re-add to the queue the constraints that touch a variable that changed. + // + // TODO(user): Avoid reprocessing the constraints that changed the variables + // with the use of timestamp. + for (const int v : modified_domains.PositionsSetAtLeastOnce()) { + if (context.domains[v].IsFixed()) context.ExploitFixedDomain(v); + for (const int c : context.var_to_constraints[v]) { + if (c >= 0 && !in_queue[c]) { + in_queue[c] = true; + queue.push_back(c); + } + } + } + modified_domains.SparseClearAll(); + } + + if (context.is_unsat) { + // Set presolved_model to the simplest UNSAT problem (empty clause). + presolved_model->Clear(); + presolved_model->add_constraints()->mutable_bool_or(); + return; + } + + // Remove all empty or affine constraints (they will be re-added later if + // needed) in the presolved model. Note that we need to remap the interval + // references. + std::vector interval_mapping(presolved_model->constraints_size()); + int new_num_constraints = 0; + const int old_num_constraints = presolved_model->constraints_size(); + for (int i = 0; i < old_num_constraints; ++i) { + const auto type = presolved_model->constraints(i).constraint_case(); + if (type == ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET) continue; + + if (type == ConstraintProto::ConstraintCase::kInterval) { + interval_mapping[i] = new_num_constraints; + } else { + // TODO(user): for now we don't remove interval because they can be used + // in constraints. + ConstraintProto* ct = presolved_model->mutable_constraints(i); + if (ContainsKey(context.affine_constraints, ct)) { + ct->Clear(); + context.UpdateConstraintVariableUsage(i); + continue; + } + } + presolved_model->mutable_constraints(new_num_constraints++) + ->Swap(presolved_model->mutable_constraints(i)); + } + presolved_model->mutable_constraints()->DeleteSubrange( + new_num_constraints, old_num_constraints - new_num_constraints); + for (ConstraintProto& ct_ref : *presolved_model->mutable_constraints()) { + ApplyToAllIntervalIndices( + [&interval_mapping](int* ref) { + *ref = interval_mapping[*ref]; + DCHECK_NE(-1, *ref); + }, + &ct_ref); + } + + // Add back the affine relations to the presolved model or to the mapping + // model, depending where they are needed. + // + // TODO(user): unfortunately, for now, this duplicates the interval relations + // with a fixed size. + int num_affine_relations = 0; + for (int var = 0; var < presolved_model->variables_size(); ++var) { + const AffineRelation::Relation r = context.GetAffineRelation(var); + if (r.representative == var) continue; + + // We can get rid of this variable, only if: + // - it is not used elsewhere. + // - whatever the value of the representative, we can always find a value + // for this variable. + CpModelProto* proto; + if (context.var_to_constraints[var].empty()) { + // Make sure that domain(representative) is tight. + const auto implied = InverseMultiplicationOfSortedDisjointIntervals( + AdditionOfSortedDisjointIntervals({{-r.offset, -r.offset}}, + context.GetRefDomain(var)), + r.coeff); + if (context.domains[r.representative].IntersectWith(implied)) { + LOG(WARNING) << "Domain of " << r.representative + << " was not fully propagated using the affine relation " + << "(representative =" << r.representative << ", coeff = " + << r.coeff << ", offset = " << r.offset << ")"; + } + proto = context.mapping_model; + } else { + // This is needed for the corner cases where 2 variables in affine + // relation with the same representative are present but no one use + // the representative. This makes sure the code below will not try to + // delete the representative. + context.var_to_constraints[r.representative].insert(-1); + proto = context.working_model; + ++num_affine_relations; + } + + ConstraintProto* ct = proto->add_constraints(); + auto* arg = ct->mutable_linear(); + arg->add_vars(var); + arg->add_coeffs(1); + arg->add_vars(r.representative); + arg->add_coeffs(-r.coeff); + arg->add_domain(r.offset); + arg->add_domain(r.offset); + } + + // Update the variables domain of the presolved_model. + for (int i = 0; i < context.domains.size(); ++i) { + context.domains[i].CopyToIntegerVariableProto( + presolved_model->mutable_variables(i)); + } + + // Set the variables of the mapping_model. + mapping_model->mutable_variables()->CopyFrom(presolved_model->variables()); + + // The strategy variable indices will be remapped in ApplyVariableMapping() + // but first we use the representative of the affine relations for the + // variables that are not present anymore. + for (DecisionStrategyProto& strategy : + *presolved_model->mutable_search_strategy()) { + DecisionStrategyProto copy = strategy; + strategy.clear_variables(); + for (const int ref : copy.variables()) { + const int var = PositiveRef(ref); + if (context.var_to_constraints[var].empty()) { + const AffineRelation::Relation r = context.GetAffineRelation(var); + if (r.representative != var) { + strategy.add_variables((r.coeff == 1) == RefIsPositive(ref) + ? r.representative + : NegatedRef(r.representative)); + } + } else { + strategy.add_variables(ref); + } + } + } + + // Remove all the unused variables from the presolved model. + postsolve_mapping->clear(); + std::vector mapping(presolved_model->variables_size(), -1); + for (int i = 0; i < presolved_model->variables_size(); ++i) { + if (context.var_to_constraints[i].empty()) continue; + mapping[i] = postsolve_mapping->size(); + postsolve_mapping->push_back(i); + } + ApplyVariableMapping(mapping, presolved_model); + + // Stats and checks. + LOG(INFO) << "- " << context.affine_relations.NumRelations() + << " affine relations where detected. " << num_affine_relations + << " where kept."; + LOG(INFO) << "- " << context.var_equiv_relations.NumRelations() + << " variable equivalence relations where detected."; + std::map sorted_rules(context.stats_by_rule_name.begin(), + context.stats_by_rule_name.end()); + for (const auto& entry : sorted_rules) { + if (entry.second == 1) { + LOG(INFO) << "- rule '" << entry.first << "' was applied 1 time."; + } else { + LOG(INFO) << "- rule '" << entry.first << "' was applied " << entry.second + << " times."; + } + } + CHECK_EQ("", ValidateCpModel(*presolved_model)); + CHECK_EQ("", ValidateCpModel(*mapping_model)); +} + +void ApplyVariableMapping(const std::vector& mapping, + CpModelProto* proto) { + // Remap all the variable/literal references in the contraints. + for (ConstraintProto& ct_ref : *proto->mutable_constraints()) { + auto f = [&mapping](int* ref) { + const int image = mapping[PositiveRef(*ref)]; + CHECK_GE(image, 0); + *ref = *ref >= 0 ? image : NegatedRef(image); + }; + ApplyToAllVariableIndices(f, &ct_ref); + ApplyToAllLiteralIndices(f, &ct_ref); + } + + // Remap the objectives. + for (CpObjectiveProto& objective : *proto->mutable_objectives()) { + const int ref = objective.objective_var(); + const int image = mapping[PositiveRef(ref)]; + CHECK_GE(image, 0); + objective.set_objective_var(ref >= 0 ? image : NegatedRef(image)); + } + + // Remap the search decision heuristic. + // Note that we delete any heuristic related to a removed variable. + for (DecisionStrategyProto& strategy : *proto->mutable_search_strategy()) { + DecisionStrategyProto copy = strategy; + strategy.clear_variables(); + for (const int ref : copy.variables()) { + const int image = mapping[PositiveRef(ref)]; + if (image >= 0) { + strategy.add_variables(ref >= 0 ? image : NegatedRef(image)); + } + } + } + + // Move the variable definitions. + std::vector new_variables; + for (int i = 0; i < mapping.size(); ++i) { + const int image = mapping[i]; + if (image < 0) continue; + if (image >= new_variables.size()) { + new_variables.resize(image + 1, IntegerVariableProto()); + } + new_variables[image].Swap(proto->mutable_variables(i)); + } + proto->clear_variables(); + for (IntegerVariableProto& proto_ref : new_variables) { + proto->add_variables()->Swap(&proto_ref); + } + + // Check that all variables are used. + for (const IntegerVariableProto& v : proto->variables()) { + CHECK_GT(v.domain_size(), 0); + } +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h new file mode 100644 index 0000000000..f0cac9efa1 --- /dev/null +++ b/ortools/sat/cp_model_presolve.h @@ -0,0 +1,58 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_CP_MODEL_PRESOLVE_H_ +#define OR_TOOLS_SAT_CP_MODEL_PRESOLVE_H_ + +#include "ortools/sat/cp_model.pb.h" + +namespace operations_research { +namespace sat { + +// Presolves the given CpModelProto into presolved_model. +// +// This also creates a mapping model that encode the correspondance between the +// two problems. This works as follow: +// - The first variables of mapping_model are in one to one correspondance with +// the variables of the initial model. +// - The presolved_model variables are in one to one correspondance with the +// variable at the indices given by postsolve_mapping in the mapping model. +// - Fixing one of the two sets of variables and solving the model will assign +// the other set to a feasible solution of the other problem. Moreover, the +// objective value of these solution will be the same. Note that solving such +// problem will take little time in practice because the propagation will +// basically do all the work. +// +// Note(user): an optimization model can be transformed in a decision one if for +// instance the objective is fixed, or independent on the rest of the problem. +// +// TODO(user): Identify disconnected components and returns a vector of +// presolved model? If we go this route, it may be nicer to store the indices +// inside the model. We can add a IntegerVariableProto::initial_index; +void PresolveCpModel(const CpModelProto& initial_model, + CpModelProto* presolved_model, CpModelProto* mapping_model, + std::vector* postsolve_mapping); + +// Replaces all the instance of a variable i (and the literals referring to it) +// by mapping[i]. The definition of variables i is also moved to its new index. +// Variables with a negative mapping value are ignored and it is an error if +// such variable is referenced anywhere (this is CHECKed). +// +// The image of the mapping should be dense in [0, new_num_variables), this is +// also CHECKed. +void ApplyVariableMapping(const std::vector& mapping, CpModelProto* proto); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_CP_MODEL_PRESOLVE_H_ diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc new file mode 100644 index 0000000000..5c1ccb99eb --- /dev/null +++ b/ortools/sat/cp_model_solver.cc @@ -0,0 +1,1250 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_model_solver.h" + +#include + +#include "ortools/base/timer.h" +#include "ortools/base/join.h" +#include "ortools/base/join.h" +#include "ortools/base/stl_util.h" +#include "ortools/graph/connectivity.h" +#include "ortools/sat/cp_model_checker.h" +#include "ortools/sat/cp_model_presolve.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/cumulative.h" +#include "ortools/sat/disjunctive.h" +#include "ortools/sat/intervals.h" +#include "ortools/sat/linear_programming_constraint.h" +#include "ortools/sat/optimization.h" +#include "ortools/sat/sat_solver.h" +#include "ortools/sat/table.h" + +namespace operations_research { +namespace sat { + +namespace { + +// ============================================================================= +// Helper classes. +// ============================================================================= + +// List all the CpModelProto references used. +struct VariableUsage { + const std::vector integers; + const std::vector intervals; + const std::vector booleans; +}; + +VariableUsage ComputeVariableUsage(const CpModelProto& model_proto) { + // Since an interval is a constraint by itself, this will just list all + // the interval constraint in order. + std::vector used_intervals; + + // TODO(user): use std::vector instead of unordered_set + sort if + // efficiency become an issue. Note that we need these to be sorted. + IndexReferences references; + for (int c = 0; c < model_proto.constraints_size(); ++c) { + const ConstraintProto& ct = model_proto.constraints(c); + if (ct.constraint_case() == ConstraintProto::ConstraintCase::kInterval) { + used_intervals.push_back(c); + } + if (HasEnforcementLiteral(ct)) { + references.literals.insert(ct.enforcement_literal(0)); + } + AddReferencesUsedByConstraint(ct, &references); + } + + // Add the objectives and search heuristics variables that needs to be + // referenceable as integer even if they are only used as Booleans. + for (const CpObjectiveProto& objective : model_proto.objectives()) { + references.variables.insert(objective.objective_var()); + } + for (const DecisionStrategyProto& strategy : model_proto.search_strategy()) { + for (const int var : strategy.variables()) { + references.variables.insert(var); + } + } + + std::vector used_integers; + for (const int var : references.variables) { + used_integers.push_back(PositiveRef(var)); + } + STLSortAndRemoveDuplicates(&used_integers); + + std::vector used_booleans; + for (const int lit : references.literals) { + used_booleans.push_back(PositiveRef(lit)); + } + STLSortAndRemoveDuplicates(&used_booleans); + + return VariableUsage{used_integers, used_intervals, used_booleans}; +} + +// Holds the sat::model and the mapping between the proto indices and the +// sat::model ones. +class ModelWithMapping { + public: + ModelWithMapping(const CpModelProto& model_proto, const VariableUsage& usage, + Model* model); + + // Shortcuts for the underlying model_ functions. + template + T Add(std::function f) { + return f(model_); + } + template + T Get(std::function f) const { + return f(*model_); + } + template + T* GetOrCreate() { + return model_->GetOrCreate(); + } + template + void TakeOwnership(T* t) { + return model_->TakeOwnership(t); + } + + bool IsInteger(int i) const { + CHECK_LT(PositiveRef(i), integers_.size()); + return integers_[PositiveRef(i)] != kNoIntegerVariable; + } + + IntegerVariable Integer(int i) const { + CHECK_LT(PositiveRef(i), integers_.size()); + const IntegerVariable var = integers_[PositiveRef(i)]; + CHECK_NE(var, kNoIntegerVariable); + return RefIsPositive(i) ? var : NegationOf(var); + } + + BooleanVariable Boolean(int i) const { + CHECK_GE(i, 0); + CHECK_LT(i, booleans_.size()); + CHECK_NE(booleans_[i], kNoBooleanVariable); + return booleans_[i]; + } + + IntervalVariable Interval(int i) const { + CHECK_GE(i, 0); + CHECK_LT(i, intervals_.size()); + CHECK_NE(intervals_[i], kNoIntervalVariable); + return intervals_[i]; + } + + sat::Literal Literal(int i) const { + CHECK_LT(PositiveRef(i), integers_.size()); + return sat::Literal(booleans_[PositiveRef(i)], RefIsPositive(i)); + } + + template + std::vector Integers(const List& list) const { + std::vector result; + for (const auto i : list) result.push_back(Integer(i)); + return result; + } + + template + std::vector Literals(const ProtoIndices& indices) const { + std::vector result; + for (const int i : indices) result.push_back(ModelWithMapping::Literal(i)); + return result; + } + + template + std::vector Intervals(const ProtoIndices& indices) const { + std::vector result; + for (const int i : indices) result.push_back(Interval(i)); + return result; + } + + const IntervalsRepository& GetIntervalsRepository() const { + const IntervalsRepository* repository = model_->Get(); + return *repository; + } + + std::vector ExtractFullAssignment() const { + std::vector result; + const int num_variables = integers_.size(); + for (int i = 0; i < num_variables; ++i) { + if (integers_[i] != kNoIntegerVariable) { + if (model_->Get(LowerBound(integers_[i])) != + model_->Get(UpperBound(integers_[i]))) { + // Notify that everything is not fixed. + result.clear(); + return {}; + } + if (model_->GetOrCreate()->IsCurrentlyIgnored( + integers_[i])) { + // This variable is "ignored" so it may not be fixed, simply use + // the current lower bound. Any value in its domain should lead to + // a feasible solution. + result.push_back(model_->Get(LowerBound(integers_[i]))); + } else { + result.push_back(model_->Get(Value(integers_[i]))); + } + } else if (booleans_[i] != kNoBooleanVariable) { + result.push_back(model_->Get(Value(booleans_[i]))); + } else { + // This variable is not used anywhere, fix it to its lower_bound. + // TODO(user): maybe it is better to fix it to its lowest possible + // magnitude. + result.push_back(lower_bounds_[i]); + } + } + return result; + } + + private: + Model* model_; + + // Note that only the variables used by at leat one constraint will be + // created, the other will have a kNo[Integer,Interval,Boolean]VariableValue. + std::vector integers_; + std::vector intervals_; + std::vector booleans_; + + // Used to return a feasible solution for the unused variables. + std::vector lower_bounds_; +}; + +template +std::vector ValuesFromProto(const Values& values) { + return std::vector(values.begin(), values.end()); +} + +// Extracts all the used variables in the CpModelProto and creates a sat::Model +// representation for them. +ModelWithMapping::ModelWithMapping(const CpModelProto& model_proto, + const VariableUsage& usage, Model* sat_model) + : model_(sat_model) { + integers_.resize(model_proto.variables_size(), kNoIntegerVariable); + booleans_.resize(model_proto.variables_size(), kNoBooleanVariable); + intervals_.resize(model_proto.constraints_size(), kNoIntervalVariable); + lower_bounds_.resize(model_proto.variables_size(), 0); + + // Fills lower_bounds_, this is only used in ExtractFullAssignment(). + for (int i = 0; i < model_proto.variables_size(); ++i) { + lower_bounds_[i] = model_proto.variables(i).domain(0); + } + + // TODO(user): Detect integers that are the negation of other variable. This + // cannot be simplified by the presolve in the current proto format. + std::vector domain; + std::vector domain_is_boolean; + for (const int i : usage.integers) { + const auto& var_proto = model_proto.variables(i); + integers_[i] = Add(NewIntegerVariable(ReadDomain(var_proto))); + } + + for (const int i : usage.intervals) { + const ConstraintProto& ct = model_proto.constraints(i); + CHECK(!HasEnforcementLiteral(ct)) << "Optional interval not yet supported."; + intervals_[i] = Add(NewInterval(Integer(ct.interval().start()), + Integer(ct.interval().end()), + Integer(ct.interval().size()))); + } + + for (const int i : usage.booleans) { + booleans_[i] = Add(NewBooleanVariable()); + + // We need to fix the Boolean if the domain of the integer variable do not + // contain 0 or contains only zero! Note that this case should not appear + // once the model is presolved. + std::vector domain = + ValuesFromProto(model_proto.variables(i).domain()); + if (domain[0] == 0 && domain[1] == 0) { + // Fix to false. + Add(ClauseConstraint({sat::Literal(booleans_[i], false)})); + } else if (!DomainInProtoContains(model_proto.variables(i), 0)) { + // Fix to true. + Add(ClauseConstraint({sat::Literal(booleans_[i], true)})); + } else if (integers_[i] != kNoIntegerVariable) { + Add(ReifiedInInterval(integers_[i], 0, 0, + sat::Literal(booleans_[i], false))); + } + } +} + +// ============================================================================= +// Constraint loading functions. +// ============================================================================= + +void LoadBoolOrConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + std::vector literals = m->Literals(ct.bool_or().literals()); + if (HasEnforcementLiteral(ct)) { + literals.push_back(m->Literal(ct.enforcement_literal(0)).Negated()); + } + m->Add(ClauseConstraint(literals)); +} + +void LoadBoolAndConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const std::vector literals = m->Literals(ct.bool_and().literals()); + if (HasEnforcementLiteral(ct)) { + const Literal is_true = m->Literal(ct.enforcement_literal(0)); + for (const Literal lit : literals) m->Add(Implication(is_true, lit)); + } else { + for (const Literal lit : literals) m->Add(ClauseConstraint({lit})); + } +} + +void LoadBoolXorConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + CHECK(!HasEnforcementLiteral(ct)) << "Not supported."; + m->Add(LiteralXorIs(m->Literals(ct.bool_xor().literals()), true)); +} + +void LoadLinearConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const std::vector vars = m->Integers(ct.linear().vars()); + const std::vector coeffs = ValuesFromProto(ct.linear().coeffs()); + if (ct.linear().domain_size() == 2) { + const int64 lb = ct.linear().domain(0); + const int64 ub = ct.linear().domain(1); + if (!HasEnforcementLiteral(ct)) { + if (lb != kint64min) m->Add(WeightedSumGreaterOrEqual(vars, coeffs, lb)); + if (ub != kint64max) m->Add(WeightedSumLowerOrEqual(vars, coeffs, ub)); + } else { + const Literal is_true = m->Literal(ct.enforcement_literal(0)); + if (lb != kint64min) { + m->Add(ConditionalWeightedSumGreaterOrEqual(is_true, vars, coeffs, lb)); + } + if (ub != kint64max) { + m->Add(ConditionalWeightedSumLowerOrEqual(is_true, vars, coeffs, ub)); + } + } + } else { + std::vector clause; + for (int i = 0; i < ct.linear().domain_size(); i += 2) { + const int64 lb = ct.linear().domain(i); + const int64 ub = ct.linear().domain(i + 1); + const Literal literal(m->Add(NewBooleanVariable()), true); + clause.push_back(literal); + if (lb != kint64min) { + m->Add(ConditionalWeightedSumGreaterOrEqual(literal, vars, coeffs, lb)); + } + if (ub != kint64max) { + m->Add(ConditionalWeightedSumLowerOrEqual(literal, vars, coeffs, ub)); + } + } + if (HasEnforcementLiteral(ct)) { + clause.push_back(m->Literal(ct.enforcement_literal(0)).Negated()); + } + + // TODO(user): In the cases where this clause only contains two literals, + // then we could have only used one literal and its negation above. + m->Add(ClauseConstraint(clause)); + } +} + +void LoadAllDiffConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const std::vector vars = m->Integers(ct.all_diff().vars()); + m->Add(AllDifferentOnBounds(vars)); +} + +void LoadIntProdConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const IntegerVariable prod = m->Integer(ct.int_prod().target()); + const std::vector vars = m->Integers(ct.int_prod().vars()); + CHECK_EQ(vars.size(), 2) << "General int_prod not supported yet."; + m->Add(ProductConstraint(vars[0], vars[1], prod)); +} + +void LoadIntDivConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const IntegerVariable div = m->Integer(ct.int_div().target()); + const std::vector vars = m->Integers(ct.int_div().vars()); + m->Add(DivisionConstraint(vars[0], vars[1], div)); +} + +void LoadIntMinConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const IntegerVariable min = m->Integer(ct.int_min().target()); + const std::vector vars = m->Integers(ct.int_min().vars()); + m->Add(IsEqualToMinOf(min, vars)); +} + +void LoadIntMaxConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const IntegerVariable max = m->Integer(ct.int_max().target()); + const std::vector vars = m->Integers(ct.int_max().vars()); + m->Add(IsEqualToMaxOf(max, vars)); +} + +void LoadNoOverlapConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + m->Add(Disjunctive(m->Intervals(ct.no_overlap().intervals()))); +} + +void LoadNoOverlap2dConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const std::vector x_intervals = + m->Intervals(ct.no_overlap_2d().x_intervals()); + const std::vector y_intervals = + m->Intervals(ct.no_overlap_2d().y_intervals()); + + const IntervalsRepository& repository = m->GetIntervalsRepository(); + std::vector x; + std::vector y; + std::vector dx; + std::vector dy; + for (int i = 0; i < x_intervals.size(); ++i) { + x.push_back(repository.StartVar(x_intervals[i])); + y.push_back(repository.StartVar(y_intervals[i])); + dx.push_back(repository.SizeVar(x_intervals[i])); + dy.push_back(repository.SizeVar(y_intervals[i])); + } + m->Add(StrictNonOverlappingRectangles(x, y, dx, dy)); +} + +void LoadCumulativeConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const std::vector intervals = + m->Intervals(ct.cumulative().intervals()); + const IntegerVariable capacity = m->Integer(ct.cumulative().capacity()); + const std::vector demands = + m->Integers(ct.cumulative().demands()); + m->Add(Cumulative(intervals, demands, capacity)); +} + +// TODO(user): Be more efficient when the element().vars() are constants. +// Ideally we should avoid creating them as integer variable... +void LoadElementConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const IntegerVariable index = m->Integer(ct.element().index()); + const IntegerVariable target = m->Integer(ct.element().target()); + const std::vector vars = m->Integers(ct.element().vars()); + + IntegerTrail* integer_trail = m->GetOrCreate(); + if (integer_trail->LowerBound(index) == integer_trail->UpperBound(index)) { + const int64 value = integer_trail->LowerBound(index).value(); + m->Add(Equality(target, vars[value])); + return; + } + + // We always fully encode the index on an element constraint. + const auto encoding = m->Add(FullyEncodeVariable((index))); + std::vector selectors; + std::vector possible_vars; + for (const auto literal_value : encoding) { + const int i = literal_value.value.value(); + CHECK_GE(i, 0) << "Should be presolved."; + CHECK_LT(i, vars.size()) << "Should be presolved."; + possible_vars.push_back(vars[i]); + selectors.push_back(literal_value.literal); + const Literal r = literal_value.literal; + + // TODO(user): Be more efficient if one of the two is a constant. Or handle + // that directly in the model function. + if (vars[i] == target) continue; + m->Add(ConditionalLowerOrEqualWithOffset(vars[i], target, 0, r)); + m->Add(ConditionalLowerOrEqualWithOffset(target, vars[i], 0, r)); + } + m->Add(PartialIsOneOfVar(target, possible_vars, selectors)); +} + +void LoadTableConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const std::vector vars = m->Integers(ct.table().vars()); + const std::vector values = ValuesFromProto(ct.table().values()); + const int num_vars = vars.size(); + const int num_tuples = values.size() / num_vars; + std::vector> tuples(num_tuples); + int count = 0; + for (int i = 0; i < num_tuples; ++i) { + for (int j = 0; j < num_vars; ++j) { + tuples[i].push_back(values[count++]); + } + } + if (ct.table().negated()) { + m->Add(NegatedTableConstraintWithoutFullEncoding(vars, tuples)); + } else { + m->Add(TableConstraint(vars, tuples)); + } +} + +void LoadAutomataConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const std::vector vars = m->Integers(ct.automata().vars()); + + const int num_transitions = ct.automata().transition_tail_size(); + std::vector> transitions; + for (int i = 0; i < num_transitions; ++i) { + transitions.push_back({ct.automata().transition_tail(i), + ct.automata().transition_label(i), + ct.automata().transition_head(i)}); + } + + const int64 starting_state = ct.automata().starting_state(); + const std::vector final_states = + ValuesFromProto(ct.automata().final_states()); + m->Add(TransitionConstraint(vars, transitions, starting_state, final_states)); +} + +void LoadCircuitConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + const int num_nodes = ct.circuit().nexts_size(); + const std::vector nexts = m->Integers(ct.circuit().nexts()); + std::vector> graph( + num_nodes, std::vector(num_nodes, kFalseLiteralIndex)); + for (int i = 0; i < num_nodes; ++i) { + if (m->Get(IsFixed(nexts[i]))) { + // This is just an optimization. Note that if nexts[i] is not used in + // other places, we didn't even need to create this constant variable in + // the IntegerTrail... + graph[i][m->Get(Value(nexts[i]))] = kTrueLiteralIndex; + continue; + } else { + const auto encoding = m->Add(FullyEncodeVariable((nexts[i]))); + for (const auto& entry : encoding) { + graph[i][entry.value.value()] = entry.literal.Index(); + } + } + } + m->Add(SubcircuitConstraint(graph)); +} + +// Makes the std::string fit in one line by cutting it in the middle if necessary. +std::string Summarize(const std::string& input) { + if (input.size() < 105) return input; + const int half = 50; + return StrCat(input.substr(0, half), " ... ", + input.substr(input.size() - half, half)); +} + +} // namespace. + +// ============================================================================= +// Public API. +// ============================================================================= + +std::string CpModelStats(const CpModelProto& model_proto) { + std::map num_constraints_by_type; + for (int c = 0; c < model_proto.constraints_size(); ++c) { + num_constraints_by_type[model_proto.constraints(c).constraint_case()]++; + } + const VariableUsage usage = ComputeVariableUsage(model_proto); + + int num_constants = 0; + std::set constant_values; + std::map, int> num_vars_per_domains; + for (const IntegerVariableProto& var : model_proto.variables()) { + if (var.domain_size() == 2 && var.domain(0) == var.domain(1)) { + ++num_constants; + constant_values.insert(var.domain(0)); + } else { + num_vars_per_domains[ReadDomain(var)]++; + } + } + + std::string result; + StrAppend(&result, "Model '", model_proto.name(), "':\n"); + + for (const DecisionStrategyProto& strategy : model_proto.search_strategy()) { + StrAppend(&result, "Search strategy: on ", strategy.variables_size(), + " variables, ", + DecisionStrategyProto::VariableSelectionStrategy_Name( + strategy.variable_selection_strategy()), + ", ", + DecisionStrategyProto::DomainReductionStrategy_Name( + strategy.domain_reduction_strategy()), + "\n"); + } + + StrAppend(&result, "#Variables: ", model_proto.variables_size(), "\n"); + if (num_vars_per_domains.size() < 20) { + for (const auto& entry : num_vars_per_domains) { + const std::string temp = StrCat(" - ", entry.second, " in ", + IntervalsAsString(entry.first), "\n"); + StrAppend(&result, Summarize(temp)); + } + } else { + size_t max_complexity = 0; + int64 min = kint64max; + int64 max = kint64min; + for (const auto& entry : num_vars_per_domains) { + min = std::min(min, entry.first.front().start); + max = std::max(max, entry.first.back().end); + max_complexity = std::max(max_complexity, entry.first.size()); + } + StrAppend(&result, " - ", num_vars_per_domains.size(), + " different domains in [", min, ",", max, + "] with a largest complexity of ", max_complexity, ".\n"); + } + + if (num_constants > 0) { + const std::string temp = + StrCat(" - ", num_constants, " constants in {", + strings::Join(constant_values, ","), "} \n"); + StrAppend(&result, Summarize(temp)); + } + + StrAppend(&result, "#Booleans: ", usage.booleans.size(), "\n"); + StrAppend(&result, "#Integers: ", usage.integers.size(), "\n"); + + std::vector constraints; + for (const auto entry : num_constraints_by_type) { + constraints.push_back( + StrCat("#", ConstraintCaseName(entry.first), ": ", entry.second)); + } + std::sort(constraints.begin(), constraints.end()); + StrAppend(&result, strings::Join(constraints, "\n")); + + return result; +} + +std::string CpSolverResponseStats(const CpSolverResponse& response) { + std::string result; + StrAppend(&result, "CpSolverResponse:"); + StrAppend(&result, + "\nstatus: ", CpSolverStatus_Name(response.status())); + + // We special case the pure-decision problem for clarity. + // + // TODO(user): This test is not ideal for the corner case where the status is + // still UNKNOWN yet we already know that if there is a solution, then its + // objective is zero... + if (response.status() != CpSolverStatus::OPTIMAL && + response.objective_value() == 0 && response.best_objective_bound() == 0) { + StrAppend(&result, "\nobjective: NA"); + StrAppend(&result, "\nbest_bound: NA"); + } else { + StrAppend(&result, "\nobjective: ", response.objective_value()); + StrAppend(&result, "\nbest_bound: ", response.best_objective_bound()); + } + + StrAppend(&result, "\nbooleans: ", response.num_booleans()); + StrAppend(&result, "\nconflicts: ", response.num_conflicts()); + StrAppend(&result, "\nbranches: ", response.num_branches()); + + // TODO(user): This is probably better named "binary_propagation", but we just + // output "propagations" to be consistent with sat/analyze.sh. + StrAppend(&result, + "\npropagations: ", response.num_binary_propagations()); + StrAppend( + &result, "\ninteger_propagations: ", response.num_integer_propagations()); + StrAppend(&result, "\nwalltime: ", response.wall_time()); + StrAppend(&result, "\nusertime: ", response.user_time()); + StrAppend(&result, + "\ndeterministic_time: ", response.deterministic_time()); + StrAppend(&result, "\n"); + return result; +} + +namespace { + +double ScaleObjectiveValue(const CpObjectiveProto& proto, int64 value) { + double result = value + proto.offset(); + if (proto.scaling_factor() == 0) return result; + return proto.scaling_factor() * result; +} + +bool LoadConstraint(const ConstraintProto& ct, ModelWithMapping* m) { + switch (ct.constraint_case()) { + case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: + return true; + case ConstraintProto::ConstraintCase::kBoolOr: + LoadBoolOrConstraint(ct, m); + return true; + case ConstraintProto::ConstraintCase::kBoolAnd: + LoadBoolAndConstraint(ct, m); + return true; + case ConstraintProto::ConstraintCase::kBoolXor: + LoadBoolXorConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kLinear: + LoadLinearConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kAllDiff: + LoadAllDiffConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kIntProd: + LoadIntProdConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kIntDiv: + LoadIntDivConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kIntMin: + LoadIntMinConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kIntMax: + LoadIntMaxConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kInterval: + // Already dealt with. + return true; + case ConstraintProto::ConstraintProto::kNoOverlap: + LoadNoOverlapConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kNoOverlap2D: + LoadNoOverlap2dConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kCumulative: + LoadCumulativeConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kElement: + LoadElementConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kTable: + LoadTableConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kAutomata: + LoadAutomataConstraint(ct, m); + return true; + case ConstraintProto::ConstraintProto::kCircuit: + LoadCircuitConstraint(ct, m); + return true; + default: + return false; + } +} + +// TODO(user): In full generality, we could encode all the constraint as an LP. +void LoadConstraintInGlobalLp(const ConstraintProto& ct, ModelWithMapping* m, + LinearProgrammingConstraint* lp) { + const double kInfinity = std::numeric_limits::infinity(); + if (HasEnforcementLiteral(ct)) return; + if (ct.constraint_case() == ConstraintProto::ConstraintCase::kBoolOr) { + // TODO(user): Support this when the LinearProgrammingConstraint support + // SetCoefficient() with literals. + } else if (ct.constraint_case() == ConstraintProto::ConstraintCase::kIntMax) { + const int target = ct.int_max().target(); + for (const int var : ct.int_max().vars()) { + // This deal with the corner case X = std::max(X, Y, Z, ..) ! + // Note that this can be presolved into X >= Y, X >= Z, ... + if (target == var) continue; + const auto lp_constraint = lp->CreateNewConstraint(-kInfinity, 0.0); + lp->SetCoefficient(lp_constraint, m->Integer(var), 1.0); + lp->SetCoefficient(lp_constraint, m->Integer(target), -1.0); + } + } else if (ct.constraint_case() == ConstraintProto::ConstraintCase::kIntMin) { + const int target = ct.int_min().target(); + for (const int var : ct.int_min().vars()) { + if (target == var) continue; + const auto lp_constraint = lp->CreateNewConstraint(-kInfinity, 0.0); + lp->SetCoefficient(lp_constraint, m->Integer(target), 1.0); + lp->SetCoefficient(lp_constraint, m->Integer(var), -1.0); + } + } else if (ct.constraint_case() == ConstraintProto::ConstraintCase::kLinear) { + // Note that we ignore the holes in the domain... + const int64 min = ct.linear().domain(0); + const int64 max = ct.linear().domain(ct.linear().domain_size() - 1); + if (min == kint64min && max == kint64max) return; + + // This is needed in case of duplicate variables in the linear constraint. + std::unordered_map terms; + for (int i = 0; i < ct.linear().vars_size(); i++) { + terms[m->Integer(ct.linear().vars(i))] += ct.linear().coeffs(i); + } + + const double lb = (min == kint64min) ? -kInfinity : min; + const double ub = (max == kint64max) ? kInfinity : max; + const auto lp_constraint = lp->CreateNewConstraint(lb, ub); + for (const auto entry : terms) { + lp->SetCoefficient(lp_constraint, entry.first, entry.second); + } + } +} + +void FillSolutionInResponse(const CpModelProto& model_proto, + const ModelWithMapping& m, + CpSolverResponse* response) { + const std::vector solution = m.ExtractFullAssignment(); + if (!solution.empty()) { + CHECK(SolutionIsFeasible(model_proto, solution)); + response->clear_solution(); + for (const int64 value : solution) response->add_solution(value); + } else { + // Not all variables are fixed. + // We fill instead the lb/ub of each variables. + response->clear_solution_lower_bounds(); + response->clear_solution_upper_bounds(); + for (int i = 0; i < model_proto.variables_size(); ++i) { + if (m.IsInteger(i)) { + response->add_solution_lower_bounds(m.Get(LowerBound(m.Integer(i)))); + response->add_solution_upper_bounds(m.Get(UpperBound(m.Integer(i)))); + } else { + const int value = m.Get(Value(m.Boolean(i))); + response->add_solution_lower_bounds(value); + response->add_solution_upper_bounds(value); + } + } + } +} + +// Adds one LinearProgrammingConstraint per connected component of the model. +void AddLPConstraints(const CpModelProto& model_proto, ModelWithMapping* m) { + const int num_constraints = model_proto.constraints().size(); + const int num_variables = model_proto.variables().size(); + + // The bipartite graph of LP constraints might be disconnected: + // make a partition of the variables into connected components. + // Constraint nodes are indexed by [0..num_constraints), + // variable nodes by [num_constraints..num_constraints+num_variables). + // TODO(user): look into biconnected components. + ConnectedComponents components; + components.Init(num_constraints + num_variables); + std::vector constraint_has_lp_representation(num_constraints); + auto get_var_index = [num_constraints](int proto_var_index) { + return num_constraints + PositiveRef(proto_var_index); + }; + + for (int i = 0; i < num_constraints; i++) { + const auto& ct = model_proto.constraints(i); + // Skip reified constraints. + if (HasEnforcementLiteral(ct)) continue; + + constraint_has_lp_representation[i] = true; + if (ct.constraint_case() == ConstraintProto::ConstraintCase::kIntMax) { + components.AddArc(i, get_var_index(ct.int_max().target())); + for (const int var : ct.int_max().vars()) { + components.AddArc(i, get_var_index(var)); + } + } else if (ct.constraint_case() == + ConstraintProto::ConstraintCase::kIntMin) { + components.AddArc(i, get_var_index(ct.int_min().target())); + for (const int var : ct.int_min().vars()) { + components.AddArc(i, get_var_index(var)); + } + } else if (ct.constraint_case() == + ConstraintProto::ConstraintCase::kLinear) { + for (const int var : ct.linear().vars()) { + components.AddArc(i, get_var_index(var)); + } + } else { + constraint_has_lp_representation[i] = false; + } + } + + std::unordered_map components_to_size; + for (int i = 0; i < num_constraints; i++) { + if (constraint_has_lp_representation[i]) { + const int id = components.GetClassRepresentative(i); + components_to_size[id] += 1; + } + } + + // Dispatch every constraint to its LinearProgrammingConstraint. + std::unordered_map representative_to_lp_constraint; + std::vector lp_constraints; + IntegerTrail* integer_trail = m->GetOrCreate(); + for (int i = 0; i < num_constraints; i++) { + if (constraint_has_lp_representation[i]) { + const auto& ct = model_proto.constraints(i); + const int id = components.GetClassRepresentative(i); + if (components_to_size[id] <= 1) continue; + if (!ContainsKey(representative_to_lp_constraint, id)) { + auto* lp = new LinearProgrammingConstraint(integer_trail); + representative_to_lp_constraint[id] = lp; + lp_constraints.push_back(lp); + } + LoadConstraintInGlobalLp(ct, m, representative_to_lp_constraint[id]); + } + } + + // Add the objective. + if (model_proto.objectives_size() != 0) { + const int var = model_proto.objectives(0).objective_var(); + const int id = components.GetClassRepresentative(get_var_index(var)); + if (ContainsKey(representative_to_lp_constraint, id)) { + representative_to_lp_constraint[id]->SetObjective(m->Integer(var), true); + } + } + + // Register LP constraints and transfer their ownership to the CP model. + for (auto* lp_constraint : lp_constraints) { + m->TakeOwnership(lp_constraint); + lp_constraint->RegisterWith(m->GetOrCreate()); + } + + VLOG_IF(1, !lp_constraints.empty()) + << "Added " << lp_constraints.size() << " LP constraints."; +} + +// The function responsible for implementing the choosen search strategy. +// +// TODO(user): expose and unit-test, it seems easy to get the order wrong, and +// that would not change the correctness. +struct Strategy { + std::vector variables; + DecisionStrategyProto::VariableSelectionStrategy var_strategy; + DecisionStrategyProto::DomainReductionStrategy domain_strategy; +}; +const std::function ConstructSearchStrategy( + const std::vector& strategies, Model* model) { + IntegerEncoder* const integer_encoder = model->GetOrCreate(); + IntegerTrail* const integer_trail = model->GetOrCreate(); + + // Note that we copy strategies to keep the return function validity + // independently of the life of the passed vector. + return [integer_encoder, integer_trail, strategies]() { + for (const Strategy& strategy : strategies) { + IntegerVariable candidate = kNoIntegerVariable; + IntegerValue candidate_lb; + IntegerValue candidate_ub; + + // TODO(user): Improve the complexity if this becomes an issue which + // may be the case if we do a fixed_search. + for (const IntegerVariable var : strategy.variables) { + if (integer_trail->IsCurrentlyIgnored(var)) continue; + const IntegerValue lb = integer_trail->LowerBound(var); + const IntegerValue ub = integer_trail->UpperBound(var); + if (lb == ub) continue; + bool select = false; + if (candidate == kNoIntegerVariable) { + select = true; + } else { + switch (strategy.var_strategy) { + case DecisionStrategyProto::CHOOSE_FIRST: + break; + case DecisionStrategyProto::CHOOSE_LOWEST_MIN: + select = lb < candidate_lb; + break; + case DecisionStrategyProto::CHOOSE_HIGHEST_MAX: + select = ub > candidate_ub; + break; + case DecisionStrategyProto::CHOOSE_MIN_DOMAIN_SIZE: + select = (ub - lb) < (candidate_ub - candidate_lb); + break; + case DecisionStrategyProto::CHOOSE_MAX_DOMAIN_SIZE: + select = (ub - lb) > (candidate_ub - candidate_lb); + break; + default: + LOG(FATAL) << "Unknown VariableSelectionStrategy " + << strategy.var_strategy; + } + } + if (select) { + candidate = var; + candidate_lb = lb; + candidate_ub = ub; + } + } + if (candidate == kNoIntegerVariable) continue; + + IntegerLiteral literal; + switch (strategy.domain_strategy) { + case DecisionStrategyProto::SELECT_MIN_VALUE: + literal = IntegerLiteral::LowerOrEqual(candidate, candidate_lb); + break; + case DecisionStrategyProto::SELECT_MAX_VALUE: + literal = IntegerLiteral::GreaterOrEqual(candidate, candidate_ub); + break; + case DecisionStrategyProto::SELECT_LOWER_HALF: + literal = IntegerLiteral::LowerOrEqual( + candidate, candidate_lb + (candidate_ub - candidate_lb) / 2); + break; + case DecisionStrategyProto::SELECT_UPPER_HALF: + literal = IntegerLiteral::GreaterOrEqual( + candidate, candidate_ub - (candidate_ub - candidate_lb) / 2); + break; + default: + LOG(FATAL) << "Unknown DomainReductionStrategy " + << strategy.domain_strategy; + } + return integer_encoder->GetOrCreateAssociatedLiteral(literal).Index(); + } + return kNoLiteralIndex; + }; +} + +void ExtractLinearObjective(const CpModelProto& model_proto, + ModelWithMapping* m, + std::vector* linear_vars, + std::vector* linear_coeffs) { + CHECK(!model_proto.objectives().empty()); + const CpObjectiveProto obj = model_proto.objectives(0); + const IntegerVariable objective_var = m->Integer(obj.objective_var()); + + // Default linear objective if we don't find any linear equality defining it. + *linear_vars = {objective_var}; + *linear_coeffs = {IntegerValue(1)}; + + // TODO(user): Expand the linear equation recursively in order to have + // as much term as possible? + for (const ConstraintProto& ct : model_proto.constraints()) { + // Skip everything that is not a linear equality constraint. + if (!ct.enforcement_literal().empty()) continue; + if (ct.constraint_case() != ConstraintProto::ConstraintCase::kLinear) { + continue; + } + if (ct.linear().domain().size() != 2) continue; + if (ct.linear().domain(0) != ct.linear().domain(1)) continue; + + // Find out if objective_var appear in this constraint. + bool present = false; + int64 objective_coeff; + const int num_terms = ct.linear().vars_size(); + for (int i = 0; i < num_terms; ++i) { + const int ref = ct.linear().vars(i); + const int64 coeff = ct.linear().coeffs(i); + if (PositiveRef(ref) == PositiveRef(obj.objective_var())) { + CHECK(!present) << "Duplicate variables not supported"; + present = true; + objective_coeff = (ref == obj.objective_var()) ? coeff : -coeff; + } + } + + // We use the longest equality we can find. + // TODO(user): Deal with objective_coeff with a magnitude greater than 1? + if (present && std::abs(objective_coeff) == 1 && + num_terms > linear_vars->size() + 1) { + linear_vars->clear(); + linear_coeffs->clear(); + const int64 rhs = ct.linear().domain(0); + if (rhs != 0) { + linear_vars->push_back(m->Add(NewIntegerVariable(rhs, rhs))); + linear_coeffs->push_back(IntegerValue(objective_coeff == 1 ? 1 : -1)); + } + for (int i = 0; i < num_terms; ++i) { + const int ref = ct.linear().vars(i); + if (PositiveRef(ref) != PositiveRef(obj.objective_var())) { + linear_vars->push_back(m->Integer(ref)); + const IntegerValue coeff(ct.linear().coeffs(i)); + linear_coeffs->push_back(objective_coeff == 1 ? -coeff : coeff); + } + } + } + } +} + +} // namespace + +CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, + Model* model) { + // Timing. + WallTimer wall_timer; + UserTimer user_timer; + wall_timer.Start(); + user_timer.Start(); + + // Initialize a default invalid response. + CpSolverResponse response; + response.set_status(CpSolverStatus::MODEL_INVALID); + + // Instanciate all the needed variables. + const VariableUsage usage = ComputeVariableUsage(model_proto); + ModelWithMapping m(model_proto, usage, model); + + const SatParameters& parameters = + model->GetOrCreate()->parameters(); + + // Load the constraints. + std::set unsupported_types; + Trail* trail = model->GetOrCreate(); + for (const ConstraintProto& ct : model_proto.constraints()) { + const int old_num_fixed = trail->Index(); + if (!LoadConstraint(ct, &m)) { + unsupported_types.insert(ConstraintCaseName(ct.constraint_case())); + continue; + } + + // We propagate after each new Boolean constraint but not the integer + // ones. So we call Propagate() manually here. TODO(user): Do that + // automatically? + model->GetOrCreate()->Propagate(); + if (trail->Index() > old_num_fixed) { + VLOG(1) << "Constraint fixed " << trail->Index() - old_num_fixed + << " Boolean variable(s): " << ct.DebugString(); + } + if (model->GetOrCreate()->IsModelUnsat()) { + LOG(INFO) << "UNSAT during extraction (after adding '" + << ConstraintCaseName(ct.constraint_case()) << "'). " + << ct.DebugString(); + break; + } + } + if (!unsupported_types.empty()) { + LOG(INFO) << "There is unsuported constraints types in this model: "; + for (const std::string& type : unsupported_types) { + LOG(INFO) << " - " << type; + } + return response; + } + + // Register the global LP constraint. + // TODO(user): Computes the connected components, and use one constraint per + // component. There is also no need for a constraint with just one equation. + if (parameters.use_global_lp_constraint()) { + AddLPConstraints(model_proto, &m); + } + + // Initialize the search strategy function. + std::function next_decision; + if (model_proto.search_strategy().empty()) { + std::vector decisions; + for (const int i : usage.integers) { + decisions.push_back(m.Integer(i)); + } + next_decision = FirstUnassignedVarAtItsMinHeuristic(decisions, model); + } else { + std::vector strategies; + for (const DecisionStrategyProto& proto : model_proto.search_strategy()) { + strategies.push_back(Strategy()); + Strategy& strategy = strategies.back(); + for (const int ref : proto.variables()) { + strategy.variables.push_back(m.Integer(ref)); + } + strategy.var_strategy = proto.variable_selection_strategy(); + strategy.domain_strategy = proto.domain_reduction_strategy(); + } + next_decision = ConstructSearchStrategy(strategies, model); + } + + // Solve. + int num_solutions = 0; + SatSolver::Status status; + if (model_proto.objectives_size() == 0) { + status = SolveIntegerProblemWithLazyEncoding( + /*assumptions=*/{}, next_decision, model); + if (status == SatSolver::MODEL_SAT) { + FillSolutionInResponse(model_proto, m, &response); + } + } else { + // Optimization problem. + CHECK_EQ(model_proto.objectives_size(), 1); + const CpObjectiveProto obj = model_proto.objectives(0); + const IntegerVariable objective_var = m.Integer(obj.objective_var()); + const auto solution_observer = + [&model_proto, &response, &num_solutions, &obj, &m, + objective_var](const Model& sat_model) { + num_solutions++; + FillSolutionInResponse(model_proto, m, &response); + response.set_objective_value( + ScaleObjectiveValue(obj, sat_model.Get(Value(objective_var)))); + LOG(INFO) << "Solution #" << num_solutions + << " obj:" << response.objective_value() << " num_bool:" + << sat_model.Get()->NumVariables(); + }; + + if (parameters.optimize_with_core()) { + std::vector linear_vars; + std::vector linear_coeffs; + ExtractLinearObjective(model_proto, &m, &linear_vars, &linear_coeffs); + status = MinimizeWithCoreAndLazyEncoding( + /*log_info=*/true, objective_var, linear_vars, linear_coeffs, + next_decision, solution_observer, model); + } else { + status = MinimizeIntegerVariableWithLinearScanAndLazyEncoding( + /*log_info=*/false, objective_var, next_decision, solution_observer, + model); + } + + if (status == SatSolver::LIMIT_REACHED) { + model->GetOrCreate()->Backtrack(0); + if (num_solutions == 0) { + response.set_objective_value( + ScaleObjectiveValue(obj, model->Get(UpperBound(objective_var)))); + } + response.set_best_objective_bound( + ScaleObjectiveValue(obj, model->Get(LowerBound(objective_var)))); + } else if (status == SatSolver::MODEL_SAT) { + // Optimal! + response.set_best_objective_bound(response.objective_value()); + } + } + + // Fill response. + switch (status) { + case SatSolver::LIMIT_REACHED: { + response.set_status(num_solutions != 0 ? CpSolverStatus::MODEL_SAT + : CpSolverStatus::UNKNOWN); + break; + } + case SatSolver::MODEL_SAT: { + response.set_status(model_proto.objectives_size() != 0 + ? CpSolverStatus::OPTIMAL + : CpSolverStatus::MODEL_SAT); + break; + } + case SatSolver::MODEL_UNSAT: { + response.set_status(CpSolverStatus::MODEL_UNSAT); + break; + } + default: + LOG(FATAL) << "Unexpected SatSolver::Status " << status; + } + response.set_num_booleans(model->Get()->NumVariables()); + response.set_num_branches(model->Get()->num_branches()); + response.set_num_conflicts(model->Get()->num_failures()); + response.set_num_binary_propagations( + model->Get()->num_propagations()); + response.set_num_integer_propagations( + model->Get() == nullptr + ? 0 + : model->Get()->num_enqueues()); + response.set_wall_time(wall_timer.Get()); + response.set_user_time(user_timer.Get()); + response.set_deterministic_time( + model->Get()->deterministic_time()); + return response; +} + +CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { + // Validate model_proto. + // TODO(user): provide an option to skip this step for speed? + { + const std::string error = ValidateCpModel(model_proto); + if (!error.empty()) { + LOG(INFO) << error; + CpSolverResponse response; + response.set_status(CpSolverStatus::MODEL_INVALID); + return response; + } + } + + CpModelProto presolved_proto; + CpModelProto mapping_proto; + std::vector postsolve_mapping; + PresolveCpModel(model_proto, &presolved_proto, &mapping_proto, + &postsolve_mapping); + + LOG(INFO) << CpModelStats(presolved_proto); + + CpSolverResponse response = + SolveCpModelWithoutPresolve(presolved_proto, model); + if (response.status() != CpSolverStatus::MODEL_SAT && + response.status() != CpSolverStatus::OPTIMAL) { + return response; + } + + // Postsolve. + for (int i = 0; i < response.solution_size(); ++i) { + auto* var_proto = mapping_proto.mutable_variables(postsolve_mapping[i]); + var_proto->clear_domain(); + var_proto->add_domain(response.solution(i)); + var_proto->add_domain(response.solution(i)); + } + for (int i = 0; i < response.solution_lower_bounds_size(); ++i) { + auto* var_proto = mapping_proto.mutable_variables(postsolve_mapping[i]); + FillDomain( + IntersectionOfSortedDisjointIntervals( + ReadDomain(*var_proto), {{response.solution_lower_bounds(i), + response.solution_upper_bounds(i)}}), + var_proto); + } + Model postsolve_model; + const CpSolverResponse postsolve_response = + SolveCpModelWithoutPresolve(mapping_proto, &postsolve_model); + CHECK_EQ(postsolve_response.status(), CpSolverStatus::MODEL_SAT); + response.clear_solution(); + response.clear_solution_lower_bounds(); + response.clear_solution_upper_bounds(); + if (!postsolve_response.solution().empty()) { + for (int i = 0; i < model_proto.variables_size(); ++i) { + response.add_solution(postsolve_response.solution(i)); + } + CHECK(SolutionIsFeasible(model_proto, + std::vector(response.solution().begin(), + response.solution().end()))); + } else { + for (int i = 0; i < model_proto.variables_size(); ++i) { + response.add_solution_lower_bounds( + postsolve_response.solution_lower_bounds(i)); + response.add_solution_upper_bounds( + postsolve_response.solution_upper_bounds(i)); + } + } + return response; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_solver.h b/ortools/sat/cp_model_solver.h new file mode 100644 index 0000000000..b82eb48ed9 --- /dev/null +++ b/ortools/sat/cp_model_solver.h @@ -0,0 +1,49 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_CP_MODEL_SOLVER_H_ +#define OR_TOOLS_SAT_CP_MODEL_SOLVER_H_ + +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/integer.h" +#include "ortools/sat/model.h" + +namespace operations_research { +namespace sat { + +// Returns a std::string with some statistics on the given CpModelProto. +std::string CpModelStats(const CpModelProto& model); + +// Returns a std::string with some statistics on the solver response. +std::string CpSolverResponseStats(const CpSolverResponse& response); + +// Solves the given CpModelProto. +// +// Note that the API takes a Model* that will be filled with the in-memory +// representation of the given CpModelProto. It is done this way so that it is +// easy to set custom parameters or time limit on the model with calls like: +// - model->SetSingleton(std::move(time_limit)); +// - model->Add(NewSatParameters(parameters_as_string)); +CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model); + +// Same as above, but do not run the CpModelProto presolver. This is exposed for +// internal usage, all client should use the version above. +// +// TODO(user): add a parameters to enable/disable the presolve. +CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, + Model* model); + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_CP_MODEL_SOLVER_H_ diff --git a/ortools/sat/cp_model_utils.cc b/ortools/sat/cp_model_utils.cc new file mode 100644 index 0000000000..e9ef00645b --- /dev/null +++ b/ortools/sat/cp_model_utils.cc @@ -0,0 +1,327 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/cp_model_utils.h" + +#include "ortools/base/stl_util.h" + +namespace operations_research { +namespace sat { + +namespace { + +template +void AddIndices(const IntList& indices, std::unordered_set* output) { + for (const int index : indices) output->insert(index); +} + +} // namespace + +void AddReferencesUsedByConstraint(const ConstraintProto& ct, + IndexReferences* output) { + switch (ct.constraint_case()) { + case ConstraintProto::ConstraintCase::kBoolOr: + AddIndices(ct.bool_or().literals(), &output->literals); + break; + case ConstraintProto::ConstraintCase::kBoolAnd: + AddIndices(ct.bool_and().literals(), &output->literals); + break; + case ConstraintProto::ConstraintCase::kBoolXor: + AddIndices(ct.bool_xor().literals(), &output->literals); + break; + case ConstraintProto::ConstraintCase::kIntDiv: + output->variables.insert(ct.int_div().target()); + AddIndices(ct.int_div().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kIntMod: + output->variables.insert(ct.int_mod().target()); + AddIndices(ct.int_mod().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kIntMax: + output->variables.insert(ct.int_max().target()); + AddIndices(ct.int_max().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kIntMin: + output->variables.insert(ct.int_min().target()); + AddIndices(ct.int_min().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kIntProd: + output->variables.insert(ct.int_prod().target()); + AddIndices(ct.int_prod().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kLinear: + AddIndices(ct.linear().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kAllDiff: + AddIndices(ct.all_diff().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kElement: + output->variables.insert(ct.element().index()); + output->variables.insert(ct.element().target()); + AddIndices(ct.element().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kCircuit: + AddIndices(ct.circuit().nexts(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kTable: + AddIndices(ct.table().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kAutomata: + AddIndices(ct.automata().vars(), &output->variables); + break; + case ConstraintProto::ConstraintCase::kInterval: + output->variables.insert(ct.interval().start()); + output->variables.insert(ct.interval().end()); + output->variables.insert(ct.interval().size()); + break; + case ConstraintProto::ConstraintCase::kNoOverlap: + AddIndices(ct.no_overlap().intervals(), &output->intervals); + break; + case ConstraintProto::ConstraintCase::kNoOverlap2D: + AddIndices(ct.no_overlap_2d().x_intervals(), &output->intervals); + AddIndices(ct.no_overlap_2d().y_intervals(), &output->intervals); + break; + case ConstraintProto::ConstraintCase::kCumulative: + output->variables.insert(ct.cumulative().capacity()); + AddIndices(ct.cumulative().intervals(), &output->intervals); + AddIndices(ct.cumulative().demands(), &output->variables); + break; + case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: + // Empty constraint. + break; + } +} + +#define APPLY_TO_SINGULAR_FIELD(ct_name, field_name) \ + { \ + int temp = ct->mutable_##ct_name()->field_name(); \ + f(&temp); \ + ct->mutable_##ct_name()->set_##field_name(temp); \ + } + +#define APPLY_TO_REPEATED_FIELD(ct_name, field_name) \ + { \ + for (int& r : *ct->mutable_##ct_name()->mutable_##field_name()) f(&r); \ + } + +void ApplyToAllLiteralIndices(const std::function& f, + ConstraintProto* ct) { + for (int& r : *ct->mutable_enforcement_literal()) f(&r); + switch (ct->constraint_case()) { + case ConstraintProto::ConstraintCase::kBoolOr: + APPLY_TO_REPEATED_FIELD(bool_or, literals); + break; + case ConstraintProto::ConstraintCase::kBoolAnd: + APPLY_TO_REPEATED_FIELD(bool_and, literals); + break; + case ConstraintProto::ConstraintCase::kBoolXor: + APPLY_TO_REPEATED_FIELD(bool_xor, literals); + break; + case ConstraintProto::ConstraintCase::kIntDiv: + break; + case ConstraintProto::ConstraintCase::kIntMod: + break; + case ConstraintProto::ConstraintCase::kIntMax: + break; + case ConstraintProto::ConstraintCase::kIntMin: + break; + case ConstraintProto::ConstraintCase::kIntProd: + break; + case ConstraintProto::ConstraintCase::kLinear: + break; + case ConstraintProto::ConstraintCase::kAllDiff: + break; + case ConstraintProto::ConstraintCase::kElement: + break; + case ConstraintProto::ConstraintCase::kCircuit: + break; + case ConstraintProto::ConstraintCase::kTable: + break; + case ConstraintProto::ConstraintCase::kAutomata: + break; + case ConstraintProto::ConstraintCase::kInterval: + break; + case ConstraintProto::ConstraintCase::kNoOverlap: + break; + case ConstraintProto::ConstraintCase::kNoOverlap2D: + break; + case ConstraintProto::ConstraintCase::kCumulative: + break; + case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: + break; + } +} + +void ApplyToAllVariableIndices(const std::function& f, + ConstraintProto* ct) { + switch (ct->constraint_case()) { + case ConstraintProto::ConstraintCase::kBoolOr: + break; + case ConstraintProto::ConstraintCase::kBoolAnd: + break; + case ConstraintProto::ConstraintCase::kBoolXor: + break; + case ConstraintProto::ConstraintCase::kIntDiv: + APPLY_TO_SINGULAR_FIELD(int_div, target); + APPLY_TO_REPEATED_FIELD(int_div, vars); + break; + case ConstraintProto::ConstraintCase::kIntMod: + APPLY_TO_SINGULAR_FIELD(int_mod, target); + APPLY_TO_REPEATED_FIELD(int_mod, vars); + break; + case ConstraintProto::ConstraintCase::kIntMax: + APPLY_TO_SINGULAR_FIELD(int_max, target); + APPLY_TO_REPEATED_FIELD(int_max, vars); + break; + case ConstraintProto::ConstraintCase::kIntMin: + APPLY_TO_SINGULAR_FIELD(int_min, target); + APPLY_TO_REPEATED_FIELD(int_min, vars); + break; + case ConstraintProto::ConstraintCase::kIntProd: + APPLY_TO_SINGULAR_FIELD(int_prod, target); + APPLY_TO_REPEATED_FIELD(int_prod, vars); + break; + case ConstraintProto::ConstraintCase::kLinear: + APPLY_TO_REPEATED_FIELD(linear, vars); + break; + case ConstraintProto::ConstraintCase::kAllDiff: + APPLY_TO_REPEATED_FIELD(all_diff, vars); + break; + case ConstraintProto::ConstraintCase::kElement: + APPLY_TO_SINGULAR_FIELD(element, index); + APPLY_TO_SINGULAR_FIELD(element, target); + APPLY_TO_REPEATED_FIELD(element, vars); + break; + case ConstraintProto::ConstraintCase::kCircuit: + APPLY_TO_REPEATED_FIELD(circuit, nexts); + break; + case ConstraintProto::ConstraintCase::kTable: + APPLY_TO_REPEATED_FIELD(table, vars); + break; + case ConstraintProto::ConstraintCase::kAutomata: + APPLY_TO_REPEATED_FIELD(automata, vars); + break; + case ConstraintProto::ConstraintCase::kInterval: + APPLY_TO_SINGULAR_FIELD(interval, start); + APPLY_TO_SINGULAR_FIELD(interval, end); + APPLY_TO_SINGULAR_FIELD(interval, size); + break; + case ConstraintProto::ConstraintCase::kNoOverlap: + break; + case ConstraintProto::ConstraintCase::kNoOverlap2D: + break; + case ConstraintProto::ConstraintCase::kCumulative: + APPLY_TO_SINGULAR_FIELD(cumulative, capacity); + APPLY_TO_REPEATED_FIELD(cumulative, demands); + break; + case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: + break; + } +} + +void ApplyToAllIntervalIndices(const std::function& f, + ConstraintProto* ct) { + switch (ct->constraint_case()) { + case ConstraintProto::ConstraintCase::kBoolOr: + break; + case ConstraintProto::ConstraintCase::kBoolAnd: + break; + case ConstraintProto::ConstraintCase::kBoolXor: + break; + case ConstraintProto::ConstraintCase::kIntDiv: + break; + case ConstraintProto::ConstraintCase::kIntMod: + break; + case ConstraintProto::ConstraintCase::kIntMax: + break; + case ConstraintProto::ConstraintCase::kIntMin: + break; + case ConstraintProto::ConstraintCase::kIntProd: + break; + case ConstraintProto::ConstraintCase::kLinear: + break; + case ConstraintProto::ConstraintCase::kAllDiff: + break; + case ConstraintProto::ConstraintCase::kElement: + break; + case ConstraintProto::ConstraintCase::kCircuit: + break; + case ConstraintProto::ConstraintCase::kTable: + break; + case ConstraintProto::ConstraintCase::kAutomata: + break; + case ConstraintProto::ConstraintCase::kInterval: + break; + case ConstraintProto::ConstraintCase::kNoOverlap: + APPLY_TO_REPEATED_FIELD(no_overlap, intervals); + break; + case ConstraintProto::ConstraintCase::kNoOverlap2D: + APPLY_TO_REPEATED_FIELD(no_overlap_2d, x_intervals); + APPLY_TO_REPEATED_FIELD(no_overlap_2d, y_intervals); + break; + case ConstraintProto::ConstraintCase::kCumulative: + APPLY_TO_REPEATED_FIELD(cumulative, intervals); + break; + case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: + break; + } +} + +#undef APPLY_TO_SINGULAR_FIELD +#undef APPLY_TO_REPEATED_FIELD + +std::string ConstraintCaseName(ConstraintProto::ConstraintCase constraint_case) { + switch (constraint_case) { + case ConstraintProto::ConstraintCase::kBoolOr: + return "kBoolOr"; + case ConstraintProto::ConstraintCase::kBoolAnd: + return "kBoolAnd"; + case ConstraintProto::ConstraintCase::kBoolXor: + return "kBoolXor"; + case ConstraintProto::ConstraintCase::kIntDiv: + return "kIntDiv"; + case ConstraintProto::ConstraintCase::kIntMod: + return "kIntMod"; + case ConstraintProto::ConstraintCase::kIntMax: + return "kIntMax"; + case ConstraintProto::ConstraintCase::kIntMin: + return "kIntMin"; + case ConstraintProto::ConstraintCase::kIntProd: + return "kIntProd"; + case ConstraintProto::ConstraintCase::kLinear: + return "kLinear"; + case ConstraintProto::ConstraintCase::kAllDiff: + return "kAllDiff"; + case ConstraintProto::ConstraintCase::kElement: + return "kElement"; + case ConstraintProto::ConstraintCase::kCircuit: + return "kCircuit"; + case ConstraintProto::ConstraintCase::kTable: + return "kTable"; + case ConstraintProto::ConstraintCase::kAutomata: + return "kAutomata"; + case ConstraintProto::ConstraintCase::kInterval: + return "kInterval"; + case ConstraintProto::ConstraintCase::kNoOverlap: + return "kNoOverlap"; + case ConstraintProto::ConstraintCase::kNoOverlap2D: + return "kNoOverlap2D"; + case ConstraintProto::ConstraintCase::kCumulative: + return "kCumulative"; + case ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET: + return "kEmpty"; + } +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/cp_model_utils.h b/ortools/sat/cp_model_utils.h new file mode 100644 index 0000000000..9a109cd5b6 --- /dev/null +++ b/ortools/sat/cp_model_utils.h @@ -0,0 +1,116 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_SAT_CP_MODEL_UTILS_H_ +#define OR_TOOLS_SAT_CP_MODEL_UTILS_H_ + +#include + +#include "ortools/base/logging.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/util/sorted_interval_list.h" + +namespace operations_research { +namespace sat { + +// Small utility functions to deal with negative variable/literal references. +inline int NegatedRef(int ref) { return -ref - 1; } +inline int PositiveRef(int ref) { return std::max(ref, NegatedRef(ref)); } +inline bool RefIsPositive(int ref) { return ref >= 0; } + +// Small utility functions to deal with half-reified constraints. +inline bool HasEnforcementLiteral(const ConstraintProto& ct) { + return !ct.enforcement_literal().empty(); +} +inline int EnforcementLiteral(const ConstraintProto& ct) { + return ct.enforcement_literal(0); +} + +// Collects all the references used by a constraint. This function is used in a +// few places to have a "generic" code dealing with constraints. Note that the +// enforcement_literal is NOT counted here. +// +// TODO(user): replace this by constant version of the Apply...() functions? +struct IndexReferences { + std::unordered_set variables; + std::unordered_set literals; + std::unordered_set intervals; +}; +void AddReferencesUsedByConstraint(const ConstraintProto& ct, + IndexReferences* output); + +// Applies the given function to all variables/literals/intervals indices of the +// constraint. This function is used in a few places to have a "generic" code +// dealing with constraints. +void ApplyToAllVariableIndices(const std::function& function, + ConstraintProto* ct); +void ApplyToAllLiteralIndices(const std::function& function, + ConstraintProto* ct); +void ApplyToAllIntervalIndices(const std::function& function, + ConstraintProto* ct); + +// Returns the name of the ConstraintProto::ConstraintCase oneof enum. +// Note(user): There is no such function in the proto API as of 16/01/2017. +std::string ConstraintCaseName(ConstraintProto::ConstraintCase constraint_case); + +// Returns true if a proto.domain() contain the given value. +// The domain is expected to be encoded as a sorted disjoint interval list. +template +bool DomainInProtoContains(const ProtoWithDomain& proto, int64 value) { + for (int i = 0; i < proto.domain_size(); i += 2) { + if (value >= proto.domain(i) && value <= proto.domain(i + 1)) return true; + } + return false; +} + +// Sets the domain field of a proto from a sorted interval list. +template +void FillDomain(const std::vector& domain, + ProtoWithDomain* proto) { + proto->clear_domain(); + CHECK(IntervalsAreSortedAndDisjoint(domain)); + for (const ClosedInterval& interval : domain) { + proto->add_domain(interval.start); + proto->add_domain(interval.end); + } +} + +// Extract a sorted interval list from the domain field of a proto. +template +std::vector ReadDomain(const ProtoWithDomain& proto) { + std::vector result; + for (int i = 0; i < proto.domain_size(); i += 2) { + result.push_back({proto.domain(i), proto.domain(i + 1)}); + } + CHECK(IntervalsAreSortedAndDisjoint(result)); + return result; +} + +// Returns the list of values in a given domain. +// This will fail if the domain contains more than one millions values. +template +std::vector AllValuesInDomain(const ProtoWithDomain& proto) { + std::vector result; + for (int i = 0; i < proto.domain_size(); i += 2) { + for (int64 v = proto.domain(i); v <= proto.domain(i + 1); ++v) { + CHECK_LE(result.size(), 1e6); + result.push_back(v); + } + } + return result; +} + +} // namespace sat +} // namespace operations_research + +#endif // OR_TOOLS_SAT_CP_MODEL_UTILS_H_ diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index 808ca7f18c..5932d7463a 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -217,7 +217,6 @@ bool DisjunctiveOverloadChecker::Propagate() { num_events_++; } const int num_events = num_events_; - start_event_is_present_.assign(num_events, false); theta_tree_.Reset(num_events); // Introduce events by nondecreasing end_max, check for overloads. @@ -229,7 +228,6 @@ bool DisjunctiveOverloadChecker::Propagate() { { const int current_event = task_to_start_event_[current_task]; const bool is_present = helper_->IsPresent(current_task); - start_event_is_present_[current_event] = is_present; // TODO(user): consider reducing max available duration. const IntegerValue energy_max = helper_->DurationMin(current_task); const IntegerValue energy_min = is_present ? energy_max : IntegerValue(0); @@ -250,7 +248,7 @@ bool DisjunctiveOverloadChecker::Propagate() { theta_tree_.GetEnvelopeOf(critical_event) - 1; for (int event = critical_event; event < num_events; event++) { - if (start_event_is_present_[event]) { + if (theta_tree_.EnergyMin(event) > 0) { const int task = start_event_to_task_[event]; helper_->AddPresenceReason(task); helper_->AddDurationMinReason(task); @@ -280,7 +278,7 @@ bool DisjunctiveOverloadChecker::Propagate() { helper_->DurationMin(optional_task) - 1; for (int event = critical_event; event < num_events; event++) { - if (start_event_is_present_[event]) { + if (theta_tree_.EnergyMin(event) > 0) { const int task = start_event_to_task_[event]; helper_->AddPresenceReason(task); helper_->AddDurationMinReason(task); diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index 636fe5ef49..e1d0f294d5 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -136,7 +136,6 @@ class DisjunctiveOverloadChecker : public PropagatorInterface { std::vector task_to_start_event_; std::vector start_event_to_task_; std::vector start_event_time_; - std::vector start_event_is_present_; }; class DisjunctiveDetectablePrecedences : public PropagatorInterface { diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 8ff54b9609..6086ee7bcc 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -202,6 +202,15 @@ bool IntegerEncoder::LiteralIsAssociated(IntegerLiteral i) const { return encoding.find(i.bound) != encoding.end(); } +LiteralIndex IntegerEncoder::GetAssociatedLiteral(IntegerLiteral i) { + if (i.var >= encoding_by_var_.size()) return kNoLiteralIndex; + const std::map& encoding = + encoding_by_var_[IntegerVariable(i.var)]; + const auto result = encoding.find(i.bound); + if (result == encoding.end()) return kNoLiteralIndex; + return result->second.Index(); +} + Literal IntegerEncoder::GetOrCreateAssociatedLiteral(IntegerLiteral i) { if (i.var < encoding_by_var_.size()) { const std::map& encoding = @@ -829,67 +838,109 @@ void IntegerTrail::MergeReasonInto(const std::vector& literals, return MergeReasonIntoInternal(output); } +// This will expand the reason of the IntegerLiteral already in tmp_queue_ until +// everything is explained in term of Literal. void IntegerTrail::MergeReasonIntoInternal(std::vector* output) const { - tmp_trail_indices_.clear(); - tmp_var_to_highest_explained_trail_index_.resize(vars_.size(), 0); - DCHECK(std::all_of(tmp_var_to_highest_explained_trail_index_.begin(), - tmp_var_to_highest_explained_trail_index_.end(), + // All relevant trail indices will be >= vars_.size(), so we can safely use + // zero to means that no literal refering to this variable is in the queue. + tmp_var_to_trail_index_in_queue_.resize(vars_.size(), 0); + DCHECK(std::all_of(tmp_var_to_trail_index_in_queue_.begin(), + tmp_var_to_trail_index_in_queue_.end(), [](int v) { return v == 0; })); - // This implement an iterative DFS on a DAG. Each time a node from the - // tmp_queue_ is expanded, we change its sign, so that when we go back - // (equivalent to the return of the recursive call), we can detect that this - // node was already expanded. - // - // To detect nodes from which we already performed the full DFS exploration, - // we use tmp_var_to_highest_explained_trail_index_. - // - // TODO(user): The order in which each trail_index is expanded will change - // how much of the reason is "minimized". Investigate if some order are better - // than other. - while (!tmp_queue_.empty()) { - const bool already_expored = tmp_queue_.back() < 0; - const int trail_index = std::abs(tmp_queue_.back()); + // During the algorithm execution, all the queue entries that do not match the + // content of tmp_var_to_trail_index_in_queue_[] will be ignored. + for (const int trail_index : tmp_queue_) { const TrailEntry& entry = integer_trail_[trail_index]; + tmp_var_to_trail_index_in_queue_[entry.var] = + std::max(tmp_var_to_trail_index_in_queue_[entry.var], trail_index); + } - // Since we already have an explanation for a larger bound (ex: x>=4) we - // don't need to add the explanation for a lower one (ex: x>=2). - if (trail_index <= tmp_var_to_highest_explained_trail_index_[entry.var]) { - tmp_queue_.pop_back(); + // We manage our heap by hand so that we can range iterate over it above, and + // this initial heapify is faster. + std::make_heap(tmp_queue_.begin(), tmp_queue_.end()); + + // We process the entries by highest trail_index first. The content of the + // queue will always be a valid reason for the literals we already added to + // the output. + tmp_to_clear_.clear(); + while (!tmp_queue_.empty()) { + const int trail_index = tmp_queue_.front(); + const TrailEntry& entry = integer_trail_[trail_index]; + std::pop_heap(tmp_queue_.begin(), tmp_queue_.end()); + tmp_queue_.pop_back(); + + // Skip any stale queue entry. Amongst all the entry refering to a given + // variable, only the latest added to the queue is valid and we detect it + // using its trail index. + if (tmp_var_to_trail_index_in_queue_[entry.var] != trail_index) { continue; } - DCHECK_GT(trail_index, 0); - if (already_expored) { - // We are in the "return" of the DFS recursive call. - DCHECK_GT(trail_index, - tmp_var_to_highest_explained_trail_index_[entry.var]); - tmp_var_to_highest_explained_trail_index_[entry.var] = trail_index; - tmp_trail_indices_.push_back(trail_index); - tmp_queue_.pop_back(); - } else { - // We make "recursive calls" from this node. - tmp_queue_.back() = -trail_index; - for (const IntegerLiteral lit : Dependencies(trail_index)) { - // Extract the next_trail_index from the returned literal, we can break - // as soon as we get a negative next_trail_index. See the encoding in - // Dependencies(). - const int next_trail_index = -lit.var; - if (next_trail_index < 0) break; - const TrailEntry& next_entry = integer_trail_[next_trail_index]; - if (next_trail_index > - tmp_var_to_highest_explained_trail_index_[next_entry.var]) { - tmp_queue_.push_back(next_trail_index); - } + // If this entry has an associated literal, then we use it as a reason + // instead of the stored reason. If later this literal needs to be + // explained, then the associated literal will be expanded with the stored + // reason. + { + const LiteralIndex associated_lit = + encoder_->GetAssociatedLiteral(IntegerLiteral::GreaterOrEqual( + IntegerVariable(entry.var), entry.bound)); + if (associated_lit != kNoLiteralIndex) { + output->push_back(Literal(associated_lit).Negated()); + + // Ignore any entries of the queue refering to this variable and make + // sure no such entry are added later. + tmp_to_clear_.push_back(entry.var); + tmp_var_to_trail_index_in_queue_[entry.var] = kint32max; + continue; } } + + // Process this entry. Note that if any of the next expansion include the + // variable entry.var in their reason, we must process it again because we + // cannot easily detect if it was needed to infer the current entry. + // + // Important: the queue might already contains entries refering to the same + // variable. The code act like if we deleted all of them at this point, we + // just do that lazily. tmp_var_to_trail_index_in_queue_[var] will + // only refer to newly added entries. + AppendLiteralsReason(trail_index, output); + tmp_var_to_trail_index_in_queue_[entry.var] = 0; + + // TODO(user): we could speed up Dependencies() by using the indices stored + // in tmp_var_to_trail_index_in_queue_ instead of redoing + // FindLowestTrailIndexThatExplainBound() from the latest trail index. + bool has_dependency = false; + for (const IntegerLiteral lit : Dependencies(trail_index)) { + // Extract the next_trail_index from the returned literal, we can break + // as soon as we get a negative next_trail_index. See the encoding in + // Dependencies(). + const int next_trail_index = -lit.var; + if (next_trail_index < 0) break; + const TrailEntry& next_entry = integer_trail_[next_trail_index]; + has_dependency = true; + + // Only add literals that are not "implied" by the ones already present. + // For instance, do not add (x >= 4) if we already have (x >= 7). This + // translate into only adding a trail index if it is larger than the one + // in the queue refering to the same variable. + if (next_trail_index > tmp_var_to_trail_index_in_queue_[next_entry.var]) { + tmp_var_to_trail_index_in_queue_[next_entry.var] = next_trail_index; + tmp_queue_.push_back(next_trail_index); + std::push_heap(tmp_queue_.begin(), tmp_queue_.end()); + } + } + + // Special case for a "leaf", we will never need this variable again. + if (!has_dependency) { + tmp_to_clear_.push_back(entry.var); + tmp_var_to_trail_index_in_queue_[entry.var] = kint32max; + } } - // Cleanup + output the reason. - for (const int trail_index : tmp_trail_indices_) { - const TrailEntry& entry = integer_trail_[trail_index]; - tmp_var_to_highest_explained_trail_index_[entry.var] = 0; - AppendLiteralsReason(trail_index, output); + // clean-up. + for (const int var : tmp_to_clear_) { + tmp_var_to_trail_index_in_queue_[var] = 0; } STLSortAndRemoveDuplicates(output); } diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 29896deec0..f30fe92fe8 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -285,6 +285,9 @@ class IntegerEncoder { // Return true iff the given integer literal is associated. bool LiteralIsAssociated(IntegerLiteral i_lit) const; + // Returns the associated literal or kNoLiteralIndex. + LiteralIndex GetAssociatedLiteral(IntegerLiteral i_lit); + // Same as CreateAssociatedLiteral() but safe to call if already created. Literal GetOrCreateAssociatedLiteral(IntegerLiteral i_lit); @@ -618,8 +621,8 @@ class IntegerTrail : public SatPropagator { // Temporary data used by MergeReasonInto(). mutable std::vector tmp_queue_; - mutable std::vector tmp_trail_indices_; - mutable std::vector tmp_var_to_highest_explained_trail_index_; + mutable std::vector tmp_to_clear_; + mutable std::vector tmp_var_to_trail_index_in_queue_; // For EnqueueLiteral(), we store a special TrailEntry to recover the reason // lazily. This vector indicates the correspondance between a literal that diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index a6bcd808b3..f2b709c17d 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -347,6 +347,18 @@ inline std::function SizeVar( }; } +inline std::function MinSize(IntervalVariable v) { + return [=](const Model& model) { + return model.Get()->MinSize(v).value(); + }; +} + +inline std::function MaxSize(IntervalVariable v) { + return [=](const Model& model) { + return model.Get()->MaxSize(v).value(); + }; +} + inline std::function IsOptional(IntervalVariable v) { return [=](const Model& model) { return model.Get()->IsOptional(v); @@ -409,6 +421,19 @@ inline std::function NewOptionalInterval( }; } +inline std::function +NewOptionalIntervalWithVariableSize(int64 min_start, int64 max_end, + int64 min_size, int64 max_size, + Literal is_present) { + return [=](Model* model) { + return model->GetOrCreate()->CreateInterval( + model->Add(NewIntegerVariable(min_start, max_end)), + model->Add(NewIntegerVariable(min_start, max_end)), + model->Add(NewIntegerVariable(min_size, max_size)), IntegerValue(0), + is_present.Index()); + }; +} + inline std::function NewIntervalFromStartAndSizeVars( IntegerVariable start, IntegerVariable size) { return [=](Model* model) { diff --git a/ortools/sat/lp_utils.cc b/ortools/sat/lp_utils.cc index ec54ba8384..906e97a3ad 100644 --- a/ortools/sat/lp_utils.cc +++ b/ortools/sat/lp_utils.cc @@ -30,6 +30,232 @@ using operations_research::MPConstraintProto; using operations_research::MPModelProto; using operations_research::MPVariableProto; +bool ConvertMPModelProtoToCpModelProto(const MPModelProto& mp_model, + CpModelProto* cp_model) { + const double kInfinity = std::numeric_limits::infinity(); + CHECK(cp_model != nullptr); + cp_model->Clear(); + cp_model->set_name(mp_model.name()); + + // To make sure we cannot have integer overflow, we use this bound for any + // unbounded variable. + // + // TODO(user): This could be made larger if needed, so be smarter if we have + // MIP problem that we cannot "convert" because of this. Note however than we + // cannot go that much further because we need to make sure we will not run + // into overflow if we add a big linear combination of such variables. It + // should always be possible for an user to scale its problem so that all + // relevant quantities are under a billion. A LP/MIP solver have a similar + // condition in disguise because problem with a difference of more than 6 + // magnitude between the variable values will likely run into numeric trouble. + const int64 kMaxVariableBound = 1ll << 30; + + // Add the variables. + const int num_variables = mp_model.variable_size(); + for (int i = 0; i < num_variables; ++i) { + const MPVariableProto& mp_var = mp_model.variable(i); + IntegerVariableProto* cp_var = cp_model->add_variables(); + cp_var->set_name(mp_var.name()); + + // Note that we must process the lower bound first. + for (const bool lower : {true, false}) { + const double bound = lower ? mp_var.lower_bound() : mp_var.upper_bound(); + if (std::abs(bound) == kInfinity) { + cp_var->add_domain(lower ? -kMaxVariableBound : kMaxVariableBound); + continue; + } + + // Reject larger bound than kMaxVariableBound. We also reject the equality + // so that after the solve, we can detect if one of our "artificial" + // bounds that we add on unbounded variable is restricting the objective. + if (std::floor(std::abs(bound)) >= kMaxVariableBound) { + LOG(ERROR) << "Large bound : " << bound; + return false; + } + + if (mp_var.is_integer()) { + // Note that the cast is "perfect" because we forbid large values. + cp_var->add_domain( + static_cast(lower ? std::ceil(bound) : std::floor(bound))); + } else { + // Continuous variable. We reject non-integer bounds. + // We do nothing if the domain is really small though. + // + // TODO(user): scale the domain. + if (bound != std::round(bound)) { + LOG(ERROR) << "Non-integer bound not supported: " << bound; + return false; + } + cp_var->add_domain(static_cast(bound)); + } + } + } + + // Variables needed to scale the double coefficients into int64. + double max_relative_coeff_error = 0.0; + double max_scaled_sum_error = 0.0; + double max_scaling_factor = 0.0; + double relative_coeff_error = 0.0; + double scaled_sum_error = 0.0; + double scaling_factor = 0.0; + std::vector coefficients; + std::vector lower_bounds; + std::vector upper_bounds; + + // Add the constraints. We scale each of them individually. + for (const MPConstraintProto& mp_constraint : mp_model.constraint()) { + auto* constraint = cp_model->add_constraints(); + constraint->set_name(mp_constraint.name()); + auto* arg = constraint->mutable_linear(); + + // First scale the coefficients of the constraints so that the constraint + // sum can always be computed without integer overflow. + coefficients.clear(); + lower_bounds.clear(); + upper_bounds.clear(); + const int num_coeffs = mp_constraint.coefficient_size(); + for (int i = 0; i < num_coeffs; ++i) { + coefficients.push_back(mp_constraint.coefficient(i)); + const auto& var_proto = cp_model->variables(mp_constraint.var_index(i)); + lower_bounds.push_back(var_proto.domain(0)); + upper_bounds.push_back(var_proto.domain(var_proto.domain_size() - 1)); + } + + // TODO(user): we could use kint64max directly here if our constraint + // propagation code was a bit more careful about integer overflow. + GetBestScalingOfDoublesToInt64(coefficients, lower_bounds, upper_bounds, + kint64max / 2, &scaling_factor, + &relative_coeff_error, &scaled_sum_error); + const int64 gcd = ComputeGcdOfRoundedDoubles(coefficients, scaling_factor); + max_relative_coeff_error = + std::max(relative_coeff_error, max_relative_coeff_error); + max_scaling_factor = std::max(scaling_factor / gcd, max_scaling_factor); + + for (int i = 0; i < num_coeffs; ++i) { + const double scaled_value = mp_constraint.coefficient(i) * scaling_factor; + const int64 value = static_cast(std::round(scaled_value)) / gcd; + if (value != 0) { + arg->add_vars(mp_constraint.var_index(i)); + arg->add_coeffs(value); + } + } + max_scaled_sum_error = + std::max(max_scaled_sum_error, scaled_sum_error / scaling_factor); + + // Add the constraint bounds. Because we are sure the scaled constraint fit + // on an int64, if the scaled bounds are too large, the constraint is either + // always true or always false. + const Fractional lb = mp_constraint.lower_bound(); + const Fractional scaled_lb = + std::round(lb * scaling_factor - scaled_sum_error); + if (lb == -kInfinity || scaled_lb <= kint64min) { + arg->add_domain(kint64min); + } else { + arg->add_domain(static_cast(scaled_lb) / gcd); + } + const Fractional ub = mp_constraint.upper_bound(); + const Fractional scaled_ub = + std::round(ub * scaling_factor + scaled_sum_error); + if (ub == kInfinity || scaled_ub >= kint64max) { + arg->add_domain(kint64max); + } else { + arg->add_domain(static_cast(scaled_ub) / gcd); + } + + // TODO(user): checks feasibility (contains zero) or support that in the + // solver! + if (arg->vars_size() == 0) constraint->Clear(); + } + + // Display the error/scaling without taking into account the objective first. + LOG(INFO) << "Maximum constraint coefficient relative error: " + << max_relative_coeff_error; + LOG(INFO) << "Maximum constraint worst-case sum absolute error: " + << max_scaled_sum_error; + LOG(INFO) << "Maximum constraint scaling factor: " << max_scaling_factor; + + // Add the objective. We use kint64max / 2 because the objective_var will + // also be added to the objective constraint. + const int64 kMaxObjective = kint64max / 2; + coefficients.clear(); + lower_bounds.clear(); + upper_bounds.clear(); + for (int i = 0; i < num_variables; ++i) { + const MPVariableProto& mp_var = mp_model.variable(i); + if (mp_var.objective_coefficient() == 0.0) continue; + coefficients.push_back(mp_var.objective_coefficient()); + const auto& var_proto = cp_model->variables(i); + lower_bounds.push_back(var_proto.domain(0)); + upper_bounds.push_back(var_proto.domain(var_proto.domain_size() - 1)); + } + if (!coefficients.empty()) { + GetBestScalingOfDoublesToInt64(coefficients, lower_bounds, upper_bounds, + kMaxObjective, &scaling_factor, + &relative_coeff_error, &scaled_sum_error); + const int64 gcd = ComputeGcdOfRoundedDoubles(coefficients, scaling_factor); + max_relative_coeff_error = + std::max(relative_coeff_error, max_relative_coeff_error); + + // Display the objective error/scaling. + LOG(INFO) << "objective coefficient relative error: " + << relative_coeff_error; + LOG(INFO) << "objective worst-case absolute error: " + << scaled_sum_error / scaling_factor; + LOG(INFO) << "objective scaling factor: " << scaling_factor / gcd; + + // Note that here we set the scaling factor for the inverse operation of + // getting the "true" objective value from the scaled one. Hence the + // inverse. + auto* objective = cp_model->add_objectives(); + objective->set_offset(mp_model.objective_offset() * scaling_factor / gcd); + objective->set_scaling_factor(1.0 / scaling_factor * gcd); + objective->set_objective_var(cp_model->variables_size()); + { + auto* objective_var = cp_model->add_variables(); + objective_var->set_name("objective"); + objective_var->add_domain(-kMaxObjective); + objective_var->add_domain(kMaxObjective); + } + + // Link the objective variable with a linear constraint. + { + auto* objective_constraint = cp_model->add_constraints(); + auto* objective_arg = objective_constraint->mutable_linear(); + objective_constraint->set_name("objective"); + objective_arg->add_domain(mp_model.maximize() ? 0 : kint64min); + objective_arg->add_domain(mp_model.maximize() ? kint64max : 0); + for (int i = 0; i < num_variables; ++i) { + const MPVariableProto& mp_var = mp_model.variable(i); + const int64 value = + static_cast( + std::round(mp_var.objective_coefficient() * scaling_factor)) / + gcd; + if (value != 0) { + objective_arg->add_vars(i); + objective_arg->add_coeffs(value); + } + } + objective_arg->add_vars(objective->objective_var()); + objective_arg->add_coeffs(-1); + } + + // If the problem was a maximization one, we need to modify the objective. + if (mp_model.maximize()) { + objective->set_objective_var(-objective->objective_var() - 1); + objective->set_scaling_factor(-objective->scaling_factor()); + } + } + + // Test the precision of the conversion. + const double kRelativeTolerance = 1e-4; + if (max_relative_coeff_error > kRelativeTolerance) { + LOG(WARNING) << "The relative error during double -> int64 conversion " + << "is too high!"; + return false; + } + return true; +} + bool ConvertBinaryMPModelProtoToBooleanProblem(const MPModelProto& mp_model, LinearBooleanProblem* problem) { CHECK(problem != nullptr); diff --git a/ortools/sat/lp_utils.h b/ortools/sat/lp_utils.h index d845a1e742..d7d6bea9e9 100644 --- a/ortools/sat/lp_utils.h +++ b/ortools/sat/lp_utils.h @@ -19,11 +19,27 @@ #include "ortools/linear_solver/linear_solver.pb.h" #include "ortools/lp_data/lp_data.h" #include "ortools/sat/boolean_problem.pb.h" +#include "ortools/sat/cp_model.pb.h" #include "ortools/sat/sat_solver.h" namespace operations_research { namespace sat { +// Converts a MIP problem to a CpModel. Returns false if the coefficients +// couldn't be converted to integers with a good enough precision. +// +// Caveats: +// - We do not support bound larger than or equal to 2^30. +// - We cap unbounded variable at 2^30. +// - Non-integer variable must have integer bounds. +// - We do not scale the variable bounds, so by assuming that a non-integer +// variable is integer, we may change the problem significantly if the +// domain is small (like [0.0, 1.0]). +// +// TODO(user): Try to remove some of the restrictions. +bool ConvertMPModelProtoToCpModelProto(const MPModelProto& mp_model, + CpModelProto* cp_model); + // Converts an integer program with only binary variables to a Boolean // optimization problem. Returns false if the problem didn't contains only // binary integer variable, or if the coefficients couldn't be converted to diff --git a/ortools/sat/table.cc b/ortools/sat/table.cc index 5ed57c376f..1df9b6f811 100644 --- a/ortools/sat/table.cc +++ b/ortools/sat/table.cc @@ -13,6 +13,7 @@ #include "ortools/sat/table.h" +#include #include #include "ortools/base/map_util.h" @@ -185,6 +186,61 @@ std::function TableConstraint( }; } +std::function NegatedTableConstraint( + const std::vector& vars, + const std::vector>& tuples) { + return [=](Model* model) { + const int n = vars.size(); + std::vector> mapping(n); + for (int i = 0; i < n; ++i) { + for (const auto pair : model->Add(FullyEncodeVariable(vars[i]))) { + mapping[i][pair.value.value()] = pair.literal; + } + } + + // For each tuple, forbid the variables values to be this tuple. + std::vector clause(n); + for (const std::vector& tuple : tuples) { + bool add_tuple = true; + for (int i = 0; i < n; ++i) { + if (ContainsKey(mapping[i], tuple[i])) { + clause[i] = FindOrDie(mapping[i], tuple[i]).Negated(); + } else { + add_tuple = false; + break; + } + } + if (add_tuple) model->Add(ClauseConstraint(clause)); + } + }; +} + +std::function NegatedTableConstraintWithoutFullEncoding( + const std::vector& vars, + const std::vector>& tuples) { + return [=](Model* model) { + const int n = vars.size(); + IntegerEncoder* encoder = model->GetOrCreate(); + std::vector clause; + for (const std::vector& tuple : tuples) { + clause.clear(); + for (int i = 0; i < n; ++i) { + const int64 value = tuple[i]; + if (value > model->Get(LowerBound(vars[i]))) { + clause.push_back(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(vars[i], IntegerValue(value - 1)))); + } + if (value < model->Get(UpperBound(vars[i]))) { + clause.push_back(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(vars[i], + IntegerValue(value + 1)))); + } + } + model->Add(ClauseConstraint(clause)); + } + }; +} + std::function LiteralTableConstraint( const std::vector>& literal_tuples, const std::vector& line_literals) { diff --git a/ortools/sat/table.h b/ortools/sat/table.h index 6171b00e61..6bdc0555c9 100644 --- a/ortools/sat/table.h +++ b/ortools/sat/table.h @@ -27,6 +27,20 @@ std::function TableConstraint( const std::vector& vars, const std::vector>& tuples); +// Enforces that none of the given tuple appear. TODO(user): we could propagate +// more than what we currently do which is simply adding one clause per tuples. +std::function NegatedTableConstraint( + const std::vector& vars, + const std::vector>& tuples); + +// Same as NegatedTableConstraint() but uses a different literal encoding. +// That is, instead of fully encoding the variables and having literal like +// (x != 4) in the clause(s), we use instead two literals: (x < 4) V (x > 4). +// This can be better for variable with large domains. +std::function NegatedTableConstraintWithoutFullEncoding( + const std::vector& vars, + const std::vector>& tuples); + // Enforces that exactly one literal in line_literals is true, and that // all literals in the corresponding line of the literal_tuples matrix are true. // This constraint assumes that exactly one literal per column of the diff --git a/ortools/sat/theta_tree.cc b/ortools/sat/theta_tree.cc index f15ffda867..afb543d40b 100644 --- a/ortools/sat/theta_tree.cc +++ b/ortools/sat/theta_tree.cc @@ -18,42 +18,40 @@ namespace sat { ThetaLambdaTree::ThetaLambdaTree() {} -// Make a tree using the first num_events events of the vectors. void ThetaLambdaTree::Reset(int num_events) { // Use 2^k leaves in the tree, with 2^k >= std::max(num_events, 2). num_events_ = num_events; for (num_leaves_ = 2; num_leaves_ < num_events; num_leaves_ <<= 1) { } - const int num_nodes = 2 * num_leaves_; - tree_energy_min_.assign(num_nodes, IntegerValue(0)); - tree_energy_opt_.assign(num_nodes, IntegerValue(0)); + + // We will never access any indices larger than this because we require the + // event to exist when we call any of the GetEvent*() functions. + const int num_nodes = num_leaves_ + num_events_ + (num_events_ & 1); tree_envelope_.assign(num_nodes, kMinIntegerValue); tree_envelope_opt_.assign(num_nodes, kMinIntegerValue); + tree_sum_of_energy_min_.assign(num_nodes, IntegerValue(0)); + tree_max_of_energy_delta_.assign(num_nodes, IntegerValue(0)); } void ThetaLambdaTree::AddOrUpdateEvent(int event, IntegerValue initial_envelope, IntegerValue energy_min, IntegerValue energy_max) { - DCHECK_LE(0, event); - DCHECK_LT(event, num_events_); DCHECK_LE(0, energy_min); DCHECK_LE(energy_min, energy_max); - const int node = num_leaves_ + event; + const int node = GetLeaf(event); tree_envelope_[node] = initial_envelope + energy_min; - tree_energy_min_[node] = energy_min; tree_envelope_opt_[node] = initial_envelope + energy_max; - tree_energy_opt_[node] = energy_max; + tree_sum_of_energy_min_[node] = energy_min; + tree_max_of_energy_delta_[node] = energy_max - energy_min; RefreshNode(node); } void ThetaLambdaTree::RemoveEvent(int event) { - DCHECK_LE(0, event); - DCHECK_LT(event, num_events_); - const int node = num_leaves_ + event; + const int node = GetLeaf(event); tree_envelope_[node] = kMinIntegerValue; - tree_energy_min_[node] = IntegerValue(0); tree_envelope_opt_[node] = kMinIntegerValue; - tree_energy_opt_[node] = IntegerValue(0); + tree_sum_of_energy_min_[node] = IntegerValue(0); + tree_max_of_energy_delta_[node] = IntegerValue(0); RefreshNode(node); } @@ -65,7 +63,9 @@ IntegerValue ThetaLambdaTree::GetOptionalEnvelope() const { int ThetaLambdaTree::GetMaxEventWithEnvelopeGreaterThan( IntegerValue target_envelope) const { DCHECK_LT(target_envelope, tree_envelope_[1]); - return GetMaxLeafWithEnvelopeGreaterThan(1, target_envelope) - num_leaves_; + IntegerValue unused; + return GetMaxLeafWithEnvelopeGreaterThan(1, target_envelope, &unused) - + num_leaves_; } void ThetaLambdaTree::GetEventsWithOptionalEnvelopeGreaterThan( @@ -80,12 +80,10 @@ void ThetaLambdaTree::GetEventsWithOptionalEnvelopeGreaterThan( } IntegerValue ThetaLambdaTree::GetEnvelopeOf(int event) const { - DCHECK_LE(0, event); - DCHECK_LT(event, num_events_); - IntegerValue env = tree_envelope_[event + num_leaves_]; + IntegerValue env = tree_envelope_[GetLeaf(event)]; for (int node = event + num_leaves_; node > 1; node >>= 1) { const int right = node | 1; - if (right != node) env += tree_energy_min_[right]; + if (right != node) env += tree_sum_of_energy_min_[right]; } return env; } @@ -95,56 +93,53 @@ void ThetaLambdaTree::RefreshNode(int node) { const int right = node | 1; const int left = right ^ 1; node >>= 1; - tree_energy_min_[node] = tree_energy_min_[left] + tree_energy_min_[right]; - tree_envelope_[node] = std::max( - tree_envelope_[right], tree_envelope_[left] + tree_energy_min_[right]); - tree_energy_opt_[node] = - std::max(tree_energy_min_[left] + tree_energy_opt_[right], - tree_energy_min_[right] + tree_energy_opt_[left]); - tree_envelope_opt_[node] = - std::max(tree_envelope_opt_[left] + tree_energy_min_[right], - tree_envelope_[left] + tree_energy_opt_[right]); - tree_envelope_opt_[node] = - std::max(tree_envelope_opt_[node], tree_envelope_opt_[right]); + const IntegerValue energy_right = tree_sum_of_energy_min_[right]; + tree_sum_of_energy_min_[node] = + tree_sum_of_energy_min_[left] + energy_right; + tree_max_of_energy_delta_[node] = std::max(tree_max_of_energy_delta_[right], + tree_max_of_energy_delta_[left]); + tree_envelope_[node] = + std::max(tree_envelope_[right], tree_envelope_[left] + energy_right); + tree_envelope_opt_[node] = std::max( + tree_envelope_opt_[right], + energy_right + + std::max(tree_envelope_opt_[left], + tree_envelope_[left] + tree_max_of_energy_delta_[right])); } while (node > 1); } int ThetaLambdaTree::GetMaxLeafWithEnvelopeGreaterThan( - int node, IntegerValue target_envelope) const { + int node, IntegerValue target_envelope, IntegerValue* extra) const { DCHECK_LT(target_envelope, tree_envelope_[node]); while (node < num_leaves_) { const int left = node << 1; const int right = left | 1; + DCHECK_LT(right, tree_envelope_.size()); if (target_envelope < tree_envelope_[right]) { node = right; } else { - target_envelope -= tree_energy_min_[right]; + target_envelope -= tree_sum_of_energy_min_[right]; node = left; } } + *extra = tree_envelope_[node] - target_envelope; return node; } -int ThetaLambdaTree::GetMaxLeafWithOptionalEnergyGreaterThan( - int node, IntegerValue node_available_energy, - IntegerValue* available_energy) const { - DCHECK_LT(node_available_energy, tree_energy_opt_[node]); +int ThetaLambdaTree::GetLeafWithMaxEnergyDelta(int node) const { + const IntegerValue delta_node = tree_max_of_energy_delta_[node]; while (node < num_leaves_) { const int left = node << 1; const int right = left | 1; - - const IntegerValue available_energy_right = - node_available_energy - tree_energy_min_[left]; - if (available_energy_right < tree_energy_opt_[right]) { - node_available_energy = available_energy_right; + DCHECK_LT(right, tree_envelope_.size()); + if (tree_max_of_energy_delta_[right] == delta_node) { node = right; - } else { // available_energy_left < tree_energy_opt_[left] - node_available_energy -= tree_energy_min_[right]; + } else { + DCHECK_EQ(tree_max_of_energy_delta_[left], delta_node); node = left; } } - *available_energy = node_available_energy; return node; } @@ -156,25 +151,31 @@ void ThetaLambdaTree::GetLeavesWithOptionalEnvelopeGreaterThan( while (node < num_leaves_) { const int left = node << 1; const int right = left | 1; + DCHECK_LT(right, tree_envelope_.size()); if (target_envelope < tree_envelope_opt_[right]) { node = right; - } else if (target_envelope < - tree_envelope_[left] + tree_energy_opt_[right]) { - *critical_leaf = - GetMaxLeafWithEnvelopeGreaterThan(left, tree_envelope_[left] - 1); - *optional_leaf = GetMaxLeafWithOptionalEnergyGreaterThan( - right, target_envelope - tree_envelope_[left], available_energy); - return; - } else { // < tree_envelope_opt_[left] + tree_energy_min_[right] - target_envelope -= tree_energy_min_[right]; - node = left; + } else { + const IntegerValue opt_energy_right = + tree_sum_of_energy_min_[right] + tree_max_of_energy_delta_[right]; + if (target_envelope < tree_envelope_[left] + opt_energy_right) { + *optional_leaf = GetLeafWithMaxEnergyDelta(right); + IntegerValue extra; + *critical_leaf = GetMaxLeafWithEnvelopeGreaterThan( + left, target_envelope - opt_energy_right, &extra); + *available_energy = tree_sum_of_energy_min_[*optional_leaf] + + tree_max_of_energy_delta_[*optional_leaf] - extra; + return; + } else { // < tree_envelope_opt_[left] + tree_sum_of_energy_min_[right] + target_envelope -= tree_sum_of_energy_min_[right]; + node = left; + } } } *critical_leaf = node; *optional_leaf = node; *available_energy = - target_envelope - tree_envelope_[node] + tree_energy_min_[node]; + target_envelope - tree_envelope_[node] + tree_sum_of_energy_min_[node]; } } // namespace sat diff --git a/ortools/sat/theta_tree.h b/ortools/sat/theta_tree.h index e04370e4b7..443fece2c1 100644 --- a/ortools/sat/theta_tree.h +++ b/ortools/sat/theta_tree.h @@ -52,30 +52,24 @@ namespace sat { // that can be present or absent, and present events come with an // initial_envelope, a minimal and a maximal energy. // All nodes maintain values on the set of present events under them: -// _ energy_min(node) = sum_{leaf \in leaves(node)} energy_min(leaf) +// _ sum_energy_min(node) = sum_{leaf \in leaves(node)} energy_min(leaf) // _ envelope(node) = // max_{leaf \in leaves(node)} // initial_envelope(leaf) + // sum_{leaf' \in leaves(node), leaf' >= leaf} energy_min(leaf'). // // Thus, the envelope of a leaf representing an event, when present, is -// initial_envelope(event) + energy_min(event). +// initial_envelope(event) + sum_energy_min(event). // -// envelope_opt and energy_opt are similar, but represent the maximum value -// a node could have if one leaf took its maximum energy: -// _ energy_opt(node) = sum_{leaf \in leaves(node)} energy_min(leaf) -// + max_{leaf \in leaves(node)} -// energy_max(leaf) - energy_min(leaf) +// We also maintain envelope_opt with is the maximum envelope a node could take +// if at most one of the event where at its maximum energy. +// _ energy_delta(leaf) = energy_max(leaf) - energy_min(leaf) +// _ max_energy_delta(node) = max_{leaf \in leaves(node)} energy_delta(leaf) // _ envelope_opt(node) = // max_{leaf \in leaves(node)} // initial_envelope(leaf) + // sum_{leaf' \in leaves(node), leaf' >= leaf} energy_min(leaf') + -// max_{leaf_opt \in leaves(node)} -// . (energy_max(leaf_opt) - energy_min(leaf_opt)) -// max_{leaf_opt \in leaves(node)} -// initial_envelope(leaf_opt) + -// energy_max(leaf_opt) - energy_min(leaf_opt) + -// sum_{leaf' \in leaves(node), leaf' >= leaf_opt} energy_min(leaf') +// max_{leaf' \in leaves(node), leaf' >= leaf} energy_delta(leaf'); // // Most articles using theta-tree variants hack Vilim's original theta tree // for the disjunctive resource constraint by manipulating envelope and @@ -134,43 +128,49 @@ class ThetaLambdaTree { // Computes a pair of events (critical_event, optional_event) such that // if optional_event was at its maximum energy, the envelope of critical_event // would be greater than target_envelope. - // This assumes that such a pair exists, i.e. GetOptionalEnvelope() - // should be greater than target_envelope. - // More formally, this finds events such that - // initial_envelope(critical_event) + - // sum_{event' >= critical_event} energy_min(event') + - // max_{optional_event >= critical_event} - // (energy_max(optional_event) - energy_min(optional_event)) - // > target envelope. + // + // This assumes that such a pair exists, i.e. GetOptionalEnvelope() should be + // greater than target_envelope. More formally, this finds events such that: + // initial_envelope(critical_event) + + // sum_{event' >= critical_event} energy_min(event') + + // max_{optional_event >= critical_event} energy_delta(optional_event) + // > target envelope. + // // For efficiency reasons, this also fills available_energy with the maximum - // value such that the optional envelope of the pair would be target_envelope, - // i.e. target_envelope - GetEnvelopeOf(event) + energy_min(optional_event). + // energy the optional task can take such that the optional envelope of the + // pair would be target_envelope, i.e. + // target_envelope - GetEnvelopeOf(event) + energy_min(optional_event). + // // This operation is O(log n). void GetEventsWithOptionalEnvelopeGreaterThan( IntegerValue target_envelope, int* critical_event, int* optional_event, IntegerValue* available_energy) const; + // Getters. + IntegerValue EnergyMin(int event) const { + return tree_sum_of_energy_min_[GetLeaf(event)]; + } + private: + // Returns the index of the leaf associated with the given event. + int GetLeaf(int event) const { + DCHECK_LE(0, event); + DCHECK_LT(event, num_events_); + return num_leaves_ + event; + } + // Propagates the change of leaf energies and envelopes towards the root. void RefreshNode(int leaf); // Finds the maximum leaf under node such that // initial_envelope(leaf) + sum_{leaf' >= leaf} energy_min(leaf') - // > target_envelope. - int GetMaxLeafWithEnvelopeGreaterThan(int node, - IntegerValue target_envelope) const; + // > target_envelope. + // Fills extra with the difference. + int GetMaxLeafWithEnvelopeGreaterThan(int node, IntegerValue target_envelope, + IntegerValue* extra) const; - // Returns the maximum leaf under node whose optional energy would overload - // node. - // Finds the maximum leaf under node such that - // sum_{leaf' under node} energy_min(leaf') + - // energy_max(leaf) - energy_min(leaf) > node_available_energy. - // available_energy will be the energy available for this leaf, - // i.e. node_available_energy - sum_{leaf' under node} energy_min(leaf') + - // energy_min(leaf). - int GetMaxLeafWithOptionalEnergyGreaterThan( - int node, IntegerValue node_available_energy, - IntegerValue* available_energy) const; + // Returns the leaf with maximum energy delta under node. + int GetLeafWithMaxEnergyDelta(int node) const; // Finds the leaves and energy relevant for // GetEventsWithOptionalEnvelopeGreaterThan(). @@ -187,9 +187,9 @@ class ThetaLambdaTree { // Envelopes and energies of nodes. std::vector tree_envelope_; - std::vector tree_energy_min_; std::vector tree_envelope_opt_; - std::vector tree_energy_opt_; + std::vector tree_sum_of_energy_min_; + std::vector tree_max_of_energy_delta_; }; } // namespace sat diff --git a/ortools/util/affine_relation.h b/ortools/util/affine_relation.h new file mode 100644 index 0000000000..b16f492643 --- /dev/null +++ b/ortools/util/affine_relation.h @@ -0,0 +1,210 @@ +// Copyright 2010-2014 Google +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OR_TOOLS_UTIL_AFFINE_RELATION_H_ +#define OR_TOOLS_UTIL_AFFINE_RELATION_H_ + +#include + +#include "ortools/base/logging.h" +#include "ortools/base/macros.h" +#include "ortools/base/port.h" +#include "ortools/base/iterator_adaptors.h" + +namespace operations_research { + +// Union-Find algorithm to maintain "representative" for relations of the form: +// x = coeff * y + offset, where "coeff" and "offset" are integers. The variable +// x and y are represented by non-negative integer indices. The idea is to +// express variables in affine relation using as little different variables as +// possible (the representatives). +// +// IMPORTANT: If there are relations with std::abs(coeff) != 1, then some +// relations might be ignored. See TryAdd() for more details. +// +// TODO(user): it might be possible to do something fancier and drop less +// relations if all the affine relations are given before hand. +class AffineRelation { + public: + AffineRelation() : num_relations_(0) {} + + // Returns the number of relations added to the class and not ignored. + int NumRelations() const { return num_relations_; } + + // Adds the relation x = coeff * y + offset to the class. + // Returns true if it wasn't ignored. + // + // This relation will only be taken into account if the representative of x + // and the one of y are different and if the relation can be transformed into + // a similar relation with integer coefficient between the two + // representatives. + // + // That is, given that: + // - x = coeff_x * representative_x + offset_x + // - y = coeff_y * representative_y + offset_y + // we have: + // coeff_x * representative_x + offset_x = + // coeff * coeff_y * representative_y + coeff * offset_y + offset. + // Which can be simplified with the introduction of new variables to: + // coeff_x * representative_x = new_coeff * representative_y + new_offset. + // And we can merge the two if: + // - new_coeff and new_offset are divisible by coeff_x. + // - OR coeff_x and new_offset are divisible by new_coeff. + // + // Checked preconditions: x >=0, y >= 0, coeff != 0 and x != y. + // + // IMPORTANT: we do not check for integer overflow, but that could be added + // if it is needed. + bool TryAdd(int x, int y, int64 coeff, int64 offset); + + // Always try to make y the representative of x. + bool TryAddInGivenOrder(int x, int y, int64 coeff, int64 offset); + + // Returns a valid relation of the form x = coeff * representative + offset. + // Note that this can return x = x. Non-const because of path-compression. + struct Relation { + int representative; + int64 coeff; + int64 offset; + // TUPLE_DEFINE_STRUCT(Relation, (ctor, ostream, eq), (int, representative), + // (int64, coeff), (int64, offset)); + }; + Relation Get(int x); + + private: + void IncreaseSizeOfMemberVectors(int new_size) { + if (new_size <= representative_.size()) return; + for (int i = representative_.size(); i < new_size; ++i) { + representative_.push_back(i); + } + offset_.resize(new_size, 0); + coeff_.resize(new_size, 1); + size_.resize(new_size, 1); + } + + void CompressPath(int x) { + DCHECK_GE(x, 0); + DCHECK_LT(x, representative_.size()); + tmp_path_.clear(); + int parent = x; + while (parent != representative_[parent]) { + tmp_path_.push_back(parent); + parent = representative_[parent]; + } + for (const int var : ::gtl::reversed_view(tmp_path_)) { + const int old_parent = representative_[var]; + offset_[var] += coeff_[var] * offset_[old_parent]; + coeff_[var] *= coeff_[old_parent]; + representative_[var] = parent; + } + } + + int num_relations_; + + // The equivalence class representative for each variable index. + std::vector representative_; + + // The offset and coefficient such that + // variable[index] = coeff * variable[representative_[index]] + offset; + std::vector coeff_; + std::vector offset_; + + // The size of each representative "tree", used to get a good complexity when + // we have the choice of which tree to merge into the other. + // + // TODO(user): Using a "rank" might be faster, but because we sometimes + // need to merge the bad subtree into the better one, it is trickier to + // maintain than in the classic union-find algorihtm. + std::vector size_; + + // Used by CompressPath() to maintain the coeff/offset during compression. + std::vector tmp_path_; +}; + +inline bool AffineRelation::TryAdd(int x, int y, int64 coeff, int64 offset) { + CHECK_NE(coeff, 0); + CHECK_NE(x, y); + CHECK_GE(x, 0); + CHECK_GE(y, 0); + IncreaseSizeOfMemberVectors(std::max(x, y) + 1); + CompressPath(x); + CompressPath(y); + const int rep_x = representative_[x]; + const int rep_y = representative_[y]; + if (rep_x == rep_y) return false; + + // TODO(user): It should be possible to optimize this code block a bit, for + // instance depending on the magnitude of new_coeff vs coeff_x, we may already + // know that one of the two merge is not possible. + const int64 coeff_x = coeff_[x]; + const int64 new_coeff = coeff * coeff_[y]; + const int64 new_offset = coeff * offset_[y] + offset - offset_[x]; + const bool condition1 = + (new_coeff % coeff_x == 0) && (new_offset % coeff_x == 0); + const bool condition2 = + (coeff_x % new_coeff == 0) && (new_offset % new_coeff == 0); + if (condition1 && (!condition2 || size_[x] <= size_[y])) { + representative_[rep_x] = rep_y; + size_[rep_y] += size_[rep_x]; + coeff_[rep_x] = new_coeff / coeff_x; + offset_[rep_x] = new_offset / coeff_x; + } else if (condition2) { + representative_[rep_y] = rep_x; + size_[rep_x] += size_[rep_y]; + coeff_[rep_y] = coeff_x / new_coeff; + offset_[rep_y] = -new_offset / new_coeff; + } else { + return false; + } + ++num_relations_; + return true; +} + +// TODO(user): Find a clean way to share the code with the function above +// once we are sure this is needed. Another option would have been to provide +// some kind of prefered representative to this class. +inline bool AffineRelation::TryAddInGivenOrder(int x, int y, int64 coeff, + int64 offset) { + CHECK_NE(coeff, 0); + CHECK_NE(x, y); + CHECK_GE(x, 0); + CHECK_GE(y, 0); + IncreaseSizeOfMemberVectors(std::max(x, y) + 1); + CompressPath(x); + CompressPath(y); + const int rep_x = representative_[x]; + const int rep_y = representative_[y]; + if (rep_x == rep_y) return false; + const int64 coeff_x = coeff_[x]; + const int64 new_coeff = coeff * coeff_[y]; + const int64 new_offset = coeff * offset_[y] + offset - offset_[x]; + if ((new_coeff % coeff_x == 0) && (new_offset % coeff_x == 0)) { + representative_[rep_x] = rep_y; + size_[rep_y] += size_[rep_x]; + coeff_[rep_x] = new_coeff / coeff_x; + offset_[rep_x] = new_offset / coeff_x; + ++num_relations_; + return true; + } + return false; +} + +inline AffineRelation::Relation AffineRelation::Get(int x) { + if (x >= representative_.size() || representative_[x] == x) return {x, 1, 0}; + CompressPath(x); + return {representative_[x], coeff_[x], offset_[x]}; +} + +} // namespace operations_research + +#endif // OR_TOOLS_UTIL_AFFINE_RELATION_H_