diff --git a/src/constraint_solver/expr_array.cc b/src/constraint_solver/expr_array.cc index c6c72e6807..d90675dd91 100644 --- a/src/constraint_solver/expr_array.cc +++ b/src/constraint_solver/expr_array.cc @@ -2398,6 +2398,275 @@ Constraint* MakeScalProdGreaterOrEqualFct(Solver* solver, } return solver->MakeSumGreaterOrEqual(terms, cst); } + +#define IS_TYPE(type, tag) type.compare(ModelVisitor::tag) == 0 + +class ExprLinearizer : public ModelParser { + public: + ExprLinearizer(hash_map* const map) + : map_(map), constant_(0) {} + + virtual ~ExprLinearizer() {} + + // Begin/End visit element. + virtual void BeginVisitModel(const string& solver_name) { + LOG(FATAL) << "Should not be here"; + } + + virtual void EndVisitModel(const string& solver_name) { + LOG(FATAL) << "Should not be here"; + } + + virtual void BeginVisitConstraint(const string& type_name, + const Constraint* const constraint) { + LOG(FATAL) << "Should not be here"; + } + + virtual void EndVisitConstraint(const string& type_name, + const Constraint* const constraint) { + LOG(FATAL) << "Should not be here"; + } + + virtual void BeginVisitExtension(const string& type) { + LOG(FATAL) << "Should not be here"; + } + + virtual void EndVisitExtension(const string& type) { + LOG(FATAL) << "Should not be here"; + } + virtual void BeginVisitIntegerExpression(const string& type_name, + const IntExpr* const expr) { + BeginVisit(true); + } + + virtual void EndVisitIntegerExpression(const string& type_name, + const IntExpr* const expr) { + if (IS_TYPE(type_name, kSum)) { + VisitSum(expr); + } else if (IS_TYPE(type_name, kScalProd)) { + VisitScalProd(expr); + } else if (IS_TYPE(type_name, kDifference)) { + VisitDifference(expr); + } else if (IS_TYPE(type_name, kOpposite)) { + VisitOpposite(expr); + } else if (IS_TYPE(type_name, kProduct)) { + VisitProduct(expr); + } else { + VisitIntegerExpression(expr); + } + EndVisit(); + } + + virtual void VisitIntegerVariable(const IntVar* const variable, + const string& operation, + int64 value, + const IntVar* const delegate) { + if (operation == ModelVisitor::kSumOperation) { + AddConstant(value); + VisitSubExpression(delegate); + } else if (operation == ModelVisitor::kDifferenceOperation) { + AddConstant(value); + PushMultiplier(-1); + VisitSubExpression(delegate); + PopMultiplier(); + } else if (operation == ModelVisitor::kProductOperation) { + PushMultiplier(value); + VisitSubExpression(delegate); + PopMultiplier(); + } + } + + virtual void VisitIntegerVariable(const IntVar* const variable, + const IntExpr* const delegate) { + if (delegate == NULL) { + if (variable->Bound()) { + AddConstant(variable->Min()); + } else { + RegisterExpression(variable, 1); + } + } else { + VisitSubExpression(delegate); + } + } + + // Visit integer arguments. + virtual void VisitIntegerArgument(const string& arg_name, int64 value) { + Top()->SetIntegerArgument(arg_name, value); + } + + virtual void VisitIntegerArrayArgument(const string& arg_name, + const int64* const values, + int size) { + Top()->SetIntegerArrayArgument(arg_name, values, size); + } + + virtual void VisitIntegerMatrixArgument(const string& arg_name, + const IntTupleSet& values) { + Top()->SetIntegerMatrixArgument(arg_name, values); + } + + // Visit integer expression argument. + virtual void VisitIntegerExpressionArgument( + const string& arg_name, + const IntExpr* const argument) { + Top()->SetIntegerExpressionArgument(arg_name, argument); + } + + virtual void VisitIntegerVariableArrayArgument( + const string& arg_name, + const IntVar* const * arguments, + int size) { + Top()->SetIntegerVariableArrayArgument(arg_name, arguments, size); + } + + // Visit interval argument. + virtual void VisitIntervalArgument(const string& arg_name, + const IntervalVar* const argument) {} + + virtual void VisitIntervalArrayArgument(const string& arg_name, + const IntervalVar* const * argument, + int size) {} + + void Visit(const IntExpr* const expr, int64 multiplier) { + PushMultiplier(multiplier); + expr->Accept(this); + PopMultiplier(); + } + + int64 Constant() const { return constant_; } + + private: + void BeginVisit(bool active) { + PushArgumentHolder(); + } + + void EndVisit() { + PopArgumentHolder(); + } + + void VisitSubExpression(const IntExpr* const cp_expr) { + cp_expr->Accept(this); + } + + void VisitSum(const IntExpr* const cp_expr) { + if (Top()->HasIntegerVariableArrayArgument(ModelVisitor::kVarsArgument)) { + const std::vector& cp_vars = + Top()->FindIntegerVariableArrayArgumentOrDie( + ModelVisitor::kVarsArgument); + for (int i = 0; i < cp_vars.size(); ++i) { + VisitSubExpression(cp_vars[i]); + } + } else if (Top()->HasIntegerExpressionArgument( + ModelVisitor::kLeftArgument)) { + const IntExpr* const left = + Top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kLeftArgument); + const IntExpr* const right = + Top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kRightArgument); + VisitSubExpression(left); + VisitSubExpression(right); + } else { + const IntExpr* const expr = + Top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kExpressionArgument); + const int64 value = + Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument); + VisitSubExpression(expr); + AddConstant(value); + } + } + + void VisitScalProd(const IntExpr* const cp_expr) { + const std::vector& cp_vars = + Top()->FindIntegerVariableArrayArgumentOrDie( + ModelVisitor::kVarsArgument); + const std::vector& cp_coefficients = + Top()->FindIntegerArrayArgumentOrDie( + ModelVisitor::kCoefficientsArgument); + CHECK_EQ(cp_vars.size(), cp_coefficients.size()); + for (int i = 0; i < cp_vars.size(); ++i) { + const int64 coefficient = cp_coefficients[i]; + PushMultiplier(coefficient); + VisitSubExpression(cp_vars[i]); + PopMultiplier(); + } + } + + void VisitDifference(const IntExpr* const cp_expr) { + if (Top()->HasIntegerExpressionArgument(ModelVisitor::kLeftArgument)) { + const IntExpr* const left = + Top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kLeftArgument); + const IntExpr* const right = + Top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kRightArgument); + VisitSubExpression(left); + PushMultiplier(-1); + VisitSubExpression(right); + PopMultiplier(); + } else { + const IntExpr* const expr = + Top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kExpressionArgument); + const int64 value = + Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument); + AddConstant(value); + PushMultiplier(-1); + VisitSubExpression(expr); + PopMultiplier(); + } + } + + void VisitOpposite(const IntExpr* const cp_expr) { + const IntExpr* const expr = + Top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kExpressionArgument); + PushMultiplier(-1); + VisitSubExpression(expr); + PopMultiplier(); + } + + void VisitProduct(const IntExpr* const cp_expr) { + if (Top()->HasIntegerExpressionArgument( + ModelVisitor::kExpressionArgument)) { + const IntExpr* const expr = + Top()->FindIntegerExpressionArgumentOrDie( + ModelVisitor::kExpressionArgument); + const int64 value = + Top()->FindIntegerArgumentOrDie(ModelVisitor::kValueArgument); + PushMultiplier(value); + VisitSubExpression(expr); + PopMultiplier(); + } else { + RegisterExpression(cp_expr, 1); + } + } + + void VisitIntegerExpression(const IntExpr* const cp_expr) { + RegisterExpression(cp_expr, 1); + } + + void RegisterExpression(const IntExpr* const expr, int64 coef) { + (*map_)[expr] += coef * multipliers_.back(); + } + + void AddConstant(int64 constant) { + constant_ += constant * multipliers_.back(); + } + + void PushMultiplier(int64 multiplier) { + multipliers_.push_back(multiplier); + } + + void PopMultiplier() { + multipliers_.pop_back(); + } + + hash_map* const map_; + std::vector multipliers_; + int64 constant_; +}; } // namespace Constraint* Solver::MakeScalProdGreaterOrEqual(const std::vector& vars, @@ -2578,6 +2847,25 @@ namespace { template IntExpr* MakeScalProdFct(Solver* solver, const std::vector& vars, const std::vector& coefs) { + // hash_map map; + // ExprLinearizer lin(&map); + // for (int i = 0; i < pre_vars.size(); ++i) { + // lin.Visit(pre_vars[i], pre_coefs[i]); + // } + // const int64 constant = lin.Constant(); + // std::vector vars; + // std::vector coefs; + // for (ConstIter > iter(map); + // !iter.at_end(); + // ++iter) { + // vars.push_back(const_cast(iter->first)->Var()); + // coefs.push_back(iter->second); + // } + // if (constant != 0) { + // vars.push_back(solver->MakeIntConst(1)); + // coefs.push_back(constant); + // } + const int size = vars.size(); if (vars.empty() || AreAllNull(coefs.data(), size)) { return solver->MakeIntConst(0LL);