missing IO for varvaluewatcher, varboundwatcher, elementequal and indexor constraints

This commit is contained in:
lperron@google.com
2012-08-17 23:20:02 +00:00
parent 79822f8ad7
commit 27eeb7fa3a
6 changed files with 258 additions and 48 deletions

View File

@@ -41,6 +41,7 @@ DEFINE_string(insert_license, "",
"Insert content of the given file into the license file.");
DEFINE_bool(collect_variables, false,
"Shows effect of the variable collector.");
DECLARE_bool(log_prefix);
namespace operations_research {
@@ -416,6 +417,7 @@ int Run() {
} // namespace operations_research
int main(int argc, char **argv) {
FLAGS_log_prefix=false;
google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_input.empty()) {
LOG(FATAL) << "Filename not specified";

View File

@@ -2871,6 +2871,11 @@ void ModelVisitor::VisitConstIntArrayArgument(const string& arg_name,
VisitIntegerArrayArgument(arg_name, values.RawData(), values.size());
}
void ModelVisitor::VisitIntegerArrayArgument(const string& arg_name,
const std::vector<int64>& values) {
VisitIntegerArrayArgument(arg_name, values.data(), values.size());
}
void ModelVisitor::VisitInt64ToBoolExtension(
ResultCallback1<bool, int64>* const callback,
int64 index_min,

View File

@@ -1474,14 +1474,17 @@ class Solver {
Constraint* MakeElementEquality(const std::vector<IntVar*>& vars,
IntVar* const index,
IntVar* const target);
Constraint* MakeElementEquality(const std::vector<IntVar*>& vars,
IntVar* const index,
int64 target);
// This constraints is a special case of the element constraint with
// an array of integer variables where all the variables are all
// differents. In that case, and with a constant value, the index
// of the element constraint is bound to be the unique index of the
// variable in 'vars' equal to the value target.
Constraint* MakeArrayPositionConstraint(const std::vector<IntVar*>& vars,
IntVar* const position,
int64 target);
Constraint* MakeIndexOfConstraint(const std::vector<IntVar*>& vars,
IntVar* const position,
int64 target);
// This method is a specialized case of the MakeConstraintDemon
// method to call the InitiatePropagate of the constraint 'ct'.
@@ -3161,6 +3164,7 @@ class ModelVisitor : public BaseObject {
static const char kAbsEqual[];
static const char kAllDifferent[];
static const char kAllowedAssignments[];
static const char kIndexOf[];
static const char kBetween[];
static const char kConvexPiecewise[];
static const char kCountEqual[];
@@ -3374,6 +3378,8 @@ class ModelVisitor : public BaseObject {
#if !defined(SWIG)
// Using SWIG on calbacks is troublesome, let's hide these methods during
// the wrapping.
virtual void VisitIntegerArrayArgument(const string& arg_name,
const std::vector<int64>& values);
void VisitConstIntArrayArgument(const string& arg_name,
const ConstIntArray& argument);
void VisitInt64ToBoolExtension(ResultCallback1<bool, int64>* const callback,

View File

@@ -1308,12 +1308,12 @@ class IntExprArrayElementCstCt : public Constraint {
// This constraint implements index == position(constant in vars).
class IntExprArrayPositionCt : public Constraint {
class IntExprIndexOfCt : public Constraint {
public:
IntExprArrayPositionCt(Solver* const s,
const std::vector<IntVar*>& vars,
IntVar* const index,
int64 target)
IntExprIndexOfCt(Solver* const s,
const std::vector<IntVar*>& vars,
IntVar* const index,
int64 target)
: Constraint(s),
vars_(vars),
size_(vars.size()),
@@ -1322,13 +1322,13 @@ class IntExprArrayPositionCt : public Constraint {
demons_(size_),
index_iterator_(index->MakeHoleIterator(true)) {}
virtual ~IntExprArrayPositionCt() {}
virtual ~IntExprIndexOfCt() {}
virtual void Post() {
for (int i = 0; i < size_; ++i) {
demons_[i] = MakeConstraintDemon1(solver(),
this,
&IntExprArrayPositionCt::Propagate,
&IntExprIndexOfCt::Propagate,
"Propagate",
i);
vars_[i]->WhenDomain(demons_[i]);
@@ -1336,7 +1336,7 @@ class IntExprArrayPositionCt : public Constraint {
Demon* const index_demon =
MakeConstraintDemon0(solver(),
this,
&IntExprArrayPositionCt::PropagateIndex,
&IntExprIndexOfCt::PropagateIndex,
"PropagateIndex");
index_->WhenDomain(index_demon);
}
@@ -1389,14 +1389,14 @@ class IntExprArrayPositionCt : public Constraint {
}
virtual string DebugString() const {
return StringPrintf("IntExprArrayPosition([%s], %s) == %" GG_LL_FORMAT "d",
return StringPrintf("IntExprIndexOf([%s], %s) == %" GG_LL_FORMAT "d",
DebugStringVector(vars_, ", ").c_str(),
index_->DebugString().c_str(),
target_);
}
virtual void Accept(ModelVisitor* const visitor) const {
visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this);
visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
vars_.data(),
size_);
@@ -1404,7 +1404,7 @@ class IntExprArrayPositionCt : public Constraint {
index_);
visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument,
target_);
visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this);
}
private:
@@ -1497,11 +1497,24 @@ Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
}
}
Constraint* Solver::MakeArrayPositionConstraint(
Constraint* Solver::MakeElementEquality(const std::vector<IntVar*>& vars,
IntVar* const index,
int64 target) {
// if (AreAllBound(vars)) {
// std::vector<int64> values(vars.size());
// for (int i = 0; i < vars.size(); ++i) {
// values[i] = vars[i]->Value();
// }
// return MakeElementEquality(values, index, target);
// }
return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target));
}
Constraint* Solver::MakeIndexOfConstraint(
const std::vector<IntVar*>& vars,
IntVar* const index,
int64 target) {
return RevAlloc(new IntExprArrayPositionCt(this, vars, index, target));
return RevAlloc(new IntExprIndexOfCt(this, vars, index, target));
}
IntExpr* Solver::MakeIndexExpression(const std::vector<IntVar*>& vars,
@@ -1516,7 +1529,7 @@ IntExpr* Solver::MakeIndexExpression(const std::vector<IntVar*>& vars,
NameVector(vars, ", ").c_str(),
value);
IntVar* const index = MakeIntVar(0, vars.size() - 1, name);
AddConstraint(MakeArrayPositionConstraint(vars, index, value));
AddConstraint(MakeIndexOfConstraint(vars, index, value));
model_cache_->InsertVarArrayConstantExpression(
index, vars, value, ModelCache::VAR_ARRAY_CONSTANT_INDEX);
return index;

View File

@@ -302,6 +302,24 @@ class DomainIntVar : public IntVar {
var_demon_(NULL),
active_watchers_(0) {}
ValueWatcher(Solver* const solver,
DomainIntVar* const variable,
const std::vector<int64>& values,
const std::vector<IntVar*>& vars)
: Constraint(solver),
variable_(variable),
iterator_(variable_->MakeHoleIterator(true)),
watchers_(16),
min_range_(kint64max),
max_range_(kint64min),
var_demon_(NULL),
active_watchers_(0) {
CHECK_EQ(vars.size(), values.size());
for (int i = 0; i < values.size(); ++i) {
SetValueWatcher(vars[i], values[i]);
}
}
~ValueWatcher() {}
IntVar* GetOrMakeValueWatcher(int64 value) {
@@ -329,13 +347,24 @@ class DomainIntVar : public IntVar {
}
min_range_.SetValue(solver(), std::min(min_range_.Value(), value));
max_range_.SetValue(solver(), std::max(max_range_.Value(), value));
watchers_.RevInsert(variable_->solver(), value, boolvar);
watchers_.RevInsert(solver(), value, boolvar);
if (posted_.Switched() && !boolvar->Bound()) {
boolvar->WhenBound(solver()->RevAlloc(new WatchDemon(this, value)));
}
return boolvar;
}
void SetValueWatcher(IntVar* const boolvar, int64 value) {
CHECK(watchers_.At(value) == NULL);
active_watchers_.Incr(solver());
min_range_.SetValue(solver(), std::min(min_range_.Value(), value));
max_range_.SetValue(solver(), std::max(max_range_.Value(), value));
watchers_.RevInsert(solver(), value, boolvar);
if (posted_.Switched() && !boolvar->Bound()) {
boolvar->WhenBound(solver()->RevAlloc(new WatchDemon(this, value)));
}
}
virtual void Post() {
var_demon_ = solver()->RevAlloc(new VarDemon(this));
variable_->WhenDomain(var_demon_);
@@ -428,17 +457,21 @@ class DomainIntVar : public IntVar {
visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
visitor->VisitIntegerExpressionArgument(
ModelVisitor::kVariableArgument, variable_);
std::vector<int64> all_coefficients;
std::vector<IntVar*> all_bool_vars;
const int64 max_r = max_range_.Value();
const int64 min_r = min_range_.Value();
for (int64 i = min_r; i <= max_r; ++i) {
IntVar* const boolvar = watchers_.At(i);
if (boolvar != NULL) {
all_coefficients.push_back(i);
all_bool_vars.push_back(boolvar);
}
}
visitor->VisitIntegerVariableArrayArgument(
ModelVisitor::kVarsArgument, all_bool_vars);
visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
all_bool_vars);
visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
all_coefficients);
visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
}
@@ -517,6 +550,24 @@ class DomainIntVar : public IntVar {
var_demon_(NULL),
active_watchers_(0) {}
BoundWatcher(Solver* const solver,
DomainIntVar* const variable,
const std::vector<int64>& values,
const std::vector<IntVar*>& vars)
: Constraint(solver),
variable_(variable),
iterator_(variable_->MakeHoleIterator(true)),
watchers_(16),
min_range_(kint64max),
max_range_(kint64min),
var_demon_(NULL),
active_watchers_(0) {
CHECK_EQ(vars.size(), values.size());
for (int i = 0; i < values.size(); ++i) {
SetBoundWatcher(vars[i], values[i]);
}
}
~BoundWatcher() {}
IntVar* GetOrMakeBoundWatcher(int64 value) {
@@ -551,6 +602,17 @@ class DomainIntVar : public IntVar {
return boolvar;
}
void SetBoundWatcher(IntVar* const boolvar, int64 value) {
CHECK(watchers_.At(value) == NULL);
active_watchers_.Incr(solver());
min_range_.SetValue(solver(), std::min(min_range_.Value(), value));
max_range_.SetValue(solver(), std::max(max_range_.Value(), value));
watchers_.RevInsert(solver(), value, boolvar);
if (posted_.Switched() && !boolvar->Bound()) {
boolvar->WhenBound(solver()->RevAlloc(new WatchDemon(this, value)));
}
}
virtual void Post() {
var_demon_ = solver()->RevAlloc(new VarDemon(this));
variable_->WhenRange(var_demon_);
@@ -762,6 +824,20 @@ class DomainIntVar : public IntVar {
}
}
Constraint* SetIsEqual(const std::vector<int64>& values,
const std::vector<IntVar*>& vars) {
if (value_watcher_ == NULL) {
solver()->SaveAndSetValue(
reinterpret_cast<void**>(&value_watcher_),
reinterpret_cast<void*>(
solver()->RevAlloc(new ValueWatcher(solver(),
this,
values,
vars))));
return value_watcher_;
}
}
virtual IntVar* IsDifferent(int64 constant) {
Solver* const s = solver();
if (constant == min_.Value() && value_watcher_ == NULL) {
@@ -822,6 +898,20 @@ class DomainIntVar : public IntVar {
}
}
Constraint* SetIsGreaterOrEqual(const std::vector<int64>& values,
const std::vector<IntVar*>& vars) {
if (bound_watcher_ == NULL) {
solver()->SaveAndSetValue(
reinterpret_cast<void**>(&bound_watcher_),
reinterpret_cast<void*>(
solver()->RevAlloc(new BoundWatcher(solver(),
this,
values,
vars))));
return bound_watcher_;
}
}
virtual IntVar* IsLessOrEqual(int64 constant) {
Solver* const s = solver();
IntExpr* const cache = s->Cache()->FindExprConstantExpression(
@@ -5765,6 +5855,22 @@ Action* NewDomainIntVarCleaner() {
return new VariableQueueCleaner;
}
Constraint* SetIsEqual(IntVar* const var,
const std::vector<int64>& values,
const std::vector<IntVar*>& vars) {
DomainIntVar* const dvar = dynamic_cast<DomainIntVar*>(var);
CHECK_NOTNULL(dvar);
return dvar->SetIsEqual(values, vars);
}
Constraint* SetIsGreaterOrEqual(IntVar* const var,
const std::vector<int64>& values,
const std::vector<IntVar*>& vars) {
DomainIntVar* const dvar = dynamic_cast<DomainIntVar*>(var);
CHECK_NOTNULL(dvar);
return dvar->SetIsGreaterOrEqual(values, vars);
}
void Solver::set_queue_cleaner_on_fail(IntVar* const var) {
DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);

View File

@@ -1303,37 +1303,49 @@ IntExpr* BuildElement(CPModelLoader* const builder,
}
// ----- kElementEqual -----
// TODO(user): Add API on solver and uncomment this method.
/*
Constraint* BuildElementEqual(CPModelLoader* const builder,
const CPConstraintProto& proto) {
IntExpr* target = NULL;
VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument,
proto,
&target));
Constraint* BuildElementEqual(CPModelLoader* const builder,
const CPConstraintProto& proto) {
IntExpr* index = NULL;
VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument,
proto,
&index));
std::vector<int64> values;
if (builder->ScanArguments(ModelVisitor::kValuesArgument,
proto,
&values)) {
IntExpr* index = NULL;
VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument,
proto,
&index));
return builder->solver()->MakeElement(values, index->Var());
}
std::vector<IntVar*> vars;
if (builder->ScanArguments(ModelVisitor::kVarsArgument,
proto,
&vars)) {
IntExpr* index = NULL;
VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument,
proto,
&index));
return builder->solver()->MakeElement(vars, index->Var());
proto,
&values)) {
IntExpr* target = NULL;
VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument,
proto,
&target));
return builder->solver()->MakeElementEquality(values,
index->Var(),
target->Var());
} else {
std::vector<IntVar*> vars;
if (builder->ScanArguments(ModelVisitor::kVarsArgument,
proto,
&vars)) {
IntExpr* target = NULL;
if (builder->ScanArguments(ModelVisitor::kTargetArgument,
proto,
&target)) {
return builder->solver()->MakeElementEquality(vars,
index->Var(),
target->Var());
} else {
int64 target_value = 0;
VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument,
proto,
&target_value));
return builder->solver()->MakeElementEquality(vars,
index->Var(),
target_value);
}
}
}
return NULL;
}
*/
}
// ----- kEndExpr -----
@@ -1400,6 +1412,27 @@ Constraint* BuildGreaterOrEqual(CPModelLoader* const builder,
return NULL;
}
// ----- kIndexOf -----
Constraint* BuildIndexOf(CPModelLoader* const builder,
const CPConstraintProto& proto) {
IntExpr* index = NULL;
VERIFY(builder->ScanArguments(ModelVisitor::kIndexArgument,
proto,
&index));
std::vector<IntVar*> vars;
VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument,
proto,
&vars));
int64 target_value = 0;
VERIFY(builder->ScanArguments(ModelVisitor::kTargetArgument,
proto,
&target_value));
return builder->solver()->MakeIndexOfConstraint(vars,
index->Var(),
target_value);
}
// ----- kIntegerVariable -----
IntExpr* BuildIntegerVariable(CPModelLoader* const builder,
@@ -2214,6 +2247,44 @@ Constraint* BuildTrueConstraint(CPModelLoader* const builder,
return builder->solver()->MakeTrueConstraint();
}
// ----- kVarValueWatcher -----
Constraint* SetIsEqual(IntVar* const var,
const std::vector<int64>& values,
const std::vector<IntVar*>& vars);
Constraint* BuildVarValueWatcher(CPModelLoader* const builder,
const CPConstraintProto& proto) {
IntExpr* expr = NULL;
VERIFY(builder->ScanArguments(ModelVisitor::kVariableArgument, proto, &expr));
std::vector<IntVar*> vars;
VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars));
std::vector<int64> values;
VERIFY(builder->ScanArguments(ModelVisitor::kValuesArgument,
proto,
&values));
return SetIsEqual(expr->Var(), values, vars);
}
// ----- kVarBoundWatcher -----
Constraint* SetIsGreaterOrEqual(IntVar* const var,
const std::vector<int64>& values,
const std::vector<IntVar*>& vars);
Constraint* BuildVarBoundWatcher(CPModelLoader* const builder,
const CPConstraintProto& proto) {
IntExpr* expr = NULL;
VERIFY(builder->ScanArguments(ModelVisitor::kVariableArgument, proto, &expr));
std::vector<IntVar*> vars;
VERIFY(builder->ScanArguments(ModelVisitor::kVarsArgument, proto, &vars));
std::vector<int64> values;
VERIFY(builder->ScanArguments(ModelVisitor::kValuesArgument,
proto,
&values));
return SetIsGreaterOrEqual(expr->Var(), values, vars);
}
#undef VERIFY
#undef VERIFY_EQ
} // namespace
@@ -2226,6 +2297,7 @@ bool CPModelLoader::BuildFromProto(const CPIntegerExpressionProto& proto) {
Solver::IntegerExpressionBuilder* const builder =
solver_->GetIntegerExpressionBuilder(tags_.Element(tag_index));
if (!builder) {
LOG(WARNING) << "Tag " << tags_.Element(tag_index) << " was not found";
return false;
}
IntExpr* const built = builder->Run(this, proto);
@@ -2243,6 +2315,7 @@ Constraint* CPModelLoader::BuildFromProto(const CPConstraintProto& proto) {
Solver::ConstraintBuilder* const builder =
solver_->GetConstraintBuilder(tags_.Element(tag_index));
if (!builder) {
LOG(WARNING) << "Tag " << tags_.Element(tag_index) << " was not found";
return NULL;
}
Constraint* const built = builder->Run(this, proto);
@@ -2255,6 +2328,7 @@ bool CPModelLoader::BuildFromProto(const CPIntervalVariableProto& proto) {
Solver::IntervalVariableBuilder* const builder =
solver_->GetIntervalVariableBuilder(tags_.Element(tag_index));
if (!builder) {
LOG(WARNING) << "Tag " << tags_.Element(tag_index) << " was not found";
return NULL;
}
IntervalVar* const built = builder->Run(this, proto);
@@ -2272,6 +2346,7 @@ bool CPModelLoader::BuildFromProto(const CPSequenceVariableProto& proto) {
Solver::SequenceVariableBuilder* const builder =
solver_->GetSequenceVariableBuilder(tags_.Element(tag_index));
if (!builder) {
LOG(WARNING) << "Tag " << tags_.Element(tag_index) << " was not found";
return NULL;
}
SequenceVar* const built = builder->Run(this, proto);
@@ -2588,12 +2663,13 @@ void Solver::InitBuilders() {
REGISTER(kDivide, BuildDivide);
REGISTER(kDurationExpr, BuildDurationExpr);
REGISTER(kElement, BuildElement);
// REGISTER(kElementEqual, BuildElementEqual);
REGISTER(kElementEqual, BuildElementEqual);
REGISTER(kEndExpr, BuildEndExpr);
REGISTER(kEquality, BuildEquality);
REGISTER(kFalseConstraint, BuildFalseConstraint);
REGISTER(kGreater, BuildGreater);
REGISTER(kGreaterOrEqual, BuildGreaterOrEqual);
REGISTER(kIndexOf, BuildIndexOf);
REGISTER(kIntegerVariable, BuildIntegerVariable);
REGISTER(kIntervalBinaryRelation, BuildIntervalBinaryRelation);
REGISTER(kIntervalDisjunction, BuildIntervalDisjunction);
@@ -2635,6 +2711,8 @@ void Solver::InitBuilders() {
REGISTER(kSumLessOrEqual, BuildSumLessOrEqual);
REGISTER(kTransition, BuildTransition);
REGISTER(kTrueConstraint, BuildTrueConstraint);
REGISTER(kVarBoundWatcher, BuildVarBoundWatcher);
REGISTER(kVarValueWatcher, BuildVarValueWatcher);
}
#undef REGISTER