diff --git a/examples/cpp/BUILD b/examples/cpp/BUILD index c83a4f8a98..2f461bcd51 100644 --- a/examples/cpp/BUILD +++ b/examples/cpp/BUILD @@ -12,9 +12,9 @@ cc_binary( name = "cryptarithm", srcs = ["cryptarithm.cc"], deps = [ - "@com_google_protobuf_cc//:protobuf", "//ortools/base", "//ortools/constraint_solver:cp", + "@com_google_protobuf_cc//:protobuf", ], ) @@ -230,7 +230,7 @@ cc_binary( srcs = ["integer_programming.cc"], deps = [ "//ortools/base", - "//ortools/linear_solver:linear_solver", + "//ortools/linear_solver", ], ) @@ -309,7 +309,7 @@ cc_binary( copts = ["-DUSE_GLOP"], deps = [ "//ortools/base", - "//ortools/linear_solver:linear_solver", + "//ortools/linear_solver", "//ortools/linear_solver:linear_solver_cc_proto", ], ) @@ -319,7 +319,7 @@ cc_binary( srcs = ["linear_solver_protocol_buffers.cc"], deps = [ "//ortools/base", - "//ortools/linear_solver:linear_solver", + "//ortools/linear_solver", "//ortools/linear_solver:linear_solver_cc_proto", ], ) @@ -415,28 +415,28 @@ cc_binary( "sat_runner.cc", ], deps = [ - "@com_google_protobuf_cc//:protobuf", + "//ortools/algorithms:sparse_permutation", + "//ortools/base", + "//ortools/base:file", + "//ortools/base:random", + "//ortools/base:status", + "//ortools/base:strings", + "//ortools/base:threadpool", + "//ortools/lp_data:mps_reader", + "//ortools/lp_data:proto_utils", "//ortools/sat:boolean_problem", "//ortools/sat:boolean_problem_cc_proto", - # "//ortools/sat:cp_model_proto", - # "//ortools/sat:cp_model_solver", + "//ortools/sat:cp_model_cc_proto", + "//ortools/sat:cp_model_solver", "//ortools/sat:drat", "//ortools/sat:lp_utils", "//ortools/sat:optimization", "//ortools/sat:sat_solver", "//ortools/sat:simplification", "//ortools/sat:symmetry", - "//ortools/base", - "//ortools/base:file", - "//ortools/base:strings", - "//ortools/base:status", - "//ortools/base:random", - "//ortools/base:threadpool", - "//ortools/algorithms:sparse_permutation", - "//ortools/lp_data:mps_reader", - "//ortools/lp_data:proto_utils", "//ortools/util:filelineiter", "//ortools/util:time_limit", + "@com_google_protobuf_cc//:protobuf", ], ) @@ -470,7 +470,7 @@ cc_binary( ], deps = [ "//ortools/base", - "//ortools/linear_solver:linear_solver", + "//ortools/linear_solver", "//ortools/linear_solver:linear_solver_cc_proto", "//ortools/lp_data:mps_reader", ], @@ -491,7 +491,7 @@ cc_binary( copts = ["-DUSE_GLOP"], deps = [ "//ortools/base", - "//ortools/linear_solver:linear_solver", + "//ortools/linear_solver", ], ) @@ -499,11 +499,11 @@ cc_binary( name = "tsp", srcs = ["tsp.cc"], deps = [ - "@com_google_protobuf_cc//:protobuf", "//ortools/base", "//ortools/constraint_solver:cp", "//ortools/constraint_solver:routing", "//ortools/constraint_solver:routing_flags", + "@com_google_protobuf_cc//:protobuf", ], ) @@ -513,7 +513,6 @@ cc_binary( "weighted_tardiness_sat.cc", ], deps = [ - "@com_google_protobuf_cc//:protobuf", "//ortools/base", "//ortools/base:file", "//ortools/base:strings", @@ -526,5 +525,6 @@ cc_binary( "//ortools/sat:precedences", "//ortools/sat:sat_solver", "//ortools/util:filelineiter", + "@com_google_protobuf_cc//:protobuf", ], ) diff --git a/examples/cpp/sat_cnf_reader.h b/examples/cpp/sat_cnf_reader.h index efed690167..2e22dd5756 100644 --- a/examples/cpp/sat_cnf_reader.h +++ b/examples/cpp/sat_cnf_reader.h @@ -103,11 +103,11 @@ class SatCnfReader { return problem_name; } - int64 StringViewAtoi(const string_view& input) { + int64 StringPieceAtoi(string_view input) { // Hack: data() is not null terminated, but we do know that it points // inside a std::string where numbers are separated by " " and since atoi64 will // stop at the first invalid char, this works. - return atoi64(input.data()); + return atoi64(input.data()); // NOLINT } void ProcessNewLine(const std::string& line, LinearBooleanProblem* problem) { @@ -123,11 +123,11 @@ class SatCnfReader { if (words_[0] == "p") { if (words_[1] == "cnf" || words_[1] == "wcnf") { - num_variables_ = StringViewAtoi(words_[2]); - num_clauses_ = StringViewAtoi(words_[3]); + num_variables_ = StringPieceAtoi(words_[2]); + num_clauses_ = StringPieceAtoi(words_[3]); if (words_[1] == "wcnf") { is_wcnf_ = true; - hard_weight_ = (words_.size() > 4) ? StringViewAtoi(words_[4]) : 0; + hard_weight_ = (words_.size() > 4) ? StringPieceAtoi(words_[4]) : 0; } } else { // TODO(user): The ToString() is only required for the open source. Fix. @@ -148,7 +148,7 @@ class SatCnfReader { int64 weight = (!is_wcnf_ && interpret_cnf_as_max_sat_) ? 1 : hard_weight_; for (int i = 0; i < size; ++i) { - const int64 signed_value = StringViewAtoi(words_[i]); + const int64 signed_value = StringPieceAtoi(words_[i]); if (i == 0 && is_wcnf_) { // Mathematically, a soft clause of weight 0 can be removed. if (signed_value == 0) { diff --git a/examples/cpp/sat_runner.cc b/examples/cpp/sat_runner.cc index d81cf82e71..9681c02dc7 100644 --- a/examples/cpp/sat_runner.cc +++ b/examples/cpp/sat_runner.cc @@ -30,6 +30,8 @@ #include "ortools/base/threadpool.h" #include "ortools/algorithms/sparse_permutation.h" #include "ortools/sat/boolean_problem.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_solver.h" #include "ortools/sat/drat.h" #include "examples/cpp/opb_reader.h" #include "ortools/sat/optimization.h" @@ -38,6 +40,7 @@ #include "ortools/sat/sat_solver.h" #include "ortools/sat/simplification.h" #include "ortools/sat/symmetry.h" +#include "ortools/util/file_util.h" #include "ortools/util/time_limit.h" #include "ortools/base/random.h" #include "ortools/base/status.h" @@ -106,11 +109,19 @@ DEFINE_bool(presolve, true, DEFINE_bool(probing, false, "If true, presolve the problem using probing."); +DEFINE_bool(use_cp_model, true, + "Whether to interpret a linear program input as a CpModelProto or " + "to read by default a CpModelProto."); + DEFINE_bool(reduce_memory_usage, false, "If true, do not keep a copy of the original problem in memory." "This reduce the memory usage, but disable the solution cheking at " "the end."); +DEFINE_string( + drat_output, "", + "If non-empty, a proof in DRAT format will be written to this file."); + namespace operations_research { namespace sat { namespace { @@ -126,7 +137,8 @@ double GetScaledTrivialBestBound(const LinearBooleanProblem& problem) { return AddOffsetAndScaleObjectiveValue(problem, best_bound); } -void LoadBooleanProblem(std::string filename, LinearBooleanProblem* problem) { +void LoadBooleanProblem(const std::string& filename, LinearBooleanProblem* problem, + CpModelProto* cp_model) { if (strings::EndsWith(filename, ".opb") || strings::EndsWith(filename, ".opb.bz2")) { OpbReader reader; @@ -145,8 +157,12 @@ void LoadBooleanProblem(std::string filename, LinearBooleanProblem* problem) { if (!reader.Load(filename, problem)) { LOG(FATAL) << "Cannot load file '" << filename << "'."; } + } else if (FLAGS_use_cp_model) { + LOG(INFO) << "Reading a CpModelProto."; + *cp_model = ReadFileToProtoOrDie(filename); } else { - file::ReadFileToProtoOrDie(filename, problem); + LOG(INFO) << "Reading a LinearBooleanProblem."; + *problem = ReadFileToProtoOrDie(filename); } } @@ -178,14 +194,20 @@ int Run() { << FLAGS_params; } - Model model; - DratWriter* drat_writer = model.GetOrCreate(); // Initialize the solver. std::unique_ptr solver(new SatSolver()); - solver->SetDratWriter(drat_writer); solver->SetParameters(parameters); + // Create a DratWriter? + std::unique_ptr drat_writer; + if (!FLAGS_drat_output.empty()) { + File* output; + CHECK_OK(file::Open(FLAGS_drat_output, "w", &output, file::Defaults())); + drat_writer.reset(new DratWriter(/*in_binary_format=*/false, output)); + solver->SetDratWriter(drat_writer.get()); + } + // The global time limit. std::unique_ptr time_limit(TimeLimit::FromParameters(parameters)); @@ -195,7 +217,21 @@ int Run() { // Read the problem. LinearBooleanProblem problem; - LoadBooleanProblem(FLAGS_input, &problem); + CpModelProto cp_model; + LoadBooleanProblem(FLAGS_input, &problem, &cp_model); + + // TODO(user): clean this hack. Ideally LinearBooleanProblem should be + // completely replaced by the more general CpModelProto. + if (cp_model.variables_size() != 0) { + Model model; + model.GetOrCreate()->SetParameters(parameters); + model.SetSingleton(std::move(time_limit)); + LOG(INFO) << CpModelStats(cp_model); + const CpSolverResponse response = SolveCpModel(cp_model, &model); + LOG(INFO) << CpSolverResponseStats(response); + exit(EXIT_SUCCESS); + } + if (FLAGS_strict_validity) { const util::Status status = ValidateBooleanProblem(problem); if (!status.ok()) { @@ -255,7 +291,8 @@ int Run() { for (int i = 0; i < generators.size(); ++i) { propagator->AddSymmetry(std::move(generators[i])); } - solver->AddPropagator(std::move(propagator)); + solver->AddPropagator(propagator.get()); + solver->TakePropagatorOwnership(std::move(propagator)); } // Optimize? @@ -292,7 +329,7 @@ int Run() { parameters.set_log_search_progress(true); solver->SetParameters(parameters); if (FLAGS_presolve) { - result = SolveWithPresolve(&solver, &solution, drat_writer); + result = SolveWithPresolve(&solver, &solution, drat_writer.get()); if (result == SatSolver::MODEL_SAT) { CHECK(IsAssignmentValid(problem, solution)); } diff --git a/examples/cpp/solve.cc b/examples/cpp/solve.cc index 91fa3864ff..d7c4a37d15 100644 --- a/examples/cpp/solve.cc +++ b/examples/cpp/solve.cc @@ -45,8 +45,8 @@ DEFINE_string(params_file, "", "If this flag is set, the --params flag is ignored."); DEFINE_string(params, "", "Solver specific parameters"); DEFINE_int64(time_limit_ms, 0, - "If stricitly positive, specifies a limit in ms on the solving" - " time."); + "If strictly positive, specifies a limit in ms on the solving " + "time. Otherwise, no time limit will be imposed."); DEFINE_string(forced_mps_format, "", "Set to force the mps format to use: free, fixed"); diff --git a/makefiles/Makefile.archive.mk b/makefiles/Makefile.archive.mk index eb38cc82c6..d7069ffbc6 100644 --- a/makefiles/Makefile.archive.mk +++ b/makefiles/Makefile.archive.mk @@ -35,6 +35,7 @@ create_dirs: $(MKDIR) temp$S$(INSTALL_DIR)$Sinclude$Sortools$Sgoogle $(MKDIR) temp$S$(INSTALL_DIR)$Sinclude$Sortools$Sgraph $(MKDIR) temp$S$(INSTALL_DIR)$Sinclude$Sortools$Slinear_solver + $(MKDIR) temp$S$(INSTALL_DIR)$Sinclude$Sortools$Slp_data $(MKDIR) temp$S$(INSTALL_DIR)$Sinclude$Sortools$Ssat $(MKDIR) temp$S$(INSTALL_DIR)$Sinclude$Sortools$Sutil $(MKDIR) temp$S$(INSTALL_DIR)$Sexamples @@ -84,6 +85,7 @@ cc_archive: cc $(COPY) ortools$Sgraph$S*.h temp$S$(INSTALL_DIR)$Sinclude$Sortools$Sgraph $(COPY) ortools$Sgen$Sortools$Sgraph$S*.h temp$S$(INSTALL_DIR)$Sinclude$Sortools$Sgraph $(COPY) ortools$Slinear_solver$S*.h temp$S$(INSTALL_DIR)$Sinclude$Sortools$Slinear_solver + $(COPY) ortools$Slp_data$S*.h temp$S$(INSTALL_DIR)$Sinclude$Sortools$Slp_data $(COPY) ortools$Sgen$Sortools$Slinear_solver$S*.pb.h temp$S$(INSTALL_DIR)$Sinclude$Sortools$Slinear_solver $(COPY) ortools$Ssat$S*.h temp$S$(INSTALL_DIR)$Sinclude$Sortools$Ssat $(COPY) ortools$Sgen$Sortools$Ssat$S*.pb.h temp$S$(INSTALL_DIR)$Sinclude$Sortools$Ssat diff --git a/makefiles/Makefile.cpp.mk b/makefiles/Makefile.cpp.mk index 4f8ae5b529..1b49da46af 100755 --- a/makefiles/Makefile.cpp.mk +++ b/makefiles/Makefile.cpp.mk @@ -102,7 +102,6 @@ FLATZINC_DEPS = \ $(SRC_DIR)/ortools/flatzinc/presolve.h \ $(SRC_DIR)/ortools/flatzinc/reporting.h \ $(SRC_DIR)/ortools/flatzinc/sat_constraint.h \ - $(SRC_DIR)/ortools/flatzinc/sat_fz_solver.h \ $(SRC_DIR)/ortools/flatzinc/solver_data.h \ $(SRC_DIR)/ortools/flatzinc/solver.h \ $(SRC_DIR)/ortools/flatzinc/solver_util.h \ @@ -206,7 +205,6 @@ FLATZINC_OBJS=\ $(OBJ_DIR)/flatzinc/presolve.$O \ $(OBJ_DIR)/flatzinc/reporting.$O \ $(OBJ_DIR)/flatzinc/sat_constraint.$O \ - $(OBJ_DIR)/flatzinc/sat_fz_solver.$O \ $(OBJ_DIR)/flatzinc/solver.$O \ $(OBJ_DIR)/flatzinc/solver_data.$O \ $(OBJ_DIR)/flatzinc/solver_util.$O @@ -251,9 +249,6 @@ $(OBJ_DIR)/flatzinc/reporting.$O: $(SRC_DIR)/ortools/flatzinc/reporting.cc $(FLA $(OBJ_DIR)/flatzinc/sat_constraint.$O: $(SRC_DIR)/ortools/flatzinc/sat_constraint.cc $(FLATZINC_DEPS) $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sflatzinc$Ssat_constraint.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssat_constraint.$O -$(OBJ_DIR)/flatzinc/sat_fz_solver.$O: $(SRC_DIR)/ortools/flatzinc/sat_fz_solver.cc $(FLATZINC_DEPS) - $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sflatzinc$Ssat_fz_solver.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssat_fz_solver.$O - $(OBJ_DIR)/flatzinc/solver.$O: $(SRC_DIR)/ortools/flatzinc/solver.cc $(FLATZINC_DEPS) $(CCC) $(CFLAGS) -c $(SRC_DIR)$Sortools$Sflatzinc$Ssolver.cc $(OBJ_OUT)$(OBJ_DIR)$Sflatzinc$Ssolver.$O diff --git a/ortools/bop/bop_fs.cc b/ortools/bop/bop_fs.cc index 97a24276fa..189c741e5a 100644 --- a/ortools/bop/bop_fs.cc +++ b/ortools/bop/bop_fs.cc @@ -105,7 +105,8 @@ BopOptimizerBase::Status GuidedSatFirstSolutionGenerator::SynchronizeIfNeeded( for (int i = 0; i < generators.size(); ++i) { propagator->AddSymmetry(std::move(generators[i])); } - sat_solver_->AddPropagator(std::move(propagator)); + sat_solver_->AddPropagator(propagator.get()); + sat_solver_->TakePropagatorOwnership(std::move(propagator)); } } diff --git a/ortools/bop/bop_portfolio.cc b/ortools/bop/bop_portfolio.cc index 2aa730d302..480d6e9131 100644 --- a/ortools/bop/bop_portfolio.cc +++ b/ortools/bop/bop_portfolio.cc @@ -307,7 +307,8 @@ void PortfolioOptimizer::CreateOptimizers( for (int i = 0; i < generators.size(); ++i) { propagator->AddSymmetry(std::move(generators[i])); } - sat_propagator_.AddPropagator(std::move(propagator)); + sat_propagator_.AddPropagator(propagator.get()); + sat_propagator_.TakePropagatorOwnership(std::move(propagator)); } const int max_num_optimizers = diff --git a/ortools/flatzinc/BUILD b/ortools/flatzinc/BUILD index 756c85960e..34f0af51d7 100644 --- a/ortools/flatzinc/BUILD +++ b/ortools/flatzinc/BUILD @@ -259,33 +259,6 @@ cc_library( ], ) -cc_library( - name = "sat_fz_solver", - srcs = ["sat_fz_solver.cc"], - hdrs = ["sat_fz_solver.h"], - deps = [ - ":checker", - ":logging", - ":model", - ":solver", - "//ortools/base", - "//ortools/base:map_util", - "//ortools/sat:cp_constraints", - "//ortools/sat:cumulative", - "//ortools/sat:disjunctive", - "//ortools/sat:flow_costs", - "//ortools/sat:integer", - "//ortools/sat:integer_expr", - "//ortools/sat:intervals", - "//ortools/sat:linear_programming_constraint", - "//ortools/sat:model", - "//ortools/sat:optimization", - "//ortools/sat:sat_solver", - "//ortools/sat:table", - "//ortools/util:sorted_interval_list", - ], -) - cc_library( name = "cp_model_fz_solver", srcs = ["cp_model_fz_solver.cc"], @@ -323,7 +296,6 @@ cc_binary( ":parser_lib", ":presolve", ":reporting", - ":sat_fz_solver", ":solver", ":solver_util", "//ortools/base", diff --git a/ortools/flatzinc/cp_model_fz_solver.cc b/ortools/flatzinc/cp_model_fz_solver.cc index e3a1e24dd2..cf3e39e638 100644 --- a/ortools/flatzinc/cp_model_fz_solver.cc +++ b/ortools/flatzinc/cp_model_fz_solver.cc @@ -34,7 +34,7 @@ #include "ortools/sat/sat_solver.h" #include "ortools/sat/table.h" -DEFINE_string(cp_model_solver_params, "", "SatParameters as a text proto."); +DEFINE_string(cp_sat_params, "", "SatParameters as a text proto."); namespace operations_research { namespace sat { @@ -748,13 +748,14 @@ void SolveFzWithCpModelProto(const fz::Model& fz_model, // Fill the objective. if (fz_model.objective() != nullptr) { - CpObjectiveProto* objective = m.proto.add_objectives(); + CpObjectiveProto* objective = m.proto.mutable_objective(); + objective->add_coeffs(1); if (fz_model.maximize()) { - objective->set_objective_var( - NegatedCpModelVariable(m.fz_var_to_index[fz_model.objective()])); objective->set_scaling_factor(-1); + objective->add_vars( + NegatedCpModelVariable(m.fz_var_to_index[fz_model.objective()])); } else { - objective->set_objective_var(m.fz_var_to_index[fz_model.objective()]); + objective->add_vars(m.fz_var_to_index[fz_model.objective()]); } } @@ -766,9 +767,9 @@ void SolveFzWithCpModelProto(const fz::Model& fz_model, // The order is important, we want the flag parameters to overwrite anything // set in m.parameters. sat::SatParameters flag_parameters; - CHECK(google::protobuf::TextFormat::ParseFromString(FLAGS_cp_model_solver_params, + CHECK(google::protobuf::TextFormat::ParseFromString(FLAGS_cp_sat_params, &flag_parameters)) - << FLAGS_cp_model_solver_params; + << FLAGS_cp_sat_params; m.parameters.MergeFrom(flag_parameters); sat_model.GetOrCreate()->SetParameters(m.parameters); diff --git a/ortools/flatzinc/fz.cc b/ortools/flatzinc/fz.cc index 3efa8756d8..6bbad50107 100644 --- a/ortools/flatzinc/fz.cc +++ b/ortools/flatzinc/fz.cc @@ -36,7 +36,6 @@ #include "ortools/flatzinc/parser.h" #include "ortools/flatzinc/presolve.h" #include "ortools/flatzinc/reporting.h" -#include "ortools/flatzinc/sat_fz_solver.h" #include "ortools/flatzinc/solver.h" #include "ortools/flatzinc/solver_util.h" @@ -66,15 +65,13 @@ DEFINE_bool( verbose_impact, false, "Increase verbosity of the impact based search when used in free search."); DEFINE_bool(verbose_mt, false, "Verbose Multi-Thread."); -DEFINE_bool(use_fz_sat, false, "Use the SAT/CP solver."); -DEFINE_bool(use_cp_model, false, "Use the SAT/CP solver through CpModel."); +DEFINE_bool(use_cp_sat, false, "Use the CP/SAT solver."); DEFINE_string(fz_model_name, "stdin", "Define problem name when reading from stdin."); // TODO(user): Remove when using ABCL in open-source. DECLARE_bool(log_prefix); DECLARE_bool(fz_use_sat); -DECLARE_bool(vmodule); using operations_research::ThreadPool; @@ -328,17 +325,11 @@ int main(int argc, char** argv) { operations_research::fz::ParseFlatzincModel(input, !FLAGS_read_from_stdin); - if (FLAGS_use_fz_sat || FLAGS_use_cp_model) { + if (FLAGS_use_cp_sat) { bool interrupt_solve = false; - if (FLAGS_use_fz_sat) { - operations_research::sat::SolveWithSat( - model, operations_research::fz::SingleThreadParameters(), - &interrupt_solve); - } else { - operations_research::sat::SolveFzWithCpModelProto( - model, operations_research::fz::SingleThreadParameters(), - &interrupt_solve); - } + operations_research::sat::SolveFzWithCpModelProto( + model, operations_research::fz::SingleThreadParameters(), + &interrupt_solve); } else { operations_research::fz::Solve(model); } diff --git a/ortools/flatzinc/sat_fz_solver.cc b/ortools/flatzinc/sat_fz_solver.cc deleted file mode 100644 index 6c2d7e1e46..0000000000 --- a/ortools/flatzinc/sat_fz_solver.cc +++ /dev/null @@ -1,1613 +0,0 @@ -// Copyright 2010-2014 Google -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "ortools/flatzinc/sat_fz_solver.h" - -#include -#include -#include "ortools/base/stringprintf.h" -#include "ortools/base/timer.h" -#include "ortools/base/join.h" -#include "ortools/base/map_util.h" -#include "ortools/flatzinc/checker.h" -#include "ortools/flatzinc/logging.h" -#include "ortools/sat/all_different.h" -#include "ortools/sat/cp_constraints.h" -#include "ortools/sat/cumulative.h" -#include "ortools/sat/disjunctive.h" -#include "ortools/sat/flow_costs.h" -#include "ortools/sat/integer.h" -#include "ortools/sat/integer_expr.h" -#include "ortools/sat/intervals.h" -#include "ortools/sat/linear_programming_constraint.h" -#include "ortools/sat/model.h" -#include "ortools/sat/optimization.h" -#include "ortools/sat/sat_solver.h" -#include "ortools/sat/table.h" -#include "ortools/util/sorted_interval_list.h" - -DEFINE_bool(fz_use_lp_constraint, true, - "Use LP solver glop to enforce all linear inequalities at once."); - -namespace operations_research { -namespace sat { -namespace { - -// Hold the sat::Model and the correspondance between flatzinc and sat vars. -struct SatModel { - Model model; - - // A flatzinc Boolean variable can appear in both maps if a constraint needs - // its integer representation as a 0-1 variable. Such an IntegerVariable is - // created lazily by LookupVar() when a constraint is requesting it. - std::unordered_map var_map; - std::unordered_map bool_map; - - // Utility functions to convert an fz::Argument to sat::IntegerVariable(s). - IntegerVariable LookupConstant(int64 value); - IntegerVariable LookupVar(fz::IntegerVariable* var); - IntegerVariable LookupVar(const fz::Argument& argument); - std::vector LookupVars(const fz::Argument& argument); - - // Utility functions to convert a Boolean fz::Argument to sat::Literal(s). - bool IsBoolean(fz::IntegerVariable* argument) const; - bool IsBoolean(const fz::Argument& argument) const; - Literal GetTrueLiteral(fz::IntegerVariable* var) const; - Literal GetTrueLiteral(const fz::Argument& argument) const; - std::vector GetTrueLiterals(const fz::Argument& argument) const; - std::vector GetFalseLiterals(const fz::Argument& argument) const; - - // Returns the full domain Boolean encoding of the given variable (encoding it - // if not already done). - std::vector FullEncoding(IntegerVariable v); - - // Returns the value of the given variable in the current assigment. It must - // be assigned, otherwise this will crash. - int64 Value(fz::IntegerVariable* var) const; -}; - -IntegerVariable SatModel::LookupConstant(const int64 value) { - return model.Add(ConstantIntegerVariable(value)); -} - -IntegerVariable SatModel::LookupVar(fz::IntegerVariable* var) { - CHECK(!var->domain.HasOneValue()); - if (ContainsKey(var_map, var)) return FindOrDie(var_map, var); - CHECK_EQ(var->domain.Min(), 0); - CHECK_EQ(var->domain.Max(), 1); - - // Otherwise, this must be a Boolean and we must construct the IntegerVariable - // associated with it. - const Literal lit = FindOrDie(bool_map, var); - const IntegerVariable int_var = - model.GetOrCreate()->GetIntegerView(lit); - InsertOrDie(&var_map, var, int_var); - return int_var; -} - -IntegerVariable SatModel::LookupVar(const fz::Argument& argument) { - if (argument.HasOneValue()) return LookupConstant(argument.Value()); - CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF); - return LookupVar(argument.variables[0]); -} - -std::vector SatModel::LookupVars( - const fz::Argument& argument) { - std::vector result; - if (argument.type == fz::Argument::VOID_ARGUMENT) return result; - - if (argument.type == fz::Argument::INT_LIST) { - for (int64 value : argument.values) { - result.push_back(LookupConstant(value)); - } - } else { - CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF_ARRAY); - for (fz::IntegerVariable* var : argument.variables) { - if (var->domain.HasOneValue()) { - result.push_back(LookupConstant(var->domain.Value())); - } else { - result.push_back(LookupVar(var)); - } - } - } - return result; -} - -std::vector SatModel::FullEncoding( - IntegerVariable var) { - return model.Add(FullyEncodeVariable(var)); -} - -bool SatModel::IsBoolean(fz::IntegerVariable* var) const { - return ContainsKey(bool_map, var); -} - -bool SatModel::IsBoolean(const fz::Argument& argument) const { - if (argument.type != fz::Argument::INT_VAR_REF) return false; - return ContainsKey(bool_map, argument.variables[0]); -} - -Literal SatModel::GetTrueLiteral(fz::IntegerVariable* var) const { - CHECK(!var->domain.HasOneValue()); - return FindOrDie(bool_map, var); -} - -Literal SatModel::GetTrueLiteral(const fz::Argument& argument) const { - CHECK(!argument.HasOneValue()); - CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF); - return FindOrDie(bool_map, argument.variables[0]); -} - -std::vector SatModel::GetTrueLiterals( - const fz::Argument& argument) const { - std::vector literals; - if (argument.type == fz::Argument::VOID_ARGUMENT) return literals; - CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF_ARRAY); - for (fz::IntegerVariable* var : argument.variables) { - literals.push_back(GetTrueLiteral(var)); - } - return literals; -} - -std::vector SatModel::GetFalseLiterals( - const fz::Argument& argument) const { - std::vector literals; - if (argument.type == fz::Argument::VOID_ARGUMENT) return literals; - CHECK_EQ(argument.type, fz::Argument::INT_VAR_REF_ARRAY); - for (fz::IntegerVariable* var : argument.variables) { - literals.push_back(GetTrueLiteral(var).Negated()); - } - return literals; -} - -int64 SatModel::Value(fz::IntegerVariable* var) const { - if (var->domain.HasOneValue()) { - return var->domain.Value(); - } - if (ContainsKey(bool_map, var)) { - return model.Get(sat::Value(FindOrDie(bool_map, var))); - } - return model.Get(sat::Value(FindOrDie(var_map, var))); -} - -// ============================================================================= -// Constraints extraction. -// ============================================================================= - -void ExtractBoolEq(const fz::Constraint& ct, SatModel* m) { - const Literal a = m->GetTrueLiteral(ct.arguments[0]); - const Literal b = m->GetTrueLiteral(ct.arguments[1]); - m->model.Add(Equality(a, b)); -} - -void ExtractBoolEqNeReif(bool is_eq, const fz::Constraint& ct, SatModel* m) { - const Literal a = m->GetTrueLiteral(ct.arguments[0]); - const Literal b = m->GetTrueLiteral(ct.arguments[1]); - Literal r = m->GetTrueLiteral(ct.arguments[2]); - if (!is_eq) r = r.Negated(); - - // We exclude 101, 011, 110 and 000. - m->model.Add(ClauseConstraint({a.Negated(), b, r.Negated()})); - m->model.Add(ClauseConstraint({a, b.Negated(), r.Negated()})); - m->model.Add(ClauseConstraint({a.Negated(), b.Negated(), r})); - m->model.Add(ClauseConstraint({a, b, r})); -} - -void ExtractBoolNe(const fz::Constraint& ct, SatModel* m) { - const Literal a = m->GetTrueLiteral(ct.arguments[0]); - const Literal b = m->GetTrueLiteral(ct.arguments[1]); - m->model.Add(Equality(a, b.Negated())); -} - -void ExtractBoolLe(const fz::Constraint& ct, SatModel* m) { - const Literal a = m->GetTrueLiteral(ct.arguments[0]); - const Literal b = m->GetTrueLiteral(ct.arguments[1]); - m->model.Add(Implication(a, b)); -} - -void ExtractBoolLeLtReif(bool is_le, const fz::Constraint& ct, SatModel* m) { - Literal a = m->GetTrueLiteral(ct.arguments[0]); - Literal b = m->GetTrueLiteral(ct.arguments[1]); - Literal r = m->GetTrueLiteral(ct.arguments[2]); - if (!is_le) { - // The negation of r <=> (a <= b) is not(r) <=> (a > b) - r = r.Negated(); - std::swap(a, b); - } - m->model.Add(ReifiedBoolLe(a, b, r)); -} - -void ExtractBoolClause(const fz::Constraint& ct, SatModel* m) { - std::vector positive = m->GetTrueLiterals(ct.arguments[0]); - const std::vector negative = m->GetFalseLiterals(ct.arguments[1]); - positive.insert(positive.end(), negative.begin(), negative.end()); - m->model.Add(ClauseConstraint(positive)); -} - -void ExtractArrayBoolAnd(const fz::Constraint& ct, SatModel* m) { - if (ct.arguments[1].HasOneValue()) { - CHECK_EQ(0, ct.arguments[1].Value()) << "Other case should be presolved."; - m->model.Add(ClauseConstraint(m->GetFalseLiterals(ct.arguments[0]))); - } else { - const Literal r = m->GetTrueLiteral(ct.arguments[1]); - m->model.Add(ReifiedBoolAnd(m->GetTrueLiterals(ct.arguments[0]), r)); - } -} - -void ExtractArrayBoolOr(const fz::Constraint& ct, SatModel* m) { - if (ct.arguments[1].HasOneValue()) { - CHECK_EQ(ct.arguments[1].Value(), 1) << "Other case should be presolved."; - m->model.Add(ClauseConstraint(m->GetTrueLiterals(ct.arguments[0]))); - } else { - const Literal r = m->GetTrueLiteral(ct.arguments[1]); - m->model.Add(ReifiedBoolOr(m->GetTrueLiterals(ct.arguments[0]), r)); - } -} - -void ExtractArrayBoolXor(const fz::Constraint& ct, SatModel* m) { - bool sum = false; - std::vector literals; - for (fz::IntegerVariable* var : ct.arguments[0].variables) { - if (var->domain.HasOneValue()) { - sum ^= (var->domain.Value() == 1); - } else { - literals.push_back(m->GetTrueLiteral(var)); - } - } - m->model.Add(LiteralXorIs(literals, !sum)); -} - -void ExtractIntMin(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - const IntegerVariable c = m->LookupVar(ct.arguments[2]); - m->model.Add(IsEqualToMinOf(c, {a, b})); -} - -void ExtractArrayIntMinimum(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable min = m->LookupVar(ct.arguments[0]); - const std::vector vars = m->LookupVars(ct.arguments[1]); - m->model.Add(IsEqualToMinOf(min, vars)); -} - -void ExtractArrayIntMaximum(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable max = m->LookupVar(ct.arguments[0]); - const std::vector vars = m->LookupVars(ct.arguments[1]); - m->model.Add(IsEqualToMaxOf(max, vars)); -} - -void ExtractIntAbs(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable v = m->LookupVar(ct.arguments[0]); - const IntegerVariable abs = m->LookupVar(ct.arguments[1]); - m->model.Add(IsEqualToMaxOf(abs, {v, NegationOf(v)})); -} - -void ExtractIntMax(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - const IntegerVariable max = m->LookupVar(ct.arguments[2]); - m->model.Add(IsEqualToMaxOf(max, {a, b})); -} - -void ExtractIntTimes(const fz::Constraint& ct, SatModel* m) { - // TODO(user): Many constraint could be optimized in the same way. - // especially the int_eq_reif between bool and so on. - if (m->IsBoolean(ct.arguments[0]) && m->IsBoolean(ct.arguments[1]) && - m->IsBoolean(ct.arguments[2])) { - const Literal a = m->GetTrueLiteral(ct.arguments[0]); - const Literal b = m->GetTrueLiteral(ct.arguments[1]); - const Literal c = m->GetTrueLiteral(ct.arguments[2]); - m->model.Add(ReifiedBoolAnd({a, b}, c)); - return; - } - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - const IntegerVariable c = m->LookupVar(ct.arguments[2]); - m->model.Add(ProductConstraint(a, b, c)); -} - -void ExtractIntDiv(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - const IntegerVariable c = m->LookupVar(ct.arguments[2]); - m->model.Add(DivisionConstraint(a, b, c)); -} - -void ExtractIntPlus(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - const IntegerVariable c = m->LookupVar(ct.arguments[2]); - m->model.Add(FixedWeightedSum({a, b, c}, std::vector{1, 1, -1}, 0)); -} - -void ExtractIntEq(const fz::Constraint& ct, SatModel* m) { - // TODO(user): use the full encoding if available? - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - m->model.Add(Equality(a, b)); -} - -void ExtractIntNe(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - IntegerEncoder* encoder = m->model.GetOrCreate(); - if (!encoder->VariableIsFullyEncoded(a) || - !encoder->VariableIsFullyEncoded(b)) { - m->model.Add(NotEqual(a, b)); - } else { - m->model.Add(AllDifferentBinary({a, b})); - } -} - -void ExtractIntLe(const fz::Constraint& ct, SatModel* m) { - if (m->IsBoolean(ct.arguments[0]) && m->IsBoolean(ct.arguments[1])) { - const Literal a = m->GetTrueLiteral(ct.arguments[0]); - const Literal b = m->GetTrueLiteral(ct.arguments[1]); - m->model.Add(Implication(a, b)); - return; - } - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - m->model.Add(LowerOrEqual(a, b)); -} - -void ExtractIntGe(const fz::Constraint& ct, SatModel* m) { - if (m->IsBoolean(ct.arguments[0]) && m->IsBoolean(ct.arguments[1])) { - const Literal a = m->GetTrueLiteral(ct.arguments[0]); - const Literal b = m->GetTrueLiteral(ct.arguments[1]); - m->model.Add(Implication(b, a)); - return; - } - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - m->model.Add(GreaterOrEqual(a, b)); -} - -void ExtractIntLeGeReif(bool is_le, const fz::Constraint& ct, SatModel* m) { - CHECK(!ct.arguments[2].HasOneValue()) << "Should be presolved."; - const Literal r = m->GetTrueLiteral(ct.arguments[2]); - - if (m->IsBoolean(ct.arguments[0]) && m->IsBoolean(ct.arguments[1])) { - Literal a = m->GetTrueLiteral(ct.arguments[0]); - Literal b = m->GetTrueLiteral(ct.arguments[1]); - if (!is_le) std::swap(a, b); - m->model.Add(ReifiedBoolLe(a, b, r)); - return; - } - - if (ct.arguments[1].HasOneValue()) { - if (ct.arguments[0].HasOneValue()) { - if (is_le) { - if (ct.arguments[0].Value() <= ct.arguments[1].Value()) { - m->model.Add(ClauseConstraint({r})); - } else { - m->model.Add(ClauseConstraint({r.Negated()})); - } - } else { - if (ct.arguments[0].Value() >= ct.arguments[1].Value()) { - m->model.Add(ClauseConstraint({r})); - } else { - m->model.Add(ClauseConstraint({r.Negated()})); - } - } - FZLOG << "Should be presolved: " << ct.DebugString() << FZENDL; - return; - } - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerValue value(ct.arguments[1].Value()); - IntegerLiteral i_lit = is_le ? IntegerLiteral::LowerOrEqual(a, value) - : IntegerLiteral::GreaterOrEqual(a, value); - m->model.Add(Equality(i_lit, r)); - } else if (ct.arguments[0].HasOneValue()) { - const IntegerValue value(ct.arguments[0].Value()); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - IntegerLiteral i_lit = is_le ? IntegerLiteral::GreaterOrEqual(b, value) - : IntegerLiteral::LowerOrEqual(b, value); - m->model.Add(Equality(i_lit, r)); - } else { - IntegerVariable a = m->LookupVar(ct.arguments[0]); - IntegerVariable b = m->LookupVar(ct.arguments[1]); - if (!is_le) std::swap(a, b); - m->model.Add(ReifiedLowerOrEqualWithOffset(a, b, 0, r)); - } -} - -void ExtractIntLt(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - m->model.Add(LowerOrEqualWithOffset(a, b, 1)); // a + 1 <= b -} - -// TODO(user): the code can probably be shared by ExtractIntLeGeReif() and -// we can easily support Gt. -void ExtractIntLtReif(const fz::Constraint& ct, SatModel* m) { - CHECK(!ct.arguments[2].HasOneValue()) << "Should be presolved."; - const Literal is_lt = m->GetTrueLiteral(ct.arguments[2]); - - if (ct.arguments[1].HasOneValue()) { - CHECK(!ct.arguments[0].HasOneValue()) << "Should be presolved."; - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerValue value(ct.arguments[1].Value() - 1); - m->model.Add(Equality(IntegerLiteral::LowerOrEqual(a, value), is_lt)); - } else if (ct.arguments[0].HasOneValue()) { - const IntegerValue value(ct.arguments[0].Value() + 1); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - m->model.Add(Equality(IntegerLiteral::GreaterOrEqual(b, value), is_lt)); - } else { - const IntegerVariable a = m->LookupVar(ct.arguments[0]); - const IntegerVariable b = m->LookupVar(ct.arguments[1]); - m->model.Add(ReifiedLowerOrEqualWithOffset(a, b, 1, is_lt)); - } -} - -// Returns a non-empty vector if the constraint sum vars[i] * coeff[i] can be -// written as a sum of literal (eventually negating the variable) by replacing a -// variable -B by (not(B) - 1) and updates the given rhs. -// -// TODO(user): Do that in the presolve? -std::vector IsSumOfLiteral(const fz::Argument& vars, - const std::vector& coeffs, - int64* rhs, SatModel* m) { - const int n = coeffs.size(); - std::vector result; - for (int i = 0; i < n; ++i) { - if (!m->IsBoolean(vars.variables[i])) return std::vector(); - if (coeffs[i] == 1) { - result.push_back(m->GetTrueLiteral(vars.variables[i])); - } else if (coeffs[i] == -1) { - result.push_back(m->GetTrueLiteral(vars.variables[i]).Negated()); - (*rhs)++; // we replace -B by (not(B) - 1); - } else { - return std::vector(); - } - } - CHECK_GE(*rhs, 0) << "Should be presolved."; - CHECK_LE(*rhs, n) << "Should be presolved."; - return result; -} - -void AddLinearConstraintToLP(const std::vector& vars, - const std::vector& coeffs, double lb, - double ub, SatModel* m) { - LinearProgrammingConstraint* lp = - m->model.GetOrCreate(); - const LinearProgrammingConstraint::ConstraintIndex ct = - lp->CreateNewConstraint(lb, ub); - for (int i = 0; i < vars.size(); i++) { - lp->SetCoefficient(ct, vars[i], coeffs[i]); - } -} - -void ExtractIntLinEq(const fz::Constraint& ct, SatModel* m) { - const std::vector vars = m->LookupVars(ct.arguments[1]); - const std::vector& coeffs = ct.arguments[0].values; - const int64 rhs = ct.arguments[2].values[0]; - m->model.Add(FixedWeightedSum(vars, coeffs, rhs)); - - if (FLAGS_fz_use_lp_constraint) { - const double value = static_cast(rhs); - AddLinearConstraintToLP(vars, coeffs, value, value, m); - } -} - -void ExtractIntLinNe(const fz::Constraint& ct, SatModel* m) { - const std::vector vars = m->LookupVars(ct.arguments[1]); - const std::vector& coeffs = ct.arguments[0].values; - const int64 rhs = ct.arguments[2].values[0]; - m->model.Add(WeightedSumNotEqual(vars, coeffs, rhs)); -} - -void ExtractIntLinLe(const fz::Constraint& ct, SatModel* m) { - const std::vector& coeffs = ct.arguments[0].values; - const int64 rhs = ct.arguments[2].values[0]; - int64 new_rhs = rhs; - std::vector lits = - IsSumOfLiteral(ct.arguments[1], coeffs, &new_rhs, m); - if (!lits.empty() && new_rhs == coeffs.size() - 1) { - // Not all literal can be true. - for (Literal& ref : lits) ref = ref.Negated(); - m->model.Add(ClauseConstraint(lits)); - } else if (!lits.empty() && new_rhs == 0) { - // Every literal must be false. - for (const Literal l : lits) m->model.Add(ClauseConstraint({l.Negated()})); - } else { - const std::vector vars = m->LookupVars(ct.arguments[1]); - m->model.Add(WeightedSumLowerOrEqual(vars, coeffs, rhs)); - - if (FLAGS_fz_use_lp_constraint) { - AddLinearConstraintToLP(vars, coeffs, - -std::numeric_limits::infinity(), - static_cast(rhs), m); - } - } -} - -void ExtractIntLinGe(const fz::Constraint& ct, SatModel* m) { - const std::vector& coeffs = ct.arguments[0].values; - const int64 rhs = ct.arguments[2].values[0]; - int64 new_rhs = rhs; - const std::vector lits = - IsSumOfLiteral(ct.arguments[1], coeffs, &new_rhs, m); - if (!lits.empty() && new_rhs == 1) { - // Not all literal can be false. - m->model.Add(ClauseConstraint(lits)); - } else if (!lits.empty() && new_rhs == coeffs.size()) { - // Every literal must be true. - for (const Literal l : lits) m->model.Add(ClauseConstraint({l})); - } else { - const std::vector vars = m->LookupVars(ct.arguments[1]); - m->model.Add(WeightedSumGreaterOrEqual(vars, coeffs, rhs)); - - if (FLAGS_fz_use_lp_constraint) { - AddLinearConstraintToLP(vars, coeffs, static_cast(rhs), - std::numeric_limits::infinity(), m); - } - } -} - -void ExtractIntLinLeReif(const fz::Constraint& ct, SatModel* m) { - const Literal r = m->GetTrueLiteral(ct.arguments[3]); - const std::vector& coeffs = ct.arguments[0].values; - const int64 rhs = ct.arguments[2].values[0]; - int64 new_rhs = rhs; - const std::vector lits = - IsSumOfLiteral(ct.arguments[1], coeffs, &new_rhs, m); - if (!lits.empty() && new_rhs == coeffs.size() - 1) { - m->model.Add(ReifiedBoolAnd(lits, r.Negated())); - } else if (!lits.empty() && new_rhs == 0) { - m->model.Add(ReifiedBoolOr(lits, r.Negated())); - } else { - const std::vector vars = m->LookupVars(ct.arguments[1]); - m->model.Add(WeightedSumLowerOrEqualReif(r, vars, coeffs, rhs)); - } -} - -void ExtractIntLinGeReif(const fz::Constraint& ct, SatModel* m) { - const std::vector& coeffs = ct.arguments[0].values; - const int64 rhs = ct.arguments[2].values[0]; - const Literal r = m->GetTrueLiteral(ct.arguments[3]); - int64 new_rhs = rhs; - const std::vector lits = - IsSumOfLiteral(ct.arguments[1], coeffs, &new_rhs, m); - if (!lits.empty() && new_rhs == 1) { - m->model.Add(ReifiedBoolOr(lits, r)); - } else if (!lits.empty() && new_rhs == coeffs.size()) { - m->model.Add(ReifiedBoolAnd(lits, r)); - } else { - const std::vector vars = m->LookupVars(ct.arguments[1]); - m->model.Add(WeightedSumGreaterOrEqualReif(r, vars, coeffs, rhs)); - } -} - -void ExtractIntLinEqReif(const fz::Constraint& ct, SatModel* m) { - const std::vector vars = m->LookupVars(ct.arguments[1]); - const std::vector& coeffs = ct.arguments[0].values; - const int64 rhs = ct.arguments[2].values[0]; - const Literal r = m->GetTrueLiteral(ct.arguments[3]); - m->model.Add(FixedWeightedSumReif(r, vars, coeffs, rhs)); -} - -void ExtractIntLinNeReif(const fz::Constraint& ct, SatModel* m) { - const std::vector vars = m->LookupVars(ct.arguments[1]); - const std::vector& coeffs = ct.arguments[0].values; - const int64 rhs = ct.arguments[2].values[0]; - const Literal r = m->GetTrueLiteral(ct.arguments[3]); - m->model.Add(FixedWeightedSumReif(r.Negated(), vars, coeffs, rhs)); -} - -// r => (a == cte). -void ImpliesEqualityToConstant(bool reverse_implication, IntegerVariable a, - int64 cte, Literal r, SatModel* m) { - if (m->model.Get(IsFixed(a))) { - if (m->model.Get(Value(a)) == IntegerValue(cte)) { - if (reverse_implication) { - m->model.GetOrCreate()->AddUnitClause(r); - } - } else { - m->model.GetOrCreate()->AddUnitClause(r.Negated()); - } - return; - } - - // TODO(user): Simply do that all the time? - // TODO(user): No need to create a literal that is trivially true or false! - IntegerEncoder* encoder = m->model.GetOrCreate(); - if (!encoder->VariableIsFullyEncoded(a)) { - if (reverse_implication) { - m->model.Add(ReifiedInInterval(a, cte, cte, r)); - } else { - const Literal ge = encoder->GetOrCreateAssociatedLiteral( - IntegerLiteral::GreaterOrEqual(a, IntegerValue(cte))); - const Literal le = encoder->GetOrCreateAssociatedLiteral( - IntegerLiteral::LowerOrEqual(a, IntegerValue(cte))); - m->model.Add(Implication(r, ge)); - m->model.Add(Implication(r, le)); - } - return; - } - - for (const auto pair : m->FullEncoding(a)) { - if (pair.value == IntegerValue(cte)) { - // Lit is equal to pair.literal. - // - // TODO(user): We could just use the same variable for this instead of - // creating two and then making them equals. - if (reverse_implication) { - m->model.Add(Equality(r, pair.literal)); - } else { - m->model.Add(Implication(r, pair.literal)); - } - return; - } - } - - // Value is not found, the literal must be false. - m->model.GetOrCreate()->AddUnitClause(r.Negated()); -} - -// r => (a == b), and if reverse_implication is true, we have the other way -// around too. -// -// TODO(user): move this and ImpliesEqualityToConstant() under .../sat/ and unit -// test it! -void ImpliesEquality(bool reverse_implication, Literal r, IntegerVariable a, - IntegerVariable b, SatModel* m) { - if (m->model.Get(IsFixed(a))) { - ImpliesEqualityToConstant(reverse_implication, b, m->model.Get(Value(a)), r, - m); - return; - } - if (m->model.Get(IsFixed(b))) { - ImpliesEqualityToConstant(reverse_implication, a, m->model.Get(Value(b)), r, - m); - return; - } - - // TODO(user): Do that all the time? - IntegerEncoder* encoder = m->model.GetOrCreate(); - if (!encoder->VariableIsFullyEncoded(a) || - !encoder->VariableIsFullyEncoded(b)) { - if (reverse_implication) { - m->model.Add(ReifiedEquality(a, b, r)); - } else if (a != b) { - // If a == b, r can take any value. - m->model.Add(ConditionalLowerOrEqualWithOffset(a, b, 0, r)); - m->model.Add(ConditionalLowerOrEqualWithOffset(b, a, 0, r)); - } - return; - } - - std::unordered_map> by_value; - for (const auto p : m->FullEncoding(a)) { - by_value[p.value].push_back(p.literal); - } - for (const auto p : m->FullEncoding(b)) { - by_value[p.value].push_back(p.literal); - } - for (const auto& p : by_value) { - if (p.second.size() == 1) { - // This value appear in only one of the variable, so if this value is - // true then r must be false. - m->model.Add(Implication(p.second[0], r.Negated())); - } else { - CHECK_EQ(p.second.size(), 2); - const Literal a = p.second[0]; - const Literal b = p.second[1]; - // This value is common: - // - a & b => r - // - a & not(b) => not(r) - // - not(a) & b => not(r) - if (reverse_implication) { - m->model.Add(ClauseConstraint({a.Negated(), b.Negated(), r})); - } - m->model.Add(ClauseConstraint({a.Negated(), b, r.Negated()})); - m->model.Add(ClauseConstraint({a, b.Negated(), r.Negated()})); - } - } -} - -void ExtractIntEqNeReif(const fz::Constraint& ct, bool eq, SatModel* m) { - // The Eq or Ne version are the same up to the sign of the "eq" literal. - Literal is_eq = m->GetTrueLiteral(ct.arguments[2]); - if (!eq) is_eq = is_eq.Negated(); - - if (ct.arguments[0].HasOneValue()) { - ImpliesEqualityToConstant(/*reverse_implication=*/true, - m->LookupVar(ct.arguments[1]), - ct.arguments[0].Value(), is_eq, m); - return; - } - - if (ct.arguments[1].HasOneValue()) { - ImpliesEqualityToConstant(/*reverse_implication=*/true, - m->LookupVar(ct.arguments[0]), - ct.arguments[1].Value(), is_eq, m); - return; - } - - // General case. This is exercised by the grid-colouring problems. - ImpliesEquality(/*reverse_implication=*/true, is_eq, - m->LookupVar(ct.arguments[0]), m->LookupVar(ct.arguments[1]), - m); -} - -// Special case added by the presolve (not in flatzinc). We encode this as a -// table constraint. -// -// TODO(user): is this the more efficient? we could at least optimize the table -// code to not create row literals when not needed. -void ExtractArray2dIntElement(const fz::Constraint& ct, SatModel* m) { - CHECK_EQ(2, ct.arguments[0].variables.size()); - CHECK_EQ(5, ct.arguments.size()); - - // the constraint is: - // values[coeff1 * vars[0] + coeff2 * vars[1] + offset] == target. - std::vector vars = m->LookupVars(ct.arguments[0]); - const std::vector& values = ct.arguments[1].values; - const int64 coeff1 = ct.arguments[3].values[0]; - const int64 coeff2 = ct.arguments[3].values[1]; - const int64 offset = ct.arguments[4].values[0] - 1; - - std::vector> tuples; - const auto encoding1 = m->FullEncoding(vars[0]); - const auto encoding2 = m->FullEncoding(vars[1]); - for (const auto& entry1 : encoding1) { - const int64 v1 = entry1.value.value(); - for (const auto& entry2 : encoding2) { - const int64 v2 = entry2.value.value(); - const int index = coeff1 * v1 + coeff2 * v2 + offset; - CHECK_GE(index, 0); - CHECK_LT(index, values.size()); - tuples.push_back({v1, v2, values[index]}); - } - } - vars.push_back(m->LookupVar(ct.arguments[2])); - m->model.Add(TableConstraint(vars, tuples)); -} - -// TODO(user): move this logic in some model function under .../sat/ and unit -// test it! Or adapt the table constraint? this is like a table with 1 columns, -// the row literal beeing the one of ct.arguments[0]. -void ExtractArrayIntElement(const fz::Constraint& ct, SatModel* m) { - if (ct.arguments[0].type != fz::Argument::INT_VAR_REF) { - return ExtractArray2dIntElement(ct, m); - } - - std::map> value_to_literals; - { - const auto encoding = m->FullEncoding(m->LookupVar(ct.arguments[0])); - const std::vector& values = ct.arguments[1].values; - if (encoding.size() != values.size()) { - FZVLOG << "array_int_element could have been slightly presolved." - << FZENDL; - } - for (const auto literal_value : encoding) { - const int i = literal_value.value.value() - 1; // minizinc use 1-index. - CHECK_GE(i, 0); - CHECK_LT(i, values.size()); - value_to_literals[values[i]].push_back(literal_value.literal); - } - } - - std::map target_by_value; - const IntegerVariable target = m->LookupVar(ct.arguments[2]); - for (const auto p : m->FullEncoding(target)) { - target_by_value[p.value] = p.literal; - } - - for (auto entry : value_to_literals) { - // target == OR(entry.second), same as ExtractBoolOr(). - const Literal r = FindOrDie(target_by_value, IntegerValue(entry.first)); - for (const Literal literal : entry.second) { - m->model.Add(Implication(literal, r)); - } - - // Note that this clause is not striclty needed because all the other - // value of target will be false and so only the literals in entry.second - // can be true out of all the literal of the argument 0. - // TODO(user): remove? - entry.second.push_back(r.Negated()); - m->model.Add(ClauseConstraint(entry.second)); - - // We remove the entry from target_by_value to see if they all appear. - target_by_value.erase(IntegerValue(entry.first)); - } - - if (!target_by_value.empty()) { - FZLOG << "array_int_element could have been presolved." << FZENDL; - for (const auto& entry : target_by_value) { - m->model.GetOrCreate()->AddUnitClause(entry.second.Negated()); - } - } -} - -// vars[i] == t. -void ExtractArrayVarIntElement(const fz::Constraint& ct, SatModel* m) { - const std::vector vars = m->LookupVars(ct.arguments[1]); - const IntegerVariable t = m->LookupVar(ct.arguments[2]); - - CHECK(!ct.arguments[0].HasOneValue()) << "Should have been presolved."; - const IntegerVariable index_var = m->LookupVar(ct.arguments[0]); - if (m->model.Get(IsFixed(index_var))) { - // TODO(user): use the full encoding if available. - m->model.Add(Equality(vars[m->model.Get(Value(index_var)) - 1], t)); - return; - } - - const auto encoding = m->FullEncoding(index_var); - if (encoding.size() != vars.size()) { - FZVLOG << "array_var_int_element could have been slightly presolved." - << FZENDL; - } - - std::vector selectors; - std::vector possible_vars; - for (const auto literal_value : encoding) { - const int i = literal_value.value.value() - 1; // minizinc use 1-index. - CHECK_GE(i, 0); - CHECK_LT(i, vars.size()); - possible_vars.push_back(vars[i]); - selectors.push_back(literal_value.literal); - ImpliesEquality(/*reverse_implication=*/false, literal_value.literal, - vars[i], t, m); - } - - // TODO(user): make a IsOneOfVar() support the full propagation. - m->model.Add(PartialIsOneOfVar(t, possible_vars, selectors)); -} - -void ExtractRegular(const fz::Constraint& ct, SatModel* m) { - const std::vector vars = m->LookupVars(ct.arguments[0]); - const int64 num_states = ct.arguments[1].Value(); - const int64 num_values = ct.arguments[2].Value(); - - const std::vector& next = ct.arguments[3].values; - std::vector> transitions; - int count = 0; - for (int i = 1; i <= num_states; ++i) { - for (int j = 1; j <= num_values; ++j) { - transitions.push_back({i, j, next[count++]}); - } - } - - const int64 initial_state = ct.arguments[4].Value(); - - std::vector final_states; - switch (ct.arguments[5].type) { - case fz::Argument::INT_VALUE: { - final_states.push_back(ct.arguments[5].values[0]); - break; - } - case fz::Argument::INT_INTERVAL: { - for (int v = ct.arguments[5].values[0]; v <= ct.arguments[5].values[1]; - ++v) { - final_states.push_back(v); - } - break; - } - case fz::Argument::INT_LIST: { - final_states = ct.arguments[5].values; - break; - } - default: { LOG(FATAL) << "Wrong constraint " << ct.DebugString(); } - } - - m->model.Add( - TransitionConstraint(vars, transitions, initial_state, final_states)); -} - -void ExtractTableInt(const fz::Constraint& ct, SatModel* m) { - const std::vector vars = m->LookupVars(ct.arguments[0]); - const std::vector& t = ct.arguments[1].values; - const int num_vars = vars.size(); - const int num_tuples = t.size() / num_vars; - std::vector> tuples(num_tuples); - int count = 0; - for (int i = 0; i < num_tuples; ++i) { - for (int j = 0; j < num_vars; ++j) { - tuples[i].push_back(t[count++]); - } - } - m->model.Add(TableConstraint(vars, tuples)); -} - -void ExtractSetInReif(const fz::Constraint& ct, SatModel* m) { - const IntegerVariable var = m->LookupVar(ct.arguments[0]); - const Literal in_set = m->GetTrueLiteral(ct.arguments[2]); - CHECK(!ct.arguments[0].HasOneValue()) << "Should be presolved: " - << ct.DebugString(); - if (ct.arguments[1].HasOneValue()) { - FZLOG << "Could have been presolved in int_eq_reif: " << ct.DebugString() - << FZENDL; - } - if (ct.arguments[1].type == fz::Argument::INT_LIST) { - std::set values(ct.arguments[1].values.begin(), - ct.arguments[1].values.end()); - const auto encoding = m->FullEncoding(var); - for (const auto& literal_value : encoding) { - if (ContainsKey(values, literal_value.value.value())) { - m->model.Add(Implication(literal_value.literal, in_set)); - } else { - m->model.Add(Implication(literal_value.literal, in_set.Negated())); - } - } - } else if (ct.arguments[1].type == fz::Argument::INT_INTERVAL) { - m->model.Add(ReifiedInInterval(var, ct.arguments[1].values[0], - ct.arguments[1].values[1], in_set)); - } else { - LOG(FATAL) << "Argument type not supported: " << ct.arguments[1].type; - } -} - -void ExtractAllDifferentInt(const fz::Constraint& ct, SatModel* m) { - const std::vector vars = m->LookupVars(ct.arguments[0]); - IntegerEncoder* encoder = m->model.GetOrCreate(); - bool all_variables_are_encoded = true; - for (const IntegerVariable v : vars) { - if (!encoder->VariableIsFullyEncoded(v)) { - all_variables_are_encoded = false; - break; - } - } - if (all_variables_are_encoded) { - m->model.Add(AllDifferentBinary(vars)); - } else { - m->model.Add(AllDifferentOnBounds(vars)); - } -} - -void ExtractDiffN(const fz::Constraint& ct, SatModel* m) { - const std::vector x = m->LookupVars(ct.arguments[0]); - const std::vector y = m->LookupVars(ct.arguments[1]); - if (ct.arguments[2].type == fz::Argument::INT_LIST && - ct.arguments[3].type == fz::Argument::INT_LIST) { - m->model.Add(StrictNonOverlappingFixedSizeRectangles( - x, y, ct.arguments[2].values, ct.arguments[3].values)); - } else { - const std::vector dx = m->LookupVars(ct.arguments[2]); - const std::vector dy = m->LookupVars(ct.arguments[3]); - m->model.Add(StrictNonOverlappingRectangles(x, y, dx, dy)); - } -} - -void ExtractDiffNNonStrict(const fz::Constraint& ct, SatModel* m) { - const std::vector x = m->LookupVars(ct.arguments[0]); - const std::vector y = m->LookupVars(ct.arguments[1]); - if (ct.arguments[2].type == fz::Argument::INT_LIST && - ct.arguments[3].type == fz::Argument::INT_LIST) { - m->model.Add(NonOverlappingFixedSizeRectangles(x, y, ct.arguments[2].values, - ct.arguments[3].values)); - } else { - const std::vector dx = m->LookupVars(ct.arguments[2]); - const std::vector dy = m->LookupVars(ct.arguments[3]); - m->model.Add(NonOverlappingRectangles(x, y, dx, dy)); - } -} - -void ExtractCumulative(const fz::Constraint& ct, SatModel* m) { - const std::vector starts = m->LookupVars(ct.arguments[0]); - const std::vector durations = m->LookupVars(ct.arguments[1]); - const std::vector demands = m->LookupVars(ct.arguments[2]); - const IntegerVariable capacity = m->LookupVar(ct.arguments[3]); - - // Convert the couple (starts, duration) into an interval variable. - std::vector intervals; - for (int i = 0; i < starts.size(); ++i) { - intervals.push_back( - m->model.Add(NewIntervalFromStartAndSizeVars(starts[i], durations[i]))); - } - - m->model.Add(Cumulative(intervals, demands, capacity)); -} - -void ExtractCircuit(const fz::Constraint& ct, bool allow_subcircuit, - SatModel* m) { - bool found_zero = false; - bool found_size = false; - for (fz::IntegerVariable* const var : ct.arguments[0].variables) { - if (var->domain.Min() == 0) { - found_zero = true; - } - if (var->domain.Max() == ct.arguments[0].variables.size()) { - found_size = true; - } - } - // Are array 1 based or 0 based. - const int offset = found_zero && !found_size ? 0 : 1; - - const std::vector vars = m->LookupVars(ct.arguments[0]); - std::vector> graph( - vars.size(), std::vector(vars.size(), kFalseLiteralIndex)); - for (int i = 0; i < vars.size(); ++i) { - if (m->model.Get(IsFixed(vars[i]))) { - graph[i][m->model.Get(Value(vars[i])) - offset] = kTrueLiteralIndex; - } else { - const auto encoding = m->FullEncoding(vars[i]); - for (const auto& entry : encoding) { - graph[i][entry.value.value() - offset] = entry.literal.Index(); - } - } - } - m->model.Add(allow_subcircuit ? SubcircuitConstraint(graph) - : CircuitConstraint(graph)); -} - -// network_flow(arcs, balance, flow) -// network_flow_cost(arcs, balance, weight, flow, cost) -void ExtractNetworkFlow(const fz::Constraint& ct, SatModel* m) { - const bool has_cost = ct.type == "network_flow_cost"; - const std::vector flow = - m->LookupVars(ct.arguments[has_cost ? 3 : 2]); - - // First, encode the flow conservation constraints as sums for performance: - // updating balance variables is done faster locally. - const int num_nodes = ct.arguments[1].values.size(); - std::vector> flows_per_node(num_nodes); - std::vector> coeffs_per_node(num_nodes); - - const int num_arcs = ct.arguments[0].values.size() / 2; - for (int arc = 0; arc < num_arcs; arc++) { - const int tail = ct.arguments[0].values[2 * arc] - 1; - flows_per_node[tail].push_back(flow[arc]); - coeffs_per_node[tail].push_back(1); - - const int head = ct.arguments[0].values[2 * arc + 1] - 1; - flows_per_node[head].push_back(flow[arc]); - coeffs_per_node[head].push_back(-1); - } - - for (int node = 0; node < num_nodes; node++) { - m->model.Add(FixedWeightedSum(flows_per_node[node], coeffs_per_node[node], - ct.arguments[1].values[node])); - } - - if (has_cost) { - std::vector filtered_flows; - std::vector filtered_costs; - - for (int arc = 0; arc < num_arcs; arc++) { - const int64 weight = ct.arguments[2].values[arc]; - if (weight == 0) continue; - - filtered_flows.push_back(flow[arc]); - filtered_costs.push_back(weight); - } - - filtered_flows.push_back(m->LookupVar(ct.arguments[4])); - filtered_costs.push_back(-1); - - m->model.Add(FixedWeightedSum(filtered_flows, filtered_costs, 0)); - } - - // Then pass the problem to global FlowCosts constraint. - std::vector balance; - for (const int value : ct.arguments[1].values) { - balance.push_back(m->model.Add(ConstantIntegerVariable(value))); - } - - const std::vector arcs = ct.arguments[0].values; - std::vector tails; - std::vector heads; - for (int arc = 0; arc < num_arcs; arc++) { - tails.push_back(arcs[2 * arc] - 1); - heads.push_back(arcs[2 * arc + 1] - 1); - } - - std::vector> weights_per_cost_type; - if (has_cost) { - std::vector weights; - for (const int64 value : ct.arguments[2].values) { - weights.push_back(static_cast(value)); - } - weights_per_cost_type.push_back(weights); - } - - std::vector total_costs_per_cost_type; - if (has_cost) { - total_costs_per_cost_type.push_back(m->LookupVar(ct.arguments[4])); - } - - m->model.Add(FlowCostsConstraint(balance, flow, tails, heads, - weights_per_cost_type, - total_costs_per_cost_type)); -} - -// Returns false iff the constraint type is not supported. -bool ExtractConstraint(const fz::Constraint& ct, SatModel* m) { - if (ct.type == "bool_eq") { - ExtractBoolEq(ct, m); - } else if (ct.type == "bool_eq_reif") { - ExtractBoolEqNeReif(/*is_eq=*/true, ct, m); - } else if (ct.type == "bool_ne" || ct.type == "bool_not") { - ExtractBoolNe(ct, m); - } else if (ct.type == "bool_ne_reif") { - ExtractBoolEqNeReif(/*is_eq=*/false, ct, m); - } else if (ct.type == "bool_le") { - ExtractBoolLe(ct, m); - } else if (ct.type == "bool_le_reif") { - ExtractBoolLeLtReif(/*is_le=*/true, ct, m); - } else if (ct.type == "bool_lt_reif") { - ExtractBoolLeLtReif(/*is_le=*/false, ct, m); - } else if (ct.type == "bool_clause") { - ExtractBoolClause(ct, m); - } else if (ct.type == "array_bool_and") { - ExtractArrayBoolAnd(ct, m); - } else if (ct.type == "array_bool_or") { - ExtractArrayBoolOr(ct, m); - } else if (ct.type == "array_bool_xor") { - ExtractArrayBoolXor(ct, m); - } else if (ct.type == "int_min") { - ExtractIntMin(ct, m); - } else if (ct.type == "int_abs") { - ExtractIntAbs(ct, m); - } else if (ct.type == "int_max") { - ExtractIntMax(ct, m); - } else if (ct.type == "int_times") { - ExtractIntTimes(ct, m); - } else if (ct.type == "int_div") { - ExtractIntDiv(ct, m); - } else if (ct.type == "int_plus") { - ExtractIntPlus(ct, m); - } else if (ct.type == "array_int_minimum" || ct.type == "minimum_int") { - ExtractArrayIntMinimum(ct, m); - } else if (ct.type == "array_int_maximum" || ct.type == "maximum_int") { - ExtractArrayIntMaximum(ct, m); - } else if (ct.type == "array_int_element" || - ct.type == "array_bool_element") { - ExtractArrayIntElement(ct, m); - } else if (ct.type == "array_var_int_element" || - ct.type == "array_var_bool_element") { - ExtractArrayVarIntElement(ct, m); - } else if (ct.type == "all_different_int") { - ExtractAllDifferentInt(ct, m); - } else if (ct.type == "int_eq" || ct.type == "bool2int") { - ExtractIntEq(ct, m); - } else if (ct.type == "int_ne") { - ExtractIntNe(ct, m); - } else if (ct.type == "int_le") { - ExtractIntLe(ct, m); - } else if (ct.type == "int_ge") { - ExtractIntGe(ct, m); - } else if (ct.type == "int_lt") { - ExtractIntLt(ct, m); - } else if (ct.type == "int_le_reif") { - ExtractIntLeGeReif(/*is_le=*/true, ct, m); - } else if (ct.type == "int_ge_reif") { - ExtractIntLeGeReif(/*is_le=*/false, ct, m); - } else if (ct.type == "int_lt_reif") { - ExtractIntLtReif(ct, m); - } else if (ct.type == "int_eq_reif") { - ExtractIntEqNeReif(ct, /*eq=*/true, m); - } else if (ct.type == "int_ne_reif") { - ExtractIntEqNeReif(ct, /*eq=*/false, m); - } else if (ct.type == "int_lin_eq") { - ExtractIntLinEq(ct, m); - } else if (ct.type == "int_lin_ne") { - ExtractIntLinNe(ct, m); - } else if (ct.type == "int_lin_le") { - ExtractIntLinLe(ct, m); - } else if (ct.type == "int_lin_ge") { - ExtractIntLinGe(ct, m); - } else if (ct.type == "int_lin_eq_reif") { - ExtractIntLinEqReif(ct, m); - } else if (ct.type == "int_lin_ne_reif") { - ExtractIntLinNeReif(ct, m); - } else if (ct.type == "int_lin_le_reif") { - ExtractIntLinLeReif(ct, m); - } else if (ct.type == "int_lin_ge_reif") { - ExtractIntLinGeReif(ct, m); - } else if (ct.type == "circuit") { - ExtractCircuit(ct, /*allow_subcircuit=*/false, m); - } else if (ct.type == "subcircuit") { - ExtractCircuit(ct, /*allow_subcircuit=*/true, m); - } else if (ct.type == "regular") { - ExtractRegular(ct, m); - } else if (ct.type == "table_int") { - ExtractTableInt(ct, m); - } else if (ct.type == "set_in_reif") { - ExtractSetInReif(ct, m); - } else if (ct.type == "diffn") { - ExtractDiffN(ct, m); - } else if (ct.type == "diffn_nonstrict") { - ExtractDiffNNonStrict(ct, m); - } else if (ct.type == "cumulative" || ct.type == "var_cumulative" || - ct.type == "variable_cumulative" || - ct.type == "fixed_cumulative") { - ExtractCumulative(ct, m); - } else if (ct.type == "network_flow" || ct.type == "network_flow_cost") { - ExtractNetworkFlow(ct, m); - } else if (ct.type == "false_constraint") { - m->model.GetOrCreate()->NotifyThatModelIsUnsat(); - } else { - return false; - } - return true; -} - -// ============================================================================= -// SAT/CP flatzinc solver. -// ============================================================================= - -// The format is fixed in the flatzinc specification. -std::string SolutionString(const SatModel& m, - const fz::SolutionOutputSpecs& output) { - if (output.variable != nullptr) { - const int64 value = m.Value(output.variable); - if (output.display_as_boolean) { - return StringPrintf("%s = %s;", output.name.c_str(), - value == 1 ? "true" : "false"); - } else { - return StringPrintf("%s = %" GG_LL_FORMAT "d;", output.name.c_str(), - value); - } - } else { - const int bound_size = output.bounds.size(); - std::string result = - StringPrintf("%s = array%dd(", output.name.c_str(), bound_size); - for (int i = 0; i < bound_size; ++i) { - if (output.bounds[i].max_value != 0) { - result.append(StringPrintf("%" GG_LL_FORMAT "d..%" GG_LL_FORMAT "d, ", - output.bounds[i].min_value, - output.bounds[i].max_value)); - } else { - result.append("{},"); - } - } - result.append("["); - for (int i = 0; i < output.flat_variables.size(); ++i) { - const int64 value = m.Value(output.flat_variables[i]); - if (output.display_as_boolean) { - result.append(StringPrintf(value ? "true" : "false")); - } else { - StrAppend(&result, value); - } - if (i != output.flat_variables.size() - 1) { - result.append(", "); - } - } - result.append("]);"); - return result; - } - return ""; -} - -std::string CheckSolutionAndGetFzString(const fz::Model& fz_model, - const SatModel& m) { - CHECK(CheckSolution(fz_model, - [&m](fz::IntegerVariable* v) { return m.Value(v); })); - - std::string solution_string; - for (const fz::SolutionOutputSpecs& output : fz_model.output()) { - solution_string.append(SolutionString(m, output)); - solution_string.append("\n"); - } - return solution_string + "----------\n"; -} - -} // namespace - -void SolveWithSat(const fz::Model& fz_model, const fz::FlatzincParameters& p, - bool* interrupt_solve) { - // Timing. - WallTimer wall_timer; - UserTimer user_timer; - wall_timer.Start(); - user_timer.Start(); - - SatModel m; - std::unique_ptr time_limit; - if (p.time_limit_in_ms > 0) { - time_limit.reset(new TimeLimit(p.time_limit_in_ms * 1e-3)); - } else { - time_limit = TimeLimit::Infinite(); - } - time_limit->RegisterExternalBooleanAsLimit(interrupt_solve); - m.model.SetSingleton(std::move(time_limit)); - - // Process the bool_not constraints to avoid creating extra Boolean variables. - std::unordered_map not_map; - for (fz::Constraint* ct : fz_model.constraints()) { - if (ct != nullptr && ct->active && - (ct->type == "bool_not" || ct->type == "bool_ne")) { - not_map[ct->arguments[0].Var()] = ct->arguments[1].Var(); - not_map[ct->arguments[1].Var()] = ct->arguments[0].Var(); - } - } - - // Extract all the variables. - int num_constants = 0; - int num_variables_with_two_values = 0; - std::set constant_values; - std::map num_vars_per_domains; - FZLOG << "Extracting " << fz_model.variables().size() << " variables. " - << FZENDL; - int num_capped_variables = 0; - for (fz::IntegerVariable* var : fz_model.variables()) { - if (!var->active) continue; - - // Will be encoded as a constant lazily as needed. - if (var->domain.HasOneValue()) { - ++num_constants; - constant_values.insert(var->domain.Value()); - continue; - } - - const int64 safe_min = - var->domain.Min() == kint64min ? kint32min : var->domain.Min(); - const int64 safe_max = - var->domain.Max() == kint64max ? kint32max : var->domain.Max(); - if (safe_min != var->domain.Min() || safe_max != var->domain.Max()) { - num_capped_variables++; - } - - // Special case for Boolean. We don't automatically create the associated - // integer variable. It will only be created if a constraint needs to see - // the Boolean variable as an IntegerVariable - if (var->domain.Min() == 0 && var->domain.Max() == 1) { - const Literal literal = - ContainsKey(not_map, var) && ContainsKey(m.bool_map, not_map[var]) - ? m.bool_map[not_map[var]].Negated() - : Literal(m.model.Add(NewBooleanVariable()), true); - InsertOrDie(&m.bool_map, var, literal); - continue; - } - - // Create the associated sat::IntegerVariable. Note that it will be lazily - // fully-encoded by the propagators that need it, except for the variables - // with just two values because it seems more efficient to do so. - // - // TODO(user): Experiment more with proactive full-encoding. Chuffed seems - // to fully encode all variables with a small domain. - std::string domain_as_string; - bool only_two_values = false; - if (var->domain.is_interval) { - only_two_values = (safe_min + 1 == safe_max); - domain_as_string = ClosedInterval({safe_min, safe_max}).DebugString(); - InsertOrDie(&m.var_map, var, - m.model.Add(NewIntegerVariable(safe_min, safe_max))); - } else { - only_two_values = (var->domain.values.size() == 2); - const std::vector domain = - SortedDisjointIntervalsFromValues(var->domain.values); - InsertOrDie(&m.var_map, var, m.model.Add(NewIntegerVariable(domain))); - domain_as_string = IntervalsAsString(domain); - } - num_vars_per_domains[domain_as_string]++; - - if (only_two_values) { - ++num_variables_with_two_values; - m.model.Add(FullyEncodeVariable(m.LookupVar(var))); - } - } - for (const auto& entry : num_vars_per_domains) { - FZLOG << " - " << entry.second << " vars in " << entry.first << FZENDL; - } - FZLOG << " - " << num_constants << " constants in {" - << strings::Join(constant_values, ",") << "}." << FZENDL; - if (num_capped_variables > 0) { - FZLOG << " - " << num_capped_variables - << " variables have been capped to fit into [int32min .. int32max]" - << FZENDL; - } - - // Extract all the constraints. - FZLOG << "Extracting " << fz_model.constraints().size() << " constraints. " - << FZENDL; - std::set unsupported_types; - Trail* trail = m.model.GetOrCreate(); - for (fz::Constraint* ct : fz_model.constraints()) { - if (ct != nullptr && ct->active) { - const int old_num_fixed = trail->Index(); - FZVLOG << "Extracting '" << ct->type << "'." << FZENDL; - if (!ExtractConstraint(*ct, &m)) { - unsupported_types.insert(ct->type); - } - - // We propagate after each new Boolean constraint but not the integer - // ones. So we call Propagate() manually here. TODO(user): Do that - // automatically? - m.model.GetOrCreate()->Propagate(); - if (trail->Index() > old_num_fixed) { - FZVLOG << "Constraint fixed " << trail->Index() - old_num_fixed - << " Boolean variable(s): " << ct->DebugString() << FZENDL; - } - if (m.model.GetOrCreate()->IsModelUnsat()) { - FZLOG << "UNSAT during extraction (after adding '" << ct->type << "')." - << FZENDL; - break; - } - } - } - if (!unsupported_types.empty()) { - FZLOG << "There are unsupported constraints types in this model: " - << FZENDL; - for (const std::string& type : unsupported_types) { - FZLOG << " - " << type << FZENDL; - } - return; - } - - // Use LinearProgrammingConstraint only if there was a linear inequality, - // i.e. if it is already instantiated in the model. - if (FLAGS_fz_use_lp_constraint && - m.model.Get() != nullptr) { - LinearProgrammingConstraint* lp = - m.model.GetOrCreate(); - lp->RegisterWith(m.model.GetOrCreate()); - } - - // Some stats. - { - int num_bool_as_int = 0; - for (auto entry : m.bool_map) { - if (ContainsKey(m.var_map, entry.first)) ++num_bool_as_int; - } - int num_fully_encoded_variables = 0; - for (int i = 0; - i < m.model.GetOrCreate()->NumIntegerVariables(); ++i) { - if (m.model.Get()->VariableIsFullyEncoded( - IntegerVariable(i))) { - ++num_fully_encoded_variables; - } - } - // We divide by two because of the automatically created NegationOf() var. - FZLOG << "Num integer variables = " - << m.model.GetOrCreate()->NumIntegerVariables() / 2 - << " (" << num_bool_as_int << " Booleans)." << FZENDL; - FZLOG << "Num fully encoded variable = " << num_fully_encoded_variables / 2 - << FZENDL; - FZLOG << "Num initial SAT variables = " - << m.model.Get()->NumVariables() << " (" - << m.model.Get()->LiteralTrail().Index() << " fixed)." - << FZENDL; - FZLOG << "Num vars with 2 values = " << num_variables_with_two_values - << FZENDL; - FZLOG << "Num constants = " - << m.model.Get()->NumConstantVariables() << FZENDL; - FZLOG << "Num integer propagators = " - << m.model.GetOrCreate()->NumPropagators() - << FZENDL; - } - - int num_solutions = 0; - int64 best_objective = 0; - std::string solutions_string; - std::string search_status; - - // Important: we use the order of the variable from flatzinc with the - // non-defined variable first. In particular we don't want to iterate on - // m.var_map which order is randomized! - // - // TODO(user): We could restrict these if we are sure all the other variables - // will be fixed once these are fixed. - std::vector decision_vars; - for (fz::IntegerVariable* var : fz_model.variables()) { - if (!var->active || var->domain.HasOneValue()) continue; - if (var->defining_constraint != nullptr) continue; - if (ContainsKey(m.bool_map, var)) continue; - decision_vars.push_back(FindOrDie(m.var_map, var)); - } - for (fz::IntegerVariable* var : fz_model.variables()) { - if (!var->active || var->domain.HasOneValue()) continue; - if (var->defining_constraint == nullptr) continue; - if (ContainsKey(m.bool_map, var)) continue; - decision_vars.push_back(FindOrDie(m.var_map, var)); - } - - // TODO(user): deal with other search parameters. - FZLOG << "Solving..." << FZENDL; - SatSolver::Status status; - if (fz_model.objective() == nullptr) { - // Decision problem. - while (num_solutions < p.num_solutions) { - status = SolveIntegerProblemWithLazyEncoding( - /*assumptions=*/{}, - FirstUnassignedVarAtItsMinHeuristic(decision_vars, &m.model), - &m.model); - - if (status == SatSolver::MODEL_SAT) { - ++num_solutions; - FZLOG << "Solution #" << num_solutions - << " num_bool:" << m.model.Get()->NumVariables() - << FZENDL; - solutions_string += CheckSolutionAndGetFzString(fz_model, m); - if (num_solutions < p.num_solutions) { - m.model.Add(ExcludeCurrentSolutionAndBacktrack()); - } - continue; - } - - if (status == SatSolver::MODEL_UNSAT) { - if (num_solutions == 0) { - search_status = "=====UNSATISFIABLE====="; - } - break; - } - - // Limit reached. - break; - } - } else { - // Optimization problem. - const IntegerVariable objective_var = m.LookupVar(fz_model.objective()); - status = MinimizeIntegerVariableWithLinearScanAndLazyEncoding( - /*log_info=*/false, - fz_model.maximize() ? NegationOf(objective_var) : objective_var, - FirstUnassignedVarAtItsMinHeuristic(decision_vars, &m.model), - [objective_var, &num_solutions, &m, &fz_model, &best_objective, - &solutions_string](const Model& sat_model) { - num_solutions++; - best_objective = sat_model.Get(LowerBound(objective_var)); - FZLOG << "Solution #" << num_solutions << " obj:" << best_objective - << " num_bool:" << sat_model.Get()->NumVariables() - << FZENDL; - solutions_string = CheckSolutionAndGetFzString(fz_model, m); - }, - &m.model); - if (num_solutions > 0) { - search_status = "=========="; - } else { - search_status = "=====UNSATISFIABLE====="; - } - } - - if (fz_model.objective() == nullptr) { - FZLOG << "Status: " << status << FZENDL; - FZLOG << "Objective: NA" << FZENDL; - FZLOG << "Best_bound: NA" << FZENDL; - } else { - m.model.GetOrCreate()->Backtrack(0); - const IntegerVariable objective_var = m.LookupVar(fz_model.objective()); - int64 best_bound = - m.model.Get(fz_model.maximize() ? UpperBound(objective_var) - : LowerBound(objective_var)); - if (num_solutions == 0) { - FZLOG << "Status: " << status << FZENDL; - FZLOG << "Objective: NA" << FZENDL; - } else { - if (status == SatSolver::MODEL_SAT) { - FZLOG << "Status: OPTIMAL" << FZENDL; - - // We need this because even if we proved unsat, that doesn't mean we - // propagated the best bound to its current value. - best_bound = best_objective; - } else { - FZLOG << "Status: " << status << FZENDL; - } - FZLOG << "Objective: " << best_objective << FZENDL; - } - FZLOG << "Best_bound: " << best_bound; - } - FZLOG << "Booleans: " << m.model.Get()->NumVariables() << FZENDL; - FZLOG << "Conflicts: " << m.model.Get()->num_failures() << FZENDL; - FZLOG << "Branches: " << m.model.Get()->num_branches() << FZENDL; - FZLOG << "Propagations: " << m.model.Get()->num_propagations() - << FZENDL; - FZLOG << "Walltime: " << wall_timer.Get() << FZENDL; - FZLOG << "Usertime: " << user_timer.Get() << FZENDL; - FZLOG << "Deterministic_time: " - << m.model.Get()->deterministic_time() << FZENDL; - - if (status == SatSolver::LIMIT_REACHED) { - search_status = "%% LIMIT_REACHED"; - } - - // Print the solution(s). - std::cout << solutions_string; - if (!search_status.empty()) { - std::cout << search_status << std::endl; - } -} - -} // namespace sat -} // namespace operations_research diff --git a/ortools/flatzinc/sat_fz_solver.h b/ortools/flatzinc/sat_fz_solver.h deleted file mode 100644 index 2a5ed40db7..0000000000 --- a/ortools/flatzinc/sat_fz_solver.h +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2010-2014 Google -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_ -#define OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_ - -#include "ortools/flatzinc/model.h" -#include "ortools/flatzinc/solver.h" - -namespace operations_research { -namespace sat { - -void SolveWithSat(const fz::Model& model, const fz::FlatzincParameters& p, - bool* interup_solve); - -} // namespace sat -} // namespace operations_research - -#endif // OR_TOOLS_FLATZINC_SAT_FZ_SOLVER_H_ diff --git a/ortools/lp_data/lp_data.h b/ortools/lp_data/lp_data.h index 096e276e6c..ac6771bd45 100644 --- a/ortools/lp_data/lp_data.h +++ b/ortools/lp_data/lp_data.h @@ -455,8 +455,9 @@ class LinearProgram { // Scales the problem using the given scaler. void Scale(SparseMatrixScaler* scaler); - // Scales the costs to always have a maximum cost magnitude of 1.0 and returns - // the used cost scaling factor. + // Scales the costs to always have a maximum cost magnitude of 1.0. The old + // cost of each variable can be retrieved by multiplying the new one with the + // returned factor. This also updates objective_scaling_factor(). Fractional ScaleObjective(); // Removes the given row indices from the LinearProgram. diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index 842da5cfe3..d6ecf8936f 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -226,14 +226,14 @@ message ConstraintProto { // // This is in a message because decision problems don't have any objective. message CpObjectiveProto { - // Index of the variable to minimize. - // - // For a maximization problem, one can refer to the negation of the real + // The linear terms of the objective to minimize. + // For a maximization problem, one can negate all coefficients in the // objective and set a scaling_factor to -1. - int32 objective_var = 1; + repeated int32 vars = 1; + repeated int64 coeffs = 4; // The displayed objective is always: - // scaling_factor * (Value(objective_var) + offset). + // scaling_factor * (sum(coefficients[i] * objective_vars[i]) + offset). // This is needed to have a consistent objective after presolve or when // scaling a double problem to express it with integers. // @@ -264,7 +264,7 @@ message DecisionStrategyProto { } VariableSelectionStrategy variable_selection_strategy = 2; - // Once a variable has been choosen, this enum describe what decision is taken + // Once a variable has been chosen, this enum describe what decision is taken // on its domain. // // TODO(user): extend as needed. @@ -298,10 +298,7 @@ message CpModelProto { repeated ConstraintProto constraints = 3; // The objective to minimize. Can be empty for pure decision problems. - // Note that we can have more than one objective for the cases where we want - // to optimize them in lexicographic order, or if we want to list the Pareto - // optimal solutions. - repeated CpObjectiveProto objectives = 4; + CpObjectiveProto objective = 4; // Defines the strategy that the solver should follow when the "fixed_search" // parameters is set to true. Note that this strategy is also used as an @@ -347,7 +344,7 @@ message CpSolverResponse { CpSolverStatus status = 1; // A feasible solution to the given problem. Depending on the returned status - // it may be optimal or just feasible. This is in one-to-one correspondance + // it may be optimal or just feasible. This is in one-to-one correspondence // with a CpModelProto::variables repeated field and list the values of all // the variables. repeated int64 solution = 2; diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index b0fae5f155..9d8f1b17c3 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -149,6 +149,38 @@ std::string ValidateLinearConstraint(const CpModelProto& model, return ""; } +std::string ValidateObjective(const CpModelProto& model, + const CpObjectiveProto& obj) { + // TODO(user): share the code with ValidateLinearConstraint(). + if (obj.vars_size() == 1 && obj.coeffs(0) == 1) return ""; + int64 sum_min = 0; + int64 sum_max = 0; + for (int i = 0; i < obj.vars_size(); ++i) { + const int ref = obj.vars(i); + const auto& var_proto = model.variables(PositiveRef(ref)); + const int64 min_domain = var_proto.domain(0); + const int64 max_domain = var_proto.domain(var_proto.domain_size() - 1); + const int64 coeff = RefIsPositive(ref) ? obj.coeffs(i) : -obj.coeffs(i); + const int64 prod1 = CapProd(min_domain, coeff); + const int64 prod2 = CapProd(max_domain, coeff); + + // Note that we use min/max with zero to disallow "alternative" terms and + // be sure that we cannot have an overflow if we do the computation in a + // different order. + sum_min = CapAdd(sum_min, std::min(0ll, std::min(prod1, prod2))); + sum_max = CapAdd(sum_max, std::max(0ll, std::max(prod1, prod2))); + for (const int64 v : {prod1, prod2, sum_min, sum_max}) { + // When introducing the objective variable, we use a [...] domain so we + // need to be more defensive here to make sure no overflow can happen in + // linear constraint propagator. + if (v == kint64max / 2 || v == kint64min / 2) { + return "Possible integer overflow in objective: " + obj.DebugString(); + } + } + } + return ""; +} + } // namespace std::string ValidateCpModel(const CpModelProto& model) { @@ -186,12 +218,14 @@ std::string ValidateCpModel(const CpModelProto& model) { break; } } - for (const CpObjectiveProto& objective : model.objectives()) { - const int v = objective.objective_var(); - if (!VariableReferenceIsValid(model, v)) { - return StrCat("Out of bound objective variable ", v, " : ", - objective.ShortDebugString()); + if (model.has_objective()) { + for (const int v : model.objective().vars()) { + if (!VariableReferenceIsValid(model, v)) { + return StrCat("Out of bound objective variable ", v, " : ", + model.objective().ShortDebugString()); + } } + RETURN_IF_NOT_EMPTY(ValidateObjective(model, model.objective())); } return ""; diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 9e7132e3ef..2a4249bfef 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -167,6 +167,11 @@ struct PresolveContext { : -domains[PositiveRef(ref)].Min(); } + bool IsUniqueOrFixed(int ref) { + return var_to_constraints[PositiveRef(ref)].size() == 1 || + domains[PositiveRef(ref)].IsFixed(); + } + // Regroups fixed variables with the same value. // TODO(user): Also regroup cte and -cte? void ExploitFixedDomain(int var) { @@ -993,16 +998,18 @@ bool PresolveInterval(ConstraintProto* ct, PresolveContext* context) { bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { const int index_ref = ct->element().index(); - if (context->var_to_constraints[PositiveRef(index_ref)].size() == 1) { - context->UpdateRuleStats("TODO element: index not used elsewhere"); - } const int target_ref = ct->element().target(); - if (context->var_to_constraints[PositiveRef(target_ref)].size() == 1) { - context->UpdateRuleStats("TODO element: target not used elsewhere"); - } + const bool unique_index = context->IsUniqueOrFixed(index_ref); + const bool unique_target = context->IsUniqueOrFixed(target_ref); + // TODO(user): think about this once we do have such constraint. if (HasEnforcementLiteral(*ct)) return false; + int num_vars = 0; + bool all_constants = true; + std::unordered_set constant_set; + + bool all_included_in_target_domain = true; bool reduced_index_domain = false; std::vector infered_domain; const std::vector target_dom = @@ -1018,6 +1025,17 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { RefIsPositive(index_ref) ? i : -i); reduced_index_domain = true; } else { + ++num_vars; + if (domain.front().start == domain.back().end) { + constant_set.insert(domain.front().start); + } else { + all_constants = false; + } + if (IntersectionOfSortedDisjointIntervals( + target_dom, ComplementOfSortedDisjointIntervals(domain)) + .empty()) { + all_included_in_target_domain = false; + } infered_domain = UnionOfSortedDisjointIntervals(infered_domain, domain); } } @@ -1031,6 +1049,33 @@ bool PresolveElement(ConstraintProto* ct, PresolveContext* context) { if (context->domains[PositiveRef(target_ref)].IntersectWith(infered_domain)) { context->UpdateRuleStats("element: reduced target domain"); } + if (all_constants && unique_index) { + // This constraint is just here to reduce the domain of the target! We can + // add it to the mapping_model to reconstruct the index value during + // postsolve and get rid of it now. + context->UpdateRuleStats("element: trivial target domain reduction"); + *(context->mapping_model->add_constraints()) = *ct; + return RemoveConstraint(ct, context); + } + if (all_included_in_target_domain && unique_target) { + context->UpdateRuleStats("element: trivial index domain reduction"); + *(context->mapping_model->add_constraints()) = *ct; + return RemoveConstraint(ct, context); + } + + if (all_constants && num_vars == constant_set.size()) { + // TODO(user): We should be able to do something for simple mapping. + context->UpdateRuleStats("TODO element: one to one mapping"); + } + if (unique_target) { + context->UpdateRuleStats("TODO element: target not used elsewhere"); + } + if (context->domains[PositiveRef(index_ref)].IsFixed()) { + context->UpdateRuleStats("TODO element: fixed index."); + } else if (unique_index) { + context->UpdateRuleStats("TODO element: index not used elsewhere"); + } + return false; } @@ -1291,8 +1336,10 @@ void PresolveCpModel(const CpModelProto& initial_model, // Hack for the objective so that it is never considered to appear in only one // constraint. - for (const CpObjectiveProto& obj : initial_model.objectives()) { - context.var_to_constraints[PositiveRef(obj.objective_var())].insert(-1); + if (initial_model.has_objective()) { + for (const int obj_var : initial_model.objective().vars()) { + context.var_to_constraints[PositiveRef(obj_var)].insert(-1); + } } while (!queue.empty() && !context.is_unsat) { @@ -1421,6 +1468,88 @@ void PresolveCpModel(const CpModelProto& initial_model, return; } + // If the objective is a single variable, try to find a linear equation that + // "defines" it and expand the objective into its longer linear + // representation. + // TODO(user): Insert in main loop. + if (context.working_model->has_objective() && + context.working_model->objective().vars_size() == 1) { + CpObjectiveProto* const mutable_objective = + context.working_model->mutable_objective(); + const int initial_obj_var = mutable_objective->vars(0); + const int64 initial_coeff = mutable_objective->coeffs(0); + const double initial_offset = mutable_objective->offset(); + // TODO(user): Expand the linear equation recursively in order to have + // as much term as possible? This would also enable expanding an objective + // with multiple terms. + int expanded_linear_index = -1; + for (int ct_index = 0; ct_index < context.working_model->constraints_size(); + ++ct_index) { + const ConstraintProto& ct = context.working_model->constraints(ct_index); + // Skip everything that is not a linear equality constraint. + if (!ct.enforcement_literal().empty()) continue; + if (ct.constraint_case() != ConstraintProto::ConstraintCase::kLinear) { + continue; + } + if (ct.linear().domain().size() != 2) continue; + if (ct.linear().domain(0) != ct.linear().domain(1)) continue; + + // Find out if initial_obj_var appear in this constraint. + bool present = false; + int64 objective_coeff; + const int num_terms = ct.linear().vars_size(); + for (int i = 0; i < num_terms; ++i) { + const int ref = ct.linear().vars(i); + const int64 coeff = ct.linear().coeffs(i); + if (PositiveRef(ref) == PositiveRef(initial_obj_var)) { + CHECK(!present) << "Duplicate variables not supported"; + present = true; + objective_coeff = ref == initial_obj_var ? coeff : -coeff; + } + } + + // We use the longest equality we can find. + // TODO(user): Deal with objective_coeff with a magnitude greater than 1? + // Accept when initial_coeff divides objective_coeff. + if (present && std::abs(objective_coeff) == 1 && + num_terms > mutable_objective->vars_size() + 1) { + expanded_linear_index = ct_index; + mutable_objective->clear_coeffs(); + mutable_objective->clear_vars(); + const int64 rhs = ct.linear().domain(0); + if (rhs != 0) { + mutable_objective->set_offset(rhs * initial_coeff * objective_coeff + + initial_offset); + } + for (int i = 0; i < num_terms; ++i) { + const int ref = ct.linear().vars(i); + if (PositiveRef(ref) != PositiveRef(initial_obj_var)) { + mutable_objective->add_vars(ref); + mutable_objective->add_coeffs( + -1 * initial_coeff * ct.linear().coeffs(i) * objective_coeff); + } + } + } + } + + if (expanded_linear_index != -1) { + context.UpdateRuleStats("objective: expanded single objective"); + ConstraintProto* const ct = + context.working_model->mutable_constraints(expanded_linear_index); + // Remove the objective variable special case and make sure the new + // objective variables cannot be removed: + for (int ref : ct->linear().vars()) { + context.var_to_constraints[PositiveRef(ref)].insert(-1); + } + context.var_to_constraints[PositiveRef(initial_obj_var)].erase(-1); + + // This function will detect that the old objective is not used + // elsewhere and remove it from the equation. + PresolveLinear(ct, &context); + context.UpdateConstraintVariableUsage(expanded_linear_index); + } + } + // Remove all empty or affine constraints (they will be re-added later if // needed) in the presolved model. Note that we need to remap the interval // references. @@ -1607,12 +1736,13 @@ void ApplyVariableMapping(const std::vector& mapping, ApplyToAllLiteralIndices(f, &ct_ref); } - // Remap the objectives. - for (CpObjectiveProto& objective : *proto->mutable_objectives()) { - const int ref = objective.objective_var(); - const int image = mapping[PositiveRef(ref)]; - CHECK_GE(image, 0); - objective.set_objective_var(ref >= 0 ? image : NegatedRef(image)); + // Remap the objective variables. + if (proto->has_objective()) { + for (int& mutable_var : *proto->mutable_objective()->mutable_vars()) { + const int image = mapping[PositiveRef(mutable_var)]; + CHECK_GE(image, 0); + mutable_var = (mutable_var >= 0 ? image : NegatedRef(image)); + } } // Remap the search decision heuristic. diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 42abeacf33..2c23def0c3 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -70,8 +70,10 @@ VariableUsage ComputeVariableUsage(const CpModelProto& model_proto) { // Add the objectives and search heuristics variables that needs to be // referenceable as integer even if they are only used as Booleans. - for (const CpObjectiveProto& objective : model_proto.objectives()) { - references.variables.insert(objective.objective_var()); + if (model_proto.has_objective()) { + for (const int obj_var : model_proto.objective().vars()) { + references.variables.insert(obj_var); + } } for (const DecisionStrategyProto& strategy : model_proto.search_strategy()) { for (const int var : strategy.variables()) { @@ -124,6 +126,16 @@ class ModelWithMapping { return integers_[PositiveRef(i)] != kNoIntegerVariable; } + // TODO(user): This does not returns true for [0,1] Integer variable that + // never appear as a literal elsewhere. This is not ideal because in + // LoadLinearConstraint() we probably still want to create the associated + // Boolean and maybe not even create the [0,1] integer variable if it is not + // used. + bool IsBoolean(int i) const { + CHECK_LT(PositiveRef(i), booleans_.size()); + return booleans_[PositiveRef(i)] != kNoBooleanVariable; + } + IntegerVariable Integer(int i) const { CHECK_LT(PositiveRef(i), integers_.size()); const IntegerVariable var = integers_[PositiveRef(i)]; @@ -146,7 +158,7 @@ class ModelWithMapping { } sat::Literal Literal(int i) const { - CHECK_LT(PositiveRef(i), integers_.size()); + CHECK_LT(PositiveRef(i), booleans_.size()); return sat::Literal(booleans_[PositiveRef(i)], RefIsPositive(i)); } @@ -215,6 +227,8 @@ class ModelWithMapping { return ContainsKey(ct_to_ignore_, ct); } + Model* model() const { return model_; } + private: void ExtractEncoding(const CpModelProto& model_proto); @@ -741,8 +755,29 @@ void LoadLinearConstraint(const ConstraintProto& ct, ModelWithMapping* m) { const int64 lb = ct.linear().domain(0); const int64 ub = ct.linear().domain(1); if (!HasEnforcementLiteral(ct)) { - if (lb != kint64min) m->Add(WeightedSumGreaterOrEqual(vars, coeffs, lb)); - if (ub != kint64max) m->Add(WeightedSumLowerOrEqual(vars, coeffs, ub)); + // Detect if there is only Booleans in order to use a more efficient + // propagator. TODO(user): we should probably also implement an + // half-reified version of this constraint. + bool all_booleans = true; + std::vector cst; + for (int i = 0; i < vars.size(); ++i) { + const int ref = ct.linear().vars(i); + if (!m->IsBoolean(ref)) { + all_booleans = false; + continue; + } + cst.push_back({m->Literal(ref), coeffs[i]}); + } + if (all_booleans) { + m->Add(BooleanLinearConstraint(lb, ub, &cst)); + } else { + if (lb != kint64min) { + m->Add(WeightedSumGreaterOrEqual(vars, coeffs, lb)); + } + if (ub != kint64max) { + m->Add(WeightedSumLowerOrEqual(vars, coeffs, ub)); + } + } } else { const Literal is_true = m->Literal(ct.enforcement_literal(0)); if (lb != kint64min) { @@ -859,8 +894,59 @@ void LoadCumulativeConstraint(const ConstraintProto& ct, ModelWithMapping* m) { m->Add(Cumulative(intervals, demands, capacity)); } +// If a variable is constant and its value appear in no other variable domains, +// then the literal encoding the index and the one encoding the target at this +// value are equivalent. +void DetectEquivalencesInElementConstraint(const ConstraintProto& ct, + ModelWithMapping* m) { + IntegerEncoder* encoder = m->GetOrCreate(); + IntegerTrail* integer_trail = m->GetOrCreate(); + + const IntegerVariable index = m->Integer(ct.element().index()); + const IntegerVariable target = m->Integer(ct.element().target()); + const std::vector vars = m->Integers(ct.element().vars()); + + if (m->Get(IsFixed(index))) return; + + std::vector union_of_non_constant_domains; + std::map constant_to_num; + for (const auto literal_value : m->Add(FullyEncodeVariable(index))) { + const int i = literal_value.value.value(); + if (m->Get(IsFixed(vars[i]))) { + const IntegerValue value(m->Get(Value(vars[i]))); + constant_to_num[value]++; + } else { + union_of_non_constant_domains = UnionOfSortedDisjointIntervals( + union_of_non_constant_domains, + integer_trail->InitialVariableDomain(vars[i])); + } + } + + // Bump the number if the constant appear in union_of_non_constant_domains. + for (const auto entry : constant_to_num) { + if (SortedDisjointIntervalsContain(union_of_non_constant_domains, + entry.first.value())) { + constant_to_num[entry.first]++; + } + } + + // Use the literal from the index encoding to encode the target at the + // "unique" values. + for (const auto literal_value : m->Add(FullyEncodeVariable(index))) { + const int i = literal_value.value.value(); + if (!m->Get(IsFixed(vars[i]))) continue; + + const IntegerValue value(m->Get(Value(vars[i]))); + if (constant_to_num[value] == 1) { + const Literal r = literal_value.literal; + encoder->AssociateToIntegerEqualValue(r, target, value); + } + } +} + // TODO(user): Be more efficient when the element().vars() are constants. -// Ideally we should avoid creating them as integer variable... +// Ideally we should avoid creating them as integer variable since we don't +// use them. void LoadElementConstraintBounds(const ConstraintProto& ct, ModelWithMapping* m) { const IntegerVariable index = m->Integer(ct.element().index()); @@ -886,11 +972,17 @@ void LoadElementConstraintBounds(const ConstraintProto& ct, selectors.push_back(literal_value.literal); const Literal r = literal_value.literal; - // TODO(user): Be more efficient if one of the two is a constant. Or handle - // that directly in the model function. if (vars[i] == target) continue; - m->Add(ConditionalLowerOrEqualWithOffset(vars[i], target, 0, r)); - m->Add(ConditionalLowerOrEqualWithOffset(target, vars[i], 0, r)); + if (m->Get(IsFixed(target))) { + const int64 value = m->Get(Value(target)); + m->Add(ImpliesInInterval(r, vars[i], value, value)); + } else if (m->Get(IsFixed(vars[i]))) { + const int64 value = m->Get(Value(vars[i])); + m->Add(ImpliesInInterval(r, target, value, value)); + } else { + m->Add(ConditionalLowerOrEqualWithOffset(vars[i], target, 0, r)); + m->Add(ConditionalLowerOrEqualWithOffset(target, vars[i], 0, r)); + } } m->Add(PartialIsOneOfVar(target, possible_vars, selectors)); } @@ -997,6 +1089,8 @@ void LoadElementConstraint(const ConstraintProto& ct, ModelWithMapping* m) { m->Get(IsFixed(variable)) || encoder->VariableIsFullyEncoded(variable); if (is_full) num_AC_variables++; } + + DetectEquivalencesInElementConstraint(ct, m); if (target_is_AC || num_AC_variables >= num_vars - 1) { LoadElementConstraintAC(ct, m); } else { @@ -1403,8 +1497,49 @@ void FillSolutionInResponse(const CpModelProto& model_proto, } } +namespace { +IntegerVariable GetOrCreateVariableEqualToSumOf( + Model* model, const std::vector>& terms) { + if (terms.empty()) return model->Add(ConstantIntegerVariable(0)); + if (terms.size() == 1 && terms.front().second == 1) { + return terms.front().first; + } + if (terms.size() == 1 && terms.front().second == -1) { + return NegationOf(terms.front().first); + } + + // Create a new variable equal to the sum, with a tight domain. + int64 sum_min = 0; + int64 sum_max = 0; + + for (const std::pair var_coeff : terms) { + const int64 min_domain = model->Get(LowerBound(var_coeff.first)); + const int64 max_domain = model->Get(UpperBound(var_coeff.first)); + const int64 coeff = var_coeff.second; + const int64 prod1 = min_domain * coeff; + const int64 prod2 = max_domain * coeff; + sum_min += std::min(prod1, prod2); + sum_max += std::max(prod1, prod2); + } + IntegerVariable new_var = model->Add(NewIntegerVariable(sum_min, sum_max)); + + // Link new variables with the linear terms. + std::vector vars; + std::vector coeffs; + for (const auto& term : terms) { + vars.push_back(term.first); + coeffs.push_back(term.second); + } + vars.push_back(new_var); + coeffs.push_back(-1); + model->Add(FixedWeightedSum(vars, coeffs, 0)); + return new_var; +} +} // namespace + // Adds one LinearProgrammingConstraint per connected component of the model. -void AddLPConstraints(const CpModelProto& model_proto, ModelWithMapping* m) { +IntegerVariable AddLPConstraints(const CpModelProto& model_proto, + ModelWithMapping* m) { const int num_constraints = model_proto.constraints().size(); const int num_variables = model_proto.variables().size(); @@ -1457,15 +1592,17 @@ void AddLPConstraints(const CpModelProto& model_proto, ModelWithMapping* m) { // Dispatch every constraint to its LinearProgrammingConstraint. std::unordered_map representative_to_lp_constraint; + std::unordered_map>> + representative_to_cp_terms; + std::vector> top_level_cp_terms; std::vector lp_constraints; - IntegerTrail* integer_trail = m->GetOrCreate(); for (int i = 0; i < num_constraints; i++) { if (constraint_has_lp_representation[i]) { const auto& ct = model_proto.constraints(i); const int id = components.GetClassRepresentative(i); if (components_to_size[id] <= 1) continue; if (!ContainsKey(representative_to_lp_constraint, id)) { - auto* lp = new LinearProgrammingConstraint(integer_trail); + auto* lp = m->model()->Create(); representative_to_lp_constraint[id] = lp; lp_constraints.push_back(lp); } @@ -1474,25 +1611,56 @@ void AddLPConstraints(const CpModelProto& model_proto, ModelWithMapping* m) { } // Add the objective. - if (model_proto.objectives_size() != 0) { - const int var = model_proto.objectives(0).objective_var(); - const int id = components.GetClassRepresentative(get_var_index(var)); - if (ContainsKey(representative_to_lp_constraint, id)) { - representative_to_lp_constraint[id]->SetObjective(m->Integer(var), true); + int num_components_containing_objective = 0; + if (model_proto.has_objective()) { + // First pass: set objective coefficients on the lp constraints, and store + // the cp terms in one vector per component. + for (int i = 0; i < model_proto.objective().coeffs_size(); ++i) { + const int var = model_proto.objective().vars(i); + const IntegerVariable cp_var = m->Integer(var); + const int64 coeff = model_proto.objective().coeffs(i); + const int id = components.GetClassRepresentative(get_var_index(var)); + if (ContainsKey(representative_to_lp_constraint, id)) { + representative_to_lp_constraint[id]->SetObjectiveCoefficient(cp_var, + coeff); + representative_to_cp_terms[id].push_back(std::make_pair(cp_var, coeff)); + } else { + // Component is too small. We still need to store the objective term. + top_level_cp_terms.push_back(std::make_pair(cp_var, coeff)); + } + } + // Second pass: Build the cp sub-objectives per component. + for (const auto& it : representative_to_cp_terms) { + const int id = it.first; + LinearProgrammingConstraint* lp = + FindOrDie(representative_to_lp_constraint, id); + const std::vector>& terms = it.second; + const IntegerVariable sub_obj_var = + GetOrCreateVariableEqualToSumOf(m->model(), terms); + top_level_cp_terms.push_back(std::make_pair(sub_obj_var, 1)); + lp->SetMainObjectiveVariable(sub_obj_var); + num_components_containing_objective++; } } - // Register LP constraints and transfer their ownership to the CP model. + const IntegerVariable main_objective_var = + GetOrCreateVariableEqualToSumOf(m->model(), top_level_cp_terms); + + // Register LP constraints. Note that this needs to be done after all the + // constraints have been added. for (auto* lp_constraint : lp_constraints) { - m->TakeOwnership(lp_constraint); lp_constraint->RegisterWith(m->GetOrCreate()); } VLOG_IF(1, !lp_constraints.empty()) << "Added " << lp_constraints.size() << " LP constraints."; + VLOG_IF(1, num_components_containing_objective > 1) + << "Objective split into " << num_components_containing_objective + << " components"; + return main_objective_var; } -// The function responsible for implementing the choosen search strategy. +// The function responsible for implementing the chosen search strategy. // // TODO(user): expose and unit-test, it seems easy to get the order wrong, and // that would not change the correctness. @@ -1590,69 +1758,22 @@ const std::function ConstructSearchStrategy( }; } -// TODO(user): Also consider linear inequality where the objective is minimized -// in the good direction. void ExtractLinearObjective(const CpModelProto& model_proto, ModelWithMapping* m, std::vector* linear_vars, std::vector* linear_coeffs) { - CHECK(!model_proto.objectives().empty()); - const CpObjectiveProto obj = model_proto.objectives(0); - const IntegerVariable objective_var = m->Integer(obj.objective_var()); - - // Default linear objective if we don't find any linear equality defining it. - *linear_vars = {objective_var}; - *linear_coeffs = {IntegerValue(1)}; - - // TODO(user): Expand the linear equation recursively in order to have - // as much term as possible? - for (const ConstraintProto& ct : model_proto.constraints()) { - // Skip everything that is not a linear equality constraint. - if (!ct.enforcement_literal().empty()) continue; - if (ct.constraint_case() != ConstraintProto::ConstraintCase::kLinear) { - continue; - } - if (ct.linear().domain().size() != 2) continue; - if (ct.linear().domain(0) != ct.linear().domain(1)) continue; - - // Find out if objective_var appear in this constraint. - bool present = false; - int64 objective_coeff; - const int num_terms = ct.linear().vars_size(); - for (int i = 0; i < num_terms; ++i) { - const int ref = ct.linear().vars(i); - const int64 coeff = ct.linear().coeffs(i); - if (PositiveRef(ref) == PositiveRef(obj.objective_var())) { - CHECK(!present) << "Duplicate variables not supported"; - present = true; - objective_coeff = (ref == obj.objective_var()) ? coeff : -coeff; - } - } - - // We use the longest equality we can find. - // TODO(user): Deal with objective_coeff with a magnitude greater than 1? - if (present && std::abs(objective_coeff) == 1 && - num_terms > linear_vars->size() + 1) { - linear_vars->clear(); - linear_coeffs->clear(); - const int64 rhs = ct.linear().domain(0); - if (rhs != 0) { - linear_vars->push_back(m->Add(NewIntegerVariable(rhs, rhs))); - linear_coeffs->push_back(IntegerValue(objective_coeff == 1 ? 1 : -1)); - } - for (int i = 0; i < num_terms; ++i) { - const int ref = ct.linear().vars(i); - if (PositiveRef(ref) != PositiveRef(obj.objective_var())) { - linear_vars->push_back(m->Integer(ref)); - const IntegerValue coeff(ct.linear().coeffs(i)); - linear_coeffs->push_back(objective_coeff == 1 ? -coeff : coeff); - } - } - } + CHECK(model_proto.has_objective()); + const CpObjectiveProto& obj = model_proto.objective(); + linear_vars->reserve(obj.vars_size()); + linear_coeffs->reserve(obj.vars_size()); + for (int i = 0; i < obj.vars_size(); ++i) { + linear_vars->push_back(m->Integer(obj.vars(i))); + linear_coeffs->push_back(IntegerValue(obj.coeffs(i))); } } CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, + bool display_fixing_constraints, Model* model) { // Timing. WallTimer wall_timer; @@ -1667,7 +1788,7 @@ CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, // We will add them all at once after model_proto is loaded. model->GetOrCreate()->DisableImplicationBetweenLiteral(); - // Instanciate all the needed variables. + // Instantiate all the needed variables. const VariableUsage usage = ComputeVariableUsage(model_proto); ModelWithMapping m(model_proto, usage, model); @@ -1698,10 +1819,10 @@ CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, } // We propagate after each new Boolean constraint but not the integer - // ones. So we call Propagate() manually here. TODO(user): Do that - // automatically? + // ones. So we call Propagate() manually here. + // TODO(user): Do that automatically? model->GetOrCreate()->Propagate(); - if (trail->Index() > old_num_fixed) { + if (display_fixing_constraints && trail->Index() > old_num_fixed) { VLOG(1) << "Constraint fixed " << trail->Index() - old_num_fixed << " Boolean variable(s): " << ct.DebugString(); } @@ -1723,9 +1844,21 @@ CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, return response; } - // Linearize some part of the problem and register LP constraint(s). + // Create an objective variable and its associated linear constraint if + // needed. + IntegerVariable objective_var = kNoIntegerVariable; + if (parameters.use_global_lp_constraint()) { - AddLPConstraints(model_proto, &m); + // Linearize some part of the problem and register LP constraint(s). + objective_var = AddLPConstraints(model_proto, &m); + } else if (model_proto.has_objective()) { + const CpObjectiveProto& obj = model_proto.objective(); + std::vector> terms; + terms.reserve(obj.vars_size()); + for (int i = 0; i < obj.vars_size(); ++i) { + terms.push_back(std::make_pair(m.Integer(obj.vars(i)), obj.coeffs(i))); + } + objective_var = GetOrCreateVariableEqualToSumOf(m.model(), terms); } model->GetOrCreate() @@ -1736,11 +1869,10 @@ CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, if (model_proto.search_strategy().empty()) { std::vector decisions; for (const int i : usage.integers) { - if (!model_proto.objectives().empty()) { + if (model_proto.has_objective()) { // Make sure we try to fix the objective to its lowest value first. - const int obj = model_proto.objectives(0).objective_var(); - if (PositiveRef(i) == PositiveRef(obj)) { - decisions.push_back(m.Integer(obj)); + if (m.Integer(i) == NegationOf(objective_var)) { + decisions.push_back(objective_var); continue; } } @@ -1773,7 +1905,7 @@ CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, // Solve. int num_solutions = 0; SatSolver::Status status; - if (model_proto.objectives_size() == 0) { + if (!model_proto.has_objective()) { status = SolveIntegerProblemWithLazyEncoding( /*assumptions=*/{}, next_decision, model); if (status == SatSolver::MODEL_SAT) { @@ -1781,9 +1913,7 @@ CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, } } else { // Optimization problem. - CHECK_EQ(model_proto.objectives_size(), 1); - const CpObjectiveProto obj = model_proto.objectives(0); - const IntegerVariable objective_var = m.Integer(obj.objective_var()); + const CpObjectiveProto& obj = model_proto.objective(); const auto solution_observer = [&model_proto, &response, &num_solutions, &obj, &m, objective_var](const Model& sat_model) { @@ -1800,15 +1930,19 @@ CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, std::vector linear_vars; std::vector linear_coeffs; ExtractLinearObjective(model_proto, &m, &linear_vars, &linear_coeffs); +#if defined(USE_CBC) || defined(USE_SCIPe) if (parameters.optimize_with_max_hs()) { status = MinimizeWithHittingSetAndLazyEncoding( VLOG_IS_ON(1), objective_var, linear_vars, linear_coeffs, next_decision, solution_observer, model); } else { +#endif // defined(USE_CBC) || defined(USE_SCIPe) status = MinimizeWithCoreAndLazyEncoding( VLOG_IS_ON(1), objective_var, linear_vars, linear_coeffs, next_decision, solution_observer, model); +#if defined(USE_CBC) || defined(USE_SCIPe) } +#endif // defined(USE_CBC) || defined(USE_SCIPe) } else { status = MinimizeIntegerVariableWithLinearScanAndLazyEncoding( /*log_info=*/false, objective_var, next_decision, solution_observer, @@ -1837,7 +1971,7 @@ CpSolverResponse SolveCpModelInternal(const CpModelProto& model_proto, break; } case SatSolver::MODEL_SAT: { - response.set_status(model_proto.objectives_size() != 0 + response.set_status(model_proto.has_objective() ? CpSolverStatus::OPTIMAL : CpSolverStatus::MODEL_SAT); break; @@ -1880,7 +2014,7 @@ CpSolverResponse SolveCpModelWithoutPresolve(const CpModelProto& model_proto, return response; } } - return SolveCpModelInternal(model_proto, model); + return SolveCpModelInternal(model_proto, true, model); } CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { @@ -1905,7 +2039,7 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { VLOG(1) << CpModelStats(presolved_proto); CpSolverResponse response = - SolveCpModelWithoutPresolve(presolved_proto, model); + SolveCpModelInternal(presolved_proto, true, model); if (response.status() != CpSolverStatus::MODEL_SAT && response.status() != CpSolverStatus::OPTIMAL) { return response; @@ -1937,7 +2071,7 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { postsolve_model.Add(operations_research::sat::NewSatParameters(params)); } const CpSolverResponse postsolve_response = - SolveCpModelInternal(mapping_proto, &postsolve_model); + SolveCpModelInternal(mapping_proto, false, &postsolve_model); CHECK_EQ(postsolve_response.status(), CpSolverStatus::MODEL_SAT); response.clear_solution(); response.clear_solution_lower_bounds(); diff --git a/ortools/sat/drat.cc b/ortools/sat/drat.cc index c07f90ccce..611577b3b4 100644 --- a/ortools/sat/drat.cc +++ b/ortools/sat/drat.cc @@ -13,13 +13,8 @@ #include "ortools/sat/drat.h" -#include "ortools/base/commandlineflags.h" #include "ortools/base/stringprintf.h" -DEFINE_string( - drat_output, "", - "If non-empty, a proof in DRAT format will be written to this file."); - namespace operations_research { namespace sat { @@ -30,16 +25,6 @@ DratWriter::~DratWriter() { } } -// static -DratWriter* DratWriter::CreateInModel(Model* model) { - if (FLAGS_drat_output.empty()) return nullptr; - File* output; - CHECK_OK(file::Open(FLAGS_drat_output, "w", &output, file::Defaults())); - DratWriter* drat_writer = new DratWriter(/*in_binary_format=*/false, output); - model->TakeOwnership(drat_writer); - return drat_writer; -} - void DratWriter::ApplyMapping( const ITIVector& mapping) { ITIVector new_mapping; diff --git a/ortools/sat/drat.h b/ortools/sat/drat.h index c76873c534..195d2539eb 100644 --- a/ortools/sat/drat.h +++ b/ortools/sat/drat.h @@ -35,10 +35,6 @@ class DratWriter { output_(output) {} ~DratWriter(); - // This tries to open the FLAGS_drat_file file and if it succeed it will - // return a non-nullptr DratWriter class. - static DratWriter* CreateInModel(Model* model); - // During the presolve step, variable get deleted and the set of non-deleted // variable is remaped in a dense set. This allows to keep track of that and // always output the DRAT clauses in term of the original variables. diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 31d1fd2ae4..d644a360f2 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -41,14 +41,10 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable var) { } } - // TODO(user): This case is annoying, so for now we want the caller to deal - // with it, hence the CHECK. We do not want to create a fixed Boolean - // variable, but we also do not want to complexify the API of - // FullDomainEncoding(). - CHECK_NE(values.size(), 1); - std::vector literals; - if (values.size() == 2) { + if (values.size() == 1) { + literals.push_back(GetLiteralTrue()); + } else if (values.size() == 2) { literals.push_back(GetOrCreateAssociatedLiteral( IntegerLiteral::LowerOrEqual(var, values[0]))); literals.push_back(literals.back().Negated()); @@ -70,7 +66,12 @@ void IntegerEncoder::FullyEncodeVariableUsingGivenLiterals( const std::vector& values) { CHECK(!VariableIsFullyEncoded(var)); CHECK(!literals.empty()); - CHECK_NE(literals.size(), 1); + + if (literals.size() == 1) { + full_encoding_index_[var] = full_encoding_.size(); + full_encoding_.push_back({ValueLiteralPair(values[0], literals[0])}); + return; + } // Sort the literals by values. std::vector encoding; @@ -1031,11 +1032,22 @@ void IntegerTrail::EnqueueLiteral( trail_->Enqueue(literal, propagator_id_); } -GenericLiteralWatcher::GenericLiteralWatcher( - IntegerTrail* integer_trail, RevRepository* rev_int_repository) +GenericLiteralWatcher::GenericLiteralWatcher(Model* model) : SatPropagator("GenericLiteralWatcher"), - integer_trail_(integer_trail), - rev_int_repository_(rev_int_repository) { + integer_trail_(model->GetOrCreate()) { + // TODO(user): Have a general mecanism to register "global" reversible + // classes and keep them synchronized with the search. + std::unique_ptr> rev_int_repository( + new RevRepository()); + rev_int_repository_ = rev_int_repository.get(); + model->SetSingleton(std::move(rev_int_repository)); + + // TODO(user): This propagator currently needs to be last because it is the + // only one enforcing that a fix-point is reached on the integer variables. + // Figure out a better interaction between the sat propagation loop and + // this one. + model->GetOrCreate()->AddLastPropagator(this); + integer_trail_->RegisterWatcher(&modified_vars_); queue_by_priority_.resize(2); // Because default priority is 1. } diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 934d0dfd81..33a518be8d 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -157,11 +157,7 @@ using InlinedIntegerLiteralVector = gtl::InlinedVector; struct IntegerDomains : public ITIVector> { - static IntegerDomains* CreateInModel(Model* model) { - IntegerDomains* domains = new IntegerDomains(); - model->TakeOwnership(domains); - return domains; - } + explicit IntegerDomains(Model* model) {} }; // Each integer variable x will be associated with a set of literals encoding @@ -184,20 +180,15 @@ struct IntegerDomains // though. class IntegerEncoder { public: - IntegerEncoder(SatSolver* sat_solver, IntegerDomains* domains) - : sat_solver_(sat_solver), domains_(domains), num_created_variables_(0) {} + explicit IntegerEncoder(Model* model) + : sat_solver_(model->GetOrCreate()), + domains_(model->GetOrCreate()), + num_created_variables_(0) {} ~IntegerEncoder() { VLOG(1) << "#variables created = " << num_created_variables_; } - static IntegerEncoder* CreateInModel(Model* model) { - IntegerEncoder* encoder = new IntegerEncoder( - model->GetOrCreate(), model->GetOrCreate()); - model->TakeOwnership(encoder); - return encoder; - } - // Fully encode a variable using its current initial domain. // This can be called only once. // @@ -332,6 +323,18 @@ class IntegerEncoder { // Adds the implications: Literal(before) <= associated_lit <= Literal(after). void AddImplications(IntegerLiteral i, Literal associated_lit); + // Get the literal always set to true, make it if it does not exist. + Literal GetLiteralTrue() { + DCHECK_EQ(0, sat_solver_->CurrentDecisionLevel()); + if (literal_index_true_ == kNoLiteralIndex) { + const Literal literal_true = + Literal(sat_solver_->NewBooleanVariable(), true); + literal_index_true_ = literal_true.Index(); + sat_solver_->AddUnitClause(literal_true); + } + return Literal(literal_index_true_); + } + SatSolver* sat_solver_; IntegerDomains* domains_; @@ -359,6 +362,10 @@ class IntegerEncoder { std::unordered_map full_encoding_index_; std::vector> full_encoding_; + // A literal that is always true, convenient to encode trivial domains. + // This will be lazily created when needed. + LiteralIndex literal_index_true_ = kNoLiteralIndex; + DISALLOW_COPY_AND_ASSIGN(IntegerEncoder); }; @@ -367,23 +374,15 @@ class IntegerEncoder { // to maintain the reason for each propagation. class IntegerTrail : public SatPropagator { public: - IntegerTrail(IntegerDomains* domains, IntegerEncoder* encoder, Trail* trail) + explicit IntegerTrail(Model* model) : SatPropagator("IntegerTrail"), num_enqueues_(0), - domains_(domains), - encoder_(encoder), - trail_(trail) {} - ~IntegerTrail() final {} - - static IntegerTrail* CreateInModel(Model* model) { - IntegerDomains* domains = model->GetOrCreate(); - IntegerEncoder* encoder = model->GetOrCreate(); - Trail* trail = model->GetOrCreate(); - IntegerTrail* integer_trail = new IntegerTrail(domains, encoder, trail); - model->GetOrCreate()->AddPropagator( - std::unique_ptr(integer_trail)); - return integer_trail; + domains_(model->GetOrCreate()), + encoder_(model->GetOrCreate()), + trail_(model->GetOrCreate()) { + model->GetOrCreate()->AddPropagator(this); } + ~IntegerTrail() final {} // SatPropagator interface. These functions make sure the current bounds // information is in sync with the current solver literal trail. Any @@ -709,28 +708,9 @@ class PropagatorInterface { // TODO(user): Move this to its own file. Add unit tests! class GenericLiteralWatcher : public SatPropagator { public: - explicit GenericLiteralWatcher(IntegerTrail* trail, - RevRepository* rev_int_repository); + explicit GenericLiteralWatcher(Model* model); ~GenericLiteralWatcher() final {} - static GenericLiteralWatcher* CreateInModel(Model* model) { - // TODO(user): Have a general mecanism to register "global" reversible - // classes and keep them synchronized with the search. - std::unique_ptr> rev_int_repository( - new RevRepository()); - GenericLiteralWatcher* watcher = new GenericLiteralWatcher( - model->GetOrCreate(), rev_int_repository.get()); - model->SetSingleton(std::move(rev_int_repository)); - - // TODO(user): This propagator currently needs to be last because it is the - // only one enforcing that a fix-point is reached on the integer variables. - // Figure out a better interaction between the sat propagation loop and - // this one. - model->GetOrCreate()->AddLastPropagator( - std::unique_ptr(watcher)); - return watcher; - } - // On propagate, the registered propagators will be called if they need to // until a fixed point is reached. Propagators with low ids will tend to be // called first, but it ultimately depends on their "waking" order. @@ -955,12 +935,6 @@ class LiteralViews { public: explicit LiteralViews(Model* model) : model_(model) {} - static LiteralViews* CreateInModel(Model* model) { - LiteralViews* const views = new LiteralViews(model); - model->TakeOwnership(views); - return views; - } - IntegerVariable GetIntegerView(const Literal lit) { const LiteralIndex index = lit.Index(); @@ -1054,13 +1028,6 @@ inline std::function Equality(IntegerVariable v, int64 value) { }; } -// Associate the given literal to the given integer inequality. -inline std::function Equality(IntegerLiteral i, Literal l) { - return [=](Model* model) { - model->GetOrCreate()->AssociateToIntegerLiteral(l, i); - }; -} - // TODO(user): This is one of the rare case where it is better to use Equality() // rather than two Implications(). Maybe we should modify our internal // implementation to use half-reified encoding? that is do not propagate the @@ -1077,7 +1044,7 @@ inline std::function Implication(Literal l, IntegerLiteral i) { model->Add(ClauseConstraint({l.Negated()})); } else { // TODO(user): Double check what happen when we associate a trivially - // true or false literal. This applies to Equality() too. + // true or false literal. IntegerEncoder* encoder = model->GetOrCreate(); const Literal current = encoder->GetOrCreateAssociatedLiteral(i); model->Add(Implication(l, current)); @@ -1085,27 +1052,15 @@ inline std::function Implication(Literal l, IntegerLiteral i) { }; } -// in_interval <=> v in [lb, ub]. -inline std::function ReifiedInInterval(IntegerVariable v, - int64 lb, int64 ub, - Literal in_interval) { +// in_interval => v in [lb, ub]. +inline std::function ImpliesInInterval(Literal in_interval, + IntegerVariable v, + int64 lb, int64 ub) { return [=](Model* model) { - IntegerEncoder* encoder = model->GetOrCreate(); - const auto lb_lit = IntegerLiteral::GreaterOrEqual(v, IntegerValue(lb)); - const auto ub_lit = IntegerLiteral::LowerOrEqual(v, IntegerValue(ub)); - if (lb <= model->Get(LowerBound(v))) { - if (ub >= model->Get(UpperBound(v))) { - model->GetOrCreate()->AddUnitClause(in_interval); - } else { - model->Add(Equality(ub_lit, in_interval)); - } - } else if (ub >= model->Get(UpperBound(v))) { - model->Add(Equality(lb_lit, in_interval)); - } else { - const Literal is_ge_lb = encoder->GetOrCreateAssociatedLiteral(lb_lit); - const Literal is_le_ub = encoder->GetOrCreateAssociatedLiteral(ub_lit); - model->Add(ReifiedBoolAnd({is_ge_lb, is_le_ub}, in_interval)); - } + model->Add(Implication( + in_interval, IntegerLiteral::GreaterOrEqual(v, IntegerValue(lb)))); + model->Add(Implication(in_interval, + IntegerLiteral::LowerOrEqual(v, IntegerValue(ub)))); }; } diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index f2b709c17d..3e15e349f9 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -34,17 +34,9 @@ const IntervalVariable kNoIntervalVariable(-1); // provides many helper functions to add precedences relation between intervals. class IntervalsRepository { public: - IntervalsRepository(IntegerTrail* integer_trail, - PrecedencesPropagator* precedences) - : integer_trail_(integer_trail), precedences_(precedences) {} - - static IntervalsRepository* CreateInModel(Model* model) { - IntervalsRepository* intervals = - new IntervalsRepository(model->GetOrCreate(), - model->GetOrCreate()); - model->TakeOwnership(intervals); - return intervals; - } + explicit IntervalsRepository(Model* model) + : integer_trail_(model->GetOrCreate()), + precedences_(model->GetOrCreate()) {} // Returns the current number of intervals in the repository. // The interval will always be identified by an integer in [0, num_intervals). diff --git a/ortools/sat/linear_programming_constraint.cc b/ortools/sat/linear_programming_constraint.cc index 6a447779a5..50b11ec0dd 100644 --- a/ortools/sat/linear_programming_constraint.cc +++ b/ortools/sat/linear_programming_constraint.cc @@ -27,9 +27,16 @@ namespace sat { const double LinearProgrammingConstraint::kEpsilon = 1e-6; -LinearProgrammingConstraint::LinearProgrammingConstraint( - IntegerTrail* integer_trail) - : integer_trail_(integer_trail) { +LinearProgrammingConstraint::LinearProgrammingConstraint(Model* model) + : integer_trail_(model->GetOrCreate()) { + // TODO(user): Find a way to make GetOrCreate() construct it by + // default. + time_limit_ = model->Mutable(); + if (time_limit_ == nullptr) { + model->SetSingleton(TimeLimit::Infinite()); + time_limit_ = model->Mutable(); + } + if (!FLAGS_lp_constraint_use_dual_ray) { // The violation_sum_ variable will be the sum of constraints' violation. violation_sum_constraint_ = lp_data_.CreateNewConstraint(); @@ -74,27 +81,27 @@ void LinearProgrammingConstraint::SetCoefficient(ConstraintIndex ct, lp_data_.SetCoefficient(ct, cvar, coefficient); } -void LinearProgrammingConstraint::SetObjective(IntegerVariable ivar, - bool is_minimization) { +void LinearProgrammingConstraint::SetObjectiveCoefficient(IntegerVariable ivar, + double coeff) { CHECK(!lp_constraint_is_registered_); - CHECK(!objective_is_defined_) << "Objective was set more than once."; objective_is_defined_ = true; - objective_cp_ = ivar; - objective_lp_ = GetOrCreateMirrorVariable(ivar); - objective_is_minimization_ = is_minimization; + objective_lp_.push_back( + std::make_pair(GetOrCreateMirrorVariable(ivar), coeff)); } void LinearProgrammingConstraint::RegisterWith(GenericLiteralWatcher* watcher) { DCHECK(!lp_constraint_is_registered_); lp_constraint_is_registered_ = true; - lp_data_.Scale(&scaler_); - - // Note that we set the objective AFTER the scaling. + // Note that the order is important so that the lp objective is exactly + // lp_to_cp_objective_scale_ times the cp one. if (objective_is_defined_) { - lp_data_.SetObjectiveCoefficient(objective_lp_, 1.0); - lp_data_.SetMaximizationProblem(!objective_is_minimization_); + for (const auto& var_coeff : objective_lp_) { + lp_data_.SetObjectiveCoefficient(var_coeff.first, var_coeff.second); + } } + lp_data_.Scale(&scaler_); + lp_to_cp_objective_scale_ = lp_data_.ScaleObjective(); if (!FLAGS_lp_constraint_use_dual_ray) { // Add all the individual violation variables. Note that it is important @@ -175,15 +182,17 @@ bool LinearProgrammingConstraint::Propagate() { if (!FLAGS_lp_constraint_use_dual_ray) { if (objective_is_defined_) { - lp_data_.SetObjectiveCoefficient(objective_lp_, 0.0); + for (auto& var_coeff : objective_lp_) { + lp_data_.SetObjectiveCoefficient(var_coeff.first, 0.0); + } } lp_data_.SetObjectiveCoefficient(violation_sum_, 1.0); + lp_data_.SetObjectiveScalingFactor(1.0); lp_data_.SetVariableBounds(violation_sum_, 0.0, std::numeric_limits::infinity()); - lp_data_.SetMaximizationProblem(false); // Feasibility deductions. - const auto status = simplex_.Solve(lp_data_, TimeLimit::Infinite().get()); + const auto status = simplex_.Solve(lp_data_, time_limit_); CHECK(status.ok()) << "LinearProgrammingConstraint encountered an error: " << status.error_message(); CHECK_EQ(simplex_.GetProblemStatus(), glop::ProblemStatus::OPTIMAL) @@ -191,14 +200,14 @@ bool LinearProgrammingConstraint::Propagate() { << simplex_.GetProblemStatus(); if (simplex_.GetVariableValue(violation_sum_) > kEpsilon) { // infeasible. - FillIntegerReason(1.0); + FillReducedCostsReason(); return integer_trail_->ReportConflict(integer_reason_); } // Reduced cost strengthening for feasibility. - ReducedCostStrengtheningDeductions(1.0, 0.0); + ReducedCostStrengtheningDeductions(0.0); if (!deductions_.empty()) { - FillIntegerReason(1.0); + FillReducedCostsReason(); for (const IntegerLiteral deduction : deductions_) { if (!integer_trail_->Enqueue(deduction, {}, integer_reason_)) { return false; @@ -210,8 +219,12 @@ bool LinearProgrammingConstraint::Propagate() { lp_data_.SetVariableBounds(violation_sum_, 0.0, 0.0); lp_data_.SetObjectiveCoefficient(violation_sum_, 0.0); if (objective_is_defined_) { - lp_data_.SetObjectiveCoefficient(objective_lp_, 1.0); - lp_data_.SetMaximizationProblem(!objective_is_minimization_); + for (auto& var_coeff : objective_lp_) { + const glop::ColIndex col = var_coeff.first; + lp_data_.SetObjectiveCoefficient( + col, var_coeff.second * scaler_.col_scale(col)); + } + lp_to_cp_objective_scale_ = lp_data_.ScaleObjective(); } for (int i = 0; i < num_vars; i++) { lp_solution_[i] = GetVariableValueAtCpScale(mirror_lp_variables_[i]); @@ -222,7 +235,7 @@ bool LinearProgrammingConstraint::Propagate() { return true; } - const auto status = simplex_.Solve(lp_data_, TimeLimit::Infinite().get()); + const auto status = simplex_.Solve(lp_data_, time_limit_); CHECK(status.ok()) << "LinearProgrammingConstraint encountered an error: " << status.error_message(); @@ -235,56 +248,31 @@ bool LinearProgrammingConstraint::Propagate() { // Optimality deductions if problem has an objective. if (objective_is_defined_ && simplex_.GetProblemStatus() == glop::ProblemStatus::OPTIMAL) { - const double objective_cp_lb = - static_cast(integer_trail_->LowerBound(objective_cp_).value()); - const double objective_cp_ub = - static_cast(integer_trail_->UpperBound(objective_cp_).value()); - - // Try to filter optimal objective value. - const double objective_value = GetVariableValueAtCpScale(objective_lp_); - if (objective_is_minimization_) { - const double new_lb = std::ceil(objective_value - kEpsilon); - if (objective_cp_lb < new_lb) { - const IntegerValue new_int_lb(static_cast(new_lb)); - FillIntegerReason(1.0); - const IntegerLiteral deduction = - IntegerLiteral::GreaterOrEqual(objective_cp_, new_int_lb); - if (!integer_trail_->Enqueue(deduction, {}, integer_reason_)) { - return false; - } - } - } else { - const double new_ub = std::floor(objective_value + kEpsilon); - if (objective_cp_ub > new_ub) { - const IntegerValue new_int_ub(static_cast(new_ub)); - FillIntegerReason(-1.0); - const IntegerLiteral deduction = - IntegerLiteral::LowerOrEqual(objective_cp_, new_int_ub); - if (!integer_trail_->Enqueue(deduction, {}, integer_reason_)) { - return false; - } + // Try to filter optimal objective value. Note that GetObjectiveValue() + // already take care of the scaling so that it returns an objective in the + // CP world. + const double relaxed_optimal_objective = simplex_.GetObjectiveValue(); + const IntegerValue old_lb = integer_trail_->LowerBound(objective_cp_); + const IntegerValue new_lb( + static_cast(std::ceil(relaxed_optimal_objective - kEpsilon))); + if (old_lb < new_lb) { + FillReducedCostsReason(); + const IntegerLiteral deduction = + IntegerLiteral::GreaterOrEqual(objective_cp_, new_lb); + if (!integer_trail_->Enqueue(deduction, {}, integer_reason_)) { + return false; } } // Reduced cost strengthening. - const double objective_slack = objective_is_minimization_ - ? objective_cp_ub - objective_value - : objective_value - objective_cp_lb; - const double objective_direction = objective_is_minimization_ ? 1.0 : -1.0; - ReducedCostStrengtheningDeductions( - objective_direction, - objective_slack * scaler_.col_scale(objective_lp_)); - + const double objective_cp_ub = + static_cast(integer_trail_->UpperBound(objective_cp_).value()); + ReducedCostStrengtheningDeductions(objective_cp_ub - + relaxed_optimal_objective); if (!deductions_.empty()) { - FillIntegerReason(objective_direction); - - // Add the opposite bound of the variable used for strengthening. - const IntegerLiteral opposite_bound = - objective_is_minimization_ - ? integer_trail_->UpperBoundAsLiteral(objective_cp_) - : integer_trail_->LowerBoundAsLiteral(objective_cp_); - integer_reason_.push_back(opposite_bound); - + FillReducedCostsReason(); + integer_reason_.push_back( + integer_trail_->UpperBoundAsLiteral(objective_cp_)); for (const IntegerLiteral deduction : deductions_) { if (!integer_trail_->Enqueue(deduction, {}, integer_reason_)) { return false; @@ -301,9 +289,8 @@ bool LinearProgrammingConstraint::Propagate() { return true; } -void LinearProgrammingConstraint::FillIntegerReason(double direction) { +void LinearProgrammingConstraint::FillReducedCostsReason() { integer_reason_.clear(); - const int num_vars = integer_variables_.size(); for (int i = 0; i < num_vars; i++) { // TODO(user): try to extend the bounds that are put in the @@ -312,8 +299,7 @@ void LinearProgrammingConstraint::FillIntegerReason(double direction) { // feasible? If the violation minimum is 10 and a variable has rc 1, // then decreasing it by 9 would still leave the problem infeasible. // Using this could allow to generalize some explanations. - const double rc = - simplex_.GetReducedCost(mirror_lp_variables_[i]) * direction; + const double rc = simplex_.GetReducedCost(mirror_lp_variables_[i]); if (rc > kEpsilon) { integer_reason_.push_back( integer_trail_->LowerBoundAsLiteral(integer_variables_[i])); @@ -326,10 +312,9 @@ void LinearProgrammingConstraint::FillIntegerReason(double direction) { void LinearProgrammingConstraint::FillDualRayReason() { integer_reason_.clear(); - const int num_vars = integer_variables_.size(); for (int i = 0; i < num_vars; i++) { - // TODO(user): Like for FillIntegerReason(), the bounds could be + // TODO(user): Like for FillReducedCostsReason(), the bounds could be // extended here. Actually, the "dual ray cost updates" is the reduced cost // of an optimal solution if we were optimizing one direction of one basic // variable. The simplex_ interface would need to be slightly extended to @@ -347,14 +332,19 @@ void LinearProgrammingConstraint::FillDualRayReason() { } void LinearProgrammingConstraint::ReducedCostStrengtheningDeductions( - double direction, double lp_objective_delta) { + double cp_objective_delta) { deductions_.clear(); + // TRICKY: while simplex_.GetObjectiveValue() use the objective scaling factor + // stored in the lp_data_, all the other functions like GetReducedCost() or + // GetVariableValue() do not. + const double lp_objective_delta = + cp_objective_delta / lp_to_cp_objective_scale_; const int num_vars = integer_variables_.size(); for (int i = 0; i < num_vars; i++) { const IntegerVariable cp_var = integer_variables_[i]; const glop::ColIndex lp_var = mirror_lp_variables_[i]; - const double rc = simplex_.GetReducedCost(lp_var) * direction; + const double rc = simplex_.GetReducedCost(lp_var); const double value = simplex_.GetVariableValue(lp_var); const double lp_other_bound = value + lp_objective_delta / rc; const double cp_other_bound = lp_other_bound / scaler_.col_scale(lp_var); diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index 3be12a88f2..b8c256edd3 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -22,6 +22,7 @@ #include "ortools/sat/integer.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" +#include "ortools/util/time_limit.h" namespace operations_research { namespace sat { @@ -62,34 +63,25 @@ class LinearProgrammingConstraint : public PropagatorInterface { public: typedef glop::RowIndex ConstraintIndex; - explicit LinearProgrammingConstraint(IntegerTrail* integer_trail); - - // Creates a LinearProgrammingConstraint for templated GetOrCreate idiom. - static LinearProgrammingConstraint* CreateInModel(Model* model) { - IntegerTrail* trail = model->GetOrCreate(); - LinearProgrammingConstraint* constraint = - new LinearProgrammingConstraint(trail); - model->TakeOwnership(constraint); - return constraint; - } + explicit LinearProgrammingConstraint(Model* model); // User API, see header description. ConstraintIndex CreateNewConstraint(double lb, double ub); - void SetCoefficient(ConstraintIndex ct, IntegerVariable ivar, - double coefficient); - // TODO(user): Allow Literals to appear in linear constraints. // TODO(user): Calling SetCoefficient() twice on the same // (constraint, variable) pair will overwrite coefficients where accumulating // them might be desired, this is a common mistake, change API. + void SetCoefficient(ConstraintIndex ct, IntegerVariable ivar, + double coefficient); - // Objective may or may not be defined. It can be defined only once, - // must be exactly one IntegerVariable, and can be either - // minimized (is_minimization = true) or maximized (is_minimization = false). - // TODO(user): change API for always minimization, so that - // maximization(var) = minimization(Negation(var)). - void SetObjective(IntegerVariable ivar, bool is_minimization); + // Set the coefficient of the variable in the objective. Calling it twice will + // overwrite the previous value. + void SetObjectiveCoefficient(IntegerVariable ivar, double coeff); + + // The main objective variable should be equal to the linear sum of + // the arguments passed to SetObjectiveCoefficient(). + void SetMainObjectiveVariable(IntegerVariable ivar) { objective_cp_ = ivar; } // PropagatorInterface API. bool Propagate() override; @@ -101,22 +93,18 @@ class LinearProgrammingConstraint : public PropagatorInterface { private: // Generates a set of IntegerLiterals explaining why the best solution can not // be improved using reduced costs. This is used to generate explanations for - // both infeasibility and bounds deductions. The direction variable should be - // 1.0 if the last Solve() was a minimization, -1.0 if it was a maximization. - void FillIntegerReason(double direction); + // both infeasibility and bounds deductions. + void FillReducedCostsReason(); - // Same as FillIntegerReason() but for the case of a DUAL_UNBOUNDED problem. - // This exploit the dual ray as a reason for the primal infeasiblity. + // Same as FillReducedCostReason() but for the case of a DUAL_UNBOUNDED + // problem. This exploit the dual ray as a reason for the primal infeasiblity. void FillDualRayReason(); // Fills the deductions vector with reduced cost deductions that can be made - // from the current state of the LP solver. This should be called after - // Solve(): if the optimization was a minimization, the direction variable - // should be 1.0 and lp_objective_delta the objective's upper bound minus the - // optimal; if the optimization was a maximization, direction should be -1.0 - // and lp_objective_delta the optimal minus the objective's lower bound. - void ReducedCostStrengtheningDeductions(double direction, - double lp_objective_delta); + // from the current state of the LP solver. The given delta should be the + // difference between the cp objective upper bound and lower bound given by + // the lp. + void ReducedCostStrengtheningDeductions(double cp_objective_delta); // Gets or creates an LP variable that mirrors a CP variable. // TODO(user): only accept positive variables to prevent having different @@ -135,6 +123,7 @@ class LinearProgrammingConstraint : public PropagatorInterface { // For the scaling. glop::SparseMatrixScaler scaler_; + glop::Fractional lp_to_cp_objective_scale_; // violation_sum_ is used to simulate phase I of the simplex and be able to // do reduced cost strengthening on problem feasibility by using the sum of @@ -155,8 +144,7 @@ class LinearProgrammingConstraint : public PropagatorInterface { // then we will switch the objective between feasibility and optimization. bool objective_is_defined_ = false; IntegerVariable objective_cp_; - glop::ColIndex objective_lp_; - bool objective_is_minimization_; + std::vector> objective_lp_; // Structures for propagators. IntegerTrail* integer_trail_; @@ -170,6 +158,9 @@ class LinearProgrammingConstraint : public PropagatorInterface { // Linear constraints cannot be created or modified after this is registered. bool lp_constraint_is_registered_ = false; + + // Time limit (shared with, owned by the sat solver). + TimeLimit* time_limit_; }; } // namespace sat diff --git a/ortools/sat/lp_utils.cc b/ortools/sat/lp_utils.cc index 88f439da36..82810e9f16 100644 --- a/ortools/sat/lp_utils.cc +++ b/ortools/sat/lp_utils.cc @@ -205,44 +205,21 @@ bool ConvertMPModelProtoToCpModelProto(const MPModelProto& mp_model, // Note that here we set the scaling factor for the inverse operation of // getting the "true" objective value from the scaled one. Hence the // inverse. - auto* objective = cp_model->add_objectives(); - objective->set_offset(mp_model.objective_offset() * scaling_factor / gcd); - objective->set_scaling_factor(1.0 / scaling_factor * gcd); - objective->set_objective_var(cp_model->variables_size()); - { - auto* objective_var = cp_model->add_variables(); - objective_var->set_name("objective"); - objective_var->add_domain(-kMaxObjective); - objective_var->add_domain(kMaxObjective); - } - - // Link the objective variable with a linear constraint. - { - auto* objective_constraint = cp_model->add_constraints(); - auto* objective_arg = objective_constraint->mutable_linear(); - objective_constraint->set_name("objective"); - objective_arg->add_domain(0); - objective_arg->add_domain(0); - for (int i = 0; i < num_variables; ++i) { - const MPVariableProto& mp_var = mp_model.variable(i); - const int64 value = - static_cast( - std::round(mp_var.objective_coefficient() * scaling_factor)) / - gcd; - if (value != 0) { - objective_arg->add_vars(i); - objective_arg->add_coeffs(value); - } + auto* objective = cp_model->mutable_objective(); + const int mult = mp_model.maximize() ? -1 : 1; + objective->set_offset(mp_model.objective_offset() * scaling_factor / gcd * + mult); + objective->set_scaling_factor(1.0 / scaling_factor * gcd * mult); + for (int i = 0; i < num_variables; ++i) { + const MPVariableProto& mp_var = mp_model.variable(i); + const int64 value = + static_cast( + std::round(mp_var.objective_coefficient() * scaling_factor)) / + gcd; + if (value != 0) { + objective->add_vars(i); + objective->add_coeffs(value * mult); } - objective_arg->add_vars(objective->objective_var()); - objective_arg->add_coeffs(-1); - } - - // If the problem was a maximization one, we need to modify the objective. - if (mp_model.maximize()) { - objective->set_objective_var(-objective->objective_var() - 1); - objective->set_scaling_factor(-objective->scaling_factor()); - objective->set_offset(-objective->offset()); } } diff --git a/ortools/sat/model.h b/ortools/sat/model.h index ad4aae0223..60d8b1caec 100644 --- a/ortools/sat/model.h +++ b/ortools/sat/model.h @@ -64,8 +64,7 @@ class Model { // Returns an object of type T that is unique to this model (this is a bit // like a "local" singleton). This returns an already created instance or - // create a new one if needed using the T::CreateInModel(Model* model) - // function of the class T. + // create a new one if needed using the T(Model* model) constructor. // // This works a bit like in a dependency injection framework and allows to // really easily wire all the classes that make up a solver together. For @@ -73,25 +72,31 @@ class Model { // or both, it can depend on a Watcher class to register itself in order to // be called when needed and so on. // - // IMPORTANT: the CreateInModel() functiond shouldn't form a cycle between + // IMPORTANT: the Model* constructors function shouldn't form a cycle between // each other, otherwise this will crash the program. + // + // TODO(user): Rename to GetOrCreateSingleton(). template T* GetOrCreate() { const size_t type_id = FastTypeId(); if (!ContainsKey(singletons_, type_id)) { - // Note that it is up to CreateInModel() to call model->TakeOwnership() - // of the returned pointer. - // - // TODO(user): Always take ownership of the pointer instead. That would - // requires some cleanup, but it is probably a safer solution and would - // allow SetSingleton() to change an instance dynamically. - T* new_t = T::CreateInModel(this); + // TODO(user): directly store std::unique_ptr<> in singletons_? + T* new_t = new T(this); singletons_[type_id] = new_t; + TakeOwnership(new_t); return new_t; } return static_cast(FindOrDie(singletons_, type_id)); } + // This returns a non-singleton object owned by the model. + template + T* Create() { + T* new_t = new T(this); + TakeOwnership(new_t); + return new_t; + } + // Registers a given instance of type T as a "local singleton" for this type. // For now this CHECKs that the object was not yet created. template diff --git a/ortools/sat/no_cycle.h b/ortools/sat/no_cycle.h index c8a9331c04..f41c1623de 100644 --- a/ortools/sat/no_cycle.h +++ b/ortools/sat/no_cycle.h @@ -16,7 +16,6 @@ #include -#include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" @@ -39,13 +38,6 @@ class NoCyclePropagator : public SatPropagator { include_propagated_arcs_in_graph_(true) {} ~NoCyclePropagator() final {} - static NoCyclePropagator* CreateInModel(Model* model) { - NoCyclePropagator* no_cycle = new NoCyclePropagator(); - model->GetOrCreate()->AddPropagator( - std::unique_ptr(no_cycle)); - return no_cycle; - } - bool Propagate(Trail* trail) final; void Untrail(const Trail& trail, int trail_index) final; gtl::Span Reason(const Trail& trail, diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index aff19953ff..9a91c61b37 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -39,27 +39,15 @@ namespace sat { // Another word is "separation logic". class PrecedencesPropagator : public SatPropagator, PropagatorInterface { public: - PrecedencesPropagator(Trail* trail, IntegerTrail* integer_trail, - GenericLiteralWatcher* watcher) + explicit PrecedencesPropagator(Model* model) : SatPropagator("PrecedencesPropagator"), - trail_(trail), - integer_trail_(integer_trail), - watcher_(watcher), - watcher_id_(watcher->Register(this)) { + trail_(model->GetOrCreate()), + integer_trail_(model->GetOrCreate()), + watcher_(model->GetOrCreate()), + watcher_id_(watcher_->Register(this)) { + model->GetOrCreate()->AddPropagator(this); integer_trail_->RegisterWatcher(&modified_vars_); - watcher->SetPropagatorPriority(watcher_id_, 0); - } - - static PrecedencesPropagator* CreateInModel(Model* model) { - PrecedencesPropagator* precedences = new PrecedencesPropagator( - model->GetOrCreate(), model->GetOrCreate(), - model->GetOrCreate()); - - // TODO(user): Find a way to have more control on the order in which - // the propagators are added. - model->GetOrCreate()->AddPropagator( - std::unique_ptr(precedences)); - return precedences; + watcher_->SetPropagatorPriority(watcher_id_, 0); } bool Propagate() final; diff --git a/ortools/sat/sat_base.h b/ortools/sat/sat_base.h index 86a1e460f3..30119ed8ea 100644 --- a/ortools/sat/sat_base.h +++ b/ortools/sat/sat_base.h @@ -224,17 +224,13 @@ struct AssignmentType { // and the information of each assignment. class Trail { public: + explicit Trail(Model* model) : Trail() {} + Trail() : num_enqueues_(0) { current_info_.trail_index = 0; current_info_.level = 0; } - static Trail* CreateInModel(Model* model) { - Trail* trail = new Trail(); - model->TakeOwnership(trail); - return trail; - } - void Resize(int num_variables); // Registers a propagator. This assigns a unique id to this propagator and diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index a58ee84a80..381296c071 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -31,13 +31,14 @@ namespace operations_research { namespace sat { -SatSolver::SatSolver() : SatSolver(new Trail()) { owned_trail_.reset(trail_); } +SatSolver::SatSolver() : SatSolver(new Model()) { owned_model_.reset(model_); } -SatSolver::SatSolver(Trail* trail) - : num_variables_(0), +SatSolver::SatSolver(Model* model) + : model_(model), + num_variables_(0), pb_constraints_(), track_binary_clauses_(false), - trail_(trail), + trail_(model->GetOrCreate()), current_decision_level_(0), last_decision_or_backtrack_trail_index_(0), assumption_level_(0), @@ -397,20 +398,20 @@ void SatSolver::AddLearnedClauseAndEnqueueUnitPropagation( } } -void SatSolver::AddPropagator(std::unique_ptr propagator) { +void SatSolver::AddPropagator(SatPropagator* propagator) { CHECK_EQ(CurrentDecisionLevel(), 0); problem_is_pure_sat_ = false; - trail_->RegisterPropagator(propagator.get()); - external_propagators_.push_back(std::move(propagator)); + trail_->RegisterPropagator(propagator); + external_propagators_.push_back(propagator); InitializePropagators(); } -void SatSolver::AddLastPropagator(std::unique_ptr propagator) { +void SatSolver::AddLastPropagator(SatPropagator* propagator) { CHECK_EQ(CurrentDecisionLevel(), 0); CHECK(last_propagator_ == nullptr); problem_is_pure_sat_ = false; - trail_->RegisterPropagator(propagator.get()); - last_propagator_ = std::move(propagator); + trail_->RegisterPropagator(propagator); + last_propagator_ = propagator; InitializePropagators(); } @@ -1538,10 +1539,10 @@ void SatSolver::InitializePropagators() { propagators_.push_back(&pb_constraints_); } for (int i = 0; i < external_propagators_.size(); ++i) { - propagators_.push_back(external_propagators_[i].get()); + propagators_.push_back(external_propagators_[i]); } if (last_propagator_ != nullptr) { - propagators_.push_back(last_propagator_.get()); + propagators_.push_back(last_propagator_); } } diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 0a4306c98e..6cdf1c3e05 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -57,16 +57,9 @@ const int kUnsatTrailIndex = -1; class SatSolver { public: SatSolver(); - explicit SatSolver(Trail* trail); + explicit SatSolver(Model* model); ~SatSolver(); - static SatSolver* CreateInModel(Model* model) { - Trail* trail = model->GetOrCreate(); - SatSolver* solver = new SatSolver(trail); - model->TakeOwnership(solver); - return solver; - } - // Parameters management. Note that calling SetParameters() will reset the // value of many heuristics. For instance: // - The restart strategy will be reinitialized. @@ -136,8 +129,11 @@ class SatSolver { // Adds and registers the given propagator with the sat solver. Note that // during propagation, they will be called in the order they where added. - void AddPropagator(std::unique_ptr propagator); - void AddLastPropagator(std::unique_ptr propagator); + void AddPropagator(SatPropagator* propagator); + void AddLastPropagator(SatPropagator* propagator); + void TakePropagatorOwnership(std::unique_ptr propagator) { + owned_propagators_.push_back(std::move(propagator)); + } // Gives a hint so the solver tries to find a solution with the given literal // set to true. Currently this take precedence over the phase saving heuristic @@ -655,6 +651,10 @@ class SatSolver { std::string StatusString(Status status) const; std::string RunningStatisticsString() const; + // This is used by the old non-model constructor. + Model* model_; + std::unique_ptr owned_model_; + BooleanVariable num_variables_; // All the clauses managed by the solver (initial and learned). This vector @@ -686,8 +686,11 @@ class SatSolver { std::vector propagators_; // Ordered list of propagators added with AddPropagator(). - std::vector> external_propagators_; - std::unique_ptr last_propagator_; + std::vector external_propagators_; + SatPropagator* last_propagator_ = nullptr; + + // For the old, non-model interface. + std::vector> owned_propagators_; // Keep track of all binary clauses so they can be exported. bool track_binary_clauses_; @@ -696,9 +699,6 @@ class SatSolver { // The solver trail. Trail* trail_; - // This is used by the non-model constructor to properly cleanup trail_. - std::unique_ptr owned_trail_; - // Used for debugging only. See SaveDebugAssignment(). VariablesAssignment debug_assignment_; diff --git a/tools/generate_deps.sh b/tools/generate_deps.sh index db61bb8f60..1dfa766ef0 100755 --- a/tools/generate_deps.sh +++ b/tools/generate_deps.sh @@ -1,7 +1,7 @@ main_dir=$2 # List all files on ortools/$main_dir -all_cc=`ls ortools/$main_dir/*.cc` +all_cc=`ls ortools/$main_dir/*.cc | grep -v test.cc` all_h=`ls ortools/$main_dir/*.h` if ls ortools/$main_dir/*proto 1> /dev/null 2>&1; then all_proto=`ls ortools/$main_dir/*.proto`