From 845c252ab40127e2a39c133ab7c2b0afc610e84a Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Mon, 13 Feb 2023 06:45:58 -0800 Subject: [PATCH] [CP-SAT] fix 2 bugs: scheduling lns, multiple_workers with presolve off --- ortools/sat/cp_model_lns.cc | 78 ++++++++++++++-------------------- ortools/sat/cp_model_solver.cc | 6 +-- 2 files changed, 35 insertions(+), 49 deletions(-) diff --git a/ortools/sat/cp_model_lns.cc b/ortools/sat/cp_model_lns.cc index c51f8646fb..73d757d1ee 100644 --- a/ortools/sat/cp_model_lns.cc +++ b/ortools/sat/cp_model_lns.cc @@ -556,9 +556,14 @@ struct Demand { return std::tie(start, height, end) < std::tie(other.start, other.height, other.end); } + + std::string DebugString() const { + return absl::StrCat("{i=", interval_index, " span=[", start, ",", end, "]", + " d=", height, "}"); + } }; -void InsertPrecedencesFromSortedListOfDemands( +void InsertPrecedencesFromSortedListOfNonOverlapingIntervals( const std::vector& demands, absl::flat_hash_set>* precedences) { for (int i = 0; i + 1 < demands.size(); ++i) { @@ -568,6 +573,16 @@ void InsertPrecedencesFromSortedListOfDemands( } } +bool IsPresent(const ConstraintProto& interval_ct, + const CpSolverResponse& initial_solution) { + if (interval_ct.enforcement_literal().size() != 1) return true; + + const int enforcement_ref = interval_ct.enforcement_literal(0); + const int enforcement_var = PositiveRef(enforcement_ref); + const int64_t value = initial_solution.solution(enforcement_var); + return RefIsPositive(enforcement_ref) == (value == 1); +} + void InsertNoOverlapPrecedences( const absl::flat_hash_set& ignored_intervals, const CpSolverResponse& initial_solution, const CpModelProto& model_proto, @@ -580,27 +595,20 @@ void InsertNoOverlapPrecedences( if (ignored_intervals.contains(interval_index)) continue; const ConstraintProto& interval_ct = model_proto.constraints(interval_index); - // We only look at intervals that are performed in the solution. The - // unperformed intervals should be automatically freed during the generation - // phase. - if (interval_ct.enforcement_literal().size() == 1) { - const int enforcement_ref = interval_ct.enforcement_literal(0); - const int enforcement_var = PositiveRef(enforcement_ref); - const int value = initial_solution.solution(enforcement_var); - if (RefIsPositive(enforcement_ref) == (value == 0)) { - continue; - } - } + if (!IsPresent(interval_ct, initial_solution)) continue; const int64_t start_value = GetLinearExpressionValue( interval_ct.interval().start(), initial_solution); const int64_t end_value = GetLinearExpressionValue( - interval_ct.interval().size(), initial_solution); + interval_ct.interval().end(), initial_solution); + DCHECK_LE(start_value, end_value); demands.push_back({interval_index, start_value, end_value, 1}); } + // TODO(user): We actually only need interval_index, start. + // No need to fill the other fields here. std::sort(demands.begin(), demands.end()); - InsertPrecedencesFromSortedListOfDemands(demands, precedences); + InsertPrecedencesFromSortedListOfNonOverlapingIntervals(demands, precedences); } void ProcessDemandListFromCumulativeConstraint( @@ -628,7 +636,8 @@ void ProcessDemandListFromCumulativeConstraint( DCHECK_GT(sum_of_min_two_capacities, 1); if (sum_of_min_two_capacities > capacity) { - InsertPrecedencesFromSortedListOfDemands(demands, precedences); + InsertPrecedencesFromSortedListOfNonOverlapingIntervals(demands, + precedences); return; } @@ -719,20 +728,9 @@ void InsertCumulativePrecedences( for (int i = 0; i < cumulative.intervals().size(); ++i) { const int interval_index = cumulative.intervals(i); if (ignored_intervals.contains(interval_index)) continue; - const ConstraintProto& interval_ct = model_proto.constraints(interval_index); - // We only look at intervals that are performed in the solution. The - // unperformed intervals should be automatically freed during the generation - // phase. - if (interval_ct.enforcement_literal().size() == 1) { - const int enforcement_ref = interval_ct.enforcement_literal(0); - const int enforcement_var = PositiveRef(enforcement_ref); - const int value = initial_solution.solution(enforcement_var); - if (RefIsPositive(enforcement_ref) == (value == 0)) { - continue; - } - } + if (!IsPresent(interval_ct, initial_solution)) continue; const int64_t start_value = GetLinearExpressionValue( interval_ct.interval().start(), initial_solution); @@ -779,10 +777,11 @@ void InsertRectanglePredecences( const std::vector& rectangles, absl::flat_hash_set>* precedences) { // TODO(user): Refine set of interesting points. - absl::flat_hash_set interesting_points; + std::vector interesting_points; for (const Rectangle& r : rectangles) { - interesting_points.insert(r.y_end - 1); + interesting_points.push_back(r.y_end - 1); } + gtl::STLSortAndRemoveDuplicates(&interesting_points); std::vector demands; for (const int64_t t : interesting_points) { demands.clear(); @@ -791,7 +790,8 @@ void InsertRectanglePredecences( demands.push_back({r.interval_index, r.x_start, r.x_end, 1}); } std::sort(demands.begin(), demands.end()); - InsertPrecedencesFromSortedListOfDemands(demands, precedences); + InsertPrecedencesFromSortedListOfNonOverlapingIntervals(demands, + precedences); } } @@ -811,27 +811,13 @@ void InsertNoOverlap2dPrecedences( if (ignored_intervals.contains(x_interval_index)) continue; const ConstraintProto& x_interval_ct = model_proto.constraints(x_interval_index); - if (x_interval_ct.enforcement_literal().size() == 1) { - const int enforcement_ref = x_interval_ct.enforcement_literal(0); - const int enforcement_var = PositiveRef(enforcement_ref); - const int value = initial_solution.solution(enforcement_var); - if (RefIsPositive(enforcement_ref) == (value == 0)) { - continue; - } - } + if (!IsPresent(x_interval_ct, initial_solution)) continue; const int y_interval_index = no_overlap_2d.y_intervals(i); if (ignored_intervals.contains(y_interval_index)) continue; const ConstraintProto& y_interval_ct = model_proto.constraints(y_interval_index); - if (y_interval_ct.enforcement_literal().size() == 1) { - const int enforcement_ref = y_interval_ct.enforcement_literal(0); - const int enforcement_var = PositiveRef(enforcement_ref); - const int value = initial_solution.solution(enforcement_var); - if (RefIsPositive(enforcement_ref) == (value == 0)) { - continue; - } - } + if (!IsPresent(y_interval_ct, initial_solution)) continue; const int64_t x_start_value = GetLinearExpressionValue( x_interval_ct.interval().start(), initial_solution); diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 487485e548..fbe4b5b7ef 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -2651,6 +2651,7 @@ class LnsSolver : public SubSolver { SatParameters local_params(parameters_); local_params.set_max_deterministic_time(data.deterministic_limit); local_params.set_stop_after_first_solution(false); + local_params.set_cp_model_presolve(true); local_params.set_log_search_progress(false); local_params.set_cp_model_probing_level(0); local_params.set_symmetry_level(0); @@ -2756,9 +2757,8 @@ class LnsSolver : public SubSolver { data.status = local_response.status(); // TODO(user): we actually do not need to postsolve if the solution is // not going to be used... - if (local_params.cp_model_presolve() && - (data.status == CpSolverStatus::OPTIMAL || - data.status == CpSolverStatus::FEASIBLE)) { + if (data.status == CpSolverStatus::OPTIMAL || + data.status == CpSolverStatus::FEASIBLE) { PostsolveResponseWrapper( local_params, helper_->ModelProto().variables_size(), mapping_proto, postsolve_mapping, &solution_values);