diff --git a/src/constraint_solver/interval.cc b/src/constraint_solver/interval.cc index 328b24bde0..e625f89d24 100644 --- a/src/constraint_solver/interval.cc +++ b/src/constraint_solver/interval.cc @@ -739,7 +739,8 @@ class PerformedVar : public BooleanVar { class FixedDurationIntervalVar : public BaseIntervalVar { public: FixedDurationIntervalVar(Solver* const s, int64 start_min, int64 start_max, - int64 duration, bool optional, const std::string& name); + int64 duration, bool optional, + const std::string& name); // Unperformed interval. FixedDurationIntervalVar(Solver* const s, const std::string& name); virtual ~FixedDurationIntervalVar() {} @@ -1225,10 +1226,8 @@ class StartVarPerformedIntervalVar : public IntervalVar { }; // TODO(user): Take care of overflows. -StartVarPerformedIntervalVar::StartVarPerformedIntervalVar(Solver* const s, - IntVar* const var, - int64 duration, - const std::string& name) +StartVarPerformedIntervalVar::StartVarPerformedIntervalVar( + Solver* const s, IntVar* const var, int64 duration, const std::string& name) : IntervalVar(s, name), start_var_(var), duration_(duration) {} int64 StartVarPerformedIntervalVar::StartMin() const { @@ -1399,6 +1398,9 @@ class StartVarIntervalVar : public BaseIntervalVar { virtual void Push() { LOG(FATAL) << "Should not be here"; } + int64 StoredMin() const { return start_min_.Value(); } + int64 StoredMax() const { return start_max_.Value(); } + private: IntVar* const start_; int64 duration_; @@ -1429,7 +1431,7 @@ int64 StartVarIntervalVar::StartMax() const { } void StartVarIntervalVar::SetStartMin(int64 m) { -if (performed_->Min() == 1) { + if (performed_->Min() == 1) { start_->SetMin(m); } else { start_min_.SetValue(solver(), std::max(m, start_min_.Value())); @@ -1549,6 +1551,43 @@ std::string StartVarIntervalVar::DebugString() const { } } +class LinkStartVarIntervalVar : public Constraint { + public: + LinkStartVarIntervalVar(Solver* const solver, + StartVarIntervalVar* const interval, + IntVar* const start, IntVar* const performed) + : Constraint(solver), + interval_(interval), + start_(start), + performed_(performed) {} + + ~LinkStartVarIntervalVar() {} + + virtual void Post() { + Demon* const demon = MakeConstraintDemon0( + solver(), this, &LinkStartVarIntervalVar::PerformedBound, + "PerformedBound"); + performed_->WhenBound(demon); + } + + virtual void InitialPropagate() { + if (performed_->Bound()) { + PerformedBound(); + } + } + + void PerformedBound() { + if (performed_->Min() == 1) { + start_->SetRange(interval_->StoredMin(), interval_->StoredMax()); + } + } + + private: + StartVarIntervalVar* const interval_; + IntVar* const start_; + IntVar* const performed_; +}; + // ----- FixedInterval ----- class FixedInterval : public IntervalVar { @@ -1690,7 +1729,7 @@ std::string FixedInterval::DebugString() const { out = "IntervalVar(start = "; } StringAppendF(&out, "%" GG_LL_FORMAT "d, duration = %" GG_LL_FORMAT - "d, performed = true)", + "d, performed = true)", start_, duration_); return out; } @@ -2202,11 +2241,9 @@ IntervalVar* Solver::MakeFixedDurationIntervalVar(int64 start_min, this, start_min, start_max, duration, optional, name))); } -void Solver::MakeFixedDurationIntervalVarArray(int count, int64 start_min, - int64 start_max, int64 duration, - bool optional, - const std::string& name, - std::vector* array) { +void Solver::MakeFixedDurationIntervalVarArray( + int count, int64 start_min, int64 start_max, int64 duration, bool optional, + const std::string& name, std::vector* array) { CHECK_GT(count, 0); CHECK(array != nullptr); array->clear(); @@ -2226,28 +2263,32 @@ IntervalVar* Solver::MakeFixedDurationIntervalVar(IntVar* const start_variable, new StartVarPerformedIntervalVar(this, start_variable, duration, name))); } - // Creates an interval var with a fixed duration, and performed var. - // The duration must be greater than 0. -IntervalVar* Solver::MakeFixedDurationIntervalVar(IntVar* const start_variable, - int64 duration, - IntVar* const performed_variable, - const std::string& name) { +// Creates an interval var with a fixed duration, and performed var. +// The duration must be greater than 0. +IntervalVar* Solver::MakeFixedDurationIntervalVar( + IntVar* const start_variable, int64 duration, + IntVar* const performed_variable, const std::string& name) { CHECK(start_variable != nullptr); CHECK(performed_variable != nullptr); CHECK_GE(duration, 0); if (!performed_variable->Bound()) { - return RegisterIntervalVar(RevAlloc( - new StartVarIntervalVar(this, start_variable, duration, performed_variable, name))); + StartVarIntervalVar* const interval = + reinterpret_cast( + RegisterIntervalVar(RevAlloc(new StartVarIntervalVar( + this, start_variable, duration, performed_variable, name)))); + AddConstraint(RevAlloc(new LinkStartVarIntervalVar( + this, interval, start_variable, performed_variable))); + return interval; } else if (performed_variable->Min() == 1) { - return RegisterIntervalVar(RevAlloc( - new StartVarPerformedIntervalVar(this, start_variable, duration, name))); + return RegisterIntervalVar(RevAlloc(new StartVarPerformedIntervalVar( + this, start_variable, duration, name))); } return nullptr; } void Solver::MakeFixedDurationIntervalVarArray( - const std::vector& start_variables, int64 duration, const std::string& name, - std::vector* array) { + const std::vector& start_variables, int64 duration, + const std::string& name, std::vector* array) { CHECK(array != nullptr); array->clear(); for (int i = 0; i < start_variables.size(); ++i) { @@ -2258,8 +2299,9 @@ void Solver::MakeFixedDurationIntervalVarArray( } void Solver::MakeFixedDurationIntervalVarArray( - const std::vector& start_variables, const std::vector& durations, - const std::string& name, std::vector* array) { + const std::vector& start_variables, + const std::vector& durations, const std::string& name, + std::vector* array) { CHECK(array != nullptr); CHECK_EQ(start_variables.size(), durations.size()); array->clear(); @@ -2271,8 +2313,9 @@ void Solver::MakeFixedDurationIntervalVarArray( } void Solver::MakeFixedDurationIntervalVarArray( - const std::vector& start_variables, const std::vector& durations, - const std::string& name, std::vector* array) { + const std::vector& start_variables, + const std::vector& durations, const std::string& name, + std::vector* array) { CHECK(array != nullptr); CHECK_EQ(start_variables.size(), durations.size()); array->clear(); @@ -2292,9 +2335,8 @@ void Solver::MakeFixedDurationIntervalVarArray( array->clear(); for (int i = 0; i < start_variables.size(); ++i) { const std::string var_name = StringPrintf("%s%i", name.c_str(), i); - array->push_back( - MakeFixedDurationIntervalVar(start_variables[i], durations[i], - performed_variables[i], var_name)); + array->push_back(MakeFixedDurationIntervalVar( + start_variables[i], durations[i], performed_variables[i], var_name)); } } @@ -2307,9 +2349,8 @@ void Solver::MakeFixedDurationIntervalVarArray( array->clear(); for (int i = 0; i < start_variables.size(); ++i) { const std::string var_name = StringPrintf("%s%i", name.c_str(), i); - array->push_back( - MakeFixedDurationIntervalVar(start_variables[i], durations[i], - performed_variables[i], var_name)); + array->push_back(MakeFixedDurationIntervalVar( + start_variables[i], durations[i], performed_variables[i], var_name)); } } @@ -2334,9 +2375,9 @@ void Solver::MakeIntervalVarArray(int count, int64 start_min, int64 start_max, array->clear(); for (int i = 0; i < count; ++i) { const std::string var_name = StringPrintf("%s%i", name.c_str(), i); - array->push_back(MakeIntervalVar(start_min, start_max, duration_min, - duration_max, end_min, end_max, optional, - var_name)); + array->push_back( + MakeIntervalVar(start_min, start_max, duration_min, duration_max, + end_min, end_max, optional, var_name)); } } diff --git a/src/constraint_solver/resource.cc b/src/constraint_solver/resource.cc index db78d749b8..212e1a1113 100644 --- a/src/constraint_solver/resource.cc +++ b/src/constraint_solver/resource.cc @@ -63,26 +63,26 @@ namespace { // TODO(user): Tie breaking. // Comparison methods, used by the STL sort. -template -bool StartMinLessThan(Task* const w1, Task* const w2) { +template bool StartMinLessThan(Task* const w1, Task* const w2) { return (w1->interval->StartMin() < w2->interval->StartMin()); } -template -bool StartMaxLessThan(Task* const w1, Task* const w2) { +template bool StartMaxLessThan(Task* const w1, Task* const w2) { return (w1->interval->StartMax() < w2->interval->StartMax()); } -template -bool EndMinLessThan(Task* const w1, Task* const w2) { +template bool EndMinLessThan(Task* const w1, Task* const w2) { return (w1->interval->EndMin() < w2->interval->EndMin()); } -template -bool EndMaxLessThan(Task* const w1, Task* const w2) { +template bool EndMaxLessThan(Task* const w1, Task* const w2) { return (w1->interval->EndMax() < w2->interval->EndMax()); } +bool IntervalStartMinLessThan(IntervalVar* i1, IntervalVar* i2) { + return i1->StartMin() < i2->StartMin(); +} + // ----- Wrappers around intervals ----- // A DisjunctiveTask is a non-preemptive task sharing a disjunctive resource. @@ -112,9 +112,7 @@ struct CumulativeTask { int64 DemandMin() const { return demand; } - void WhenAnything(Demon* const demon) { - interval->WhenAnything(demon); - } + void WhenAnything(Demon* const demon) { interval->WhenAnything(demon); } std::string DebugString() const { return StringPrintf("Task{ %s, demand: %" GG_LL_FORMAT "d }", @@ -182,9 +180,9 @@ struct ThetaNode { } std::string DebugString() const { - return StringPrintf("ThetaNode{ p = %" GG_LL_FORMAT "d, e = %" GG_LL_FORMAT - "d }", - total_processing, total_ect < 0LL ? -1LL : total_ect); + return StringPrintf( + "ThetaNode{ p = %" GG_LL_FORMAT "d, e = %" GG_LL_FORMAT "d }", + total_processing, total_ect < 0LL ? -1LL : total_ect); } int64 total_processing; @@ -807,9 +805,9 @@ class RankedPropagator : public Constraint { first_sentinel > 0 ? RankedInterval(first_sentinel - 1) : nullptr; IntVar* const first_slack = first_sentinel > 0 ? RankedSlack(first_sentinel - 1) : nullptr; - IntervalVar* const last_interval = last_sentinel < last_position - ? RankedInterval(last_sentinel + 1) - : nullptr; + IntervalVar* const last_interval = + last_sentinel < last_position ? RankedInterval(last_sentinel + 1) + : nullptr; // Nothing to do afterwards, exiting. if (first_interval == nullptr && last_interval == nullptr) { @@ -942,16 +940,84 @@ class FullDisjunctiveConstraint : public DisjunctiveConstraint { } virtual void InitialPropagate() { - do { + bool all_optional_or_unperformed = true; + for (const IntervalVar* const interval : intervals_) { + if (interval->MustBePerformed()) { + all_optional_or_unperformed = false; + break; + } + } + if (all_optional_or_unperformed) { // Nothing to deduce + return; + } + + bool all_times_fixed = true; + for (const IntervalVar* const interval : intervals_) { + if (interval->MayBePerformed() && + (interval->StartMin() != interval->StartMax() || + interval->DurationMin() != interval->DurationMax() || + interval->EndMin() != interval->EndMax())) { + all_times_fixed = false; + break; + } + } + + if (all_times_fixed) { + PropagatePerformed(); + } else { do { do { - // OverloadChecking is symmetrical. It has the same effect on the - // straight and the mirrored version. - straight_.OverloadChecking(); - } while (straight_.DetectablePrecedences() || - mirror_.DetectablePrecedences()); - } while (straight_not_last_.Propagate() || mirror_not_last_.Propagate()); - } while (straight_.EdgeFinder() || mirror_.EdgeFinder()); + do { + // OverloadChecking is symmetrical. It has the same effect on the + // straight and the mirrored version. + straight_.OverloadChecking(); + } while (straight_.DetectablePrecedences() || + mirror_.DetectablePrecedences()); + } while (straight_not_last_.Propagate() || mirror_not_last_.Propagate()); + } while (straight_.EdgeFinder() || mirror_.EdgeFinder()); + } + } + + bool Intersect(IntervalVar* const i1, IntervalVar* const i2) const { + return i1->StartMin() < i2->EndMax() && i2->StartMin() < i1->EndMax(); + } + + void PropagatePerformed() { + performed_.clear(); + optional_.clear(); + for (IntervalVar* const interval : intervals_) { + if (interval->MustBePerformed()) { + performed_.push_back(interval); + } else if (interval->MayBePerformed()) { + optional_.push_back(interval); + } + } + // Checks feasibility of performed; + if (performed_.empty()) return; + std::sort(performed_.begin(), performed_.end(), IntervalStartMinLessThan); + for (int i = 0; i < performed_.size() - 1; ++i) { + if (performed_[i]->EndMax() > performed_[i + 1]->StartMin()) { + solver()->Fail(); + } + } + + // Checks if optional intervals can be inserted. + if (optional_.empty()) return; + int index = 0; + const int num_performed = performed_.size(); + std::sort(optional_.begin(), optional_.end(), IntervalStartMinLessThan); + for (IntervalVar* const candidate : optional_) { + const int64 start = candidate->StartMin(); + while (index < num_performed && start >= performed_[index]->EndMax()) { + index++; + } + if (index == num_performed) return; + if (Intersect(candidate, performed_[index]) || + (index < num_performed - 1 && + Intersect(candidate, performed_[index + 1]))) { + candidate->SetPerformed(false); + } + } } void Accept(ModelVisitor* const visitor) const { @@ -1082,6 +1148,8 @@ class FullDisjunctiveConstraint : public DisjunctiveConstraint { std::vector actives_; std::vector time_cumuls_; std::vector time_slacks_; + std::vector performed_; + std::vector optional_; DISALLOW_COPY_AND_ASSIGN(FullDisjunctiveConstraint); }; @@ -1106,16 +1174,16 @@ struct DualCapacityThetaNode { const CumulativeTask& task) : energy(task.EnergyMin()), energetic_end_min(capacity * task.interval->StartMin() + energy), - residual_energetic_end_min( - residual_capacity * task.interval->StartMin() + energy) {} + residual_energetic_end_min(residual_capacity * + task.interval->StartMin() + energy) {} // Constructor for a single cumulative task in the Theta set DualCapacityThetaNode(int64 capacity, int64 residual_capacity, const VariableCumulativeTask& task) : energy(task.EnergyMin()), energetic_end_min(capacity * task.interval->StartMin() + energy), - residual_energetic_end_min( - residual_capacity * task.interval->StartMin() + energy) {} + residual_energetic_end_min(residual_capacity * + task.interval->StartMin() + energy) {} // Sets this DualCapacityThetaNode to the result of the natural binary // operation over the two given operands, corresponding to the following set @@ -1329,8 +1397,7 @@ int64 SafeProduct(int64 a, int64 b) { } // One-sided cumulative edge finder. -template -class EdgeFinder : public Constraint { +template class EdgeFinder : public Constraint { public: EdgeFinder(Solver* const solver, const std::vector& tasks, int64 capacity) @@ -1623,8 +1690,7 @@ bool TimeLessThan(const ProfileDelta& delta1, const ProfileDelta& delta2) { // // The implementation is quite naive, and could certainly be improved, for // example by maintaining the profile incrementally. -template -class CumulativeTimeTable : public Constraint { +template class CumulativeTimeTable : public Constraint { public: CumulativeTimeTable(Solver* const solver, const std::vector& tasks, int64 capacity) @@ -1854,7 +1920,9 @@ class CumulativeConstraint : public Constraint { virtual std::string DebugString() const { return StringPrintf("CumulativeConstraint([%s], %" GG_LL_FORMAT "d)", JoinDebugString(tasks_, ", ").c_str(), capacity_); - }; + } + ; + private: // Post temporal disjunctions for tasks that cannot overlap. void PostAllDisjunctions() { @@ -2029,8 +2097,7 @@ class VariableDemandCumulativeConstraint : public Constraint { virtual std::string DebugString() const { return StringPrintf( "VariableDemandCumulativeConstraint([%s], %" GG_LL_FORMAT "d)", - JoinDebugString(tasks_, ", ").c_str(), - capacity_); + JoinDebugString(tasks_, ", ").c_str(), capacity_); } private: