experimental code to linearize the scal prod

This commit is contained in:
lperron@google.com
2012-06-29 17:09:19 +00:00
parent 8e2352e545
commit 640ba535ff

View File

@@ -2398,6 +2398,275 @@ Constraint* MakeScalProdGreaterOrEqualFct(Solver* solver,
}
return solver->MakeSumGreaterOrEqual(terms, cst);
}
#define IS_TYPE(type, tag) type.compare(ModelVisitor::tag) == 0
class ExprLinearizer : public ModelParser {
public:
ExprLinearizer(hash_map<const IntExpr*, int64>* const map)
: map_(map), constant_(0) {}
virtual ~ExprLinearizer() {}
// Begin/End visit element.
virtual void BeginVisitModel(const string& solver_name) {
LOG(FATAL) << "Should not be here";
}
virtual void EndVisitModel(const string& solver_name) {
LOG(FATAL) << "Should not be here";
}
virtual void BeginVisitConstraint(const string& type_name,
const Constraint* const constraint) {
LOG(FATAL) << "Should not be here";
}
virtual void EndVisitConstraint(const string& type_name,
const Constraint* const constraint) {
LOG(FATAL) << "Should not be here";
}
virtual void BeginVisitExtension(const string& type) {
LOG(FATAL) << "Should not be here";
}
virtual void EndVisitExtension(const string& type) {
LOG(FATAL) << "Should not be here";
}
virtual void BeginVisitIntegerExpression(const string& type_name,
const IntExpr* const expr) {
BeginVisit(true);
}
virtual void EndVisitIntegerExpression(const string& type_name,
const IntExpr* const expr) {
if (IS_TYPE(type_name, kSum)) {
VisitSum(expr);
} else if (IS_TYPE(type_name, kScalProd)) {
VisitScalProd(expr);
} else if (IS_TYPE(type_name, kDifference)) {
VisitDifference(expr);
} else if (IS_TYPE(type_name, kOpposite)) {
VisitOpposite(expr);
} else if (IS_TYPE(type_name, kProduct)) {
VisitProduct(expr);
} else {
VisitIntegerExpression(expr);
}
EndVisit();
}
virtual void VisitIntegerVariable(const IntVar* const variable,
const string& operation,
int64 value,
const IntVar* const delegate) {
if (operation == ModelVisitor::kSumOperation) {
AddConstant(value);
VisitSubExpression(delegate);
} else if (operation == ModelVisitor::kDifferenceOperation) {
AddConstant(value);
PushMultiplier(-1);
VisitSubExpression(delegate);
PopMultiplier();
} else if (operation == ModelVisitor::kProductOperation) {
PushMultiplier(value);
VisitSubExpression(delegate);
PopMultiplier();
}
}
virtual void VisitIntegerVariable(const IntVar* const variable,
const IntExpr* const delegate) {
if (delegate == NULL) {
if (variable->Bound()) {
AddConstant(variable->Min());
} else {
RegisterExpression(variable, 1);
}
} else {
VisitSubExpression(delegate);
}
}
// Visit integer arguments.
virtual void VisitIntegerArgument(const string& arg_name, int64 value) {
Top()->SetIntegerArgument(arg_name, value);
}
virtual void VisitIntegerArrayArgument(const string& arg_name,
const int64* const values,
int size) {
Top()->SetIntegerArrayArgument(arg_name, values, size);
}
virtual void VisitIntegerMatrixArgument(const string& arg_name,
const IntTupleSet& values) {
Top()->SetIntegerMatrixArgument(arg_name, values);
}
// Visit integer expression argument.
virtual void VisitIntegerExpressionArgument(
const string& arg_name,
const IntExpr* const argument) {
Top()->SetIntegerExpressionArgument(arg_name, argument);
}
virtual void VisitIntegerVariableArrayArgument(
const string& arg_name,
const IntVar* const * arguments,
int size) {
Top()->SetIntegerVariableArrayArgument(arg_name, arguments, size);
}
// Visit interval argument.
virtual void VisitIntervalArgument(const string& arg_name,
const IntervalVar* const argument) {}
virtual void VisitIntervalArrayArgument(const string& arg_name,
const IntervalVar* const * argument,
int size) {}
void Visit(const IntExpr* const expr, int64 multiplier) {
PushMultiplier(multiplier);
expr->Accept(this);
PopMultiplier();
}
int64 Constant() const { return constant_; }
private:
void BeginVisit(bool active) {
PushArgumentHolder();
}
void EndVisit() {
PopArgumentHolder();
}
void VisitSubExpression(const IntExpr* const cp_expr) {
cp_expr->Accept(this);
}
void VisitSum(const IntExpr* const cp_expr) {
if (Top()->HasIntegerVariableArrayArgument(ModelVisitor::kVarsArgument)) {
const std::vector<const IntVar*>& cp_vars =
Top()->FindIntegerVariableArrayArgumentOrDie(
ModelVisitor::kVarsArgument);
for (int i = 0; i < cp_vars.size(); ++i) {
VisitSubExpression(cp_vars[i]);
}
} else if (Top()->HasIntegerExpressionArgument(
ModelVisitor::kLeftArgument)) {
const IntExpr* const left =
Top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kLeftArgument);
const IntExpr* const right =
Top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kRightArgument);
VisitSubExpression(left);
VisitSubExpression(right);
} else {
const IntExpr* const expr =
Top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kExpressionArgument);
const int64 value =
Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
VisitSubExpression(expr);
AddConstant(value);
}
}
void VisitScalProd(const IntExpr* const cp_expr) {
const std::vector<const IntVar*>& cp_vars =
Top()->FindIntegerVariableArrayArgumentOrDie(
ModelVisitor::kVarsArgument);
const std::vector<int64>& cp_coefficients =
Top()->FindIntegerArrayArgumentOrDie(
ModelVisitor::kCoefficientsArgument);
CHECK_EQ(cp_vars.size(), cp_coefficients.size());
for (int i = 0; i < cp_vars.size(); ++i) {
const int64 coefficient = cp_coefficients[i];
PushMultiplier(coefficient);
VisitSubExpression(cp_vars[i]);
PopMultiplier();
}
}
void VisitDifference(const IntExpr* const cp_expr) {
if (Top()->HasIntegerExpressionArgument(ModelVisitor::kLeftArgument)) {
const IntExpr* const left =
Top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kLeftArgument);
const IntExpr* const right =
Top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kRightArgument);
VisitSubExpression(left);
PushMultiplier(-1);
VisitSubExpression(right);
PopMultiplier();
} else {
const IntExpr* const expr =
Top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kExpressionArgument);
const int64 value =
Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
AddConstant(value);
PushMultiplier(-1);
VisitSubExpression(expr);
PopMultiplier();
}
}
void VisitOpposite(const IntExpr* const cp_expr) {
const IntExpr* const expr =
Top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kExpressionArgument);
PushMultiplier(-1);
VisitSubExpression(expr);
PopMultiplier();
}
void VisitProduct(const IntExpr* const cp_expr) {
if (Top()->HasIntegerExpressionArgument(
ModelVisitor::kExpressionArgument)) {
const IntExpr* const expr =
Top()->FindIntegerExpressionArgumentOrDie(
ModelVisitor::kExpressionArgument);
const int64 value =
Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
PushMultiplier(value);
VisitSubExpression(expr);
PopMultiplier();
} else {
RegisterExpression(cp_expr, 1);
}
}
void VisitIntegerExpression(const IntExpr* const cp_expr) {
RegisterExpression(cp_expr, 1);
}
void RegisterExpression(const IntExpr* const expr, int64 coef) {
(*map_)[expr] += coef * multipliers_.back();
}
void AddConstant(int64 constant) {
constant_ += constant * multipliers_.back();
}
void PushMultiplier(int64 multiplier) {
multipliers_.push_back(multiplier);
}
void PopMultiplier() {
multipliers_.pop_back();
}
hash_map<const IntExpr*, int64>* const map_;
std::vector<int64> multipliers_;
int64 constant_;
};
} // namespace
Constraint* Solver::MakeScalProdGreaterOrEqual(const std::vector<IntVar*>& vars,
@@ -2578,6 +2847,25 @@ namespace {
template<class T> IntExpr* MakeScalProdFct(Solver* solver,
const std::vector<IntVar*>& vars,
const std::vector<T>& coefs) {
// hash_map<const IntExpr*, int64> map;
// ExprLinearizer lin(&map);
// for (int i = 0; i < pre_vars.size(); ++i) {
// lin.Visit(pre_vars[i], pre_coefs[i]);
// }
// const int64 constant = lin.Constant();
// std::vector<IntVar*> vars;
// std::vector<T> coefs;
// for (ConstIter<hash_map<const IntExpr*, int64> > iter(map);
// !iter.at_end();
// ++iter) {
// vars.push_back(const_cast<IntExpr*>(iter->first)->Var());
// coefs.push_back(iter->second);
// }
// if (constant != 0) {
// vars.push_back(solver->MakeIntConst(1));
// coefs.push_back(constant);
// }
const int size = vars.size();
if (vars.empty() || AreAllNull<T>(coefs.data(), size)) {
return solver->MakeIntConst(0LL);