diff --git a/constraint_solver/collect_variables.cc b/constraint_solver/collect_variables.cc index 44d0d4d1ed..3da837cad2 100644 --- a/constraint_solver/collect_variables.cc +++ b/constraint_solver/collect_variables.cc @@ -27,6 +27,97 @@ namespace operations_research { namespace { + +class ArgumentHolder { + public: + struct Matrix { + Matrix() : values(NULL), rows(0), columns(0) {} + Matrix(const int64 * const * v, int r, int c) + : values(v), rows(r), columns(c) {} + const int64* const * values; + int rows; + int columns; + }; + + const string& type_name() const { + return type_name_; + } + + void set_type_name(const string& type_name) { + type_name_ = type_name; + } + + void set_integer_matrix_argument(const string& arg_name, + const int64* const * const values, + int rows, + int columns) { + matrix_argument_[arg_name] = Matrix(values, rows, columns); + } + + void set_integer_expression_argument(const string& arg_name, + const IntExpr* const expr) { + integer_expression_argument_[arg_name] = expr; + } + + void set_integer_variable_array_argument(const string& arg_name, + const IntVar* const * const vars, + int size) { + for (int i = 0; i < size; ++i) { + integer_variable_array_argument_[arg_name].push_back(vars[i]); + } + } + + void set_interval_argument(const string& arg_name, + const IntervalVar* const var) { + interval_argument_[arg_name] = var; + } + + void set_interval_array_argument(const string& arg_name, + const IntervalVar* const * const vars, + int size) { + for (int i = 0; i < size; ++i) { + interval_array_argument_[arg_name].push_back(vars[i]); + } + } + + void set_sequence_argument(const string& arg_name, + const SequenceVar* const var) { + sequence_argument_[arg_name] = var; + } + + void set_sequence_array_argument(const string& arg_name, + const SequenceVar* const * const vars, + int size) { + for (int i = 0; i < size; ++i) { + sequence_array_argument_[arg_name].push_back(vars[i]); + } + } + + const IntExpr* FindIntegerExpressionArgumentOrDie(const string& arg_name) { + return FindOrDie(integer_expression_argument_, arg_name); + } + + const std::vector& FindIntegerVariableArrayArgumentOrDie( + const string& arg_name) { + return FindOrDie(integer_variable_array_argument_, arg_name); + } + + const Matrix& FindIntegerMatrixArgumentOrDie(const string& arg_name) { + return FindOrDie(matrix_argument_, arg_name); + } + + private: + string type_name_; + hash_map integer_expression_argument_; + hash_map interval_argument_; + hash_map sequence_argument_; + hash_map > integer_variable_array_argument_; + hash_map > interval_array_argument_; + hash_map > sequence_array_argument_; + hash_map matrix_argument_; +}; + class CollectVariablesVisitor : public ModelVisitor { public: CollectVariablesVisitor(std::vector* const primary_integer_variables, @@ -41,32 +132,75 @@ class CollectVariablesVisitor : public ModelVisitor { // Header/footers. virtual void BeginVisitModel(const string& solver_name) { - LOG(INFO) << "Starts collecting variables on model " << solver_name; + PushArgumentHolder(); } virtual void EndVisitModel(const string& solver_name) { - LOG(INFO) << "Finishes collecting variables."; + PopArgumentHolder(); + primaries_->assign(primary_set_.begin(), primary_set_.end()); + secondaries_->assign(secondary_set_.begin(), secondary_set_.end()); + intervals_->assign(interval_set_.begin(), interval_set_.end()); + sequences_->assign(sequence_set_.begin(), sequence_set_.end()); } virtual void BeginVisitConstraint(const string& type_name, const Constraint* const constraint) { - if (constraint->IsCastConstraint()) { - const CastConstraint* const cast_constraint = - reinterpret_cast(constraint); - ignored_set_.insert(cast_constraint->target_var()); - } + PushArgumentHolder(); } virtual void EndVisitConstraint(const string& type_name, const Constraint* const constraint) { + if (type_name.compare(ModelVisitor::kLinkExprVar) == 0 || + type_name.compare(ModelVisitor::kSumEqual) == 0 || + type_name.compare(ModelVisitor::kCountEqual) == 0 || + type_name.compare(ModelVisitor::kElementEqual) == 0 || + type_name.compare(ModelVisitor::kScalProdEqual) == 0 || + type_name.compare(ModelVisitor::kIsEqual) == 0 || + type_name.compare(ModelVisitor::kIsDifferent) == 0 || + type_name.compare(ModelVisitor::kIsGreaterOrEqual) == 0 || + type_name.compare(ModelVisitor::kIsLessOrEqual) == 0) { + IntExpr* const target_expr = + const_cast( + top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kTargetArgument)); + IntVar* const target_var = target_expr->Var(); + IgnoreIntegerVariable(target_var); + } else if (type_name.compare(ModelVisitor::kAllowedAssignments) == 0) { + const ArgumentHolder::Matrix& matrix = + top()->FindIntegerMatrixArgumentOrDie(ModelVisitor::kTuplesArgument); + vector > counters(matrix.columns); + for (int i = 0; i < matrix.rows; ++i) { + for (int j = 0; j < matrix.columns; ++j) { + counters[j].insert(matrix.values[i][j]); + } + } + for (int j = 0; j < matrix.columns; ++j) { + if (counters[j].size() == matrix.rows) { + vector vars = + top()->FindIntegerVariableArrayArgumentOrDie( + ModelVisitor::kVarsArgument); + LOG(INFO) << "Found index variable in allowed assignment constraint: " + << vars[j]->DebugString(); + for (int k = 0; k < matrix.columns; ++k) { + if (j != k) { + IgnoreIntegerVariable(const_cast(vars[k])); + } + } + break; + } + } + } + PopArgumentHolder(); } virtual void BeginVisitIntegerExpression(const string& type_name, const IntExpr* const expr) { + PushArgumentHolder(); } virtual void EndVisitIntegerExpression(const string& type_name, const IntExpr* const expr) { + PopArgumentHolder(); } virtual void VisitIntegerVariable(const IntVar* const variable, @@ -80,7 +214,6 @@ class CollectVariablesVisitor : public ModelVisitor { !ContainsKey(ignored_set_, var) && !var->Bound()) { primary_set_.insert(var); - primaries_->push_back(const_cast(var)); } } } @@ -101,7 +234,6 @@ class CollectVariablesVisitor : public ModelVisitor { IntervalVar* const var = const_cast(variable); if (!ContainsKey(interval_set_, var)) { interval_set_.insert(var); - intervals_->push_back(var); } } } @@ -119,17 +251,25 @@ class CollectVariablesVisitor : public ModelVisitor { SequenceVar* const var = const_cast(variable); if (!ContainsKey(sequence_set_, var)) { sequence_set_.insert(var); - sequences_->push_back(var); } for (int i = 0; i < var->size(); ++i) { var->Interval(i)->Accept(this); } } + // Integer arguments + virtual void VisitIntegerMatrixArgument(const string& arg_name, + const int64* const * const values, + int rows, + int columns) { + top()->set_integer_matrix_argument(arg_name, values, rows, columns); + } + // Variables. virtual void VisitIntegerExpressionArgument( const string& arg_name, const IntExpr* const argument) { + top()->set_integer_expression_argument(arg_name, argument); argument->Accept(this); } @@ -137,6 +277,7 @@ class CollectVariablesVisitor : public ModelVisitor { const string& arg_name, const IntVar* const * arguments, int size) { + top()->set_integer_variable_array_argument(arg_name, arguments, size); for (int i = 0; i < size; ++i) { arguments[i]->Accept(this); } @@ -145,12 +286,14 @@ class CollectVariablesVisitor : public ModelVisitor { // Visit interval argument. virtual void VisitIntervalArgument(const string& arg_name, const IntervalVar* const argument) { + top()->set_interval_argument(arg_name, argument); argument->Accept(this); } virtual void VisitIntervalArgumentArray(const string& arg_name, const IntervalVar* const * arguments, int size) { + top()->set_interval_array_argument(arg_name, arguments, size); for (int i = 0; i < size; ++i) { arguments[i]->Accept(this); } @@ -159,18 +302,41 @@ class CollectVariablesVisitor : public ModelVisitor { // Visit sequence argument. virtual void VisitSequenceArgument(const string& arg_name, const SequenceVar* const argument) { + top()->set_sequence_argument(arg_name, argument); argument->Accept(this); } virtual void VisitSequenceArgumentArray(const string& arg_name, const SequenceVar* const * arguments, int size) { + top()->set_sequence_array_argument(arg_name, arguments, size); for (int i = 0; i < size; ++i) { arguments[i]->Accept(this); } } private: + void PushArgumentHolder() { + holders_.push_back(new ArgumentHolder); + } + + void PopArgumentHolder() { + CHECK(!holders_.empty()); + delete holders_.back(); + holders_.pop_back(); + } + + ArgumentHolder* top() const { + CHECK(!holders_.empty()); + return holders_.back(); + } + + void IgnoreIntegerVariable(IntVar* const var) { + primary_set_.erase(var); + secondary_set_.erase(var); + ignored_set_.insert(var); + } + std::vector* const primaries_; std::vector* const secondaries_; std::vector* const sequences_; @@ -180,6 +346,7 @@ class CollectVariablesVisitor : public ModelVisitor { hash_set ignored_set_; hash_set sequence_set_; hash_set interval_set_; + std::vector holders_; }; } // namespace diff --git a/constraint_solver/expr_array.cc b/constraint_solver/expr_array.cc index 467548ded2..3d2f0a33ba 100644 --- a/constraint_solver/expr_array.cc +++ b/constraint_solver/expr_array.cc @@ -1402,9 +1402,7 @@ IntExpr* Solver::MakeSum(IntVar* const* vars, int size) { sum_max += vars[i]->Max(); } IntVar* const sum_var = MakeIntVar(sum_min, sum_max); - AddCastConstraint(RevAlloc(new SumConstraint(this, vars, size, sum_var)), - sum_var, - NULL); + AddConstraint(RevAlloc(new SumConstraint(this, vars, size, sum_var)));, return sum_var; } } diff --git a/constraint_solver/expr_cst.cc b/constraint_solver/expr_cst.cc index 5a620ca4be..7cc8e9d104 100644 --- a/constraint_solver/expr_cst.cc +++ b/constraint_solver/expr_cst.cc @@ -287,23 +287,23 @@ Constraint* Solver::MakeNonEquality(IntVar* const e, int v) { // ----- is_equal_cst Constraint ----- namespace { -class IsEqualCstCt : public Constraint { +class IsEqualCstCt : public CastConstraint { public: IsEqualCstCt(Solver* const s, IntVar* const v, int64 c, IntVar* const b) - : Constraint(s), var_(v), cst_(c), boolvar_(b), demon_(NULL) {} + : CastConstraint(s, b), var_(v), cst_(c), demon_(NULL) {} virtual void Post() { demon_ = solver()->MakeConstraintInitialPropagateCallback(this); var_->WhenDomain(demon_); - boolvar_->WhenBound(demon_); + target_var_->WhenBound(demon_); } virtual void InitialPropagate() { bool inhibit = var_->Bound(); int64 u = var_->Contains(cst_); int64 l = inhibit ? u : 0; - boolvar_->SetRange(l, u); - if (boolvar_->Bound()) { + target_var_->SetRange(l, u); + if (target_var_->Bound()) { inhibit = true; - if (boolvar_->Min() == 0) { + if (target_var_->Min() == 0) { var_->RemoveValue(cst_); } else { var_->SetValue(cst_); @@ -317,7 +317,7 @@ class IsEqualCstCt : public Constraint { return StringPrintf("IsEqualCstCt(%s, %" GG_LL_FORMAT "d, %s)", var_->DebugString().c_str(), cst_, - boolvar_->DebugString().c_str()); + target_var_->DebugString().c_str()); } void Accept(ModelVisitor* const visitor) const { @@ -326,14 +326,13 @@ class IsEqualCstCt : public Constraint { var_); visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_); visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument, - boolvar_); + target_var_); visitor->EndVisitConstraint(ModelVisitor::kIsEqual, this); } private: IntVar* const var_; int64 cst_; - IntVar* const boolvar_; Demon* demon_; }; } // namespace @@ -358,10 +357,14 @@ IntVar* Solver::MakeIsEqualCstVar(IntVar* const var, int64 value) { if (cache != NULL) { return cache->Var(); } else { + string name = var->name(); + if (name.empty()) { + name = var->DebugString(); + } IntVar* const boolvar = MakeBoolVar( StringPrintf("StatusVar<%s == %" GG_LL_FORMAT "d>", - var->name().c_str(), value)); - Constraint* const maintain = + name.c_str(), value)); + CastConstraint* const maintain = RevAlloc(new IsEqualCstCt(this, var, value, boolvar)); AddConstraint(maintain); model_cache_->InsertVarConstantExpression( @@ -397,25 +400,25 @@ Constraint* Solver::MakeIsEqualCstCt(IntVar* const var, // ----- is_diff_cst Constraint ----- namespace { -class IsDiffCstCt : public Constraint { +class IsDiffCstCt : public CastConstraint { public: IsDiffCstCt(Solver* const s, IntVar* const v, int64 c, IntVar* const b) - : Constraint(s), var_(v), cst_(c), boolvar_(b), demon_(NULL) {} + : CastConstraint(s, b), var_(v), cst_(c), demon_(NULL) {} virtual void Post() { demon_ = solver()->MakeConstraintInitialPropagateCallback(this); var_->WhenDomain(demon_); - boolvar_->WhenBound(demon_); + target_var_->WhenBound(demon_); } virtual void InitialPropagate() { bool inhibit = var_->Bound(); int64 l = 1 - var_->Contains(cst_); int64 u = inhibit ? l : 1; - boolvar_->SetRange(l, u); - if (boolvar_->Bound()) { + target_var_->SetRange(l, u); + if (target_var_->Bound()) { inhibit = true; - if (boolvar_->Min() == 1) { + if (target_var_->Min() == 1) { var_->RemoveValue(cst_); } else { var_->SetValue(cst_); @@ -430,7 +433,7 @@ class IsDiffCstCt : public Constraint { return StringPrintf("IsDiffCstCt(%s, %" GG_LL_FORMAT "d, %s)", var_->DebugString().c_str(), cst_, - boolvar_->DebugString().c_str()); + target_var_->DebugString().c_str()); } void Accept(ModelVisitor* const visitor) const { @@ -439,14 +442,13 @@ class IsDiffCstCt : public Constraint { var_); visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_); visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument, - boolvar_); + target_var_); visitor->EndVisitConstraint(ModelVisitor::kIsDifferent, this); } private: IntVar* const var_; int64 cst_; - IntVar* const boolvar_; Demon* demon_; }; } // namespace @@ -471,10 +473,14 @@ IntVar* Solver::MakeIsDifferentCstVar(IntVar* const var, int64 value) { if (cache != NULL) { return cache->Var(); } else { + string name = var->name(); + if (name.empty()) { + name = var->DebugString(); + } IntVar* const boolvar = MakeBoolVar( StringPrintf("StatusVar<%s != %" GG_LL_FORMAT "d>", - var->name().c_str(), value)); - Constraint* const maintain = + name.c_str(), value)); + CastConstraint* const maintain = RevAlloc(new IsDiffCstCt(this, var, value, boolvar)); AddConstraint(maintain); model_cache_->InsertVarConstantExpression( @@ -508,24 +514,24 @@ Constraint* Solver::MakeIsDifferentCstCt(IntVar* const var, // ----- is_greater_equal_cst Constraint ----- namespace { -class IsGreaterEqualCstCt : public Constraint { +class IsGreaterEqualCstCt : public CastConstraint { public: IsGreaterEqualCstCt(Solver* const s, IntVar* const v, int64 c, IntVar* const b) - : Constraint(s), var_(v), cst_(c), boolvar_(b), demon_(NULL) {} + : CastConstraint(s, b), var_(v), cst_(c), demon_(NULL) {} virtual void Post() { demon_ = solver()->MakeConstraintInitialPropagateCallback(this); var_->WhenRange(demon_); - boolvar_->WhenBound(demon_); + target_var_->WhenBound(demon_); } virtual void InitialPropagate() { bool inhibit = false; int64 u = var_->Max() >= cst_; int64 l = var_->Min() >= cst_; - boolvar_->SetRange(l, u); - if (boolvar_->Bound()) { + target_var_->SetRange(l, u); + if (target_var_->Bound()) { inhibit = true; - if (boolvar_->Min() == 0) { + if (target_var_->Min() == 0) { var_->SetMax(cst_ - 1); } else { var_->SetMin(cst_); @@ -539,7 +545,7 @@ class IsGreaterEqualCstCt : public Constraint { return StringPrintf("IsGreaterEqualCstCt(%s, %" GG_LL_FORMAT "d, %s)", var_->DebugString().c_str(), cst_, - boolvar_->DebugString().c_str()); + target_var_->DebugString().c_str()); } virtual void Accept(ModelVisitor* const visitor) const { @@ -548,14 +554,13 @@ class IsGreaterEqualCstCt : public Constraint { var_); visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_); visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument, - boolvar_); + target_var_); visitor->EndVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this); } private: IntVar* const var_; int64 cst_; - IntVar* const boolvar_; Demon* demon_; }; } // namespace @@ -574,10 +579,14 @@ IntVar* Solver::MakeIsGreaterOrEqualCstVar(IntVar* const var, int64 value) { if (cache != NULL) { return cache->Var(); } else { + string name = var->name(); + if (name.empty()) { + name = var->DebugString(); + } IntVar* const boolvar = MakeBoolVar( StringPrintf("StatusVar<%s >= %" GG_LL_FORMAT "d>", - var->name().c_str(), value)); - Constraint* const maintain = + name.c_str(), value)); + CastConstraint* const maintain = RevAlloc(new IsGreaterEqualCstCt(this, var, value, boolvar)); AddConstraint(maintain); model_cache_->InsertVarConstantExpression( @@ -614,25 +623,25 @@ Constraint* Solver::MakeIsGreaterCstCt(IntVar* const v, int64 c, // ----- is_lesser_equal_cst Constraint ----- namespace { -class IsLessEqualCstCt : public Constraint { +class IsLessEqualCstCt : public CastConstraint { public: IsLessEqualCstCt(Solver* const s, IntVar* const v, int64 c, IntVar* const b) - : Constraint(s), var_(v), cst_(c), boolvar_(b), demon_(NULL) {} + : CastConstraint(s, b), var_(v), cst_(c), demon_(NULL) {} virtual void Post() { demon_ = solver()->MakeConstraintInitialPropagateCallback(this); var_->WhenRange(demon_); - boolvar_->WhenBound(demon_); + target_var_->WhenBound(demon_); } virtual void InitialPropagate() { bool inhibit = false; int64 u = var_->Min() <= cst_; int64 l = var_->Max() <= cst_; - boolvar_->SetRange(l, u); - if (boolvar_->Bound()) { + target_var_->SetRange(l, u); + if (target_var_->Bound()) { inhibit = true; - if (boolvar_->Min() == 0) { + if (target_var_->Min() == 0) { var_->SetMin(cst_ + 1); } else { var_->SetMax(cst_); @@ -647,7 +656,7 @@ class IsLessEqualCstCt : public Constraint { return StringPrintf("IsLessEqualCstCt(%s, %" GG_LL_FORMAT "d, %s)", var_->DebugString().c_str(), cst_, - boolvar_->DebugString().c_str()); + target_var_->DebugString().c_str()); } virtual void Accept(ModelVisitor* const visitor) const { @@ -656,14 +665,13 @@ class IsLessEqualCstCt : public Constraint { var_); visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_); visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument, - boolvar_); + target_var_); visitor->EndVisitConstraint(ModelVisitor::kIsLessOrEqual, this); } private: IntVar* const var_; int64 cst_; - IntVar* const boolvar_; Demon* demon_; }; } // namespace @@ -683,10 +691,14 @@ IntVar* Solver::MakeIsLessOrEqualCstVar(IntVar* const var, int64 value) { if (cache != NULL) { return cache->Var(); } else { + string name = var->name(); + if (name.empty()) { + name = var->DebugString(); + } IntVar* const boolvar = MakeBoolVar( StringPrintf("StatusVar<%s <= %" GG_LL_FORMAT "d>", - var->name().c_str(), value)); - Constraint* const maintain = + name.c_str(), value)); + CastConstraint* const maintain = RevAlloc(new IsLessEqualCstCt(this, var, value, boolvar)); AddConstraint(maintain); model_cache_->InsertVarConstantExpression( diff --git a/constraint_solver/io.cc b/constraint_solver/io.cc index d1e0716b4f..463dd0250b 100644 --- a/constraint_solver/io.cc +++ b/constraint_solver/io.cc @@ -582,7 +582,7 @@ class SecondPassVisitor : public ModelVisitor { virtual void EndVisitConstraint(const string& type_name, const Constraint* const constraint) { - // We ignore delegate constraints, they will be regenerated automatically. + // We ignore cast constraints, they will be regenerated automatically. if (constraint->IsCastConstraint()) { return; }