diff --git a/examples/cpp/model_util.cc b/examples/cpp/model_util.cc index a3311b282a..6fc2be10fd 100644 --- a/examples/cpp/model_util.cc +++ b/examples/cpp/model_util.cc @@ -41,6 +41,7 @@ DEFINE_string(insert_license, "", "Insert content of the given file into the license file."); DEFINE_bool(collect_variables, false, "Shows effect of the variable collector."); +DECLARE_bool(log_prefix); namespace operations_research { @@ -416,6 +417,7 @@ int Run() { } // namespace operations_research int main(int argc, char **argv) { + FLAGS_log_prefix=false; google::ParseCommandLineFlags(&argc, &argv, true); if (FLAGS_input.empty()) { LOG(FATAL) << "Filename not specified"; diff --git a/src/constraint_solver/constraint_solver.cc b/src/constraint_solver/constraint_solver.cc index ad5ef4a15c..c68c09209e 100644 --- a/src/constraint_solver/constraint_solver.cc +++ b/src/constraint_solver/constraint_solver.cc @@ -2871,6 +2871,11 @@ void ModelVisitor::VisitConstIntArrayArgument(const string& arg_name, VisitIntegerArrayArgument(arg_name, values.RawData(), values.size()); } +void ModelVisitor::VisitIntegerArrayArgument(const string& arg_name, + const std::vector& values) { + VisitIntegerArrayArgument(arg_name, values.data(), values.size()); +} + void ModelVisitor::VisitInt64ToBoolExtension( ResultCallback1* const callback, int64 index_min, diff --git a/src/constraint_solver/constraint_solver.h b/src/constraint_solver/constraint_solver.h index 83f30e82d7..00604e359a 100644 --- a/src/constraint_solver/constraint_solver.h +++ b/src/constraint_solver/constraint_solver.h @@ -1474,14 +1474,17 @@ class Solver { Constraint* MakeElementEquality(const std::vector& vars, IntVar* const index, IntVar* const target); + Constraint* MakeElementEquality(const std::vector& vars, + IntVar* const index, + int64 target); // This constraints is a special case of the element constraint with // an array of integer variables where all the variables are all // differents. In that case, and with a constant value, the index // of the element constraint is bound to be the unique index of the // variable in 'vars' equal to the value target. - Constraint* MakeArrayPositionConstraint(const std::vector& vars, - IntVar* const position, - int64 target); + Constraint* MakeIndexOfConstraint(const std::vector& vars, + IntVar* const position, + int64 target); // This method is a specialized case of the MakeConstraintDemon // method to call the InitiatePropagate of the constraint 'ct'. @@ -3161,6 +3164,7 @@ class ModelVisitor : public BaseObject { static const char kAbsEqual[]; static const char kAllDifferent[]; static const char kAllowedAssignments[]; + static const char kIndexOf[]; static const char kBetween[]; static const char kConvexPiecewise[]; static const char kCountEqual[]; @@ -3374,6 +3378,8 @@ class ModelVisitor : public BaseObject { #if !defined(SWIG) // Using SWIG on calbacks is troublesome, let's hide these methods during // the wrapping. + virtual void VisitIntegerArrayArgument(const string& arg_name, + const std::vector& values); void VisitConstIntArrayArgument(const string& arg_name, const ConstIntArray& argument); void VisitInt64ToBoolExtension(ResultCallback1* const callback, diff --git a/src/constraint_solver/element.cc b/src/constraint_solver/element.cc index bba0394c6a..97323310c3 100644 --- a/src/constraint_solver/element.cc +++ b/src/constraint_solver/element.cc @@ -1308,12 +1308,12 @@ class IntExprArrayElementCstCt : public Constraint { // This constraint implements index == position(constant in vars). -class IntExprArrayPositionCt : public Constraint { +class IntExprIndexOfCt : public Constraint { public: - IntExprArrayPositionCt(Solver* const s, - const std::vector& vars, - IntVar* const index, - int64 target) + IntExprIndexOfCt(Solver* const s, + const std::vector& vars, + IntVar* const index, + int64 target) : Constraint(s), vars_(vars), size_(vars.size()), @@ -1322,13 +1322,13 @@ class IntExprArrayPositionCt : public Constraint { demons_(size_), index_iterator_(index->MakeHoleIterator(true)) {} - virtual ~IntExprArrayPositionCt() {} + virtual ~IntExprIndexOfCt() {} virtual void Post() { for (int i = 0; i < size_; ++i) { demons_[i] = MakeConstraintDemon1(solver(), this, - &IntExprArrayPositionCt::Propagate, + &IntExprIndexOfCt::Propagate, "Propagate", i); vars_[i]->WhenDomain(demons_[i]); @@ -1336,7 +1336,7 @@ class IntExprArrayPositionCt : public Constraint { Demon* const index_demon = MakeConstraintDemon0(solver(), this, - &IntExprArrayPositionCt::PropagateIndex, + &IntExprIndexOfCt::PropagateIndex, "PropagateIndex"); index_->WhenDomain(index_demon); } @@ -1389,14 +1389,14 @@ class IntExprArrayPositionCt : public Constraint { } virtual string DebugString() const { - return StringPrintf("IntExprArrayPosition([%s], %s) == %" GG_LL_FORMAT "d", + return StringPrintf("IntExprIndexOf([%s], %s) == %" GG_LL_FORMAT "d", DebugStringVector(vars_, ", ").c_str(), index_->DebugString().c_str(), target_); } virtual void Accept(ModelVisitor* const visitor) const { - visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this); + visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this); visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument, vars_.data(), size_); @@ -1404,7 +1404,7 @@ class IntExprArrayPositionCt : public Constraint { index_); visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_); - visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this); + visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this); } private: @@ -1497,11 +1497,24 @@ Constraint* Solver::MakeElementEquality(const std::vector& vars, } } -Constraint* Solver::MakeArrayPositionConstraint( +Constraint* Solver::MakeElementEquality(const std::vector& vars, + IntVar* const index, + int64 target) { + // if (AreAllBound(vars)) { + // std::vector values(vars.size()); + // for (int i = 0; i < vars.size(); ++i) { + // values[i] = vars[i]->Value(); + // } + // return MakeElementEquality(values, index, target); + // } + return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target)); +} + +Constraint* Solver::MakeIndexOfConstraint( const std::vector& vars, IntVar* const index, int64 target) { - return RevAlloc(new IntExprArrayPositionCt(this, vars, index, target)); + return RevAlloc(new IntExprIndexOfCt(this, vars, index, target)); } IntExpr* Solver::MakeIndexExpression(const std::vector& vars, @@ -1516,7 +1529,7 @@ IntExpr* Solver::MakeIndexExpression(const std::vector& vars, NameVector(vars, ", ").c_str(), value); IntVar* const index = MakeIntVar(0, vars.size() - 1, name); - AddConstraint(MakeArrayPositionConstraint(vars, index, value)); + AddConstraint(MakeIndexOfConstraint(vars, index, value)); model_cache_->InsertVarArrayConstantExpression( index, vars, value, ModelCache::VAR_ARRAY_CONSTANT_INDEX); return index; diff --git a/src/constraint_solver/expressions.cc b/src/constraint_solver/expressions.cc index c428c12531..2963cf93c6 100644 --- a/src/constraint_solver/expressions.cc +++ b/src/constraint_solver/expressions.cc @@ -302,6 +302,24 @@ class DomainIntVar : public IntVar { var_demon_(NULL), active_watchers_(0) {} + ValueWatcher(Solver* const solver, + DomainIntVar* const variable, + const std::vector& values, + const std::vector& vars) + : Constraint(solver), + variable_(variable), + iterator_(variable_->MakeHoleIterator(true)), + watchers_(16), + min_range_(kint64max), + max_range_(kint64min), + var_demon_(NULL), + active_watchers_(0) { + CHECK_EQ(vars.size(), values.size()); + for (int i = 0; i < values.size(); ++i) { + SetValueWatcher(vars[i], values[i]); + } + } + ~ValueWatcher() {} IntVar* GetOrMakeValueWatcher(int64 value) { @@ -329,13 +347,24 @@ class DomainIntVar : public IntVar { } min_range_.SetValue(solver(), std::min(min_range_.Value(), value)); max_range_.SetValue(solver(), std::max(max_range_.Value(), value)); - watchers_.RevInsert(variable_->solver(), value, boolvar); + watchers_.RevInsert(solver(), value, boolvar); if (posted_.Switched() && !boolvar->Bound()) { boolvar->WhenBound(solver()->RevAlloc(new WatchDemon(this, value))); } return boolvar; } + void SetValueWatcher(IntVar* const boolvar, int64 value) { + CHECK(watchers_.At(value) == NULL); + active_watchers_.Incr(solver()); + min_range_.SetValue(solver(), std::min(min_range_.Value(), value)); + max_range_.SetValue(solver(), std::max(max_range_.Value(), value)); + watchers_.RevInsert(solver(), value, boolvar); + if (posted_.Switched() && !boolvar->Bound()) { + boolvar->WhenBound(solver()->RevAlloc(new WatchDemon(this, value))); + } + } + virtual void Post() { var_demon_ = solver()->RevAlloc(new VarDemon(this)); variable_->WhenDomain(var_demon_); @@ -428,17 +457,21 @@ class DomainIntVar : public IntVar { visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this); visitor->VisitIntegerExpressionArgument( ModelVisitor::kVariableArgument, variable_); + std::vector all_coefficients; std::vector all_bool_vars; const int64 max_r = max_range_.Value(); const int64 min_r = min_range_.Value(); for (int64 i = min_r; i <= max_r; ++i) { IntVar* const boolvar = watchers_.At(i); if (boolvar != NULL) { + all_coefficients.push_back(i); all_bool_vars.push_back(boolvar); } } - visitor->VisitIntegerVariableArrayArgument( - ModelVisitor::kVarsArgument, all_bool_vars); + visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument, + all_bool_vars); + visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, + all_coefficients); visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this); } @@ -517,6 +550,24 @@ class DomainIntVar : public IntVar { var_demon_(NULL), active_watchers_(0) {} + BoundWatcher(Solver* const solver, + DomainIntVar* const variable, + const std::vector& values, + const std::vector& vars) + : Constraint(solver), + variable_(variable), + iterator_(variable_->MakeHoleIterator(true)), + watchers_(16), + min_range_(kint64max), + max_range_(kint64min), + var_demon_(NULL), + active_watchers_(0) { + CHECK_EQ(vars.size(), values.size()); + for (int i = 0; i < values.size(); ++i) { + SetBoundWatcher(vars[i], values[i]); + } + } + ~BoundWatcher() {} IntVar* GetOrMakeBoundWatcher(int64 value) { @@ -551,6 +602,17 @@ class DomainIntVar : public IntVar { return boolvar; } + void SetBoundWatcher(IntVar* const boolvar, int64 value) { + CHECK(watchers_.At(value) == NULL); + active_watchers_.Incr(solver()); + min_range_.SetValue(solver(), std::min(min_range_.Value(), value)); + max_range_.SetValue(solver(), std::max(max_range_.Value(), value)); + watchers_.RevInsert(solver(), value, boolvar); + if (posted_.Switched() && !boolvar->Bound()) { + boolvar->WhenBound(solver()->RevAlloc(new WatchDemon(this, value))); + } + } + virtual void Post() { var_demon_ = solver()->RevAlloc(new VarDemon(this)); variable_->WhenRange(var_demon_); @@ -762,6 +824,20 @@ class DomainIntVar : public IntVar { } } + Constraint* SetIsEqual(const std::vector& values, + const std::vector& vars) { + if (value_watcher_ == NULL) { + solver()->SaveAndSetValue( + reinterpret_cast(&value_watcher_), + reinterpret_cast( + solver()->RevAlloc(new ValueWatcher(solver(), + this, + values, + vars)))); + return value_watcher_; + } + } + virtual IntVar* IsDifferent(int64 constant) { Solver* const s = solver(); if (constant == min_.Value() && value_watcher_ == NULL) { @@ -822,6 +898,20 @@ class DomainIntVar : public IntVar { } } + Constraint* SetIsGreaterOrEqual(const std::vector& values, + const std::vector& vars) { + if (bound_watcher_ == NULL) { + solver()->SaveAndSetValue( + reinterpret_cast(&bound_watcher_), + reinterpret_cast( + solver()->RevAlloc(new BoundWatcher(solver(), + this, + values, + vars)))); + return bound_watcher_; + } + } + virtual IntVar* IsLessOrEqual(int64 constant) { Solver* const s = solver(); IntExpr* const cache = s->Cache()->FindExprConstantExpression( @@ -5765,6 +5855,22 @@ Action* NewDomainIntVarCleaner() { return new VariableQueueCleaner; } +Constraint* SetIsEqual(IntVar* const var, + const std::vector& values, + const std::vector& vars) { + DomainIntVar* const dvar = dynamic_cast(var); + CHECK_NOTNULL(dvar); + return dvar->SetIsEqual(values, vars); +} + +Constraint* SetIsGreaterOrEqual(IntVar* const var, + const std::vector& values, + const std::vector& vars) { + DomainIntVar* const dvar = dynamic_cast(var); + CHECK_NOTNULL(dvar); + return dvar->SetIsGreaterOrEqual(values, vars); +} + void Solver::set_queue_cleaner_on_fail(IntVar* const var) { DCHECK_EQ(DOMAIN_INT_VAR, var->VarType()); DomainIntVar* const dvar = reinterpret_cast(var); diff --git a/src/constraint_solver/io.cc b/src/constraint_solver/io.cc index a47eb3c2bb..1090ea7025 100644 --- a/src/constraint_solver/io.cc +++ b/src/constraint_solver/io.cc @@ -1303,37 +1303,49 @@ IntExpr* BuildElement(CPModelLoader* const builder, } // ----- kElementEqual ----- -// TODO(user): Add API on solver and uncomment this method. -/* - Constraint* BuildElementEqual(CPModelLoader* const builder, - const CPConstraintProto& proto) { - IntExpr* target = NULL; - VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, - proto, - &target)); + +Constraint* BuildElementEqual(CPModelLoader* const builder, + const CPConstraintProto& proto) { + IntExpr* index = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument, + proto, + &index)); std::vector values; if (builder->ScanArguments(ModelVisitor::kValuesArgument, - proto, - &values)) { - IntExpr* index = NULL; - VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument, - proto, - &index)); - return builder->solver()->MakeElement(values, index->Var()); - } - std::vector vars; - if (builder->ScanArguments(ModelVisitor::kVarsArgument, - proto, - &vars)) { - IntExpr* index = NULL; - VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument, - proto, - &index)); - return builder->solver()->MakeElement(vars, index->Var()); + proto, + &values)) { + IntExpr* target = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, + proto, + &target)); + return builder->solver()->MakeElementEquality(values, + index->Var(), + target->Var()); + } else { + std::vector vars; + if (builder->ScanArguments(ModelVisitor::kVarsArgument, + proto, + &vars)) { + IntExpr* target = NULL; + if (builder->ScanArguments(ModelVisitor::kTargetArgument, + proto, + &target)) { + return builder->solver()->MakeElementEquality(vars, + index->Var(), + target->Var()); + } else { + int64 target_value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, + proto, + &target_value)); + return builder->solver()->MakeElementEquality(vars, + index->Var(), + target_value); + } + } } return NULL; - } -*/ +} // ----- kEndExpr ----- @@ -1400,6 +1412,27 @@ Constraint* BuildGreaterOrEqual(CPModelLoader* const builder, return NULL; } +// ----- kIndexOf ----- + +Constraint* BuildIndexOf(CPModelLoader* const builder, + const CPConstraintProto& proto) { + IntExpr* index = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument, + proto, + &index)); + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, + proto, + &vars)); + int64 target_value = 0; + VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument, + proto, + &target_value)); + return builder->solver()->MakeIndexOfConstraint(vars, + index->Var(), + target_value); +} + // ----- kIntegerVariable ----- IntExpr* BuildIntegerVariable(CPModelLoader* const builder, @@ -2214,6 +2247,44 @@ Constraint* BuildTrueConstraint(CPModelLoader* const builder, return builder->solver()->MakeTrueConstraint(); } +// ----- kVarValueWatcher ----- + +Constraint* SetIsEqual(IntVar* const var, + const std::vector& values, + const std::vector& vars); + +Constraint* BuildVarValueWatcher(CPModelLoader* const builder, + const CPConstraintProto& proto) { + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kVariableArgument, proto, &expr)); + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + std::vector values; + VERIFY(builder->ScanArguments(ModelVisitor::kValuesArgument, + proto, + &values)); + return SetIsEqual(expr->Var(), values, vars); +} + +// ----- kVarBoundWatcher ----- + +Constraint* SetIsGreaterOrEqual(IntVar* const var, + const std::vector& values, + const std::vector& vars); + +Constraint* BuildVarBoundWatcher(CPModelLoader* const builder, + const CPConstraintProto& proto) { + IntExpr* expr = NULL; + VERIFY(builder->ScanArguments(ModelVisitor::kVariableArgument, proto, &expr)); + std::vector vars; + VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars)); + std::vector values; + VERIFY(builder->ScanArguments(ModelVisitor::kValuesArgument, + proto, + &values)); + return SetIsGreaterOrEqual(expr->Var(), values, vars); +} + #undef VERIFY #undef VERIFY_EQ } // namespace @@ -2226,6 +2297,7 @@ bool CPModelLoader::BuildFromProto(const CPIntegerExpressionProto& proto) { Solver::IntegerExpressionBuilder* const builder = solver_->GetIntegerExpressionBuilder(tags_.Element(tag_index)); if (!builder) { + LOG(WARNING) << "Tag " << tags_.Element(tag_index) << " was not found"; return false; } IntExpr* const built = builder->Run(this, proto); @@ -2243,6 +2315,7 @@ Constraint* CPModelLoader::BuildFromProto(const CPConstraintProto& proto) { Solver::ConstraintBuilder* const builder = solver_->GetConstraintBuilder(tags_.Element(tag_index)); if (!builder) { + LOG(WARNING) << "Tag " << tags_.Element(tag_index) << " was not found"; return NULL; } Constraint* const built = builder->Run(this, proto); @@ -2255,6 +2328,7 @@ bool CPModelLoader::BuildFromProto(const CPIntervalVariableProto& proto) { Solver::IntervalVariableBuilder* const builder = solver_->GetIntervalVariableBuilder(tags_.Element(tag_index)); if (!builder) { + LOG(WARNING) << "Tag " << tags_.Element(tag_index) << " was not found"; return NULL; } IntervalVar* const built = builder->Run(this, proto); @@ -2272,6 +2346,7 @@ bool CPModelLoader::BuildFromProto(const CPSequenceVariableProto& proto) { Solver::SequenceVariableBuilder* const builder = solver_->GetSequenceVariableBuilder(tags_.Element(tag_index)); if (!builder) { + LOG(WARNING) << "Tag " << tags_.Element(tag_index) << " was not found"; return NULL; } SequenceVar* const built = builder->Run(this, proto); @@ -2588,12 +2663,13 @@ void Solver::InitBuilders() { REGISTER(kDivide, BuildDivide); REGISTER(kDurationExpr, BuildDurationExpr); REGISTER(kElement, BuildElement); - // REGISTER(kElementEqual, BuildElementEqual); + REGISTER(kElementEqual, BuildElementEqual); REGISTER(kEndExpr, BuildEndExpr); REGISTER(kEquality, BuildEquality); REGISTER(kFalseConstraint, BuildFalseConstraint); REGISTER(kGreater, BuildGreater); REGISTER(kGreaterOrEqual, BuildGreaterOrEqual); + REGISTER(kIndexOf, BuildIndexOf); REGISTER(kIntegerVariable, BuildIntegerVariable); REGISTER(kIntervalBinaryRelation, BuildIntervalBinaryRelation); REGISTER(kIntervalDisjunction, BuildIntervalDisjunction); @@ -2635,6 +2711,8 @@ void Solver::InitBuilders() { REGISTER(kSumLessOrEqual, BuildSumLessOrEqual); REGISTER(kTransition, BuildTransition); REGISTER(kTrueConstraint, BuildTrueConstraint); + REGISTER(kVarBoundWatcher, BuildVarBoundWatcher); + REGISTER(kVarValueWatcher, BuildVarValueWatcher); } #undef REGISTER