From e7d71ffcd3927b1fd00b38599e94a796e09ea0fb Mon Sep 17 00:00:00 2001 From: "lperron@google.com" Date: Tue, 23 Jul 2013 16:09:40 +0000 Subject: [PATCH] correct ismember code --- src/constraint_solver/expr_cst.cc | 71 ++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 26 deletions(-) diff --git a/src/constraint_solver/expr_cst.cc b/src/constraint_solver/expr_cst.cc index d7a04ce700..8b360eff9b 100644 --- a/src/constraint_solver/expr_cst.cc +++ b/src/constraint_solver/expr_cst.cc @@ -976,20 +976,50 @@ class IsMemberCt : public Constraint { } virtual void Post() { - demon_ = solver()->MakeConstraintInitialPropagateCallback(this); + demon_ = MakeConstraintDemon0( + solver(), this, &IsMemberCt::VarDomain, "VarDomain"); if (!var_->Bound()) { var_->WhenDomain(demon_); } if (!boolvar_->Bound()) { - boolvar_->WhenBound(demon_); + Demon* const bdemon = MakeConstraintDemon0( + solver(), this, &IsMemberCt::TargetBound, "TargetBound"); + boolvar_->WhenBound(bdemon); } } virtual void InitialPropagate() { - if (boolvar_->Min() == 1LL) { - demon_->inhibit(solver()); - var_->SetValues(values_.RawData(), values_.size()); - } else if (boolvar_->Max() == 1LL) { + boolvar_->SetRange(0, 1); + if (boolvar_->Bound()) { + TargetBound(); + } else { + VarDomain(); + } + } + + virtual string DebugString() const { + return StringPrintf("IsMemberCt(%s, %s, %s)", + var_->DebugString().c_str(), + values_.DebugString().c_str(), + boolvar_->DebugString().c_str()); + } + + virtual void Accept(ModelVisitor* const visitor) const { + visitor->BeginVisitConstraint(ModelVisitor::kIsMember, this); + visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument, + var_); + visitor->VisitConstIntArrayArgument(ModelVisitor::kValuesArgument, + values_); + visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument, + boolvar_); + visitor->EndVisitConstraint(ModelVisitor::kIsMember, this); + } + + private: + void VarDomain() { + if (boolvar_->Bound()) { + TargetBound(); + } else { for (int offset = 0; offset < values_.size(); ++offset) { const int candidate = (support_ + offset) % values_.size(); if (var_->Contains(values_[candidate])) { @@ -1020,31 +1050,20 @@ class IsMemberCt : public Constraint { // No positive support, setting boolvar to false. demon_->inhibit(solver()); boolvar_->SetValue(0); - } else { // boolvar_ set to 0. + } + } + + void TargetBound() { + DCHECK(boolvar_->Bound()); + if (boolvar_->Min() == 1LL) { + demon_->inhibit(solver()); + var_->SetValues(values_.RawData(), values_.size()); + } else { demon_->inhibit(solver()); var_->RemoveValues(values_.RawData(), values_.size()); } } - virtual string DebugString() const { - return StringPrintf("IsMemberCt(%s, %s, %s)", - var_->DebugString().c_str(), - values_.DebugString().c_str(), - boolvar_->DebugString().c_str()); - } - - virtual void Accept(ModelVisitor* const visitor) const { - visitor->BeginVisitConstraint(ModelVisitor::kIsMember, this); - visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument, - var_); - visitor->VisitConstIntArrayArgument(ModelVisitor::kValuesArgument, - values_); - visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument, - boolvar_); - visitor->EndVisitConstraint(ModelVisitor::kIsMember, this); - } - - private: IntVar* const var_; hash_set values_as_set_; ConstIntArray values_;