special case on disjunctive constraint when all times are fixed (and not the performed statuses). Fix performance slowdown (15x) on mspsp from fz

This commit is contained in:
lperron@google.com
2014-07-07 13:29:56 +00:00
parent 1fea37737b
commit 1851661fcf
2 changed files with 181 additions and 73 deletions

View File

@@ -739,7 +739,8 @@ class PerformedVar : public BooleanVar {
class FixedDurationIntervalVar : public BaseIntervalVar {
public:
FixedDurationIntervalVar(Solver* const s, int64 start_min, int64 start_max,
int64 duration, bool optional, const std::string& name);
int64 duration, bool optional,
const std::string& name);
// Unperformed interval.
FixedDurationIntervalVar(Solver* const s, const std::string& name);
virtual ~FixedDurationIntervalVar() {}
@@ -1225,10 +1226,8 @@ class StartVarPerformedIntervalVar : public IntervalVar {
};
// TODO(user): Take care of overflows.
StartVarPerformedIntervalVar::StartVarPerformedIntervalVar(Solver* const s,
IntVar* const var,
int64 duration,
const std::string& name)
StartVarPerformedIntervalVar::StartVarPerformedIntervalVar(
Solver* const s, IntVar* const var, int64 duration, const std::string& name)
: IntervalVar(s, name), start_var_(var), duration_(duration) {}
int64 StartVarPerformedIntervalVar::StartMin() const {
@@ -1399,6 +1398,9 @@ class StartVarIntervalVar : public BaseIntervalVar {
virtual void Push() { LOG(FATAL) << "Should not be here"; }
int64 StoredMin() const { return start_min_.Value(); }
int64 StoredMax() const { return start_max_.Value(); }
private:
IntVar* const start_;
int64 duration_;
@@ -1429,7 +1431,7 @@ int64 StartVarIntervalVar::StartMax() const {
}
void StartVarIntervalVar::SetStartMin(int64 m) {
if (performed_->Min() == 1) {
if (performed_->Min() == 1) {
start_->SetMin(m);
} else {
start_min_.SetValue(solver(), std::max(m, start_min_.Value()));
@@ -1549,6 +1551,43 @@ std::string StartVarIntervalVar::DebugString() const {
}
}
class LinkStartVarIntervalVar : public Constraint {
public:
LinkStartVarIntervalVar(Solver* const solver,
StartVarIntervalVar* const interval,
IntVar* const start, IntVar* const performed)
: Constraint(solver),
interval_(interval),
start_(start),
performed_(performed) {}
~LinkStartVarIntervalVar() {}
virtual void Post() {
Demon* const demon = MakeConstraintDemon0(
solver(), this, &LinkStartVarIntervalVar::PerformedBound,
"PerformedBound");
performed_->WhenBound(demon);
}
virtual void InitialPropagate() {
if (performed_->Bound()) {
PerformedBound();
}
}
void PerformedBound() {
if (performed_->Min() == 1) {
start_->SetRange(interval_->StoredMin(), interval_->StoredMax());
}
}
private:
StartVarIntervalVar* const interval_;
IntVar* const start_;
IntVar* const performed_;
};
// ----- FixedInterval -----
class FixedInterval : public IntervalVar {
@@ -1690,7 +1729,7 @@ std::string FixedInterval::DebugString() const {
out = "IntervalVar(start = ";
}
StringAppendF(&out, "%" GG_LL_FORMAT "d, duration = %" GG_LL_FORMAT
"d, performed = true)",
"d, performed = true)",
start_, duration_);
return out;
}
@@ -2202,11 +2241,9 @@ IntervalVar* Solver::MakeFixedDurationIntervalVar(int64 start_min,
this, start_min, start_max, duration, optional, name)));
}
void Solver::MakeFixedDurationIntervalVarArray(int count, int64 start_min,
int64 start_max, int64 duration,
bool optional,
const std::string& name,
std::vector<IntervalVar*>* array) {
void Solver::MakeFixedDurationIntervalVarArray(
int count, int64 start_min, int64 start_max, int64 duration, bool optional,
const std::string& name, std::vector<IntervalVar*>* array) {
CHECK_GT(count, 0);
CHECK(array != nullptr);
array->clear();
@@ -2226,28 +2263,32 @@ IntervalVar* Solver::MakeFixedDurationIntervalVar(IntVar* const start_variable,
new StartVarPerformedIntervalVar(this, start_variable, duration, name)));
}
// Creates an interval var with a fixed duration, and performed var.
// The duration must be greater than 0.
IntervalVar* Solver::MakeFixedDurationIntervalVar(IntVar* const start_variable,
int64 duration,
IntVar* const performed_variable,
const std::string& name) {
// Creates an interval var with a fixed duration, and performed var.
// The duration must be greater than 0.
IntervalVar* Solver::MakeFixedDurationIntervalVar(
IntVar* const start_variable, int64 duration,
IntVar* const performed_variable, const std::string& name) {
CHECK(start_variable != nullptr);
CHECK(performed_variable != nullptr);
CHECK_GE(duration, 0);
if (!performed_variable->Bound()) {
return RegisterIntervalVar(RevAlloc(
new StartVarIntervalVar(this, start_variable, duration, performed_variable, name)));
StartVarIntervalVar* const interval =
reinterpret_cast<StartVarIntervalVar*>(
RegisterIntervalVar(RevAlloc(new StartVarIntervalVar(
this, start_variable, duration, performed_variable, name))));
AddConstraint(RevAlloc(new LinkStartVarIntervalVar(
this, interval, start_variable, performed_variable)));
return interval;
} else if (performed_variable->Min() == 1) {
return RegisterIntervalVar(RevAlloc(
new StartVarPerformedIntervalVar(this, start_variable, duration, name)));
return RegisterIntervalVar(RevAlloc(new StartVarPerformedIntervalVar(
this, start_variable, duration, name)));
}
return nullptr;
}
void Solver::MakeFixedDurationIntervalVarArray(
const std::vector<IntVar*>& start_variables, int64 duration, const std::string& name,
std::vector<IntervalVar*>* array) {
const std::vector<IntVar*>& start_variables, int64 duration,
const std::string& name, std::vector<IntervalVar*>* array) {
CHECK(array != nullptr);
array->clear();
for (int i = 0; i < start_variables.size(); ++i) {
@@ -2258,8 +2299,9 @@ void Solver::MakeFixedDurationIntervalVarArray(
}
void Solver::MakeFixedDurationIntervalVarArray(
const std::vector<IntVar*>& start_variables, const std::vector<int64>& durations,
const std::string& name, std::vector<IntervalVar*>* array) {
const std::vector<IntVar*>& start_variables,
const std::vector<int64>& durations, const std::string& name,
std::vector<IntervalVar*>* array) {
CHECK(array != nullptr);
CHECK_EQ(start_variables.size(), durations.size());
array->clear();
@@ -2271,8 +2313,9 @@ void Solver::MakeFixedDurationIntervalVarArray(
}
void Solver::MakeFixedDurationIntervalVarArray(
const std::vector<IntVar*>& start_variables, const std::vector<int>& durations,
const std::string& name, std::vector<IntervalVar*>* array) {
const std::vector<IntVar*>& start_variables,
const std::vector<int>& durations, const std::string& name,
std::vector<IntervalVar*>* array) {
CHECK(array != nullptr);
CHECK_EQ(start_variables.size(), durations.size());
array->clear();
@@ -2292,9 +2335,8 @@ void Solver::MakeFixedDurationIntervalVarArray(
array->clear();
for (int i = 0; i < start_variables.size(); ++i) {
const std::string var_name = StringPrintf("%s%i", name.c_str(), i);
array->push_back(
MakeFixedDurationIntervalVar(start_variables[i], durations[i],
performed_variables[i], var_name));
array->push_back(MakeFixedDurationIntervalVar(
start_variables[i], durations[i], performed_variables[i], var_name));
}
}
@@ -2307,9 +2349,8 @@ void Solver::MakeFixedDurationIntervalVarArray(
array->clear();
for (int i = 0; i < start_variables.size(); ++i) {
const std::string var_name = StringPrintf("%s%i", name.c_str(), i);
array->push_back(
MakeFixedDurationIntervalVar(start_variables[i], durations[i],
performed_variables[i], var_name));
array->push_back(MakeFixedDurationIntervalVar(
start_variables[i], durations[i], performed_variables[i], var_name));
}
}
@@ -2334,9 +2375,9 @@ void Solver::MakeIntervalVarArray(int count, int64 start_min, int64 start_max,
array->clear();
for (int i = 0; i < count; ++i) {
const std::string var_name = StringPrintf("%s%i", name.c_str(), i);
array->push_back(MakeIntervalVar(start_min, start_max, duration_min,
duration_max, end_min, end_max, optional,
var_name));
array->push_back(
MakeIntervalVar(start_min, start_max, duration_min, duration_max,
end_min, end_max, optional, var_name));
}
}

View File

@@ -63,26 +63,26 @@ namespace {
// TODO(user): Tie breaking.
// Comparison methods, used by the STL sort.
template <class Task>
bool StartMinLessThan(Task* const w1, Task* const w2) {
template <class Task> bool StartMinLessThan(Task* const w1, Task* const w2) {
return (w1->interval->StartMin() < w2->interval->StartMin());
}
template <class Task>
bool StartMaxLessThan(Task* const w1, Task* const w2) {
template <class Task> bool StartMaxLessThan(Task* const w1, Task* const w2) {
return (w1->interval->StartMax() < w2->interval->StartMax());
}
template <class Task>
bool EndMinLessThan(Task* const w1, Task* const w2) {
template <class Task> bool EndMinLessThan(Task* const w1, Task* const w2) {
return (w1->interval->EndMin() < w2->interval->EndMin());
}
template <class Task>
bool EndMaxLessThan(Task* const w1, Task* const w2) {
template <class Task> bool EndMaxLessThan(Task* const w1, Task* const w2) {
return (w1->interval->EndMax() < w2->interval->EndMax());
}
bool IntervalStartMinLessThan(IntervalVar* i1, IntervalVar* i2) {
return i1->StartMin() < i2->StartMin();
}
// ----- Wrappers around intervals -----
// A DisjunctiveTask is a non-preemptive task sharing a disjunctive resource.
@@ -112,9 +112,7 @@ struct CumulativeTask {
int64 DemandMin() const { return demand; }
void WhenAnything(Demon* const demon) {
interval->WhenAnything(demon);
}
void WhenAnything(Demon* const demon) { interval->WhenAnything(demon); }
std::string DebugString() const {
return StringPrintf("Task{ %s, demand: %" GG_LL_FORMAT "d }",
@@ -182,9 +180,9 @@ struct ThetaNode {
}
std::string DebugString() const {
return StringPrintf("ThetaNode{ p = %" GG_LL_FORMAT "d, e = %" GG_LL_FORMAT
"d }",
total_processing, total_ect < 0LL ? -1LL : total_ect);
return StringPrintf(
"ThetaNode{ p = %" GG_LL_FORMAT "d, e = %" GG_LL_FORMAT "d }",
total_processing, total_ect < 0LL ? -1LL : total_ect);
}
int64 total_processing;
@@ -807,9 +805,9 @@ class RankedPropagator : public Constraint {
first_sentinel > 0 ? RankedInterval(first_sentinel - 1) : nullptr;
IntVar* const first_slack =
first_sentinel > 0 ? RankedSlack(first_sentinel - 1) : nullptr;
IntervalVar* const last_interval = last_sentinel < last_position
? RankedInterval(last_sentinel + 1)
: nullptr;
IntervalVar* const last_interval =
last_sentinel < last_position ? RankedInterval(last_sentinel + 1)
: nullptr;
// Nothing to do afterwards, exiting.
if (first_interval == nullptr && last_interval == nullptr) {
@@ -942,16 +940,84 @@ class FullDisjunctiveConstraint : public DisjunctiveConstraint {
}
virtual void InitialPropagate() {
do {
bool all_optional_or_unperformed = true;
for (const IntervalVar* const interval : intervals_) {
if (interval->MustBePerformed()) {
all_optional_or_unperformed = false;
break;
}
}
if (all_optional_or_unperformed) { // Nothing to deduce
return;
}
bool all_times_fixed = true;
for (const IntervalVar* const interval : intervals_) {
if (interval->MayBePerformed() &&
(interval->StartMin() != interval->StartMax() ||
interval->DurationMin() != interval->DurationMax() ||
interval->EndMin() != interval->EndMax())) {
all_times_fixed = false;
break;
}
}
if (all_times_fixed) {
PropagatePerformed();
} else {
do {
do {
// OverloadChecking is symmetrical. It has the same effect on the
// straight and the mirrored version.
straight_.OverloadChecking();
} while (straight_.DetectablePrecedences() ||
mirror_.DetectablePrecedences());
} while (straight_not_last_.Propagate() || mirror_not_last_.Propagate());
} while (straight_.EdgeFinder() || mirror_.EdgeFinder());
do {
// OverloadChecking is symmetrical. It has the same effect on the
// straight and the mirrored version.
straight_.OverloadChecking();
} while (straight_.DetectablePrecedences() ||
mirror_.DetectablePrecedences());
} while (straight_not_last_.Propagate() || mirror_not_last_.Propagate());
} while (straight_.EdgeFinder() || mirror_.EdgeFinder());
}
}
bool Intersect(IntervalVar* const i1, IntervalVar* const i2) const {
return i1->StartMin() < i2->EndMax() && i2->StartMin() < i1->EndMax();
}
void PropagatePerformed() {
performed_.clear();
optional_.clear();
for (IntervalVar* const interval : intervals_) {
if (interval->MustBePerformed()) {
performed_.push_back(interval);
} else if (interval->MayBePerformed()) {
optional_.push_back(interval);
}
}
// Checks feasibility of performed;
if (performed_.empty()) return;
std::sort(performed_.begin(), performed_.end(), IntervalStartMinLessThan);
for (int i = 0; i < performed_.size() - 1; ++i) {
if (performed_[i]->EndMax() > performed_[i + 1]->StartMin()) {
solver()->Fail();
}
}
// Checks if optional intervals can be inserted.
if (optional_.empty()) return;
int index = 0;
const int num_performed = performed_.size();
std::sort(optional_.begin(), optional_.end(), IntervalStartMinLessThan);
for (IntervalVar* const candidate : optional_) {
const int64 start = candidate->StartMin();
while (index < num_performed && start >= performed_[index]->EndMax()) {
index++;
}
if (index == num_performed) return;
if (Intersect(candidate, performed_[index]) ||
(index < num_performed - 1 &&
Intersect(candidate, performed_[index + 1]))) {
candidate->SetPerformed(false);
}
}
}
void Accept(ModelVisitor* const visitor) const {
@@ -1082,6 +1148,8 @@ class FullDisjunctiveConstraint : public DisjunctiveConstraint {
std::vector<IntVar*> actives_;
std::vector<IntVar*> time_cumuls_;
std::vector<IntVar*> time_slacks_;
std::vector<IntervalVar*> performed_;
std::vector<IntervalVar*> optional_;
DISALLOW_COPY_AND_ASSIGN(FullDisjunctiveConstraint);
};
@@ -1106,16 +1174,16 @@ struct DualCapacityThetaNode {
const CumulativeTask& task)
: energy(task.EnergyMin()),
energetic_end_min(capacity * task.interval->StartMin() + energy),
residual_energetic_end_min(
residual_capacity * task.interval->StartMin() + energy) {}
residual_energetic_end_min(residual_capacity *
task.interval->StartMin() + energy) {}
// Constructor for a single cumulative task in the Theta set
DualCapacityThetaNode(int64 capacity, int64 residual_capacity,
const VariableCumulativeTask& task)
: energy(task.EnergyMin()),
energetic_end_min(capacity * task.interval->StartMin() + energy),
residual_energetic_end_min(
residual_capacity * task.interval->StartMin() + energy) {}
residual_energetic_end_min(residual_capacity *
task.interval->StartMin() + energy) {}
// Sets this DualCapacityThetaNode to the result of the natural binary
// operation over the two given operands, corresponding to the following set
@@ -1329,8 +1397,7 @@ int64 SafeProduct(int64 a, int64 b) {
}
// One-sided cumulative edge finder.
template <class Task>
class EdgeFinder : public Constraint {
template <class Task> class EdgeFinder : public Constraint {
public:
EdgeFinder(Solver* const solver, const std::vector<Task*>& tasks,
int64 capacity)
@@ -1623,8 +1690,7 @@ bool TimeLessThan(const ProfileDelta& delta1, const ProfileDelta& delta2) {
//
// The implementation is quite naive, and could certainly be improved, for
// example by maintaining the profile incrementally.
template <class Task>
class CumulativeTimeTable : public Constraint {
template <class Task> class CumulativeTimeTable : public Constraint {
public:
CumulativeTimeTable(Solver* const solver, const std::vector<Task*>& tasks,
int64 capacity)
@@ -1854,7 +1920,9 @@ class CumulativeConstraint : public Constraint {
virtual std::string DebugString() const {
return StringPrintf("CumulativeConstraint([%s], %" GG_LL_FORMAT "d)",
JoinDebugString(tasks_, ", ").c_str(), capacity_);
};
}
;
private:
// Post temporal disjunctions for tasks that cannot overlap.
void PostAllDisjunctions() {
@@ -2029,8 +2097,7 @@ class VariableDemandCumulativeConstraint : public Constraint {
virtual std::string DebugString() const {
return StringPrintf(
"VariableDemandCumulativeConstraint([%s], %" GG_LL_FORMAT "d)",
JoinDebugString(tasks_, ", ").c_str(),
capacity_);
JoinDebugString(tasks_, ", ").c_str(), capacity_);
}
private: