From ea2e1b63afbccb671f41d6b36a67f17463363123 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Sun, 29 Dec 2024 23:33:55 +0100 Subject: [PATCH] [CP-SAT] polish python linear expr code --- ortools/sat/python/linear_expr.cc | 89 ++++++++++---------- ortools/sat/python/linear_expr.h | 31 +++---- ortools/sat/python/swig_helper.cc | 134 +++++++++++------------------- 3 files changed, 111 insertions(+), 143 deletions(-) diff --git a/ortools/sat/python/linear_expr.cc b/ortools/sat/python/linear_expr.cc index df649f8a9b..11f4c417a4 100644 --- a/ortools/sat/python/linear_expr.cc +++ b/ortools/sat/python/linear_expr.cc @@ -44,7 +44,7 @@ LinearExpr* LinearExpr::Sum(const std::vector& exprs) { } } -LinearExpr* LinearExpr::Sum(const std::vector& exprs) { +LinearExpr* LinearExpr::MixedSum(const std::vector& exprs) { std::vector lin_exprs; int64_t int_offset = 0; double double_offset = 0.0; @@ -84,17 +84,8 @@ LinearExpr* LinearExpr::Sum(const std::vector& exprs) { } } -LinearExpr* LinearExpr::WeightedSum(const std::vector& exprs, - const std::vector& coeffs) { - if (exprs.empty()) return new FloatConstant(0.0); - if (exprs.size() == 1) { - return new FloatAffine(exprs[0], coeffs[0], 0.0); - } - return new FloatWeightedSum(exprs, coeffs, 0.0); -} - -LinearExpr* LinearExpr::WeightedSum(const std::vector& exprs, - const std::vector& coeffs) { +LinearExpr* LinearExpr::WeightedSumInt(const std::vector& exprs, + const std::vector& coeffs) { if (exprs.empty()) return new IntConstant(0); if (exprs.size() == 1) { return new IntAffine(exprs[0], coeffs[0], 0); @@ -102,30 +93,17 @@ LinearExpr* LinearExpr::WeightedSum(const std::vector& exprs, return new IntWeightedSum(exprs, coeffs, 0); } -LinearExpr* LinearExpr::WeightedSum(const std::vector& exprs, - const std::vector& coeffs) { - std::vector lin_exprs; - std::vector lin_coeffs; - double cst = 0.0; - for (int i = 0; i < exprs.size(); ++i) { - if (exprs[i].expr != nullptr) { - lin_exprs.push_back(exprs[i].expr); - lin_coeffs.push_back(coeffs[i]); - } else { - cst += coeffs[i] * - (exprs[i].double_value + static_cast(exprs[i].int_value)); - } +LinearExpr* LinearExpr::WeightedSumDouble(const std::vector& exprs, + const std::vector& coeffs) { + if (exprs.empty()) return new FloatConstant(0.0); + if (exprs.size() == 1) { + return new FloatAffine(exprs[0], coeffs[0], 0.0); } - - if (lin_exprs.empty()) return new FloatConstant(cst); - if (lin_exprs.size() == 1) { - return new FloatAffine(lin_exprs[0], lin_coeffs[0], cst); - } - return new FloatWeightedSum(lin_exprs, lin_coeffs, cst); + return new FloatWeightedSum(exprs, coeffs, 0.0); } -LinearExpr* LinearExpr::WeightedSum(const std::vector& exprs, - const std::vector& coeffs) { +LinearExpr* LinearExpr::MixedWeightedSumInt( + const std::vector& exprs, const std::vector& coeffs) { std::vector lin_exprs; std::vector lin_coeffs; int64_t int_cst = 0; @@ -162,33 +140,56 @@ LinearExpr* LinearExpr::WeightedSum(const std::vector& exprs, return new IntWeightedSum(lin_exprs, lin_coeffs, int_cst); } -LinearExpr* LinearExpr::Term(LinearExpr* expr, double coeff) { - return new FloatAffine(expr, coeff, 0.0); +LinearExpr* LinearExpr::MixedWeightedSumDouble( + const std::vector& exprs, const std::vector& coeffs) { + std::vector lin_exprs; + std::vector lin_coeffs; + double cst = 0.0; + for (int i = 0; i < exprs.size(); ++i) { + if (exprs[i].expr != nullptr) { + lin_exprs.push_back(exprs[i].expr); + lin_coeffs.push_back(coeffs[i]); + } else { + cst += coeffs[i] * + (exprs[i].double_value + static_cast(exprs[i].int_value)); + } + } + + if (lin_exprs.empty()) return new FloatConstant(cst); + if (lin_exprs.size() == 1) { + return new FloatAffine(lin_exprs[0], lin_coeffs[0], cst); + } + return new FloatWeightedSum(lin_exprs, lin_coeffs, cst); } -LinearExpr* LinearExpr::Term(LinearExpr* expr, int64_t coeff) { +LinearExpr* LinearExpr::TermInt(LinearExpr* expr, int64_t coeff) { return new IntAffine(expr, coeff, 0); } -LinearExpr* LinearExpr::Affine(LinearExpr* expr, double coeff, double offset) { - if (coeff == 1.0 && offset == 0.0) return expr; - return new FloatAffine(expr, coeff, offset); +LinearExpr* LinearExpr::TermDouble(LinearExpr* expr, double coeff) { + return new FloatAffine(expr, coeff, 0.0); } -LinearExpr* LinearExpr::Affine(LinearExpr* expr, int64_t coeff, - int64_t offset) { +LinearExpr* LinearExpr::AffineInt(LinearExpr* expr, int64_t coeff, + int64_t offset) { if (coeff == 1 && offset == 0) return expr; return new IntAffine(expr, coeff, offset); } -LinearExpr* LinearExpr::Constant(double value) { - return new FloatConstant(value); +LinearExpr* LinearExpr::AffineDouble(LinearExpr* expr, double coeff, + double offset) { + if (coeff == 1.0 && offset == 0.0) return expr; + return new FloatAffine(expr, coeff, offset); } -LinearExpr* LinearExpr::Constant(int64_t value) { +LinearExpr* LinearExpr::ConstantInt(int64_t value) { return new IntConstant(value); } +LinearExpr* LinearExpr::ConstantDouble(double value) { + return new FloatConstant(value); +} + LinearExpr* LinearExpr::Add(LinearExpr* expr) { return new BinaryAdd(this, expr); } diff --git a/ortools/sat/python/linear_expr.h b/ortools/sat/python/linear_expr.h index 9ed2a35c42..60af496a9f 100644 --- a/ortools/sat/python/linear_expr.h +++ b/ortools/sat/python/linear_expr.h @@ -62,21 +62,22 @@ class LinearExpr { virtual std::string DebugString() const { return "LinearExpr()"; } static LinearExpr* Sum(const std::vector& exprs); - static LinearExpr* Sum(const std::vector& exprs); - static LinearExpr* WeightedSum(const std::vector& exprs, - const std::vector& coeffs); - static LinearExpr* WeightedSum(const std::vector& exprs, - const std::vector& coeffs); - static LinearExpr* WeightedSum(const std::vector& exprs, - const std::vector& coeffs); - static LinearExpr* WeightedSum(const std::vector& exprs, - const std::vector& coeffs); - static LinearExpr* Term(LinearExpr* expr, int64_t coeff); - static LinearExpr* Term(LinearExpr* expr, double coeff); - static LinearExpr* Affine(LinearExpr* expr, int64_t coeff, int64_t offset); - static LinearExpr* Affine(LinearExpr* expr, double coeff, double offset); - static LinearExpr* Constant(int64_t value); - static LinearExpr* Constant(double value); + static LinearExpr* MixedSum(const std::vector& exprs); + static LinearExpr* WeightedSumInt(const std::vector& exprs, + const std::vector& coeffs); + static LinearExpr* WeightedSumDouble(const std::vector& exprs, + const std::vector& coeffs); + static LinearExpr* MixedWeightedSumInt(const std::vector& exprs, + const std::vector& coeffs); + static LinearExpr* MixedWeightedSumDouble( + const std::vector& exprs, const std::vector& coeffs); + static LinearExpr* TermInt(LinearExpr* expr, int64_t coeff); + static LinearExpr* TermDouble(LinearExpr* expr, double coeff); + static LinearExpr* AffineInt(LinearExpr* expr, int64_t coeff, int64_t offset); + static LinearExpr* AffineDouble(LinearExpr* expr, double coeff, + double offset); + static LinearExpr* ConstantInt(int64_t value); + static LinearExpr* ConstantDouble(double value); LinearExpr* Add(LinearExpr* expr); LinearExpr* AddInt(int64_t cst); diff --git a/ortools/sat/python/swig_helper.cc b/ortools/sat/python/swig_helper.cc index 09000f8172..5c15800861 100644 --- a/ortools/sat/python/swig_helper.cc +++ b/ortools/sat/python/swig_helper.cc @@ -313,95 +313,65 @@ PYBIND11_MODULE(swig_helper, m) { .def_readonly("expr", &ExprOrValue::expr) .def_readonly("int_value", &ExprOrValue::int_value); - py::implicitly_convertible(); - py::implicitly_convertible(); py::implicitly_convertible(); + py::implicitly_convertible(); + py::implicitly_convertible(); py::class_(m, "LinearExpr", kLinearExprClassDoc) // We make sure to keep the order of the overloads: LinearExpr* before // ExprOrValue as this is faster to parse and type check. - .def_static( - "sum", - py::overload_cast&>(&LinearExpr::Sum), - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "sum", - py::overload_cast&>(&LinearExpr::Sum), - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("weighted_sum", - py::overload_cast&, - const std::vector&>( - &LinearExpr::WeightedSum), + .def_static("sum", (&LinearExpr::Sum), arg("exprs"), py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("weighted_sum", - py::overload_cast&, - const std::vector&>( - &LinearExpr::WeightedSum), + .def_static("sum", &LinearExpr::MixedSum, arg("exprs"), py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("weighted_sum", - py::overload_cast&, - const std::vector&>( - &LinearExpr::WeightedSum), + .def_static("weighted_sum", &LinearExpr::WeightedSumInt, arg("exprs"), + arg("coeffs"), py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("weighted_sum", &LinearExpr::WeightedSumDouble, arg("exprs"), + arg("coeffs"), py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("weighted_sum", &LinearExpr::MixedWeightedSumInt, + arg("exprs"), arg("coeffs"), py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("weighted_sum", - py::overload_cast&, - const std::vector&>( - &LinearExpr::WeightedSum), + .def_static("weighted_sum", &LinearExpr::MixedWeightedSumDouble, + arg("exprs"), arg("coeffs"), py::return_value_policy::automatic, py::keep_alive<0, 1>()) // Make sure to keep the order of the overloads: int before float as an // an integer value will be silently converted to a float. - .def_static("term", - py::overload_cast(&LinearExpr::Term), - arg("expr"), arg("coeff"), "Returns expr * coeff.", + .def_static("term", &LinearExpr::TermInt, arg("expr").none(false), + arg("coeff"), "Returns expr * coeff.", py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("term", - py::overload_cast(&LinearExpr::Term), - arg("expr"), arg("coeff"), "Returns expr * coeff.", + .def_static("term", &LinearExpr::TermDouble, arg("expr").none(false), + arg("coeff"), "Returns expr * coeff.", py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "affine", - py::overload_cast(&LinearExpr::Affine), - arg("expr"), arg("coeff"), arg("offset"), - "Returns expr * coeff + offset.", py::return_value_policy::automatic, - py::keep_alive<0, 1>()) - .def_static( - "affine", - py::overload_cast(&LinearExpr::Affine), - arg("expr"), arg("coeff"), arg("offset"), - "Returns expr * coeff + offset.", py::return_value_policy::automatic, - py::keep_alive<0, 1>()) - .def_static("constant", py::overload_cast(&LinearExpr::Constant), - arg("value"), "Returns a constant linear expression.", + .def_static("affine", &LinearExpr::AffineInt, arg("expr").none(false), + arg("coeff"), arg("offset"), "Returns expr * coeff + offset.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("affine", &LinearExpr::AffineDouble, arg("expr").none(false), + arg("coeff"), arg("offset"), "Returns expr * coeff + offset.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("constant", &LinearExpr::ConstantInt, arg("value"), + "Returns a constant linear expression.", py::return_value_policy::automatic) - .def_static("constant", py::overload_cast(&LinearExpr::Constant), - arg("value"), "Returns a constant linear expression.", + .def_static("constant", &LinearExpr::ConstantDouble, arg("value"), + "Returns a constant linear expression.", py::return_value_policy::automatic) - .def_static( - "Sum", - py::overload_cast&>(&LinearExpr::Sum), - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "Sum", - py::overload_cast&>(&LinearExpr::Sum), - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("WeightedSum", - py::overload_cast&, - const std::vector&>( - &LinearExpr::WeightedSum), + .def_static("Sum", &LinearExpr::Sum, arg("exprs"), py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static("WeightedSum", - py::overload_cast&, - const std::vector&>( - &LinearExpr::WeightedSum), + .def_static("Sum", &LinearExpr::MixedSum, arg("exprs"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("WeightedSum", &LinearExpr::MixedWeightedSumInt, arg("exprs"), + arg("coeffs"), py::return_value_policy::automatic, + py::keep_alive<0, 1>()) + .def_static("WeightedSum", &LinearExpr::MixedWeightedSumDouble, + arg("exprs"), arg("coeffs"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("Term", &LinearExpr::TermInt, arg("expr").none(false), + arg("coeff"), "Returns expr * coeff.", + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def_static("Term", &LinearExpr::TermDouble, arg("expr").none(false), + arg("coeff"), "Returns expr * coeff.", py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "Term", py::overload_cast(&LinearExpr::Term), - arg("expr").none(false), arg("coeff"), "Returns expr * coeff.", - py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def_static( - "Term", py::overload_cast(&LinearExpr::Term), - arg("expr").none(false), arg("coeff"), "Returns expr * coeff.", - py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__str__", &LinearExpr::ToString) .def("__repr__", &LinearExpr::DebugString) .def("is_integer", &LinearExpr::IsInteger) @@ -427,18 +397,14 @@ PYBIND11_MODULE(swig_helper, m) { py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__rsub__", &LinearExpr::RSubDouble, arg("cst"), py::return_value_policy::automatic, py::keep_alive<0, 1>()) - .def("__mul__", py::overload_cast(&LinearExpr::MulInt), - arg("cst"), py::return_value_policy::automatic, - py::keep_alive<0, 1>()) - .def("__mul__", py::overload_cast(&LinearExpr::MulDouble), - arg("cst"), py::return_value_policy::automatic, - py::keep_alive<0, 1>()) - .def("__rmul__", py::overload_cast(&LinearExpr::MulInt), - arg("cst"), py::return_value_policy::automatic, - py::keep_alive<0, 1>()) - .def("__rmul__", py::overload_cast(&LinearExpr::MulDouble), - arg("cst"), py::return_value_policy::automatic, - py::keep_alive<0, 1>()) + .def("__mul__", &LinearExpr::MulInt, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__mul__", &LinearExpr::MulDouble, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rmul__", &LinearExpr::MulInt, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) + .def("__rmul__", &LinearExpr::MulDouble, arg("cst"), + py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def("__neg__", &LinearExpr::Neg, py::return_value_policy::automatic, py::keep_alive<0, 1>()) .def(