diff --git a/src/constraint_solver/expr_array.cc b/src/constraint_solver/expr_array.cc index cfe6bd4cf2..dc7cb94c16 100644 --- a/src/constraint_solver/expr_array.cc +++ b/src/constraint_solver/expr_array.cc @@ -969,6 +969,144 @@ class ArrayBoolAndEq : public CastConstraint { RevSwitch decided_; }; +class ArrayBoolOrEq : public CastConstraint { + public: + ArrayBoolOrEq(Solver* const s, + const std::vector& vars, + IntVar* const target) + : CastConstraint(s, target), + vars_(vars), + demons_(vars.size()), + unbounded_(0) {} + + virtual ~ArrayBoolOrEq() {} + + virtual void Post() { + for (int i = 0; i < vars_.size(); ++i) { + if (!vars_[i]->Bound()) { + demons_[i] = MakeConstraintDemon1(solver(), + this, + &ArrayBoolOrEq::PropagateVar, + "PropagateVar", + i); + vars_[i]->WhenBound(demons_[i]); + } + + } + Demon* const target_demon = + MakeConstraintDemon0(solver(), + this, + &ArrayBoolOrEq::PropagateTarget, + "PropagateTarget"); + target_var_->WhenBound(target_demon); + } + + virtual void InitialPropagate() { + target_var_->SetRange(0, 1); + if (target_var_->Min() == 0) { + for (int i = 0; i < vars_.size(); ++i) { + vars_[i]->SetMax(0); + } + } else { + int zeros = 0; + int ones = 0; + int unbounded = 0; + for (int i = 0; i < vars_.size(); ++i) { + unbounded += !vars_[i]->Bound(); + zeros += vars_[i]->Max() == 0; + ones += vars_[i]->Min() == 1; + } + if (ones > 0) { + InhibitAll(); + target_var_->SetMin(1); + } else if (unbounded == 0) { + target_var_->SetMax(0); + } else if (target_var_->Min() == 1 && unbounded == 1) { + const int index = FindPossibleOne(); + CHECK(index != -1); + vars_[index]->SetMin(1); + } else { + unbounded_.SetValue(solver(), unbounded); + } + } + } + + void PropagateVar(int index) { + if (vars_[index]->Min() == 0) { + unbounded_.Decr(solver()); + if (target_var_->Min() == 1 && + unbounded_.Value() == 1 && + !decided_.Switched()) { + const int to_set = FindPossibleOne(); + if (to_set != -1) { + vars_[to_set]->SetMin(1); + decided_.Switch(solver()); + } else { + solver()->Fail(); + } + } + } else { + InhibitAll(); + target_var_->SetMin(1); + } + } + + void PropagateTarget() { + if (target_var_->Min() == 0) { + for (int i = 0; i < vars_.size(); ++i) { + vars_[i]->SetMax(0); + } + } else { + if (unbounded_.Value() == 1 && !decided_.Switched()) { + const int to_set = FindPossibleOne(); + if (to_set != -1) { + vars_[to_set]->SetMin(1); + decided_.Switch(solver()); + } else { + solver()->Fail(); + } + } + } + } + + void InhibitAll() { + for (int i = 0; i < demons_.size(); ++i) { + if (demons_[i] != NULL) { + demons_[i]->inhibit(solver()); + } + } + } + + int FindPossibleOne() { + for (int i = 0; i < vars_.size(); ++i) { + if (vars_[i]->Max() == 1) { + return i; + } + } + return -1; + } + + string DebugString() const { + return StringPrintf("Or(%s) == %s", + DebugStringVector(vars_, ", ").c_str(), + target_var_->DebugString().c_str()); + } + + void Accept(ModelVisitor* const visitor) const { + visitor->BeginVisitConstraint(ModelVisitor::kMaxEqual, this); + visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument, + vars_.data(), vars_.size()); + visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument, + target_var_); + visitor->EndVisitConstraint(ModelVisitor::kMaxEqual, this); + } + + private: + std::vector vars_; + std::vector demons_; + NumericalRev unbounded_; + RevSwitch decided_; +}; // ---------- Specialized cases ---------- @@ -2240,17 +2378,25 @@ IntExpr* Solver::MakeMax(const std::vector& vars) { if (cache != NULL) { return cache->Var(); } else { - int64 new_min = kint64min; - int64 new_max = kint64min; - for (int i = 0; i < size; ++i) { - new_min = std::max(new_min, vars[i]->Min()); - new_max = std::max(new_max, vars[i]->Max()); + if (AreAllBooleans(vars.data(), vars.size())) { + IntVar* const new_var = MakeBoolVar(); + AddConstraint(RevAlloc(new ArrayBoolOrEq(this, vars, new_var))); + model_cache_->InsertVarArrayExpression( + new_var, vars, ModelCache::VAR_ARRAY_MIN); + return new_var; + } else { + int64 new_min = kint64min; + int64 new_max = kint64min; + for (int i = 0; i < size; ++i) { + new_min = std::max(new_min, vars[i]->Min()); + new_max = std::max(new_max, vars[i]->Max()); + } + IntVar* const new_var = MakeIntVar(new_min, new_max); + AddConstraint(RevAlloc(new MaxConstraint(this, vars, new_var))); + model_cache_->InsertVarArrayExpression( + new_var, vars, ModelCache::VAR_ARRAY_MAX); + return new_var; } - IntVar* const new_var = MakeIntVar(new_min, new_max); - AddConstraint(RevAlloc(new MaxConstraint(this, vars, new_var))); - model_cache_->InsertVarArrayExpression( - new_var, vars, ModelCache::VAR_ARRAY_MAX); - return new_var; } } } @@ -2266,12 +2412,8 @@ Constraint* Solver::MakeMinEquality(const std::vector& vars, Constraint* Solver::MakeMaxEquality(const std::vector& vars, IntVar* const max_var) { - if (max_var->Bound() && - max_var->Min() == 1 && - AreAllBooleans(vars.data(), vars.size()) && - vars.size() > 2) { - return RevAlloc( - new SumBooleanGreaterOrEqualToOne(this, vars.data(), vars.size())); + if (AreAllBooleans(vars.data(), vars.size())) { + return RevAlloc(new ArrayBoolOrEq(this, vars, max_var)); } else { return RevAlloc(new MaxConstraint(this, vars, max_var)); }