cleanup on global arithmetic constraint

This commit is contained in:
lperron@google.com
2010-12-16 22:45:09 +00:00
parent 1d16db0eaa
commit c8e3d9e5cc

View File

@@ -293,13 +293,7 @@ class VarEqualVarPlusOffset : public ArithmeticConstraint {
class ScalProdOpConstant : public ArithmeticConstraint {
public:
enum Operation {
LESS_OR_EQUAL,
EQUAL,
GREATER_OR_EQUAL
};
ScalProdOpConstant(int64 constant, Operation op)
: constant_(constant), operation_(op) {}
ScalProdOpConstant(int64 lb, int64 ub) : lb_(lb), ub_(ub) {}
virtual ~ScalProdOpConstant() {}
void AddTerm(int var_index, int64 coefficient) {
@@ -316,7 +310,12 @@ class ScalProdOpConstant : public ArithmeticConstraint {
if (find_other != coefficients_.end()) {
hash_map<int, int64>::iterator find_var = coefficients_.find(var);
const int64 other_coefficient = find_other->second;
constant_ += other_coefficient * offset;
if (lb_ != kint64min) {
lb_ += other_coefficient * offset;
}
if (ub_ != kint64max) {
ub_ += other_coefficient * offset;
}
coefficients_.erase(find_other);
if (find_var == coefficients_.end()) {
coefficients_[var] = other_coefficient;
@@ -340,33 +339,39 @@ class ScalProdOpConstant : public ArithmeticConstraint {
if (it->second != 0) {
if (first) {
first = false;
if (it->second == 1) {
output += StringPrintf("var<%d>", it->first);
} else if (it->second == -1) {
output += StringPrintf("-var<%d>", it->first);
} else {
output += StringPrintf("%lld*var<%d>", it->second, it->first);
}
} else if (it->second == 1) {
output += StringPrintf(" + var<%d>", it->first);
} else if (it->second == -1) {
output += StringPrintf(" - var<%d>", it->first);
} else if (it->second > 0) {
output += StringPrintf(" + %lld*var<%d>", it->second, it->first);
} else {
output += " + ";
output += StringPrintf(" - %lld*var<%d>", -it->second, it->first);
}
output += StringPrintf("%lld * var<%d>", it->second, it->first);
}
}
output += ") ";
switch (operation_) {
case LESS_OR_EQUAL:
output += "<=";
break;
case EQUAL:
output += "==";
break;
case GREATER_OR_EQUAL:
output += ">=";
break;
default:
LOG(FATAL) << "Should not be here";
if (lb_ == ub_) {
output += StringPrintf(" == %lld)", ub_);
} else if (lb_ == kint64min) {
output += StringPrintf(" <= %lld)", ub_);
} else if (ub_ == kint64max) {
output += StringPrintf(" >= %lld)", lb_);
} else {
output += StringPrintf(" in [%lld .. %lld])", lb_, ub_);
}
output += StringPrintf(" %lld", constant_);
return output;
}
private:
hash_map<int, int64> coefficients_;
int64 constant_;
const Operation operation_;
int64 lb_;
int64 ub_;
};
class OrConstraint : public ArithmeticConstraint {
@@ -453,8 +458,7 @@ class GlobalArithmeticConstraint : public Constraint {
const vector<int64> coefficients,
int64 constant) {
ScalProdOpConstant* const constraint =
new ScalProdOpConstant(constant,
ScalProdOpConstant::GREATER_OR_EQUAL);
new ScalProdOpConstant(constant, kint64max);
for (int index = 0; index < vars.size(); ++index) {
constraint->AddTerm(VarIndex(vars[index]), coefficients[index]);
}
@@ -465,7 +469,7 @@ class GlobalArithmeticConstraint : public Constraint {
const vector<int64> coefficients,
int64 constant) {
ScalProdOpConstant* const constraint =
new ScalProdOpConstant(constant, ScalProdOpConstant::LESS_OR_EQUAL);
new ScalProdOpConstant(kint64min, constant);
for (int index = 0; index < vars.size(); ++index) {
constraint->AddTerm(VarIndex(vars[index]), coefficients[index]);
}
@@ -476,7 +480,7 @@ class GlobalArithmeticConstraint : public Constraint {
const vector<int64> coefficients,
int64 constant) {
ScalProdOpConstant* const constraint =
new ScalProdOpConstant(constant, ScalProdOpConstant::EQUAL);
new ScalProdOpConstant(constant, constant);
for (int index = 0; index < vars.size(); ++index) {
constraint->AddTerm(VarIndex(vars[index]), coefficients[index]);
}
@@ -486,8 +490,7 @@ class GlobalArithmeticConstraint : public Constraint {
int MakeSumGreaterOrEqualConstant(const vector<IntVar*> vars,
int64 constant) {
ScalProdOpConstant* const constraint =
new ScalProdOpConstant(constant,
ScalProdOpConstant::GREATER_OR_EQUAL);
new ScalProdOpConstant(constant, kint64max);
for (int index = 0; index < vars.size(); ++index) {
constraint->AddTerm(VarIndex(vars[index]), 1);
}
@@ -496,7 +499,7 @@ class GlobalArithmeticConstraint : public Constraint {
int MakeSumLessOrEqualConstant(const vector<IntVar*> vars, int64 constant) {
ScalProdOpConstant* const constraint =
new ScalProdOpConstant(constant, ScalProdOpConstant::LESS_OR_EQUAL);
new ScalProdOpConstant(kint64min, constant);
for (int index = 0; index < vars.size(); ++index) {
constraint->AddTerm(VarIndex(vars[index]), 1);
}
@@ -505,7 +508,7 @@ class GlobalArithmeticConstraint : public Constraint {
int MakeSumEqualConstant(const vector<IntVar*> vars, int64 constant) {
ScalProdOpConstant* const constraint =
new ScalProdOpConstant(constant, ScalProdOpConstant::EQUAL);
new ScalProdOpConstant(constant, constant);
for (int index = 0; index < vars.size(); ++index) {
constraint->AddTerm(VarIndex(vars[index]), 1);
}