faster implementation of array_bool_or and array_bool_and

This commit is contained in:
lperron@google.com
2012-06-30 15:28:09 +00:00
parent 6c55707545
commit 8aa32b5159

View File

@@ -969,6 +969,144 @@ class ArrayBoolAndEq : public CastConstraint {
RevSwitch decided_;
};
class ArrayBoolOrEq : public CastConstraint {
public:
ArrayBoolOrEq(Solver* const s,
const std::vector<IntVar*>& vars,
IntVar* const target)
: CastConstraint(s, target),
vars_(vars),
demons_(vars.size()),
unbounded_(0) {}
virtual ~ArrayBoolOrEq() {}
virtual void Post() {
for (int i = 0; i < vars_.size(); ++i) {
if (!vars_[i]->Bound()) {
demons_[i] = MakeConstraintDemon1(solver(),
this,
&ArrayBoolOrEq::PropagateVar,
"PropagateVar",
i);
vars_[i]->WhenBound(demons_[i]);
}
}
Demon* const target_demon =
MakeConstraintDemon0(solver(),
this,
&ArrayBoolOrEq::PropagateTarget,
"PropagateTarget");
target_var_->WhenBound(target_demon);
}
virtual void InitialPropagate() {
target_var_->SetRange(0, 1);
if (target_var_->Min() == 0) {
for (int i = 0; i < vars_.size(); ++i) {
vars_[i]->SetMax(0);
}
} else {
int zeros = 0;
int ones = 0;
int unbounded = 0;
for (int i = 0; i < vars_.size(); ++i) {
unbounded += !vars_[i]->Bound();
zeros += vars_[i]->Max() == 0;
ones += vars_[i]->Min() == 1;
}
if (ones > 0) {
InhibitAll();
target_var_->SetMin(1);
} else if (unbounded == 0) {
target_var_->SetMax(0);
} else if (target_var_->Min() == 1 && unbounded == 1) {
const int index = FindPossibleOne();
CHECK(index != -1);
vars_[index]->SetMin(1);
} else {
unbounded_.SetValue(solver(), unbounded);
}
}
}
void PropagateVar(int index) {
if (vars_[index]->Min() == 0) {
unbounded_.Decr(solver());
if (target_var_->Min() == 1 &&
unbounded_.Value() == 1 &&
!decided_.Switched()) {
const int to_set = FindPossibleOne();
if (to_set != -1) {
vars_[to_set]->SetMin(1);
decided_.Switch(solver());
} else {
solver()->Fail();
}
}
} else {
InhibitAll();
target_var_->SetMin(1);
}
}
void PropagateTarget() {
if (target_var_->Min() == 0) {
for (int i = 0; i < vars_.size(); ++i) {
vars_[i]->SetMax(0);
}
} else {
if (unbounded_.Value() == 1 && !decided_.Switched()) {
const int to_set = FindPossibleOne();
if (to_set != -1) {
vars_[to_set]->SetMin(1);
decided_.Switch(solver());
} else {
solver()->Fail();
}
}
}
}
void InhibitAll() {
for (int i = 0; i < demons_.size(); ++i) {
if (demons_[i] != NULL) {
demons_[i]->inhibit(solver());
}
}
}
int FindPossibleOne() {
for (int i = 0; i < vars_.size(); ++i) {
if (vars_[i]->Max() == 1) {
return i;
}
}
return -1;
}
string DebugString() const {
return StringPrintf("Or(%s) == %s",
DebugStringVector(vars_, ", ").c_str(),
target_var_->DebugString().c_str());
}
void Accept(ModelVisitor* const visitor) const {
visitor->BeginVisitConstraint(ModelVisitor::kMaxEqual, this);
visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
vars_.data(), vars_.size());
visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
target_var_);
visitor->EndVisitConstraint(ModelVisitor::kMaxEqual, this);
}
private:
std::vector<IntVar*> vars_;
std::vector<Demon*> demons_;
NumericalRev<int> unbounded_;
RevSwitch decided_;
};
// ---------- Specialized cases ----------
@@ -2240,17 +2378,25 @@ IntExpr* Solver::MakeMax(const std::vector<IntVar*>& vars) {
if (cache != NULL) {
return cache->Var();
} else {
int64 new_min = kint64min;
int64 new_max = kint64min;
for (int i = 0; i < size; ++i) {
new_min = std::max(new_min, vars[i]->Min());
new_max = std::max(new_max, vars[i]->Max());
if (AreAllBooleans(vars.data(), vars.size())) {
IntVar* const new_var = MakeBoolVar();
AddConstraint(RevAlloc(new ArrayBoolOrEq(this, vars, new_var)));
model_cache_->InsertVarArrayExpression(
new_var, vars, ModelCache::VAR_ARRAY_MIN);
return new_var;
} else {
int64 new_min = kint64min;
int64 new_max = kint64min;
for (int i = 0; i < size; ++i) {
new_min = std::max(new_min, vars[i]->Min());
new_max = std::max(new_max, vars[i]->Max());
}
IntVar* const new_var = MakeIntVar(new_min, new_max);
AddConstraint(RevAlloc(new MaxConstraint(this, vars, new_var)));
model_cache_->InsertVarArrayExpression(
new_var, vars, ModelCache::VAR_ARRAY_MAX);
return new_var;
}
IntVar* const new_var = MakeIntVar(new_min, new_max);
AddConstraint(RevAlloc(new MaxConstraint(this, vars, new_var)));
model_cache_->InsertVarArrayExpression(
new_var, vars, ModelCache::VAR_ARRAY_MAX);
return new_var;
}
}
}
@@ -2266,12 +2412,8 @@ Constraint* Solver::MakeMinEquality(const std::vector<IntVar*>& vars,
Constraint* Solver::MakeMaxEquality(const std::vector<IntVar*>& vars,
IntVar* const max_var) {
if (max_var->Bound() &&
max_var->Min() == 1 &&
AreAllBooleans(vars.data(), vars.size()) &&
vars.size() > 2) {
return RevAlloc(
new SumBooleanGreaterOrEqualToOne(this, vars.data(), vars.size()));
if (AreAllBooleans(vars.data(), vars.size())) {
return RevAlloc(new ArrayBoolOrEq(this, vars, max_var));
} else {
return RevAlloc(new MaxConstraint(this, vars, max_var));
}