diff --git a/ortools/linear_solver/python/model_builder_helper.cc b/ortools/linear_solver/python/model_builder_helper.cc index 9a84fbc736..8036828159 100644 --- a/ortools/linear_solver/python/model_builder_helper.cc +++ b/ortools/linear_solver/python/model_builder_helper.cc @@ -439,11 +439,7 @@ PYBIND11_MODULE(model_builder_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddInPlace(other); - return expr; - } - return expr->Add(other); + return (num_uses == 4) ? expr->AddInPlace(other) : expr->Add(other); }, py::arg("other").none(false), "Returns the sum of `self` and `other`.") @@ -453,46 +449,43 @@ PYBIND11_MODULE(model_builder_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddFloatInPlace(cst); - return expr; - } - return expr->AddFloat(cst); + return (num_uses == 4) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); }, py::arg("cst"), "Returns `self` + `cst`.") .def( - "__iadd__", + "__radd__", [](py::object self, std::shared_ptr other) -> std::shared_ptr { + const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - expr->AddInPlace(other); - return expr; - }, - py::arg("other").none(false), - "Returns the sum of `self` and `other`.") - .def( - "__iadd__", - [](py::object self, double cst) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); - expr->AddFloatInPlace(cst); - return expr; + return (num_uses == 4) ? expr->AddInPlace(other) : expr->Add(other); }, py::arg("cst"), "Returns `self` + `cst`.") - .def("__radd__", &LinearExpr::Add, py::arg("other").none(false), - "Returns `self` + `other`.") .def( "__radd__", [](py::object self, double cst) -> std::shared_ptr { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddFloatInPlace(cst); - return expr; - } - return expr->AddFloat(cst); + return (num_uses == 4) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); + }, + py::arg("cst"), "Returns `self` + `cst`.") + .def( + "__iadd__", + [](std::shared_ptr expr, + std::shared_ptr other) -> std::shared_ptr { + return expr->AddInPlace(other); + }, + py::arg("other").none(false), + "Returns the sum of `self` and `other`.") + .def( + "__iadd__", + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return expr->AddFloatInPlace(cst); }, py::arg("cst"), "Returns `self` + `cst`.") .def( @@ -502,11 +495,8 @@ PYBIND11_MODULE(model_builder_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddInPlace(other->Neg()); - return expr; - } - return expr->Sub(other); + return (num_uses == 4) ? expr->AddInPlace(other->Neg()) + : expr->Sub(other); }, py::arg("other").none(false), "Returns `self` - `other`.") .def( @@ -515,30 +505,23 @@ PYBIND11_MODULE(model_builder_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddFloatInPlace(-cst); - return expr; - } - return expr->SubFloat(cst); + return (num_uses == 4) ? expr->AddFloatInPlace(-cst) + : expr->SubFloat(cst); }, py::arg("cst"), "Returns `self` - `cst`.") .def( "__isub__", - [](py::object self, + [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); expr->AddInPlace(other->Neg()); - return expr; + return expr->AddInPlace(other->Neg()); }, py::arg("other").none(false), "Returns `self` - `other`.") .def( "__isub__", - [](py::object self, double cst) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); - expr->AddFloatInPlace(-cst); - return expr; + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return expr->AddFloatInPlace(-cst); }, py::arg("cst"), "Returns `self` - `cst`.") .def_property_readonly( diff --git a/ortools/linear_solver/wrappers/model_builder_helper.cc b/ortools/linear_solver/wrappers/model_builder_helper.cc index 611b07058c..bc5bcdf5eb 100644 --- a/ortools/linear_solver/wrappers/model_builder_helper.cc +++ b/ortools/linear_solver/wrappers/model_builder_helper.cc @@ -988,6 +988,72 @@ std::string FlatExpr::DebugString() const { return s; } +SumArray::SumArray(std::vector> exprs, + double offset) + : exprs_(std::move(exprs)), offset_(offset) {} + +void SumArray::Visit(ExprVisitor& lin, double c) { + for (int i = 0; i < exprs_.size(); ++i) { + lin.AddToProcess(exprs_[i], c); + } + if (offset_ != 0.0) { + lin.AddConstant(offset_ * c); + } +} + +std::string SumArray::ToString() const { + if (exprs_.empty()) { + if (offset_ != 0.0) { + return absl::StrCat(offset_); + } + } + std::string s = "("; + for (int i = 0; i < exprs_.size(); ++i) { + if (i > 0) { + absl::StrAppend(&s, " + "); + } + absl::StrAppend(&s, exprs_[i]->ToString()); + } + if (offset_ != 0.0) { + if (offset_ > 0.0) { + absl::StrAppend(&s, " + ", offset_); + } else { + absl::StrAppend(&s, " - ", -offset_); + } + } + absl::StrAppend(&s, ")"); + return s; +} + +std::string SumArray::DebugString() const { + std::string s = absl::StrCat( + "SumArray(", + absl::StrJoin(exprs_, ", ", + [](std::string* out, std::shared_ptr expr) { + absl::StrAppend(out, expr->DebugString()); + })); + if (offset_ != 0.0) { + absl::StrAppend(&s, ", offset=", offset_); + } + absl::StrAppend(&s, ")"); + return s; +} + +std::shared_ptr SumArray::AddInPlace( + std::shared_ptr expr) { + exprs_.push_back(std::move(expr)); + return shared_from_this(); +} + +std::shared_ptr SumArray::AddFloatInPlace(double cst) { + offset_ += cst; + return shared_from_this(); +} + +int SumArray::num_exprs() const { return exprs_.size(); } + +double SumArray::offset() const { return offset_; } + void FixedValue::Visit(ExprVisitor& lin, double c) { lin.AddConstant(value_ * c); } diff --git a/ortools/linear_solver/wrappers/model_builder_helper.h b/ortools/linear_solver/wrappers/model_builder_helper.h index cfd3de5e0e..4f2371da2c 100644 --- a/ortools/linear_solver/wrappers/model_builder_helper.h +++ b/ortools/linear_solver/wrappers/model_builder_helper.h @@ -26,8 +26,6 @@ #include "absl/container/btree_map.h" #include "absl/container/fixed_array.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" #include "ortools/linear_solver/linear_solver.pb.h" #include "ortools/linear_solver/model_exporter.h" #include "ortools/util/solve_interrupter.h" @@ -150,61 +148,17 @@ class FlatExpr : public LinearExpr { class SumArray : public LinearExpr { public: explicit SumArray(std::vector> exprs, - double offset) - : exprs_(std::move(exprs)), offset_(offset) {} + double offset); ~SumArray() override = default; - void Visit(ExprVisitor& lin, double c) override { - for (int i = 0; i < exprs_.size(); ++i) { - lin.AddToProcess(exprs_[i], c); - } - if (offset_ != 0.0) { - lin.AddConstant(offset_ * c); - } - } + void Visit(ExprVisitor& lin, double c) override; - std::string ToString() const override { - if (exprs_.empty()) { - if (offset_ != 0.0) { - return absl::StrCat(offset_); - } - } - std::string s = "("; - for (int i = 0; i < exprs_.size(); ++i) { - if (i > 0) { - absl::StrAppend(&s, " + "); - } - absl::StrAppend(&s, exprs_[i]->ToString()); - } - if (offset_ != 0.0) { - if (offset_ > 0.0) { - absl::StrAppend(&s, " + ", offset_); - } else { - absl::StrAppend(&s, " - ", -offset_); - } - } - absl::StrAppend(&s, ")"); - return s; - } - - std::string DebugString() const override { - std::string s = absl::StrCat( - "SumArray(", - absl::StrJoin(exprs_, ", ", - [](std::string* out, std::shared_ptr expr) { - absl::StrAppend(out, expr->DebugString()); - })); - if (offset_ != 0.0) { - absl::StrAppend(&s, ", offset=", offset_); - } - absl::StrAppend(&s, ")"); - return s; - } - - void AddInPlace(std::shared_ptr expr) { exprs_.push_back(expr); } - void AddFloatInPlace(double cst) { offset_ += cst; } - int num_exprs() const { return exprs_.size(); } - double offset() const { return offset_; } + std::string ToString() const override; + std::string DebugString() const override; + std::shared_ptr AddInPlace(std::shared_ptr expr); + std::shared_ptr AddFloatInPlace(double cst); + int num_exprs() const; + double offset() const; private: std::vector> exprs_; diff --git a/ortools/sat/python/cp_model_helper.cc b/ortools/sat/python/cp_model_helper.cc index 87ef120b8f..10e57c7657 100644 --- a/ortools/sat/python/cp_model_helper.cc +++ b/ortools/sat/python/cp_model_helper.cc @@ -897,11 +897,7 @@ PYBIND11_MODULE(cp_model_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddInPlace(other); - return expr; - } - return expr->Add(other); + return (num_uses == 4) ? expr->AddInPlace(other) : expr->Add(other); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, Add)) @@ -911,11 +907,8 @@ PYBIND11_MODULE(cp_model_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddIntInPlace(cst); - return expr; - } - return expr->AddInt(cst); + return (num_uses == 4) ? expr->AddIntInPlace(cst) + : expr->AddInt(cst); }, DOC(operations_research, sat, python, LinearExpr, AddInt)) .def( @@ -924,11 +917,8 @@ PYBIND11_MODULE(cp_model_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddFloatInPlace(cst); - return expr; - } - return expr->AddFloat(cst); + return (num_uses == 4) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, AddFloat)) @@ -938,11 +928,8 @@ PYBIND11_MODULE(cp_model_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddIntInPlace(cst); - return expr; - } - return expr->AddInt(cst); + return (num_uses == 4) ? expr->AddIntInPlace(cst) + : expr->AddInt(cst); }, py::arg("cst"), DOC(operations_research, sat, python, LinearExpr, AddInt)) @@ -952,41 +939,31 @@ PYBIND11_MODULE(cp_model_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddFloatInPlace(cst); - return expr; - } - return expr->AddFloat(cst); + return (num_uses == 4) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); }, py::arg("cst"), DOC(operations_research, sat, python, LinearExpr, AddFloat)) .def( "__iadd__", - [](py::object self, + [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); - expr->AddInPlace(other); - return expr; + return expr->AddInPlace(other); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, Add)) .def( "__iadd__", - [](py::object self, int64_t cst) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); - expr->AddIntInPlace(cst); - return expr; + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + return expr->AddIntInPlace(cst); }, DOC(operations_research, sat, python, LinearExpr, AddInt)) .def( "__iadd__", - [](py::object self, double cst) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); - expr->AddFloatInPlace(cst); - return expr; + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return expr->AddFloatInPlace(cst); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, AddFloat)) @@ -997,11 +974,8 @@ PYBIND11_MODULE(cp_model_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddInPlace(other->Neg()); - return expr; - } - return expr->Sub(other); + return (num_uses == 4) ? expr->AddInPlace(other->Neg()) + : expr->Sub(other); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, Sub)) @@ -1011,11 +985,8 @@ PYBIND11_MODULE(cp_model_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddIntInPlace(-cst); - return expr; - } - return expr->SubInt(cst); + return (num_uses == 4) ? expr->AddIntInPlace(-cst) + : expr->SubInt(cst); }, py::arg("cst"), DOC(operations_research, sat, python, LinearExpr, SubInt)) @@ -1025,41 +996,31 @@ PYBIND11_MODULE(cp_model_helper, m) { const int num_uses = Py_REFCNT(self.ptr()); std::shared_ptr expr = self.cast>(); - if (num_uses == 4) { - expr->AddFloatInPlace(-cst); - return expr; - } - return expr->SubFloat(cst); + return (num_uses == 4) ? expr->AddFloatInPlace(-cst) + : expr->SubFloat(cst); }, py::arg("cst"), DOC(operations_research, sat, python, LinearExpr, SubFloat)) .def( "__isub__", - [](py::object self, + [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); - expr->AddInPlace(other->MulInt(-1)); - return expr; + return expr->AddInPlace(other->Neg()); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, Sub)) .def( "__isub__", - [](py::object self, int64_t cst) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); - expr->AddIntInPlace(-cst); - return expr; + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + return expr->AddIntInPlace(-cst); }, DOC(operations_research, sat, python, LinearExpr, SubInt)) .def( "__isub__", - [](py::object self, double cst) -> std::shared_ptr { - std::shared_ptr expr = - self.cast>(); - expr->AddFloatInPlace(-cst); - return expr; + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return expr->AddFloatInPlace(-cst); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, SubFloat)) @@ -1074,8 +1035,6 @@ PYBIND11_MODULE(cp_model_helper, m) { .def_property_readonly("coefficient", &FloatAffine::coefficient) .def_property_readonly("offset", &FloatAffine::offset); - // We adding an operator like __add__(int), we need to add all overloads, - // otherwise they are not found. py::class_, LinearExpr>( m, "IntAffine", DOC(operations_research, sat, python, IntAffine)) .def(py::init, int64_t, int64_t>()) diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index aa06c59b2e..9bbaee5513 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -2461,6 +2461,14 @@ TRFM""" s -= model.new_bool_var("") model.add(s == 10) + def test_radd(self): + model = cp_model.CpModel() + x = [model.new_int_var(0, 10, f"x{i}") for i in range(10)] + expr = 1 + sum(x) + self.assertEqual( + str(expr), "(x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + 1)" + ) + def test_simplification1(self): model = cp_model.CpModel() x = model.new_int_var(-10, 10, "x") diff --git a/ortools/sat/python/linear_expr.cc b/ortools/sat/python/linear_expr.cc index 42077ef46f..f8c2954f62 100644 --- a/ortools/sat/python/linear_expr.cc +++ b/ortools/sat/python/linear_expr.cc @@ -340,8 +340,20 @@ SumArray::SumArray(std::vector> exprs, DCHECK_GE(exprs_.size(), 2); } -void SumArray::AddInPlace(std::shared_ptr expr) { +std::shared_ptr SumArray::AddInPlace( + std::shared_ptr expr) { exprs_.push_back(std::move(expr)); + return shared_from_this(); +} + +std::shared_ptr SumArray::AddIntInPlace(int64_t cst) { + int_offset_ += cst; + return shared_from_this(); +} + +std::shared_ptr SumArray::AddFloatInPlace(double cst) { + double_offset_ += cst; + return shared_from_this(); } bool SumArray::VisitAsInt(IntExprVisitor& lin, int64_t c) { diff --git a/ortools/sat/python/linear_expr.h b/ortools/sat/python/linear_expr.h index ae92d1c676..06d973f9ea 100644 --- a/ortools/sat/python/linear_expr.h +++ b/ortools/sat/python/linear_expr.h @@ -286,9 +286,9 @@ class SumArray : public LinearExpr { std::string ToString() const override; std::string DebugString() const override; - void AddInPlace(std::shared_ptr expr); - void AddIntInPlace(int64_t cst) { int_offset_ += cst; } - void AddFloatInPlace(double cst) { double_offset_ += cst; } + std::shared_ptr AddInPlace(std::shared_ptr expr); + std::shared_ptr AddIntInPlace(int64_t cst); + std::shared_ptr AddFloatInPlace(double cst); int num_exprs() const { return exprs_.size(); } int64_t int_offset() const { return int_offset_; } double double_offset() const { return double_offset_; }