diff --git a/ortools/linear_solver/python/model_builder_helper.cc b/ortools/linear_solver/python/model_builder_helper.cc index 214f0b26dc..326e52c2e8 100644 --- a/ortools/linear_solver/python/model_builder_helper.cc +++ b/ortools/linear_solver/python/model_builder_helper.cc @@ -15,6 +15,15 @@ #include "ortools/linear_solver/wrappers/model_builder_helper.h" +#include + +#if PY_VERSION_HEX >= 0x030E00A7 && !defined(PYPY_VERSION) +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#include "internal/pycore_interpframe.h" +#undef Py_BUILD_CORE +#endif + #include #include #include @@ -300,6 +309,41 @@ std::shared_ptr WeightedSumArguments( } } +#if PY_VERSION_HEX >= 0x030E00A7 && !defined(PYPY_VERSION) +bool check_unique_temporary(PyObject* op) { + PyFrameObject* frame = PyEval_GetFrame(); + if (frame == NULL) { + return false; + } + _PyInterpreterFrame* f = frame->f_frame; + _PyStackRef* base = _PyFrame_Stackbase(f); + _PyStackRef* stackpointer = f->stackpointer; + + while (stackpointer > base) { + stackpointer--; + if (op == PyStackRef_AsPyObjectBorrow(*stackpointer)) { + // We want detect if the object is a temporary and borrowed. If so, it + // should be only referenced once in the stack, but it should not be safe. + return !PyStackRef_IsHeapSafe(*stackpointer); + } + } + return false; +} + +template +bool IsFree(std::shared_ptr expr) { + PyObject* lhs = py::cast(expr).ptr(); + const int num_uses = Py_REFCNT(lhs); + const bool is_referenced_in_caller_frame = check_unique_temporary(lhs); + return num_uses == 3 && !is_referenced_in_caller_frame; +} +#else +template +bool IsFree(std::shared_ptr expr) { + return Py_REFCNT(py::cast(expr).ptr()) == 4; +} +#endif + PYBIND11_MODULE(model_builder_helper, m) { pybind11_protobuf::ImportNativeProtoCasters(); @@ -434,43 +478,33 @@ PYBIND11_MODULE(model_builder_helper, m) { .def(py::init>, double>()) .def( "__add__", - [](py::object self, + [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); - return (num_uses == 4) ? expr->AddInPlace(other) : expr->Add(other); + return IsFree(expr) ? expr->AddInPlace(other) : expr->Add(other); }, py::arg("other").none(false), "Returns the sum of `self` and `other`.") .def( "__add__", - [](py::object self, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); - return (num_uses == 4) ? expr->AddFloatInPlace(cst) - : expr->AddFloat(cst); + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); }, py::arg("cst"), "Returns `self` + `cst`.") .def( "__radd__", - [](py::object self, + [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); - return (num_uses == 4) ? expr->AddInPlace(other) : expr->Add(other); + return IsFree(expr) ? expr->AddInPlace(other) : expr->Add(other); }, py::arg("cst"), "Returns `self` + `cst`.") .def( "__radd__", - [](py::object self, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); - return (num_uses == 4) ? expr->AddFloatInPlace(cst) - : expr->AddFloat(cst); + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); }, py::arg("cst"), "Returns `self` + `cst`.") .def( @@ -490,23 +524,18 @@ PYBIND11_MODULE(model_builder_helper, m) { py::arg("cst"), "Returns `self` + `cst`.") .def( "__sub__", - [](py::object self, + [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); - return (num_uses == 4) ? expr->AddInPlace(other->Neg()) - : expr->Sub(other); + return IsFree(expr) ? expr->AddInPlace(other->Neg()) + : expr->Sub(other); }, py::arg("other").none(false), "Returns `self` - `other`.") .def( "__sub__", - [](py::object self, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(self.ptr()); - std::shared_ptr expr = - self.cast>(); - return (num_uses == 4) ? expr->AddFloatInPlace(-cst) - : expr->SubFloat(cst); + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddFloatInPlace(-cst) + : expr->SubFloat(cst); }, py::arg("cst"), "Returns `self` - `cst`.") .def( diff --git a/ortools/sat/python/BUILD.bazel b/ortools/sat/python/BUILD.bazel index e607ddf06b..2230b4742c 100644 --- a/ortools/sat/python/BUILD.bazel +++ b/ortools/sat/python/BUILD.bazel @@ -138,7 +138,7 @@ py_library( py_test( name = "cp_model_test", - size = "small", + size = "medium", srcs = ["cp_model_test.py"], deps = [ ":cp_model", diff --git a/ortools/sat/python/cp_model_helper.cc b/ortools/sat/python/cp_model_helper.cc index a13d1aa22f..e21a33b602 100644 --- a/ortools/sat/python/cp_model_helper.cc +++ b/ortools/sat/python/cp_model_helper.cc @@ -13,6 +13,13 @@ #include +#if PY_VERSION_HEX >= 0x030E00A7 && !defined(PYPY_VERSION) +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#include "internal/pycore_interpframe.h" +#undef Py_BUILD_CORE +#endif + #include #include #include @@ -1107,6 +1114,41 @@ std::shared_ptr CpBaseModel::AddRoutesInternal( return std::make_shared(shared_from_this(), ct_index); } +#if PY_VERSION_HEX >= 0x030E00A7 && !defined(PYPY_VERSION) +bool check_unique_temporary(PyObject* op) { + PyFrameObject* frame = PyEval_GetFrame(); + if (frame == NULL) { + return false; + } + _PyInterpreterFrame* f = frame->f_frame; + _PyStackRef* base = _PyFrame_Stackbase(f); + _PyStackRef* stackpointer = f->stackpointer; + + while (stackpointer > base) { + stackpointer--; + if (op == PyStackRef_AsPyObjectBorrow(*stackpointer)) { + // We want detect if the object is a temporary and borrowed. If so, it + // should be only referenced once in the stack, but it should not be safe. + return !PyStackRef_IsHeapSafe(*stackpointer); + } + } + return false; +} + +template +bool IsFree(std::shared_ptr expr) { + PyObject* lhs = py::cast(expr).ptr(); + const int num_uses = Py_REFCNT(lhs); + const bool is_referenced_in_caller_frame = check_unique_temporary(lhs); + return num_uses == 3 && !is_referenced_in_caller_frame; +} +#else +template +bool IsFree(std::shared_ptr expr) { + return Py_REFCNT(py::cast(expr).ptr()) == 4; +} +#endif + PYBIND11_MODULE(cp_model_helper, m) { py::module::import("ortools.util.python.sorted_interval_list"); @@ -1536,8 +1578,7 @@ PYBIND11_MODULE(cp_model_helper, m) { "__add__", [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddInPlace(other) : expr->Add(other); + return IsFree(expr) ? expr->AddInPlace(other) : expr->Add(other); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, Add)) @@ -1545,18 +1586,15 @@ PYBIND11_MODULE(cp_model_helper, m) { "__add__", [](std::shared_ptr expr, int64_t cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddIntInPlace(cst) - : expr->AddInt(cst); + return IsFree(expr) ? expr->AddIntInPlace(cst) : expr->AddInt(cst); }, DOC(operations_research, sat, python, LinearExpr, AddInt)) .def( "__add__", [](std::shared_ptr expr, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddFloatInPlace(cst) - : expr->AddFloat(cst); + return IsFree(expr) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, AddFloat)) @@ -1564,8 +1602,7 @@ PYBIND11_MODULE(cp_model_helper, m) { "__radd__", [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddInPlace(other) : expr->Add(other); + return IsFree(expr) ? expr->AddInPlace(other) : expr->Add(other); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, Add)) @@ -1573,9 +1610,7 @@ PYBIND11_MODULE(cp_model_helper, m) { "__radd__", [](std::shared_ptr expr, int64_t cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddIntInPlace(cst) - : expr->AddInt(cst); + return IsFree(expr) ? expr->AddIntInPlace(cst) : expr->AddInt(cst); }, py::arg("cst"), DOC(operations_research, sat, python, LinearExpr, AddInt)) @@ -1583,9 +1618,8 @@ PYBIND11_MODULE(cp_model_helper, m) { "__radd__", [](std::shared_ptr expr, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddFloatInPlace(cst) - : expr->AddFloat(cst); + return IsFree(expr) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); }, py::arg("cst"), DOC(operations_research, sat, python, LinearExpr, AddFloat)) @@ -1616,9 +1650,8 @@ PYBIND11_MODULE(cp_model_helper, m) { "__sub__", [](std::shared_ptr expr, std::shared_ptr other) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddInPlace(other->Neg()) - : expr->Sub(other); + return IsFree(expr) ? expr->AddInPlace(other->Neg()) + : expr->Sub(other); }, py::arg("other").none(false), DOC(operations_research, sat, python, LinearExpr, Sub)) @@ -1626,9 +1659,7 @@ PYBIND11_MODULE(cp_model_helper, m) { "__sub__", [](std::shared_ptr expr, int64_t cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddIntInPlace(-cst) - : expr->SubInt(cst); + return IsFree(expr) ? expr->AddIntInPlace(-cst) : expr->SubInt(cst); }, py::arg("cst"), DOC(operations_research, sat, python, LinearExpr, SubInt)) @@ -1636,9 +1667,8 @@ PYBIND11_MODULE(cp_model_helper, m) { "__sub__", [](std::shared_ptr expr, double cst) -> std::shared_ptr { - const int num_uses = Py_REFCNT(py::cast(expr).ptr()); - return (num_uses == 4) ? expr->AddFloatInPlace(-cst) - : expr->SubFloat(cst); + return IsFree(expr) ? expr->AddFloatInPlace(-cst) + : expr->SubFloat(cst); }, py::arg("cst"), DOC(operations_research, sat, python, LinearExpr, SubFloat))