more work on variable collector, add more constraints as sub-classes of cast constraint

This commit is contained in:
lperron@google.com
2011-12-17 21:20:58 +00:00
parent e5cb0df1ca
commit dd008982c7
4 changed files with 235 additions and 58 deletions

View File

@@ -27,6 +27,97 @@
namespace operations_research {
namespace {
class ArgumentHolder {
public:
struct Matrix {
Matrix() : values(NULL), rows(0), columns(0) {}
Matrix(const int64 * const * v, int r, int c)
: values(v), rows(r), columns(c) {}
const int64* const * values;
int rows;
int columns;
};
const string& type_name() const {
return type_name_;
}
void set_type_name(const string& type_name) {
type_name_ = type_name;
}
void set_integer_matrix_argument(const string& arg_name,
const int64* const * const values,
int rows,
int columns) {
matrix_argument_[arg_name] = Matrix(values, rows, columns);
}
void set_integer_expression_argument(const string& arg_name,
const IntExpr* const expr) {
integer_expression_argument_[arg_name] = expr;
}
void set_integer_variable_array_argument(const string& arg_name,
const IntVar* const * const vars,
int size) {
for (int i = 0; i < size; ++i) {
integer_variable_array_argument_[arg_name].push_back(vars[i]);
}
}
void set_interval_argument(const string& arg_name,
const IntervalVar* const var) {
interval_argument_[arg_name] = var;
}
void set_interval_array_argument(const string& arg_name,
const IntervalVar* const * const vars,
int size) {
for (int i = 0; i < size; ++i) {
interval_array_argument_[arg_name].push_back(vars[i]);
}
}
void set_sequence_argument(const string& arg_name,
const SequenceVar* const var) {
sequence_argument_[arg_name] = var;
}
void set_sequence_array_argument(const string& arg_name,
const SequenceVar* const * const vars,
int size) {
for (int i = 0; i < size; ++i) {
sequence_array_argument_[arg_name].push_back(vars[i]);
}
}
const IntExpr* FindIntegerExpressionArgumentOrDie(const string& arg_name) {
return FindOrDie(integer_expression_argument_, arg_name);
}
const std::vector<const IntVar*>& FindIntegerVariableArrayArgumentOrDie(
const string& arg_name) {
return FindOrDie(integer_variable_array_argument_, arg_name);
}
const Matrix& FindIntegerMatrixArgumentOrDie(const string& arg_name) {
return FindOrDie(matrix_argument_, arg_name);
}
private:
string type_name_;
hash_map<string, const IntExpr*> integer_expression_argument_;
hash_map<string, const IntervalVar*> interval_argument_;
hash_map<string, const SequenceVar*> sequence_argument_;
hash_map<string,
std::vector<const IntVar*> > integer_variable_array_argument_;
hash_map<string, std::vector<const IntervalVar*> > interval_array_argument_;
hash_map<string, std::vector<const SequenceVar*> > sequence_array_argument_;
hash_map<string, Matrix> matrix_argument_;
};
class CollectVariablesVisitor : public ModelVisitor {
public:
CollectVariablesVisitor(std::vector<IntVar*>* const primary_integer_variables,
@@ -41,32 +132,75 @@ class CollectVariablesVisitor : public ModelVisitor {
// Header/footers.
virtual void BeginVisitModel(const string& solver_name) {
LOG(INFO) << "Starts collecting variables on model " << solver_name;
PushArgumentHolder();
}
virtual void EndVisitModel(const string& solver_name) {
LOG(INFO) << "Finishes collecting variables.";
PopArgumentHolder();
primaries_->assign(primary_set_.begin(), primary_set_.end());
secondaries_->assign(secondary_set_.begin(), secondary_set_.end());
intervals_->assign(interval_set_.begin(), interval_set_.end());
sequences_->assign(sequence_set_.begin(), sequence_set_.end());
}
virtual void BeginVisitConstraint(const string& type_name,
const Constraint* const constraint) {
if (constraint->IsCastConstraint()) {
const CastConstraint* const cast_constraint =
reinterpret_cast<const CastConstraint* const>(constraint);
ignored_set_.insert(cast_constraint->target_var());
}
PushArgumentHolder();
}
virtual void EndVisitConstraint(const string& type_name,
const Constraint* const constraint) {
if (type_name.compare(ModelVisitor::kLinkExprVar) == 0 ||
type_name.compare(ModelVisitor::kSumEqual) == 0 ||
type_name.compare(ModelVisitor::kCountEqual) == 0 ||
type_name.compare(ModelVisitor::kElementEqual) == 0 ||
type_name.compare(ModelVisitor::kScalProdEqual) == 0 ||
type_name.compare(ModelVisitor::kIsEqual) == 0 ||
type_name.compare(ModelVisitor::kIsDifferent) == 0 ||
type_name.compare(ModelVisitor::kIsGreaterOrEqual) == 0 ||
type_name.compare(ModelVisitor::kIsLessOrEqual) == 0) {
IntExpr* const target_expr =
const_cast<IntExpr*>(
top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kTargetArgument));
IntVar* const target_var = target_expr->Var();
IgnoreIntegerVariable(target_var);
} else if (type_name.compare(ModelVisitor::kAllowedAssignments) == 0) {
const ArgumentHolder::Matrix& matrix =
top()->FindIntegerMatrixArgumentOrDie(ModelVisitor::kTuplesArgument);
vector<hash_set<int> > counters(matrix.columns);
for (int i = 0; i < matrix.rows; ++i) {
for (int j = 0; j < matrix.columns; ++j) {
counters[j].insert(matrix.values[i][j]);
}
}
for (int j = 0; j < matrix.columns; ++j) {
if (counters[j].size() == matrix.rows) {
vector<const IntVar*> vars =
top()->FindIntegerVariableArrayArgumentOrDie(
ModelVisitor::kVarsArgument);
LOG(INFO) << "Found index variable in allowed assignment constraint: "
<< vars[j]->DebugString();
for (int k = 0; k < matrix.columns; ++k) {
if (j != k) {
IgnoreIntegerVariable(const_cast<IntVar*>(vars[k]));
}
}
break;
}
}
}
PopArgumentHolder();
}
virtual void BeginVisitIntegerExpression(const string& type_name,
const IntExpr* const expr) {
PushArgumentHolder();
}
virtual void EndVisitIntegerExpression(const string& type_name,
const IntExpr* const expr) {
PopArgumentHolder();
}
virtual void VisitIntegerVariable(const IntVar* const variable,
@@ -80,7 +214,6 @@ class CollectVariablesVisitor : public ModelVisitor {
!ContainsKey(ignored_set_, var) &&
!var->Bound()) {
primary_set_.insert(var);
primaries_->push_back(const_cast<IntVar*>(var));
}
}
}
@@ -101,7 +234,6 @@ class CollectVariablesVisitor : public ModelVisitor {
IntervalVar* const var = const_cast<IntervalVar*>(variable);
if (!ContainsKey(interval_set_, var)) {
interval_set_.insert(var);
intervals_->push_back(var);
}
}
}
@@ -119,17 +251,25 @@ class CollectVariablesVisitor : public ModelVisitor {
SequenceVar* const var = const_cast<SequenceVar*>(variable);
if (!ContainsKey(sequence_set_, var)) {
sequence_set_.insert(var);
sequences_->push_back(var);
}
for (int i = 0; i < var->size(); ++i) {
var->Interval(i)->Accept(this);
}
}
// Integer arguments
virtual void VisitIntegerMatrixArgument(const string& arg_name,
const int64* const * const values,
int rows,
int columns) {
top()->set_integer_matrix_argument(arg_name, values, rows, columns);
}
// Variables.
virtual void VisitIntegerExpressionArgument(
const string& arg_name,
const IntExpr* const argument) {
top()->set_integer_expression_argument(arg_name, argument);
argument->Accept(this);
}
@@ -137,6 +277,7 @@ class CollectVariablesVisitor : public ModelVisitor {
const string& arg_name,
const IntVar* const * arguments,
int size) {
top()->set_integer_variable_array_argument(arg_name, arguments, size);
for (int i = 0; i < size; ++i) {
arguments[i]->Accept(this);
}
@@ -145,12 +286,14 @@ class CollectVariablesVisitor : public ModelVisitor {
// Visit interval argument.
virtual void VisitIntervalArgument(const string& arg_name,
const IntervalVar* const argument) {
top()->set_interval_argument(arg_name, argument);
argument->Accept(this);
}
virtual void VisitIntervalArgumentArray(const string& arg_name,
const IntervalVar* const * arguments,
int size) {
top()->set_interval_array_argument(arg_name, arguments, size);
for (int i = 0; i < size; ++i) {
arguments[i]->Accept(this);
}
@@ -159,18 +302,41 @@ class CollectVariablesVisitor : public ModelVisitor {
// Visit sequence argument.
virtual void VisitSequenceArgument(const string& arg_name,
const SequenceVar* const argument) {
top()->set_sequence_argument(arg_name, argument);
argument->Accept(this);
}
virtual void VisitSequenceArgumentArray(const string& arg_name,
const SequenceVar* const * arguments,
int size) {
top()->set_sequence_array_argument(arg_name, arguments, size);
for (int i = 0; i < size; ++i) {
arguments[i]->Accept(this);
}
}
private:
void PushArgumentHolder() {
holders_.push_back(new ArgumentHolder);
}
void PopArgumentHolder() {
CHECK(!holders_.empty());
delete holders_.back();
holders_.pop_back();
}
ArgumentHolder* top() const {
CHECK(!holders_.empty());
return holders_.back();
}
void IgnoreIntegerVariable(IntVar* const var) {
primary_set_.erase(var);
secondary_set_.erase(var);
ignored_set_.insert(var);
}
std::vector<IntVar*>* const primaries_;
std::vector<IntVar*>* const secondaries_;
std::vector<SequenceVar*>* const sequences_;
@@ -180,6 +346,7 @@ class CollectVariablesVisitor : public ModelVisitor {
hash_set<IntVar*> ignored_set_;
hash_set<SequenceVar*> sequence_set_;
hash_set<IntervalVar*> interval_set_;
std::vector<ArgumentHolder*> holders_;
};
} // namespace

View File

@@ -1402,9 +1402,7 @@ IntExpr* Solver::MakeSum(IntVar* const* vars, int size) {
sum_max += vars[i]->Max();
}
IntVar* const sum_var = MakeIntVar(sum_min, sum_max);
AddCastConstraint(RevAlloc(new SumConstraint(this, vars, size, sum_var)),
sum_var,
NULL);
AddConstraint(RevAlloc(new SumConstraint(this, vars, size, sum_var)));,
return sum_var;
}
}

View File

@@ -287,23 +287,23 @@ Constraint* Solver::MakeNonEquality(IntVar* const e, int v) {
// ----- is_equal_cst Constraint -----
namespace {
class IsEqualCstCt : public Constraint {
class IsEqualCstCt : public CastConstraint {
public:
IsEqualCstCt(Solver* const s, IntVar* const v, int64 c, IntVar* const b)
: Constraint(s), var_(v), cst_(c), boolvar_(b), demon_(NULL) {}
: CastConstraint(s, b), var_(v), cst_(c), demon_(NULL) {}
virtual void Post() {
demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
var_->WhenDomain(demon_);
boolvar_->WhenBound(demon_);
target_var_->WhenBound(demon_);
}
virtual void InitialPropagate() {
bool inhibit = var_->Bound();
int64 u = var_->Contains(cst_);
int64 l = inhibit ? u : 0;
boolvar_->SetRange(l, u);
if (boolvar_->Bound()) {
target_var_->SetRange(l, u);
if (target_var_->Bound()) {
inhibit = true;
if (boolvar_->Min() == 0) {
if (target_var_->Min() == 0) {
var_->RemoveValue(cst_);
} else {
var_->SetValue(cst_);
@@ -317,7 +317,7 @@ class IsEqualCstCt : public Constraint {
return StringPrintf("IsEqualCstCt(%s, %" GG_LL_FORMAT "d, %s)",
var_->DebugString().c_str(),
cst_,
boolvar_->DebugString().c_str());
target_var_->DebugString().c_str());
}
void Accept(ModelVisitor* const visitor) const {
@@ -326,14 +326,13 @@ class IsEqualCstCt : public Constraint {
var_);
visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
boolvar_);
target_var_);
visitor->EndVisitConstraint(ModelVisitor::kIsEqual, this);
}
private:
IntVar* const var_;
int64 cst_;
IntVar* const boolvar_;
Demon* demon_;
};
} // namespace
@@ -358,10 +357,14 @@ IntVar* Solver::MakeIsEqualCstVar(IntVar* const var, int64 value) {
if (cache != NULL) {
return cache->Var();
} else {
string name = var->name();
if (name.empty()) {
name = var->DebugString();
}
IntVar* const boolvar = MakeBoolVar(
StringPrintf("StatusVar<%s == %" GG_LL_FORMAT "d>",
var->name().c_str(), value));
Constraint* const maintain =
name.c_str(), value));
CastConstraint* const maintain =
RevAlloc(new IsEqualCstCt(this, var, value, boolvar));
AddConstraint(maintain);
model_cache_->InsertVarConstantExpression(
@@ -397,25 +400,25 @@ Constraint* Solver::MakeIsEqualCstCt(IntVar* const var,
// ----- is_diff_cst Constraint -----
namespace {
class IsDiffCstCt : public Constraint {
class IsDiffCstCt : public CastConstraint {
public:
IsDiffCstCt(Solver* const s, IntVar* const v, int64 c, IntVar* const b)
: Constraint(s), var_(v), cst_(c), boolvar_(b), demon_(NULL) {}
: CastConstraint(s, b), var_(v), cst_(c), demon_(NULL) {}
virtual void Post() {
demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
var_->WhenDomain(demon_);
boolvar_->WhenBound(demon_);
target_var_->WhenBound(demon_);
}
virtual void InitialPropagate() {
bool inhibit = var_->Bound();
int64 l = 1 - var_->Contains(cst_);
int64 u = inhibit ? l : 1;
boolvar_->SetRange(l, u);
if (boolvar_->Bound()) {
target_var_->SetRange(l, u);
if (target_var_->Bound()) {
inhibit = true;
if (boolvar_->Min() == 1) {
if (target_var_->Min() == 1) {
var_->RemoveValue(cst_);
} else {
var_->SetValue(cst_);
@@ -430,7 +433,7 @@ class IsDiffCstCt : public Constraint {
return StringPrintf("IsDiffCstCt(%s, %" GG_LL_FORMAT "d, %s)",
var_->DebugString().c_str(),
cst_,
boolvar_->DebugString().c_str());
target_var_->DebugString().c_str());
}
void Accept(ModelVisitor* const visitor) const {
@@ -439,14 +442,13 @@ class IsDiffCstCt : public Constraint {
var_);
visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
boolvar_);
target_var_);
visitor->EndVisitConstraint(ModelVisitor::kIsDifferent, this);
}
private:
IntVar* const var_;
int64 cst_;
IntVar* const boolvar_;
Demon* demon_;
};
} // namespace
@@ -471,10 +473,14 @@ IntVar* Solver::MakeIsDifferentCstVar(IntVar* const var, int64 value) {
if (cache != NULL) {
return cache->Var();
} else {
string name = var->name();
if (name.empty()) {
name = var->DebugString();
}
IntVar* const boolvar = MakeBoolVar(
StringPrintf("StatusVar<%s != %" GG_LL_FORMAT "d>",
var->name().c_str(), value));
Constraint* const maintain =
name.c_str(), value));
CastConstraint* const maintain =
RevAlloc(new IsDiffCstCt(this, var, value, boolvar));
AddConstraint(maintain);
model_cache_->InsertVarConstantExpression(
@@ -508,24 +514,24 @@ Constraint* Solver::MakeIsDifferentCstCt(IntVar* const var,
// ----- is_greater_equal_cst Constraint -----
namespace {
class IsGreaterEqualCstCt : public Constraint {
class IsGreaterEqualCstCt : public CastConstraint {
public:
IsGreaterEqualCstCt(Solver* const s, IntVar* const v, int64 c,
IntVar* const b)
: Constraint(s), var_(v), cst_(c), boolvar_(b), demon_(NULL) {}
: CastConstraint(s, b), var_(v), cst_(c), demon_(NULL) {}
virtual void Post() {
demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
var_->WhenRange(demon_);
boolvar_->WhenBound(demon_);
target_var_->WhenBound(demon_);
}
virtual void InitialPropagate() {
bool inhibit = false;
int64 u = var_->Max() >= cst_;
int64 l = var_->Min() >= cst_;
boolvar_->SetRange(l, u);
if (boolvar_->Bound()) {
target_var_->SetRange(l, u);
if (target_var_->Bound()) {
inhibit = true;
if (boolvar_->Min() == 0) {
if (target_var_->Min() == 0) {
var_->SetMax(cst_ - 1);
} else {
var_->SetMin(cst_);
@@ -539,7 +545,7 @@ class IsGreaterEqualCstCt : public Constraint {
return StringPrintf("IsGreaterEqualCstCt(%s, %" GG_LL_FORMAT "d, %s)",
var_->DebugString().c_str(),
cst_,
boolvar_->DebugString().c_str());
target_var_->DebugString().c_str());
}
virtual void Accept(ModelVisitor* const visitor) const {
@@ -548,14 +554,13 @@ class IsGreaterEqualCstCt : public Constraint {
var_);
visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
boolvar_);
target_var_);
visitor->EndVisitConstraint(ModelVisitor::kIsGreaterOrEqual, this);
}
private:
IntVar* const var_;
int64 cst_;
IntVar* const boolvar_;
Demon* demon_;
};
} // namespace
@@ -574,10 +579,14 @@ IntVar* Solver::MakeIsGreaterOrEqualCstVar(IntVar* const var, int64 value) {
if (cache != NULL) {
return cache->Var();
} else {
string name = var->name();
if (name.empty()) {
name = var->DebugString();
}
IntVar* const boolvar = MakeBoolVar(
StringPrintf("StatusVar<%s >= %" GG_LL_FORMAT "d>",
var->name().c_str(), value));
Constraint* const maintain =
name.c_str(), value));
CastConstraint* const maintain =
RevAlloc(new IsGreaterEqualCstCt(this, var, value, boolvar));
AddConstraint(maintain);
model_cache_->InsertVarConstantExpression(
@@ -614,25 +623,25 @@ Constraint* Solver::MakeIsGreaterCstCt(IntVar* const v, int64 c,
// ----- is_lesser_equal_cst Constraint -----
namespace {
class IsLessEqualCstCt : public Constraint {
class IsLessEqualCstCt : public CastConstraint {
public:
IsLessEqualCstCt(Solver* const s, IntVar* const v, int64 c, IntVar* const b)
: Constraint(s), var_(v), cst_(c), boolvar_(b), demon_(NULL) {}
: CastConstraint(s, b), var_(v), cst_(c), demon_(NULL) {}
virtual void Post() {
demon_ = solver()->MakeConstraintInitialPropagateCallback(this);
var_->WhenRange(demon_);
boolvar_->WhenBound(demon_);
target_var_->WhenBound(demon_);
}
virtual void InitialPropagate() {
bool inhibit = false;
int64 u = var_->Min() <= cst_;
int64 l = var_->Max() <= cst_;
boolvar_->SetRange(l, u);
if (boolvar_->Bound()) {
target_var_->SetRange(l, u);
if (target_var_->Bound()) {
inhibit = true;
if (boolvar_->Min() == 0) {
if (target_var_->Min() == 0) {
var_->SetMin(cst_ + 1);
} else {
var_->SetMax(cst_);
@@ -647,7 +656,7 @@ class IsLessEqualCstCt : public Constraint {
return StringPrintf("IsLessEqualCstCt(%s, %" GG_LL_FORMAT "d, %s)",
var_->DebugString().c_str(),
cst_,
boolvar_->DebugString().c_str());
target_var_->DebugString().c_str());
}
virtual void Accept(ModelVisitor* const visitor) const {
@@ -656,14 +665,13 @@ class IsLessEqualCstCt : public Constraint {
var_);
visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, cst_);
visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
boolvar_);
target_var_);
visitor->EndVisitConstraint(ModelVisitor::kIsLessOrEqual, this);
}
private:
IntVar* const var_;
int64 cst_;
IntVar* const boolvar_;
Demon* demon_;
};
} // namespace
@@ -683,10 +691,14 @@ IntVar* Solver::MakeIsLessOrEqualCstVar(IntVar* const var, int64 value) {
if (cache != NULL) {
return cache->Var();
} else {
string name = var->name();
if (name.empty()) {
name = var->DebugString();
}
IntVar* const boolvar = MakeBoolVar(
StringPrintf("StatusVar<%s <= %" GG_LL_FORMAT "d>",
var->name().c_str(), value));
Constraint* const maintain =
name.c_str(), value));
CastConstraint* const maintain =
RevAlloc(new IsLessEqualCstCt(this, var, value, boolvar));
AddConstraint(maintain);
model_cache_->InsertVarConstantExpression(

View File

@@ -582,7 +582,7 @@ class SecondPassVisitor : public ModelVisitor {
virtual void EndVisitConstraint(const string& type_name,
const Constraint* const constraint) {
// We ignore delegate constraints, they will be regenerated automatically.
// We ignore cast constraints, they will be regenerated automatically.
if (constraint->IsCastConstraint()) {
return;
}