experimental code to linearize the scal prod
This commit is contained in:
@@ -2398,6 +2398,275 @@ Constraint* MakeScalProdGreaterOrEqualFct(Solver* solver,
|
||||
}
|
||||
return solver->MakeSumGreaterOrEqual(terms, cst);
|
||||
}
|
||||
|
||||
#define IS_TYPE(type, tag) type.compare(ModelVisitor::tag) == 0
|
||||
|
||||
class ExprLinearizer : public ModelParser {
|
||||
public:
|
||||
ExprLinearizer(hash_map<const IntExpr*, int64>* const map)
|
||||
: map_(map), constant_(0) {}
|
||||
|
||||
virtual ~ExprLinearizer() {}
|
||||
|
||||
// Begin/End visit element.
|
||||
virtual void BeginVisitModel(const string& solver_name) {
|
||||
LOG(FATAL) << "Should not be here";
|
||||
}
|
||||
|
||||
virtual void EndVisitModel(const string& solver_name) {
|
||||
LOG(FATAL) << "Should not be here";
|
||||
}
|
||||
|
||||
virtual void BeginVisitConstraint(const string& type_name,
|
||||
const Constraint* const constraint) {
|
||||
LOG(FATAL) << "Should not be here";
|
||||
}
|
||||
|
||||
virtual void EndVisitConstraint(const string& type_name,
|
||||
const Constraint* const constraint) {
|
||||
LOG(FATAL) << "Should not be here";
|
||||
}
|
||||
|
||||
virtual void BeginVisitExtension(const string& type) {
|
||||
LOG(FATAL) << "Should not be here";
|
||||
}
|
||||
|
||||
virtual void EndVisitExtension(const string& type) {
|
||||
LOG(FATAL) << "Should not be here";
|
||||
}
|
||||
virtual void BeginVisitIntegerExpression(const string& type_name,
|
||||
const IntExpr* const expr) {
|
||||
BeginVisit(true);
|
||||
}
|
||||
|
||||
virtual void EndVisitIntegerExpression(const string& type_name,
|
||||
const IntExpr* const expr) {
|
||||
if (IS_TYPE(type_name, kSum)) {
|
||||
VisitSum(expr);
|
||||
} else if (IS_TYPE(type_name, kScalProd)) {
|
||||
VisitScalProd(expr);
|
||||
} else if (IS_TYPE(type_name, kDifference)) {
|
||||
VisitDifference(expr);
|
||||
} else if (IS_TYPE(type_name, kOpposite)) {
|
||||
VisitOpposite(expr);
|
||||
} else if (IS_TYPE(type_name, kProduct)) {
|
||||
VisitProduct(expr);
|
||||
} else {
|
||||
VisitIntegerExpression(expr);
|
||||
}
|
||||
EndVisit();
|
||||
}
|
||||
|
||||
virtual void VisitIntegerVariable(const IntVar* const variable,
|
||||
const string& operation,
|
||||
int64 value,
|
||||
const IntVar* const delegate) {
|
||||
if (operation == ModelVisitor::kSumOperation) {
|
||||
AddConstant(value);
|
||||
VisitSubExpression(delegate);
|
||||
} else if (operation == ModelVisitor::kDifferenceOperation) {
|
||||
AddConstant(value);
|
||||
PushMultiplier(-1);
|
||||
VisitSubExpression(delegate);
|
||||
PopMultiplier();
|
||||
} else if (operation == ModelVisitor::kProductOperation) {
|
||||
PushMultiplier(value);
|
||||
VisitSubExpression(delegate);
|
||||
PopMultiplier();
|
||||
}
|
||||
}
|
||||
|
||||
virtual void VisitIntegerVariable(const IntVar* const variable,
|
||||
const IntExpr* const delegate) {
|
||||
if (delegate == NULL) {
|
||||
if (variable->Bound()) {
|
||||
AddConstant(variable->Min());
|
||||
} else {
|
||||
RegisterExpression(variable, 1);
|
||||
}
|
||||
} else {
|
||||
VisitSubExpression(delegate);
|
||||
}
|
||||
}
|
||||
|
||||
// Visit integer arguments.
|
||||
virtual void VisitIntegerArgument(const string& arg_name, int64 value) {
|
||||
Top()->SetIntegerArgument(arg_name, value);
|
||||
}
|
||||
|
||||
virtual void VisitIntegerArrayArgument(const string& arg_name,
|
||||
const int64* const values,
|
||||
int size) {
|
||||
Top()->SetIntegerArrayArgument(arg_name, values, size);
|
||||
}
|
||||
|
||||
virtual void VisitIntegerMatrixArgument(const string& arg_name,
|
||||
const IntTupleSet& values) {
|
||||
Top()->SetIntegerMatrixArgument(arg_name, values);
|
||||
}
|
||||
|
||||
// Visit integer expression argument.
|
||||
virtual void VisitIntegerExpressionArgument(
|
||||
const string& arg_name,
|
||||
const IntExpr* const argument) {
|
||||
Top()->SetIntegerExpressionArgument(arg_name, argument);
|
||||
}
|
||||
|
||||
virtual void VisitIntegerVariableArrayArgument(
|
||||
const string& arg_name,
|
||||
const IntVar* const * arguments,
|
||||
int size) {
|
||||
Top()->SetIntegerVariableArrayArgument(arg_name, arguments, size);
|
||||
}
|
||||
|
||||
// Visit interval argument.
|
||||
virtual void VisitIntervalArgument(const string& arg_name,
|
||||
const IntervalVar* const argument) {}
|
||||
|
||||
virtual void VisitIntervalArrayArgument(const string& arg_name,
|
||||
const IntervalVar* const * argument,
|
||||
int size) {}
|
||||
|
||||
void Visit(const IntExpr* const expr, int64 multiplier) {
|
||||
PushMultiplier(multiplier);
|
||||
expr->Accept(this);
|
||||
PopMultiplier();
|
||||
}
|
||||
|
||||
int64 Constant() const { return constant_; }
|
||||
|
||||
private:
|
||||
void BeginVisit(bool active) {
|
||||
PushArgumentHolder();
|
||||
}
|
||||
|
||||
void EndVisit() {
|
||||
PopArgumentHolder();
|
||||
}
|
||||
|
||||
void VisitSubExpression(const IntExpr* const cp_expr) {
|
||||
cp_expr->Accept(this);
|
||||
}
|
||||
|
||||
void VisitSum(const IntExpr* const cp_expr) {
|
||||
if (Top()->HasIntegerVariableArrayArgument(ModelVisitor::kVarsArgument)) {
|
||||
const std::vector<const IntVar*>& cp_vars =
|
||||
Top()->FindIntegerVariableArrayArgumentOrDie(
|
||||
ModelVisitor::kVarsArgument);
|
||||
for (int i = 0; i < cp_vars.size(); ++i) {
|
||||
VisitSubExpression(cp_vars[i]);
|
||||
}
|
||||
} else if (Top()->HasIntegerExpressionArgument(
|
||||
ModelVisitor::kLeftArgument)) {
|
||||
const IntExpr* const left =
|
||||
Top()->FindIntegerExpressionArgumentOrDie(
|
||||
ModelVisitor::kLeftArgument);
|
||||
const IntExpr* const right =
|
||||
Top()->FindIntegerExpressionArgumentOrDie(
|
||||
ModelVisitor::kRightArgument);
|
||||
VisitSubExpression(left);
|
||||
VisitSubExpression(right);
|
||||
} else {
|
||||
const IntExpr* const expr =
|
||||
Top()->FindIntegerExpressionArgumentOrDie(
|
||||
ModelVisitor::kExpressionArgument);
|
||||
const int64 value =
|
||||
Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
|
||||
VisitSubExpression(expr);
|
||||
AddConstant(value);
|
||||
}
|
||||
}
|
||||
|
||||
void VisitScalProd(const IntExpr* const cp_expr) {
|
||||
const std::vector<const IntVar*>& cp_vars =
|
||||
Top()->FindIntegerVariableArrayArgumentOrDie(
|
||||
ModelVisitor::kVarsArgument);
|
||||
const std::vector<int64>& cp_coefficients =
|
||||
Top()->FindIntegerArrayArgumentOrDie(
|
||||
ModelVisitor::kCoefficientsArgument);
|
||||
CHECK_EQ(cp_vars.size(), cp_coefficients.size());
|
||||
for (int i = 0; i < cp_vars.size(); ++i) {
|
||||
const int64 coefficient = cp_coefficients[i];
|
||||
PushMultiplier(coefficient);
|
||||
VisitSubExpression(cp_vars[i]);
|
||||
PopMultiplier();
|
||||
}
|
||||
}
|
||||
|
||||
void VisitDifference(const IntExpr* const cp_expr) {
|
||||
if (Top()->HasIntegerExpressionArgument(ModelVisitor::kLeftArgument)) {
|
||||
const IntExpr* const left =
|
||||
Top()->FindIntegerExpressionArgumentOrDie(
|
||||
ModelVisitor::kLeftArgument);
|
||||
const IntExpr* const right =
|
||||
Top()->FindIntegerExpressionArgumentOrDie(
|
||||
ModelVisitor::kRightArgument);
|
||||
VisitSubExpression(left);
|
||||
PushMultiplier(-1);
|
||||
VisitSubExpression(right);
|
||||
PopMultiplier();
|
||||
} else {
|
||||
const IntExpr* const expr =
|
||||
Top()->FindIntegerExpressionArgumentOrDie(
|
||||
ModelVisitor::kExpressionArgument);
|
||||
const int64 value =
|
||||
Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
|
||||
AddConstant(value);
|
||||
PushMultiplier(-1);
|
||||
VisitSubExpression(expr);
|
||||
PopMultiplier();
|
||||
}
|
||||
}
|
||||
|
||||
void VisitOpposite(const IntExpr* const cp_expr) {
|
||||
const IntExpr* const expr =
|
||||
Top()->FindIntegerExpressionArgumentOrDie(
|
||||
ModelVisitor::kExpressionArgument);
|
||||
PushMultiplier(-1);
|
||||
VisitSubExpression(expr);
|
||||
PopMultiplier();
|
||||
}
|
||||
|
||||
void VisitProduct(const IntExpr* const cp_expr) {
|
||||
if (Top()->HasIntegerExpressionArgument(
|
||||
ModelVisitor::kExpressionArgument)) {
|
||||
const IntExpr* const expr =
|
||||
Top()->FindIntegerExpressionArgumentOrDie(
|
||||
ModelVisitor::kExpressionArgument);
|
||||
const int64 value =
|
||||
Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument);
|
||||
PushMultiplier(value);
|
||||
VisitSubExpression(expr);
|
||||
PopMultiplier();
|
||||
} else {
|
||||
RegisterExpression(cp_expr, 1);
|
||||
}
|
||||
}
|
||||
|
||||
void VisitIntegerExpression(const IntExpr* const cp_expr) {
|
||||
RegisterExpression(cp_expr, 1);
|
||||
}
|
||||
|
||||
void RegisterExpression(const IntExpr* const expr, int64 coef) {
|
||||
(*map_)[expr] += coef * multipliers_.back();
|
||||
}
|
||||
|
||||
void AddConstant(int64 constant) {
|
||||
constant_ += constant * multipliers_.back();
|
||||
}
|
||||
|
||||
void PushMultiplier(int64 multiplier) {
|
||||
multipliers_.push_back(multiplier);
|
||||
}
|
||||
|
||||
void PopMultiplier() {
|
||||
multipliers_.pop_back();
|
||||
}
|
||||
|
||||
hash_map<const IntExpr*, int64>* const map_;
|
||||
std::vector<int64> multipliers_;
|
||||
int64 constant_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Constraint* Solver::MakeScalProdGreaterOrEqual(const std::vector<IntVar*>& vars,
|
||||
@@ -2578,6 +2847,25 @@ namespace {
|
||||
template<class T> IntExpr* MakeScalProdFct(Solver* solver,
|
||||
const std::vector<IntVar*>& vars,
|
||||
const std::vector<T>& coefs) {
|
||||
// hash_map<const IntExpr*, int64> map;
|
||||
// ExprLinearizer lin(&map);
|
||||
// for (int i = 0; i < pre_vars.size(); ++i) {
|
||||
// lin.Visit(pre_vars[i], pre_coefs[i]);
|
||||
// }
|
||||
// const int64 constant = lin.Constant();
|
||||
// std::vector<IntVar*> vars;
|
||||
// std::vector<T> coefs;
|
||||
// for (ConstIter<hash_map<const IntExpr*, int64> > iter(map);
|
||||
// !iter.at_end();
|
||||
// ++iter) {
|
||||
// vars.push_back(const_cast<IntExpr*>(iter->first)->Var());
|
||||
// coefs.push_back(iter->second);
|
||||
// }
|
||||
// if (constant != 0) {
|
||||
// vars.push_back(solver->MakeIntConst(1));
|
||||
// coefs.push_back(constant);
|
||||
// }
|
||||
|
||||
const int size = vars.size();
|
||||
if (vars.empty() || AreAllNull<T>(coefs.data(), size)) {
|
||||
return solver->MakeIntConst(0LL);
|
||||
|
||||
Reference in New Issue
Block a user