From 2d54341470b1bca27862066a2be7d655f4d5b10e Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Thu, 28 Mar 2019 21:23:28 +0100 Subject: [PATCH] more improvements to diffn --- ortools/sat/cp_model_lns.cc | 20 +++++ ortools/sat/cp_model_lns.h | 19 +++++ ortools/sat/cp_model_solver.cc | 136 +++++++++++++++++++-------------- ortools/sat/cumulative.cc | 11 +-- ortools/sat/cumulative.h | 6 +- ortools/sat/diffn.cc | 133 ++++++++++++++------------------ ortools/sat/diffn.h | 66 ++++++++++------ ortools/sat/intervals.h | 3 + 8 files changed, 233 insertions(+), 161 deletions(-) diff --git a/ortools/sat/cp_model_lns.cc b/ortools/sat/cp_model_lns.cc index c2f45f7e5e..8ad55fe450 100644 --- a/ortools/sat/cp_model_lns.cc +++ b/ortools/sat/cp_model_lns.cc @@ -133,6 +133,26 @@ Neighborhood NeighborhoodGeneratorHelper::RelaxGivenVariables( return FixGivenVariables(initial_solution, fixed_variables); } +double NeighborhoodGenerator::GetUCBScore(int64 total_num_calls) const { + DCHECK_GE(total_num_calls, num_calls_); + if (num_calls_ <= 10) return std::numeric_limits::infinity(); + return current_average_ + sqrt((2 * log(total_num_calls)) / num_calls_); +} + +void NeighborhoodGenerator::AddSolveData(double objective_diff, + double deterministic_time) { + double gain_per_time_unit = objective_diff / (1.0 + deterministic_time); + // TODO(user): Add more data. + // TODO(user): Weight more recent data. + num_calls_++; + // degrade the current average to forget old learnings. + if (num_calls_ <= 100) { + current_average_ += (gain_per_time_unit - current_average_) / num_calls_; + } else { + current_average_ = 0.9 * current_average_ + 0.1 * gain_per_time_unit; + } +} + namespace { void GetRandomSubset(int seed, double relative_size, std::vector* base) { diff --git a/ortools/sat/cp_model_lns.h b/ortools/sat/cp_model_lns.h index 0f615aa44c..a8e0b0e622 100644 --- a/ortools/sat/cp_model_lns.h +++ b/ortools/sat/cp_model_lns.h @@ -133,9 +133,28 @@ class NeighborhoodGenerator { // Returns a short description of the generator. std::string name() const { return name_; } + // Uses UCB1 algorithm to compute the score (Multi armed bandit problem). + // Details are at + // https://lilianweng.github.io/lil-log/2018/01/23/the-multi-armed-bandit-problem-and-its-solutions.html. + // 'total_num_calls' should be the sum of calls across all generators part of + // the multi armed bandit problem. + // If the generator is called less than 10 times then the method returns + // inifinity as score in order to get more data about the generator + // performance. + double GetUCBScore(int64 total_num_calls) const; + + // Updates the records using the current improvement in objective for the + // generator. + void AddSolveData(double objective_diff, double deterministic_time); + + // Number of times this generator is called. + int64 num_calls() const { return num_calls_; } + protected: const NeighborhoodGeneratorHelper& helper_; const std::string name_; + int64 num_calls_ = 0; + double current_average_ = 0.0; }; // Pick a random subset of variables. diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index 1df00fb811..32b564ff5f 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -1964,6 +1965,7 @@ CpSolverResponse SolveCpModelWithLNS( int num_no_progress = 0; const int num_threads = std::max(1, parameters->lns_num_threads()); + int64 total_num_calls = 0; OptimizeWithLNS( num_threads, [&]() { @@ -2019,10 +2021,9 @@ CpSolverResponse SolveCpModelWithLNS( } } } - AdaptiveParameterValue& difficulty = - difficulties[seed % generators.size()]; - const double saved_difficulty = difficulty.value(); const int selected_generator = seed % generators.size(); + AdaptiveParameterValue& difficulty = difficulties[selected_generator]; + const double saved_difficulty = difficulty.value(); Neighborhood neighborhood = generators[selected_generator]->Generate( response, num_workers * seed + worker_id, saved_difficulty); CpModelProto& local_problem = neighborhood.cp_model; @@ -2071,63 +2072,82 @@ CpSolverResponse SolveCpModelWithLNS( } const bool neighborhood_is_reduced = neighborhood.is_reduced; - return - [neighborhood_is_reduced, &num_no_progress, &model_proto, &response, - &difficulty, local_response, &observer, limit, solution_info]() { - // TODO(user): This is not ideal in multithread because even - // though the saved_difficulty will be the same for all thread, we - // will Increase()/Decrease() the difficuty sequentially more than - // once. - if (local_response.status() == CpSolverStatus::OPTIMAL || - local_response.status() == CpSolverStatus::INFEASIBLE) { - if (neighborhood_is_reduced) { - difficulty.Increase(); - } else { - // We solved the full model here. - response = local_response; - } - } else { - difficulty.Decrease(); + return [neighborhood_is_reduced, &num_no_progress, &model_proto, + &response, &difficulty, local_response, &observer, limit, + solution_info, &generators, selected_generator, + &total_num_calls]() { + // TODO(user): This is not ideal in multithread because even though + // the saved_difficulty will be the same for all thread, we will + // Increase()/Decrease() the difficuty sequentially more than once. + if (local_response.status() == CpSolverStatus::OPTIMAL || + local_response.status() == CpSolverStatus::INFEASIBLE) { + if (neighborhood_is_reduced) { + difficulty.Increase(); + } else { + // We solved the full model here. + response = local_response; + } + } else { + difficulty.Decrease(); + } + // Update the generator record. + double objective_diff = 0.0; + if (local_response.status() == CpSolverStatus::OPTIMAL || + local_response.status() == CpSolverStatus::FEASIBLE) { + objective_diff = std::abs(local_response.objective_value() - + response.objective_value()); + } + total_num_calls++; + generators[selected_generator]->AddSolveData( + objective_diff, local_response.deterministic_time()); + VLOG(2) + << generators[selected_generator]->name() + << ": [difficulty: " << difficulty.value() + << ", deterministic time: " << local_response.deterministic_time() + << ", status: " << CpSolverStatus_Name(local_response.status()) + << ", num calls: " << generators[selected_generator]->num_calls() + << ", UCB1 Score: " + << generators[selected_generator]->GetUCBScore(total_num_calls) + << "]"; + if (local_response.status() == CpSolverStatus::FEASIBLE || + local_response.status() == CpSolverStatus::OPTIMAL) { + // If the objective are the same, we override the solution, + // otherwise we just ignore this local solution and increment + // num_no_progress. + double coeff = model_proto.objective().scaling_factor(); + if (coeff == 0.0) coeff = 1.0; + if (local_response.objective_value() * coeff >= + response.objective_value() * coeff) { + if (local_response.objective_value() * coeff > + response.objective_value() * coeff) { + return; } - if (local_response.status() == CpSolverStatus::FEASIBLE || - local_response.status() == CpSolverStatus::OPTIMAL) { - // If the objective are the same, we override the solution, - // otherwise we just ignore this local solution and increment - // num_no_progress. - double coeff = model_proto.objective().scaling_factor(); - if (coeff == 0.0) coeff = 1.0; - if (local_response.objective_value() * coeff >= - response.objective_value() * coeff) { - if (local_response.objective_value() * coeff > - response.objective_value() * coeff) { - return; - } - ++num_no_progress; - } else { - num_no_progress = 0; - } + ++num_no_progress; + } else { + num_no_progress = 0; + } - // Update the global response. - *(response.mutable_solution()) = local_response.solution(); - response.set_objective_value(local_response.objective_value()); - response.set_wall_time(limit->GetElapsedTime()); - response.set_user_time(response.user_time() + - local_response.user_time()); - response.set_deterministic_time( - response.deterministic_time() + - local_response.deterministic_time()); - if (DEBUG_MODE || FLAGS_cp_model_check_intermediate_solutions) { - CHECK(SolutionIsFeasible( - model_proto, - std::vector(local_response.solution().begin(), - local_response.solution().end()))); - } - if (num_no_progress == 0) { // Improving solution. - response.set_solution_info(solution_info); - observer(response); - } - } - }; + // Update the global response. + *(response.mutable_solution()) = local_response.solution(); + response.set_objective_value(local_response.objective_value()); + response.set_wall_time(limit->GetElapsedTime()); + response.set_user_time(response.user_time() + + local_response.user_time()); + response.set_deterministic_time( + response.deterministic_time() + + local_response.deterministic_time()); + if (DEBUG_MODE || FLAGS_cp_model_check_intermediate_solutions) { + CHECK(SolutionIsFeasible( + model_proto, + std::vector(local_response.solution().begin(), + local_response.solution().end()))); + } + if (num_no_progress == 0) { // Improving solution. + response.set_solution_info(solution_info); + observer(response); + } + } + }; }); if (response.status() == CpSolverStatus::FEASIBLE) { diff --git a/ortools/sat/cumulative.cc b/ortools/sat/cumulative.cc index 160b41a98e..889c24f909 100644 --- a/ortools/sat/cumulative.cc +++ b/ortools/sat/cumulative.cc @@ -34,8 +34,8 @@ namespace sat { std::function Cumulative( const std::vector& vars, const std::vector& demand_vars, - const IntegerVariable& capacity) { - return [=](Model* model) { + const IntegerVariable& capacity, SchedulingConstraintHelper* helper) { + return [=](Model* model) mutable { if (vars.empty()) return; IntervalsRepository* intervals = model->GetOrCreate(); @@ -123,9 +123,10 @@ std::function Cumulative( Trail* trail = model->GetOrCreate(); IntegerTrail* integer_trail = model->GetOrCreate(); - SchedulingConstraintHelper* helper = - new SchedulingConstraintHelper(vars, model); - model->TakeOwnership(helper); + if (helper == nullptr) { + helper = new SchedulingConstraintHelper(vars, model); + model->TakeOwnership(helper); + } // Propagator responsible for applying Timetabling filtering rule. It // increases the minimum of the start variables, decrease the maximum of the diff --git a/ortools/sat/cumulative.h b/ortools/sat/cumulative.h index fe320cd437..c5cdf1ab67 100644 --- a/ortools/sat/cumulative.h +++ b/ortools/sat/cumulative.h @@ -39,10 +39,14 @@ namespace sat { // // This constraint assumes that an interval can be optional or have a duration // of zero. The demands and the capacity can be any non-negative number. +// +// Optimization: If one already have an helper constructed from the interval +// variable, it can be passed as last argument. std::function Cumulative( const std::vector& vars, const std::vector& demand_vars, - const IntegerVariable& capacity_var); + const IntegerVariable& capacity_var, + SchedulingConstraintHelper* helper = nullptr); // Adds a simple cumulative constraint on the given intervals, the associated // demands and the capacity variables. See the comment of Cumulative() above for diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index d64226939e..5fdea40cf7 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -30,37 +30,29 @@ namespace operations_research { namespace sat { -void AddCumulativeRelaxation(const std::vector& x, - const std::vector& y, - Model* model) { - IntervalsRepository* const repository = - model->GetOrCreate(); - std::vector starts; +void AddCumulativeRelaxation(SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y, Model* model) { std::vector sizes; - std::vector ends; + int64 min_starts = kint64max; int64 max_ends = kint64min; - - for (const IntervalVariable& interval : y) { - starts.push_back(repository->StartVar(interval)); - IntegerVariable s_var = repository->SizeVar(interval); + for (int box = 0; box < y->NumTasks(); ++box) { + IntegerVariable s_var = y->DurationVars()[box]; if (s_var == kNoIntegerVariable) { - s_var = model->Add( - ConstantIntegerVariable(repository->MinSize(interval).value())); + s_var = model->Add(ConstantIntegerVariable(y->DurationMin(box).value())); } sizes.push_back(s_var); - ends.push_back(repository->EndVar(interval)); - min_starts = std::min(min_starts, model->Get(LowerBound(starts.back()))); - max_ends = std::max(max_ends, model->Get(UpperBound(ends.back()))); + min_starts = std::min(min_starts, y->StartMin(box).value()); + max_ends = std::max(max_ends, y->EndMax(box).value()); } const IntegerVariable min_start_var = model->Add(NewIntegerVariable(min_starts, max_ends)); - model->Add(IsEqualToMinOf(min_start_var, starts)); + model->Add(IsEqualToMinOf(min_start_var, y->StartVars())); const IntegerVariable max_end_var = model->Add(NewIntegerVariable(min_starts, max_ends)); - model->Add(IsEqualToMaxOf(max_end_var, ends)); + model->Add(IsEqualToMaxOf(max_end_var, y->EndVars())); const IntegerVariable capacity = model->Add(NewIntegerVariable(0, CapSub(max_ends, min_starts))); @@ -68,7 +60,7 @@ void AddCumulativeRelaxation(const std::vector& x, model->Add(WeightedSumGreaterOrEqual({capacity, min_start_var, max_end_var}, coeffs, 0)); - model->Add(Cumulative(x, sizes, capacity)); + model->Add(Cumulative(x->Intervals(), sizes, capacity, x)); } namespace { @@ -135,22 +127,28 @@ std::vector> SplitDisjointBoxes( #define RETURN_IF_FALSE(f) \ if (!(f)) return false; -NonOverlappingRectanglesEnergyPropagator:: - NonOverlappingRectanglesEnergyPropagator( - const std::vector& x, - const std::vector& y, Model* model) - : x_(x, model), y_(y, model) {} - NonOverlappingRectanglesEnergyPropagator:: ~NonOverlappingRectanglesEnergyPropagator() {} bool NonOverlappingRectanglesEnergyPropagator::Propagate() { - cached_areas_.resize(x_.NumTasks()); + const int num_boxes = x_.NumTasks(); + x_.SetTimeDirection(true); + y_.SetTimeDirection(true); active_boxes_.clear(); - for (int box = 0; box < x_.NumTasks(); ++box) { + cached_areas_.resize(num_boxes); + cached_dimensions_.resize(num_boxes); + for (int box = 0; box < num_boxes; ++box) { cached_areas_[box] = x_.DurationMin(box) * y_.DurationMin(box); if (cached_areas_[box] == 0) continue; + + // TODO(user): Also consider shifted end max. + Dimension& dimension = cached_dimensions_[box]; + dimension.x_min = x_.ShiftedStartMin(box); + dimension.x_max = x_.EndMax(box); + dimension.y_min = y_.ShiftedStartMin(box); + dimension.y_max = y_.EndMax(box); + active_boxes_.push_back(box); } if (active_boxes_.size() <= 1) return true; @@ -182,24 +180,16 @@ int NonOverlappingRectanglesEnergyPropagator::RegisterWith( void NonOverlappingRectanglesEnergyPropagator::SortBoxesIntoNeighbors( int box, absl::Span local_boxes) { - const IntegerValue box_x_min = x_.StartMin(box); - const IntegerValue box_x_max = x_.EndMax(box); - const IntegerValue box_y_min = y_.StartMin(box); - const IntegerValue box_y_max = y_.EndMax(box); + const Dimension& box_dim = cached_dimensions_[box]; neighbors_.clear(); for (const int other_box : local_boxes) { if (other_box == box) continue; - - const IntegerValue other_x_min = x_.StartMin(other_box); - const IntegerValue other_x_max = x_.EndMax(other_box); - const IntegerValue other_y_min = y_.StartMin(other_box); - const IntegerValue other_y_max = y_.EndMax(other_box); - - const IntegerValue span_x = - std::max(box_x_max, other_x_max) - std::min(box_x_min, other_x_min) + 1; - const IntegerValue span_y = - std::max(box_y_max, other_y_max) - std::min(box_y_min, other_y_min) + 1; + const Dimension& other_dim = cached_dimensions_[other_box]; + const IntegerValue span_x = std::max(box_dim.x_max, other_dim.x_max) - + std::min(box_dim.x_min, other_dim.x_min) + 1; + const IntegerValue span_y = std::max(box_dim.y_max, other_dim.y_max) - + std::min(box_dim.y_min, other_dim.y_min) + 1; neighbors_.push_back({other_box, span_x * span_y}); } std::sort(neighbors_.begin(), neighbors_.end()); @@ -210,11 +200,7 @@ bool NonOverlappingRectanglesEnergyPropagator::FailWhenEnergyIsTooLarge( // Note that we only consider the smallest dimension of each boxes here. SortBoxesIntoNeighbors(box, local_boxes); - IntegerValue area_min_x = x_.StartMin(box); - IntegerValue area_max_x = x_.EndMax(box); - IntegerValue area_min_y = y_.StartMin(box); - IntegerValue area_max_y = y_.EndMax(box); - + Dimension area = cached_dimensions_[box]; IntegerValue sum_of_areas = cached_areas_[box]; IntegerValue total_sum_of_areas = sum_of_areas; @@ -223,12 +209,10 @@ bool NonOverlappingRectanglesEnergyPropagator::FailWhenEnergyIsTooLarge( } const auto add_box_energy_in_rectangle_reason = [&](int b) { - x_.AddStartMinReason(b, area_min_x); - x_.AddDurationMinReason(b, x_.DurationMin(b)); - x_.AddEndMaxReason(b, area_max_x); - y_.AddStartMinReason(b, area_min_y); - y_.AddDurationMinReason(b, y_.DurationMin(b)); - y_.AddEndMaxReason(b, area_max_y); + x_.AddEnergyAfterReason(b, x_.DurationMin(b), area.x_min); + x_.AddEndMaxReason(b, area.x_max); + y_.AddEnergyAfterReason(b, y_.DurationMin(b), area.y_min); + y_.AddEndMaxReason(b, area.y_max); }; for (int i = 0; i < neighbors_.size(); ++i) { @@ -236,15 +220,12 @@ bool NonOverlappingRectanglesEnergyPropagator::FailWhenEnergyIsTooLarge( CHECK_GT(cached_areas_[other_box], 0); // Update Bounding box. - area_min_x = std::min(area_min_x, x_.StartMin(other_box)); - area_max_x = std::max(area_max_x, x_.EndMax(other_box)); - area_min_y = std::min(area_min_y, y_.StartMin(other_box)); - area_max_y = std::max(area_max_y, y_.EndMax(other_box)); + area.TakeUnionWith(cached_dimensions_[other_box]); // Update sum of areas. sum_of_areas += cached_areas_[other_box]; const IntegerValue bounding_area = - (area_max_x - area_min_x) * (area_max_y - area_min_y); + (area.x_max - area.x_min) * (area.y_max - area.y_min); if (bounding_area >= total_sum_of_areas) { // Nothing will be deduced. Exiting. return true; @@ -267,13 +248,14 @@ bool NonOverlappingRectanglesEnergyPropagator::FailWhenEnergyIsTooLarge( // Note that x_ and y_ must be initialized with enough intervals when passed // to the disjunctive propagators. NonOverlappingRectanglesDisjunctivePropagator:: - NonOverlappingRectanglesDisjunctivePropagator( - const std::vector& x, - const std::vector& y, bool strict, Model* model) - : global_x_(x, model), - global_y_(y, model), - x_(x, model), - y_(y, model), + NonOverlappingRectanglesDisjunctivePropagator(bool strict, + SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y, + Model* model) + : global_x_(*x), + global_y_(*y), + x_(x->Intervals(), model), + y_(y->Intervals(), model), strict_(strict), watcher_(model->GetOrCreate()), overload_checker_(true, &x_), @@ -287,17 +269,17 @@ NonOverlappingRectanglesDisjunctivePropagator:: NonOverlappingRectanglesDisjunctivePropagator:: ~NonOverlappingRectanglesDisjunctivePropagator() {} -void NonOverlappingRectanglesDisjunctivePropagator::RegisterWith( - GenericLiteralWatcher* watcher, int fast_priority, int slow_priority) { - fast_id_ = watcher->Register(this); - watcher->SetPropagatorPriority(fast_id_, fast_priority); - global_x_.WatchAllTasks(fast_id_, watcher); - global_y_.WatchAllTasks(fast_id_, watcher); +void NonOverlappingRectanglesDisjunctivePropagator::Register( + int fast_priority, int slow_priority) { + fast_id_ = watcher_->Register(this); + watcher_->SetPropagatorPriority(fast_id_, fast_priority); + global_x_.WatchAllTasks(fast_id_, watcher_); + global_y_.WatchAllTasks(fast_id_, watcher_); - const int slow_id = watcher->Register(this); - watcher->SetPropagatorPriority(slow_id, slow_priority); - global_x_.WatchAllTasks(slow_id, watcher); - global_y_.WatchAllTasks(slow_id, watcher); + const int slow_id = watcher_->Register(this); + watcher_->SetPropagatorPriority(slow_id, slow_priority); + global_x_.WatchAllTasks(slow_id, watcher_); + global_y_.WatchAllTasks(slow_id, watcher_); } bool NonOverlappingRectanglesDisjunctivePropagator:: @@ -425,6 +407,9 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: } bool NonOverlappingRectanglesDisjunctivePropagator::Propagate() { + global_x_.SetTimeDirection(true); + global_y_.SetTimeDirection(true); + std::function inner_propagate; if (watcher_->GetCurrentId() == fast_id_) { inner_propagate = [this]() { diff --git a/ortools/sat/diffn.h b/ortools/sat/diffn.h index 4d4b2fefe6..3d703d4dba 100644 --- a/ortools/sat/diffn.h +++ b/ortools/sat/diffn.h @@ -35,9 +35,9 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface { // The strict parameters indicates how to place zero width or zero height // boxes. If strict is true, these boxes must not 'cross' another box, and are // pushed by the other boxes. - NonOverlappingRectanglesEnergyPropagator( - const std::vector& x, - const std::vector& y, Model* model); + NonOverlappingRectanglesEnergyPropagator(SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y) + : x_(*x), y_(*y) {} ~NonOverlappingRectanglesEnergyPropagator() override; bool Propagate() final; @@ -47,12 +47,27 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface { void SortBoxesIntoNeighbors(int box, absl::Span local_boxes); bool FailWhenEnergyIsTooLarge(int box, absl::Span local_boxes); - SchedulingConstraintHelper x_; - SchedulingConstraintHelper y_; + SchedulingConstraintHelper& x_; + SchedulingConstraintHelper& y_; std::vector active_boxes_; std::vector cached_areas_; + struct Dimension { + IntegerValue x_min; + IntegerValue x_max; + IntegerValue y_min; + IntegerValue y_max; + + void TakeUnionWith(const Dimension& other) { + x_min = std::min(x_min, other.x_min); + y_min = std::min(y_min, other.y_min); + x_max = std::max(x_max, other.x_max); + y_max = std::max(y_max, other.y_max); + } + }; + std::vector cached_dimensions_; + struct Neighbor { int box; IntegerValue distance_to_bounding_box; @@ -76,14 +91,14 @@ class NonOverlappingRectanglesDisjunctivePropagator // boxes. If strict is true, these boxes must not 'cross' another box, and are // pushed by the other boxes. // The slow_propagators select which disjunctive algorithms to propagate. - NonOverlappingRectanglesDisjunctivePropagator( - const std::vector& x, - const std::vector& y, bool strict, Model* model); + NonOverlappingRectanglesDisjunctivePropagator(bool strict, + SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y, + Model* model); ~NonOverlappingRectanglesDisjunctivePropagator() override; bool Propagate() final; - void RegisterWith(GenericLiteralWatcher* watcher, int fast_priority, - int slow_priority); + void Register(int fast_priority, int slow_priority); private: bool PropagateTwoBoxes(); @@ -91,8 +106,8 @@ class NonOverlappingRectanglesDisjunctivePropagator const SchedulingConstraintHelper& x, const SchedulingConstraintHelper& y, std::function inner_propagate); - SchedulingConstraintHelper global_x_; - SchedulingConstraintHelper global_y_; + SchedulingConstraintHelper& global_x_; + SchedulingConstraintHelper& global_y_; SchedulingConstraintHelper x_; SchedulingConstraintHelper y_; const bool strict_; @@ -127,9 +142,8 @@ class NonOverlappingRectanglesDisjunctivePropagator // Add a cumulative relaxation. That is, on one direction, it does not enforce // the rectangle aspect, allowing vertical slices to move freely. -void AddCumulativeRelaxation(const std::vector& x, - const std::vector& y, - Model* model); +void AddCumulativeRelaxation(SchedulingConstraintHelper* x, + SchedulingConstraintHelper* y, Model* model); // Enforces that the boxes with corners in (x, y), (x + dx, y), (x, y + dy) // and (x + dx, y + dy) do not overlap. @@ -139,22 +153,28 @@ inline std::function NonOverlappingRectangles( const std::vector& x, const std::vector& y, bool is_strict) { return [=](Model* model) { - GenericLiteralWatcher* const watcher = - model->GetOrCreate(); + SchedulingConstraintHelper* x_helper = + new SchedulingConstraintHelper(x, model); + SchedulingConstraintHelper* y_helper = + new SchedulingConstraintHelper(y, model); + model->TakeOwnership(x_helper); + model->TakeOwnership(y_helper); NonOverlappingRectanglesEnergyPropagator* energy_constraint = - new NonOverlappingRectanglesEnergyPropagator(x, y, model); + new NonOverlappingRectanglesEnergyPropagator(x_helper, y_helper); + GenericLiteralWatcher* const watcher = + model->GetOrCreate(); watcher->SetPropagatorPriority(energy_constraint->RegisterWith(watcher), 3); model->TakeOwnership(energy_constraint); NonOverlappingRectanglesDisjunctivePropagator* constraint = - new NonOverlappingRectanglesDisjunctivePropagator(x, y, is_strict, - model); - constraint->RegisterWith(watcher, /*fast_priority=*/3, /*slow_priority=*/4); + new NonOverlappingRectanglesDisjunctivePropagator(is_strict, x_helper, + y_helper, model); + constraint->Register(/*fast_priority=*/3, /*slow_priority=*/4); model->TakeOwnership(constraint); - AddCumulativeRelaxation(x, y, model); - AddCumulativeRelaxation(y, x, model); + AddCumulativeRelaxation(x_helper, y_helper, model); + AddCumulativeRelaxation(y_helper, x_helper, model); }; } diff --git a/ortools/sat/intervals.h b/ortools/sat/intervals.h index 9e06120d2b..9218710f43 100644 --- a/ortools/sat/intervals.h +++ b/ortools/sat/intervals.h @@ -236,6 +236,9 @@ class SchedulingConstraintHelper { // Returns the underlying integer variables. const std::vector& StartVars() const { return start_vars_; } const std::vector& EndVars() const { return end_vars_; } + const std::vector& DurationVars() const { + return duration_vars_; + } const std::vector& Intervals() const { return intervals_; } // Registers the given propagator id to be called if any of the tasks