diff --git a/examples/cpp/BUILD b/examples/cpp/BUILD index 2f461bcd51..2b45dac924 100644 --- a/examples/cpp/BUILD +++ b/examples/cpp/BUILD @@ -293,6 +293,7 @@ cc_binary( "//ortools/base:file", "//ortools/base:strings", "//ortools/base:timer", + "//ortools/sat:cp_model_solver", "//ortools/sat:disjunctive", "//ortools/sat:integer", "//ortools/sat:intervals", @@ -370,6 +371,7 @@ cc_binary( "//ortools/base:file", "//ortools/base:strings", "//ortools/sat:cp_constraints", + "//ortools/sat:cp_model_solver", "//ortools/sat:cumulative", "//ortools/sat:disjunctive", "//ortools/sat:integer", @@ -516,6 +518,7 @@ cc_binary( "//ortools/base", "//ortools/base:file", "//ortools/base:strings", + "//ortools/sat:cp_model_solver", "//ortools/sat:disjunctive", "//ortools/sat:integer", "//ortools/sat:integer_expr", diff --git a/examples/cpp/jobshop_sat.cc b/examples/cpp/jobshop_sat.cc index 12665e3434..ac539d7af7 100644 --- a/examples/cpp/jobshop_sat.cc +++ b/examples/cpp/jobshop_sat.cc @@ -24,6 +24,7 @@ #include "ortools/base/strutil.h" #include "examples/cpp/flexible_jobshop.h" #include "examples/cpp/jobshop.h" +#include "ortools/sat/cp_model_solver.h" #include "ortools/sat/disjunctive.h" #include "ortools/sat/intervals.h" #include "ortools/sat/model.h" diff --git a/examples/cpp/rcpsp_sat.cc b/examples/cpp/rcpsp_sat.cc index 82125fb6fa..63ea7598b8 100644 --- a/examples/cpp/rcpsp_sat.cc +++ b/examples/cpp/rcpsp_sat.cc @@ -17,6 +17,7 @@ #include "ortools/base/commandlineflags.h" #include "ortools/base/logging.h" #include "ortools/base/timer.h" +#include "ortools/sat/cp_model_solver.h" #include "ortools/sat/cumulative.h" #include "ortools/sat/disjunctive.h" #include "ortools/sat/integer_expr.h" diff --git a/examples/cpp/shift_minimization_sat.cc b/examples/cpp/shift_minimization_sat.cc index e79a55865d..3a4eb6396e 100644 --- a/examples/cpp/shift_minimization_sat.cc +++ b/examples/cpp/shift_minimization_sat.cc @@ -39,6 +39,7 @@ #include "ortools/util/filelineiter.h" #include "ortools/base/split.h" #include "ortools/sat/cp_constraints.h" +#include "ortools/sat/cp_model_solver.h" #include "ortools/sat/integer_expr.h" #include "ortools/sat/model.h" #include "ortools/sat/optimization.h" diff --git a/examples/cpp/weighted_tardiness_sat.cc b/examples/cpp/weighted_tardiness_sat.cc index da0062bfb8..97e319599b 100644 --- a/examples/cpp/weighted_tardiness_sat.cc +++ b/examples/cpp/weighted_tardiness_sat.cc @@ -12,24 +12,21 @@ // limitations under the License. #include +#include #include #include "ortools/base/commandlineflags.h" #include "ortools/base/commandlineflags.h" #include "ortools/base/logging.h" -#include "ortools/base/stringprintf.h" -#include "ortools/base/strtoint.h" #include "ortools/base/timer.h" #include "google/protobuf/text_format.h" #include "ortools/base/join.h" #include "ortools/base/split.h" +#include "ortools/base/strtoint.h" #include "ortools/base/strutil.h" -#include "ortools/sat/disjunctive.h" -#include "ortools/sat/integer_expr.h" -#include "ortools/sat/intervals.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_solver.h" #include "ortools/sat/model.h" -#include "ortools/sat/optimization.h" -#include "ortools/sat/precedences.h" #include "ortools/util/filelineiter.h" DEFINE_string(input, "examples/data/weighted_tardiness/wt40.txt", @@ -37,9 +34,6 @@ DEFINE_string(input, "examples/data/weighted_tardiness/wt40.txt", DEFINE_int32(size, 40, "Size of the problem in the wt file."); DEFINE_int32(n, 28, "1-based instance number in the wt file."); DEFINE_string(params, "", "Sat parameters in text proto format."); -DEFINE_bool(use_boolean_precedences, false, - "Whether we create Boolean variables for all the possible " - "precedences between tasks on the same machine, or not."); DEFINE_int32(upper_bound, -1, "If positive, look for a solution <= this."); namespace operations_research { @@ -86,32 +80,81 @@ void Solve(const std::vector& durations, const std::vector& due_dates, LOG(INFO) << "Trival cost bound = " << heuristic_bound; // Create the model. - Model model; - std::vector decision_vars; - std::vector tasks(num_tasks); - std::vector tardiness_vars(num_tasks); + CpModelProto cp_model; + cp_model.set_name("weighted_tardiness"); + auto new_variable = [&cp_model](int64 lb, int64 ub) { + const int index = cp_model.variables_size(); + IntegerVariableProto* var = cp_model.add_variables(); + var->add_domain(lb); + var->add_domain(ub); + return index; + }; + auto new_interval = [&cp_model](int start, int duration, int end) { + const int index = cp_model.constraints_size(); + ConstraintProto* ct = cp_model.add_constraints(); + ct->mutable_interval()->set_start(start); + ct->mutable_interval()->set_size(duration); + ct->mutable_interval()->set_end(end); + return index; + }; + + std::vector tasks_interval(num_tasks); + std::vector tasks_start(num_tasks); + std::vector tasks_duration(num_tasks); + std::vector tasks_end(num_tasks); + std::vector tardiness_vars(num_tasks); for (int i = 0; i < num_tasks; ++i) { - tasks[i] = model.Add(NewInterval(0, horizon, durations[i])); + tasks_start[i] = new_variable(0, horizon - durations[i]); + tasks_duration[i] = new_variable(durations[i], durations[i]); + tasks_end[i] = new_variable(durations[i], horizon); + tasks_interval[i] = + new_interval(tasks_start[i], tasks_duration[i], tasks_end[i]); if (due_dates[i] == 0) { - tardiness_vars[i] = model.Get(EndVar(tasks[i])); + tardiness_vars[i] = tasks_end[i]; } else { - tardiness_vars[i] = - model.Add(NewIntegerVariable(0, std::max(0, horizon - due_dates[i]))); - model.Add(LowerOrEqualWithOffset(model.Get(EndVar(tasks[i])), - tardiness_vars[i], -due_dates[i])); + tardiness_vars[i] = new_variable(0, std::max(0, horizon - due_dates[i])); + + // tardiness_vars >= end - due_date + LinearConstraintProto* arg = cp_model.add_constraints()->mutable_linear(); + arg->add_vars(tardiness_vars[i]); + arg->add_coeffs(1); + arg->add_vars(tasks_end[i]); + arg->add_coeffs(-1); + arg->add_domain(-due_dates[i]); + arg->add_domain(kint64max); } + } + + // Decision heuristic. Note that we don't instantiate all the variables. As a + // consequence, in the values returned by the solution observer for the + // non-fully instantiated variable will be the variable lower bounds after + // propagation. + { + DecisionStrategyProto* strategy = cp_model.add_search_strategy(); + for (int i = 0; i < num_tasks; ++i) strategy->add_variables(tasks_start[i]); // Experiments showed that the heuristic of choosing first the task that - // comes last (because of the NegationOf()) works a lot better. This make - // sense because these are the task with the most influence on the cost. - decision_vars.push_back(NegationOf(model.Get(StartVar(tasks[i])))); - } - if (FLAGS_use_boolean_precedences) { - model.Add(DisjunctiveWithBooleanPrecedences(tasks)); - } else { - model.Add(Disjunctive(tasks)); + // comes last works a lot better. This make sense because these are the task + // with the most influence on the cost. + strategy->set_variable_selection_strategy( + DecisionStrategyProto::CHOOSE_HIGHEST_MAX); + strategy->set_domain_reduction_strategy( + DecisionStrategyProto::SELECT_MAX_VALUE); } + // Disjunction between all the task intervals + { + ConstraintProto* ct = cp_model.add_constraints(); + NoOverlapConstraintProto* arg = ct->mutable_no_overlap(); + for (const int interval : tasks_interval) { + arg->add_intervals(interval); + } + } + + // TODO(user): We can't set an objective upper bound with the current cp_model + // interface, so we can't use heuristic or FLAGS_upper_bound here. The best is + // probably to provide a "solution hint" instead. + // // Set a known upper bound (or use the flag). This has a bigger impact than // can be expected at first: // - It avoid spending time finding not so good solution. @@ -121,12 +164,10 @@ void Solve(const std::vector& durations, const std::vector& due_dates, // // Note however than for big problem, this will drastically augment the time // to get a first feasible solution (but then the heuristic gave one to us). - const IntegerVariable objective_var = - model.Add(NewWeightedSum(weights, tardiness_vars)); - if (FLAGS_upper_bound >= 0) { - model.Add(LowerOrEqual(objective_var, FLAGS_upper_bound)); - } else { - model.Add(LowerOrEqual(objective_var, heuristic_bound)); + CpObjectiveProto* objective = cp_model.mutable_objective(); + for (int i = 0; i < num_tasks; ++i) { + objective->add_vars(tardiness_vars[i]); + objective->add_coeffs(weights[i]); } // Optional preprocessing: add precedences that don't change the optimal @@ -150,89 +191,81 @@ void Solve(const std::vector& durations, const std::vector& due_dates, } ++num_added_precedences; - model.Add(LowerOrEqual(model.Get(EndVar(tasks[i])), - model.Get(StartVar(tasks[j])))); + ConstraintProto* ct = cp_model.add_constraints(); + LinearConstraintProto* arg = ct->mutable_linear(); + arg->add_vars(tasks_start[j]); + arg->add_coeffs(1); + arg->add_vars(tasks_end[i]); + arg->add_coeffs(-1); + arg->add_domain(0); + arg->add_domain(kint64max); } } } LOG(INFO) << "Added " << num_added_precedences << " precedences that will not affect the optimal solution value."; - if (FLAGS_use_boolean_precedences) { - // We disable the lazy encoding in this case. - decision_vars.clear(); - } - // Solve it. // - // Note that we only fully instanciate the start/end and only look at the + // Note that we only fully instantiate the start/end and only look at the // lower bound for the objective and the tardiness variables. + Model model; model.Add(NewSatParameters(FLAGS_params)); - MinimizeIntegerVariableWithLinearScanAndLazyEncoding( - /*log_info=*/true, objective_var, - /*next_decision=*/ - UnassignedVarWithLowestMinAtItsMinHeuristic(decision_vars, &model), - /*feasible_solution_observer=*/ - [&](const Model& model) { - const int64 objective = model.Get(LowerBound(objective_var)); - LOG(INFO) << "Cost " << objective; + model.Add(NewFeasibleSolutionObserver([&](const std::vector& values) { + // Note that we conpute the "real" cost here and do not use the tardiness + // variables. This is because in the core based appraoch, the tardiness + // variable might be fixed before the end date, and we just have a >= + // relation. + int64 objective = 0; + for (int i = 0; i < num_tasks; ++i) { + objective += + weights[i] * std::max(0ll, values[tasks_end[i]] - due_dates[i]); + } + LOG(INFO) << "Cost " << objective; - // Debug code. - { - int64 tardiness_objective = 0; - for (int i = 0; i < num_tasks; ++i) { - tardiness_objective += - weights[i] * - std::max(0ll, model.Get(Value(model.Get(EndVar(tasks[i])))) - - due_dates[i]); - } - CHECK_EQ(objective, tardiness_objective); + // Print the current solution. + std::vector sorted_tasks(num_tasks); + std::iota(sorted_tasks.begin(), sorted_tasks.end(), 0); + std::sort(sorted_tasks.begin(), sorted_tasks.end(), [&](int v1, int v2) { + return values[tasks_start[v1]] < values[tasks_start[v2]]; + }); + std::string solution = "0"; + int end = 0; + for (const int i : sorted_tasks) { + const int64 cost = weights[i] * values[tardiness_vars[i]]; + StrAppend(&solution, "| #", i, " "); + if (cost > 0) { + // Display the cost in red. + StrAppend(&solution, "\033[1;31m(+", cost, ") \033[0m"); + } + StrAppend(&solution, "|", values[tasks_end[i]]); + CHECK_EQ(end, values[tasks_start[i]]); + end += durations[i]; + CHECK_EQ(end, values[tasks_end[i]]); + } + LOG(INFO) << "solution: " << solution; + })); - tardiness_objective = 0; - for (int i = 0; i < num_tasks; ++i) { - tardiness_objective += - weights[i] * model.Get(LowerBound(tardiness_vars[i])); - } - CHECK_EQ(objective, tardiness_objective); - } - - // Print the current solution. - std::vector sorted_tasks = tasks; - std::sort(sorted_tasks.begin(), sorted_tasks.end(), - [&model](IntervalVariable v1, IntervalVariable v2) { - return model.Get(Value(model.Get(StartVar(v1)))) < - model.Get(Value(model.Get(StartVar(v2)))); - }); - std::string solution = "0"; - int end = 0; - for (const IntervalVariable v : sorted_tasks) { - const int64 cost = weights[v.value()] * - model.Get(LowerBound(tardiness_vars[v.value()])); - solution += StringPrintf("| #%d ", v.value()); - if (cost > 0) { - // Display the cost in red. - solution += StringPrintf("\033[1;31m(+%lld) \033[0m", cost); - } - solution += - StringPrintf("|%lld", model.Get(Value(model.Get(EndVar(v))))); - CHECK_EQ(end, model.Get(Value(model.Get(StartVar(v))))); - end += durations[v.value()]; - CHECK_EQ(end, model.Get(Value(model.Get(EndVar(v))))); - } - LOG(INFO) << "solution: " << solution; - }, - &model); + LOG(INFO) << CpModelStats(cp_model); + const CpSolverResponse response = SolveCpModel(cp_model, &model); + LOG(INFO) << CpSolverResponseStats(response); } } // namespace sat +} // namespace operations_research + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags( &argc, &argv, true); + if (FLAGS_input.empty()) { + LOG(FATAL) << "Please supply a data file with --input="; + } -void LoadAndSolve() { std::vector numbers; std::vector entries; for (const std::string& line : operations_research::FileLines(FLAGS_input)) { entries = strings::Split(line, ' ', strings::SkipEmpty()); for (const std::string& entry : entries) { - numbers.push_back(atoi32(entry)); + numbers.push_back(operations_research::atoi32(entry)); } } @@ -253,15 +286,6 @@ void LoadAndSolve() { std::vector due_dates; for (int j = 0; j < FLAGS_size; ++j) due_dates.push_back(numbers[index++]); - sat::Solve(durations, due_dates, weights); -} -} // namespace operations_research - -int main(int argc, char** argv) { - gflags::ParseCommandLineFlags( &argc, &argv, true); - if (FLAGS_input.empty()) { - LOG(FATAL) << "Please supply a data file with --input="; - } - operations_research::LoadAndSolve(); + operations_research::sat::Solve(durations, due_dates, weights); return EXIT_SUCCESS; } diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 342f6eae00..c1588b826d 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -26,6 +26,7 @@ #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/cumulative.h" #include "ortools/sat/disjunctive.h" +#include "ortools/sat/integer.h" #include "ortools/sat/intervals.h" #include "ortools/sat/linear_programming_constraint.h" #include "ortools/sat/optimization.h" @@ -1832,6 +1833,27 @@ std::function NewFeasibleSolutionObserver( }; } +std::function NewSatParameters(const std::string& params) { + return [=](Model* model) { + sat::SatParameters parameters; + if (!params.empty()) { + CHECK(google::protobuf::TextFormat::ParseFromString(params, ¶meters)) << params; + model->GetOrCreate()->SetParameters(parameters); + model->SetSingleton(TimeLimit::FromParameters(parameters)); + } + return parameters; + }; +} + +std::function NewSatParameters( + const sat::SatParameters& parameters) { + return [=](Model* model) { + model->GetOrCreate()->SetParameters(parameters); + model->SetSingleton(TimeLimit::FromParameters(parameters)); + return parameters; + }; +} + namespace { // Because we also use this function for postsolve, we call it with diff --git a/ortools/sat/cp_model_solver.h b/ortools/sat/cp_model_solver.h index 517af3f9a3..28d1e44127 100644 --- a/ortools/sat/cp_model_solver.h +++ b/ortools/sat/cp_model_solver.h @@ -15,8 +15,8 @@ #define OR_TOOLS_SAT_CP_MODEL_SOLVER_H_ #include "ortools/sat/cp_model.pb.h" -#include "ortools/sat/integer.h" #include "ortools/sat/model.h" +#include "ortools/sat/sat_parameters.pb.h" namespace operations_research { namespace sat { @@ -42,11 +42,21 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model); // search. The values will be in one to one correspondence with the variables // in the model_proto. // +// Hack: For the non-fully instantiated variables, the value will be the +// propagated lower bound. Note that this will be fixed with the TODO below. +// // TODO(user): Change the API to take the full CpSolverResponse() so we have // solve statistics and the current objective value. std::function NewFeasibleSolutionObserver( const std::function& values)>& observer); +// Allows to change the default parameters with +// model->Add(NewSatParameters(parameters_as_string_or_proto)) +// before calling SolveCpModel(). +std::function NewSatParameters(const std::string& params); +std::function NewSatParameters( + const SatParameters& parameters); + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/optimization.cc b/ortools/sat/optimization.cc index df91f2b592..5021a28247 100644 --- a/ortools/sat/optimization.cc +++ b/ortools/sat/optimization.cc @@ -431,7 +431,7 @@ SatSolver::Status SolveWithWPM1(LogBehavior log, Logger logger(log); FuMalikSymmetryBreaker symmetry; - // The curent lower_bound on the cost. + // The current lower_bound on the cost. // It will be correct after the initialization. Coefficient lower_bound(static_cast(problem.objective().offset())); Coefficient upper_bound(kint64max); @@ -1180,7 +1180,7 @@ SatSolver::Status MinimizeWithCoreAndLazyEncoding( IntegerValue objective(0); for (int i = 0; i < variables.size(); ++i) { objective += - coefficients[i] * IntegerValue(model->Get(Value(variables[i]))); + coefficients[i] * IntegerValue(model->Get(LowerBound(variables[i]))); } if (objective >= best_objective && num_solutions > 0) return true; @@ -1226,7 +1226,11 @@ SatSolver::Status MinimizeWithCoreAndLazyEncoding( // This is used by the "stratified" approach. We will only consider terms with // a weight not lower than this threshold. The threshold will decrease as the // algorithm progress. - IntegerValue stratified_threshold = kMaxIntegerValue; + IntegerValue stratified_threshold = + sat_solver->parameters().max_sat_stratification() == + SatParameters::STRATIFICATION_NONE + ? IntegerValue(0) + : kMaxIntegerValue; // TODO(user): The core is returned in the same order as the assumptions, // so we don't really need this map, we could just do a linear scan to diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 6cdf1c3e05..a04a4130f5 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -1115,28 +1115,6 @@ inline std::function ExcludeCurrentSolutionAndBacktrack() { }; } -inline std::function NewSatParameters( - const std::string& params) { - return [=](Model* model) { - sat::SatParameters parameters; - if (!params.empty()) { - CHECK(google::protobuf::TextFormat::ParseFromString(params, ¶meters)) << params; - model->GetOrCreate()->SetParameters(parameters); - model->SetSingleton(TimeLimit::FromParameters(parameters)); - } - return parameters; - }; -} - -inline std::function NewSatParameters( - const sat::SatParameters& parameters) { - return [=](Model* model) { - model->GetOrCreate()->SetParameters(parameters); - model->SetSingleton(TimeLimit::FromParameters(parameters)); - return parameters; - }; -} - // Returns a std::string representation of a SatSolver::Status. std::string SatStatusString(SatSolver::Status status); inline std::ostream& operator<<(std::ostream& os, SatSolver::Status status) {