diff --git a/src/constraint_solver/expr_array.cc b/src/constraint_solver/expr_array.cc index 996656785c..0b8e5ba5e1 100644 --- a/src/constraint_solver/expr_array.cc +++ b/src/constraint_solver/expr_array.cc @@ -2253,7 +2253,7 @@ class ExprLinearizer : public ModelParser { // ----- Factory functions ----- -Constraint* MakeScalProdEqualityFct(Solver* const solver, +Constraint* MakeScalProdEqualityAux(Solver* const solver, const std::vector& vars, const std::vector& coefficients, int64 cst) { @@ -2273,11 +2273,13 @@ Constraint* MakeScalProdEqualityFct(Solver* const solver, if (AreAllOnes(coefficients)) { return solver->MakeSumEquality(vars, cst); } - if (AreAllBooleans(vars) && AreAllPositive(coefficients) && size > 2) { + if (AreAllBooleans(vars) && + AreAllPositive(coefficients) && size > 2) { // TODO(user) : bench BooleanScalProdEqVar with IntConst. return solver->RevAlloc( new PositiveBooleanScalProdEqCst(solver, vars, coefficients, cst)); } + // Simplications. int constants = 0; int positives = 0; @@ -2381,6 +2383,50 @@ Constraint* MakeScalProdEqualityFct(Solver* const solver, return solver->MakeSumEquality(terms, solver->MakeIntConst(cst)); } +Constraint* MakeScalProdEqualityFct(Solver* const solver, + const std::vector& pre_vars, + const std::vector& pre_coefs, + int64 rhs) { + int64 constant = 0; + std::vector vars; + std::vector coefs; + vars.reserve(pre_vars.size()); + coefs.reserve(pre_coefs.size()); + // Try linear scan of the variables to check if there is nothing to do. + bool ok = true; + for (int i = 0; i < pre_vars.size(); ++i) { + IntVar* const v = pre_vars[i]; + const int64 c = pre_coefs[i]; + if (v->Bound()) { + constant += c * v->Min(); + } else if (solver->CastExpression(v) == nullptr) { + vars.push_back(v); + coefs.push_back(c); + } else { + ok = false; + vars.clear(); + coefs.clear(); + break; + } + } + if (!ok) { + // Instrospect the variables to simplify the sum. + hash_map map; + ExprLinearizer lin(&map); + for (int i = 0; i < pre_vars.size(); ++i) { + lin.Visit(pre_vars[i], pre_coefs[i]); + } + constant = lin.Constant(); + for (const auto& iter : map) { + if (iter.second != 0) { + vars.push_back(iter.first); + coefs.push_back(iter.second); + } + } + } + return MakeScalProdEqualityAux(solver, vars, coefs, rhs - constant); +} + Constraint* MakeScalProdEqualityVarFct(Solver* const solver, const std::vector& vars, const std::vector& coefficients,