// Copyright 2010-2025 Google LLC // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "ortools/math_opt/elemental/elemental.h" #include #include #include #include #include #include #include #include #include "absl/base/attributes.h" #include "absl/base/no_destructor.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ortools/base/status_builder.h" #include "ortools/base/status_macros.h" #include "ortools/math_opt/elemental/arrays.h" #include "ortools/math_opt/elemental/attr_key.h" #include "ortools/math_opt/elemental/attributes.h" #include "ortools/math_opt/elemental/derived_data.h" #include "ortools/math_opt/elemental/elements.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" #include "pybind11/stl.h" #include "pybind11_protobuf/native_proto_caster.h" namespace py = ::pybind11; namespace { using ::operations_research::math_opt::AttrKey; using ::operations_research::math_opt::is_attr_key_v; // The conversion performed by `PyToCppEnum` is rather expensive as it requires // doing lookups (`PyObject_{Has,Get}AttrString()`), type // checking, and conversions. `PyToCppEnum` is invoked several times by // `pybind11` to select which overload to call. We've found this to eventually // be a bottleneck in benchmarks, costing nearly as much time as actually // solving the problem. However, `PyToCppEnum` is always called with enum // constants, which always have the same underlying python object, so we can // cache the result for efficiency: the conversions are only done once per enum // value, and cache lookups are essentially free (see analysis in cl/692865844). // The cache is shared between all instantiations of `PyToCppEnum`. struct EnumCacheValue { // The name of the C++ enum. std::string enum_name; // The numeric value of the enum. int64_t int_value = -1; }; absl::flat_hash_map& GetEnumCache() { static absl::NoDestructor< absl::flat_hash_map> cache; return *cache; } // Converts a python enum to a C++ enum. template bool PyToCppEnum(PyObject* const py_enum, EnumT& cpp_enum, const absl::string_view enum_name, const int num_enum_values) { // Note: this is not thread-safe, but pybind11 grabs the global lock before // calling into c++, so we never race: // https://pybind11.readthedocs.io/en/stable/advanced/misc.html#global-interpreter-lock-gil const auto [it, inserted] = GetEnumCache().try_emplace(py_enum); EnumCacheValue& enum_value = it->second; if (inserted) { // Cache miss, so we need to do the conversion work. enum_value.enum_name = py_enum->ob_type->tp_name; if (!PyObject_HasAttrString(py_enum, "value")) { // This is not even an enum, so we can't convert it. Does not happen // unless the user is doing something shady. GetEnumCache().erase(it); return false; } // Get the numeric value of the `.value` member. const auto py_value = py::reinterpret_steal( PyObject_GetAttrString(py_enum, "value")); if (!PyLong_Check(py_value.ptr())) { GetEnumCache().erase(it); return false; } enum_value.int_value = py::int_(py_value); } // Check enum name. if (enum_value.enum_name != enum_name) { return false; } // Convert to C++ enum after checking the range to avoid UB. if (enum_value.int_value < 0 || enum_value.int_value >= num_enum_values) { return false; } cpp_enum = static_cast(enum_value.int_value); return true; } // An adapter to convert a string view constant to a char array (pybind11 uses // the latter). template struct AsArray { constexpr AsArray() { #if __cplusplus >= 202002L std::copy(name.begin(), name.end(), array); #else // `std::copy` is not constexpr before C++20. char* p = array; for (const char c : name) { *p++ = c; } *p = '\0'; #endif } char array[name.size() + 1] = {}; }; // An RAII object that allows creating a UTF8 string from a numpy unicode // string. We could use `UnicodeText` for this, but unicode conversion is // already in python, so we don't need to take a dependency on it. class Utf8String { public: Utf8String(const char* unicode, const Py_ssize_t unicode_size_bytes) : unicode_object_( py::reinterpret_borrow(PyUnicode_DecodeUTF32( unicode, unicode_size_bytes, nullptr, nullptr))) {} absl::string_view view() const { // NOTE: `PyUnicode_AsUTF8AndSize` caches the result of the conversion. Py_ssize_t utf8_size; const char* utf8_data = PyUnicode_AsUTF8AndSize(unicode_object_.ptr(), &utf8_size); return absl::string_view(utf8_data, utf8_size); } private: const py::object unicode_object_; }; // A view of a 2d numpy array as ranges of `AttrKey`s. template class AttrKeyArrayView { public: static_assert(is_attr_key_v); using value_type = AttrKeyT; explicit AttrKeyArrayView( const py::array_t& array ABSL_ATTRIBUTE_LIFETIME_BOUND) : array_(array) { if (array.ndim() != 2) { throw std::invalid_argument(absl::StrCat( "array has incorrect number of dimensions: ", array.ndim(), "; expected 2")); } if (array.shape(1) != AttrKeyT::size()) { throw std::invalid_argument( absl::StrCat("expected array shape (..., ", AttrKeyT::size(), "), got (", array.shape(0), ", ", array.shape(1), ")")); } } Py_ssize_t size() const { return array_.shape(0); } AttrKeyT operator[](const int64_t i) const { std::array key_ids; const auto ids = array_.unchecked<2>(); for (int j = 0; j < AttrKeyT::size(); ++j) { key_ids[j] = ids(i, j); } return {key_ids}; } private: const py::array_t& array_; }; // Throws an exception if the status is not OK. // Invalid arguments are converted to `std::invalid_argument`, all other errors // are converted to `std::runtime_error`. void ThrowIfError(const absl::Status& status) { if (!status.ok()) { if (status.code() == absl::StatusCode::kInvalidArgument) { throw std::invalid_argument(std::string(status.message())); } else { throw std::runtime_error(std::string(status.message())); } } } // Throws a `ValueError` if the status is not OK, otherwise forwards the value. template T&& ThrowIfError(absl::StatusOr&& status_or ABSL_ATTRIBUTE_LIFETIME_BOUND) { ThrowIfError(status_or.status()); return std::move(status_or).value(); } // Checks that `strings` is a 1d array of dtype U. absl::Status CheckStringArray(const py::array& strings) { const char dtype = strings.dtype().char_(); if (strings.ndim() == 1 && dtype == 'U') { return absl::OkStatus(); } return util::InvalidArgumentErrorBuilder() << "expected a 1d array of dtype U:, got " << strings.ndim() << "d array of dtype " << dtype; } // Checks that 1d range `values` has no duplicates. template absl::Status CheckForDuplicates(const InRange& values) { // `py::array_t` has no `value_type`. using T = std::remove_cv_t< std::remove_reference_t()[0])>>; absl::flat_hash_set seen; seen.reserve(values.size()); for (int i = 0; i < values.size(); ++i) { if (!seen.insert(values[i]).second) { return util::InvalidArgumentErrorBuilder() << "array has duplicates: " << values[i]; } } return absl::OkStatus(); } // Maps `fn` over random-accessible range `in` and returns an array of the // results. template py::array_t MapToArray(const InRange& in, const FnT& fn) { const int64_t n = static_cast(in.size()); py::array_t out(n); auto out_ref = out.template mutable_unchecked<1>(); for (int i = 0; i < n; ++i) { out_ref[i] = fn(in[i]); } return out; } } // namespace using ::operations_research::math_opt::AttrKey; using ::operations_research::math_opt::AttrTypeDescriptorT; using ::operations_research::math_opt::ElementId; using ::operations_research::math_opt::ElementType; using ::operations_research::math_opt::GetIndexIfAttr; using ::py::arg; namespace PYBIND11_NAMESPACE { namespace detail { // Type caster for `ElementType`. template <> struct type_caster { PYBIND11_TYPE_CASTER(ElementType, const_name("ElementType")); // Python -> C++ conversion. bool load(handle src, bool) { return PyToCppEnum(src.ptr(), value, "ElementType", operations_research::math_opt::kNumElements); } // C++ -> Python conversion. static handle cast(ElementType src, return_value_policy, handle) { return PyLong_FromLong(static_cast(src)); } }; // Type caster for `ElementId`. template struct type_caster> { PYBIND11_TYPE_CASTER(ElementId, const_name("ElementId")); // Python -> C++ conversion. bool load(handle src, bool) { value = ElementId( static_cast(reinterpret_borrow(src.ptr()))); return true; } // C++ -> Python conversion. static handle cast(ElementId src, return_value_policy, handle) { return PyLong_FromLong(src.value()); } }; // Type caster for casting enum values from python enums to C++ `Elemental` // enums. template struct type_caster() >= 0)>> { PYBIND11_TYPE_CASTER( AttrType, const_name(AsArray::kName>().array)); // Python -> C++ conversion. bool load(handle src, bool) { return PyToCppEnum(src.ptr(), value, AttrTypeDescriptorT::kName, AttrTypeDescriptorT::NumAttrs()); } // C++ -> Python conversion. static handle cast(AttrType src, return_value_policy, handle) { return PyLong_FromLong(static_cast(src)); } }; // Type caster for `AttrKey`. This also canonicalizes the keys. template struct type_caster> : array_caster, int64_t, false, n> {}; } // namespace detail } // namespace PYBIND11_NAMESPACE namespace operations_research::math_opt { namespace { template absl::Status ValidateSliceKeyIndex(const AttrType attr, const int key_index) { const int key_size = GetAttrKeySize(); if (key_index < 0 || key_index >= key_size) { return util::InvalidArgumentErrorBuilder() << "key_index must be in [0, " << key_size << ") for attribute: " << attr << " but key_index was: " << key_index; } return absl::OkStatus(); } // A wrapper that turns a compile-time (template) int argument of a template // function into a normal argument so that: // ApplyOnIndex(Fn, idx) == Fn() for 0 <= idx < n. // Other values of idx cause a CHECK-failure. // // The template argument n contains the number of supported values of idx. // ApplyOnIndex instantiates Fn for values of the template argument in [0, n) // and then dispatches between them dynamically. // // Fn must have the same return type for all values of idx, and the return // type must be movable. // // This function take time linear in n. template ), Fn>> R ApplyOnIndex(Fn fn, const int index) { CHECK_GE(index, 0); CHECK_LT(index, n); std::optional result; // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) ForEachIndex([index, &fn, &result]() { if (k == index) { result = fn.template operator()(); } }); CHECK(result.has_value()); return *std::move(result); } // Like Elemental::Slice, but where the slicing index is given at runtime // instead of compile time. Bad values of either key_index or element_id give // Status errors. template absl::StatusOr>> DynamicSlice( const Elemental& e, const AttrType attr, const int key_index, const int element_id) { RETURN_IF_ERROR(ValidateSliceKeyIndex(attr, key_index)); return ApplyOnIndex()>( // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) [&e, attr, element_id]() { return e.Slice(attr, element_id); }, key_index); } // Like Elemental::GetSliceSize, but where the slicing index is given at runtime // instead of compile time. Bad values of either key_index or element_id give // Status errors. template absl::StatusOr DynamicGetSliceSize(const Elemental& e, const AttrType attr, const int key_index, const int element_id) { RETURN_IF_ERROR(ValidateSliceKeyIndex(attr, key_index)); return ApplyOnIndex()>( // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) [&e, attr, element_id]() { return e.GetSliceSize(attr, element_id); }, key_index); } // Converts a vector attribute keys (with each with size key size), with // num_keys elements to a numpy array with shape (num_keys, key_size). template py::array_t ConvertAttrKeysToNpArray( const absl::Span> keys) { const int64_t num_keys = static_cast(keys.size()); py::array_t result( py::array_t::ShapeContainer({num_keys, key_size}), nullptr); auto fast_result = result.mutable_unchecked<2>(); for (int64_t i = 0; i < num_keys; ++i) { for (int j = 0; j < key_size; ++j) { fast_result(i, j) = keys[i][j]; } } return result; } // Checks that the elements in `keys` exist in `e`. template absl::Status CheckForElementExistence( const Elemental& e, const AttrType attr, const AttrKeyArrayView>& keys) { const int64_t num_keys = keys.size(); for (int i = 0; i < num_keys; ++i) { const auto key = keys[i]; for (int j = 0; j < GetAttrKeySize(); ++j) { const auto element_type = GetElementTypes(attr)[j]; if (!e.ElementExistsUntyped(element_type, key[j])) { return util::InvalidArgumentErrorBuilder() << element_type << " id " << key[j] << " does not exist"; } } } return absl::OkStatus(); } absl::StatusOr GetDiffHandle(const Elemental& elemental, const int64_t diff_id) { std::optional handle = elemental.GetDiffHandle(diff_id); if (handle == std::nullopt) { return util::InvalidArgumentErrorBuilder() << "no diff with id: " << diff_id; } return *handle; } template py::array_t slice_attr(const Elemental& e, const AttrType attr, const int key_index, const int element_id) { const std::vector> slice = ThrowIfError(DynamicSlice(e, attr, key_index, element_id)); return ConvertAttrKeysToNpArray(absl::MakeConstSpan(slice)); } template int64_t get_attr_slice_size(const Elemental& e, const AttrType attr, const int key_index, const int element_id) { return ThrowIfError(DynamicGetSliceSize(e, attr, key_index, element_id)); } } // namespace PYBIND11_MODULE(cpp_elemental, py_module) { pybind11_protobuf::ImportNativeProtoCasters(); // Wrap `Elemental`. py::class_> elemental(py_module, "CppElemental"); elemental.def(py::init(), py::kw_only(), arg("model_name") = "", arg("primary_objective_name") = ""); elemental.def_property_readonly("model_name", &Elemental::model_name); elemental.def_property_readonly("primary_objective_name", &Elemental::primary_objective_name); elemental.def("__repr__", [](const Elemental& e) { return e.DebugString(); }); elemental.def( "clone", [](Elemental& e, std::optional new_model_name) { return std::make_unique(e.Clone(new_model_name)); }, py::kw_only(), arg("new_model_name") = std::nullopt); elemental.def( "export_model", [](const Elemental& e, bool remove_names) { return ThrowIfError(e.ExportModel(remove_names)); }, py::kw_only(), arg("remove_names") = false); elemental.def_static( "from_model_proto", [](const ModelProto& proto) { return std::make_unique( ThrowIfError(Elemental::FromModelProto(proto))); }, arg("proto")); elemental.def("add_diff", [](Elemental& e) { return e.AddDiff().id(); }); elemental.def("delete_diff", [](Elemental& e, const int64_t diff_handle) { const Elemental::DiffHandle h = ThrowIfError(GetDiffHandle(e, diff_handle)); CHECK(e.DeleteDiff(h)); }); elemental.def("advance_diff", [](Elemental& e, const int64_t diff_handle) { const Elemental::DiffHandle h = ThrowIfError(GetDiffHandle(e, diff_handle)); CHECK(e.Advance(h)); }); elemental.def( "export_model_update", [](Elemental& e, const int64_t diff_handle, const bool remove_names) { const Elemental::DiffHandle h = ThrowIfError(GetDiffHandle(e, diff_handle)); return ThrowIfError(e.ExportModelUpdate(h, remove_names)); }, arg("diff_handle"), py::kw_only(), arg("remove_names") = false); // Element counting operations. elemental.def("get_num_elements", &Elemental::NumElements, arg("element_type")); elemental.def("get_next_element_id", &Elemental::NextElementId, arg("element_type")); elemental.def("ensure_next_element_id_at_least", &Elemental::EnsureNextElementIdAtLeastUntyped, arg("element_type"), arg("element_id")); // Non-batch element operations. elemental.def("add_element", &Elemental::AddElementUntyped, arg("element_type"), arg("name")); elemental.def("delete_element", &Elemental::DeleteElementUntyped, arg("element_type"), arg("element_id")); elemental.def("element_exists", &Elemental::ElementExistsUntyped, arg("element_type"), arg("element_id")); elemental.def( "get_element_name", [](Elemental& e, const ElementType elem_type, const int64_t id) { return ThrowIfError(e.GetElementNameUntyped(elem_type, id)); }, arg("element_type"), arg("element_id")); // Batch Element operations. elemental.def( "add_elements", [](Elemental& e, const ElementType elem_type, const int num_elements) { py::array_t result(num_elements); auto ids = result.mutable_unchecked<1>(); for (int i = 0; i < num_elements; ++i) { ids(i) = e.AddElementUntyped(elem_type, ""); } return result; }, arg("element_type"), arg("num_elements")); elemental.def( "add_named_elements", // Note: `array_t` only supports POD types, so `np.array(str)` is passed // as type-erased `py::array` (see // https://github.com/pybind/pybind11/issues/1775). Each element is a // `char[array.itemsize]` containing unicode data. [](Elemental& e, const ElementType elem_type, const py::array& names) { ThrowIfError(CheckStringArray(names)); const int64_t num_elements = static_cast(names.shape(0)); const char* unicode_data = static_cast(names.request().ptr); const Py_ssize_t itemsize_bytes = names.request().itemsize; py::array_t result(names.size()); auto ids = result.mutable_unchecked<1>(); for (int i = 0; i < num_elements; ++i) { ids(i) = e.AddElementUntyped( elem_type, Utf8String(unicode_data + i * itemsize_bytes, itemsize_bytes) .view()); } return result; }, arg("element_type"), arg("names")); elemental.def( "delete_elements", [](Elemental& e, const ElementType elem_type, const py::array_t& elements) { ThrowIfError(CheckForDuplicates(elements.unchecked<1>())); return MapToArray(elements.unchecked<1>(), [&](int64_t id) { return e.DeleteElementUntyped(elem_type, id); }); }, arg("element_type"), arg("num_elements")); elemental.def( "elements_exist", [](Elemental& e, const ElementType elem_type, const py::array_t& elements) { return MapToArray(elements.unchecked<1>(), [&](int64_t id) { return e.ElementExistsUntyped(elem_type, id); }); }, arg("element_type"), arg("elements")); elemental.def( "get_element_names", [](Elemental& e, const ElementType elem_type, py::array_t& elements) { const int64_t num_elements = static_cast(elements.shape(0)); auto ids = elements.unchecked<1>(); std::vector names; names.reserve(num_elements); for (int i = 0; i < num_elements; ++i) { names.emplace_back() = ThrowIfError(e.GetElementNameUntyped(elem_type, ids(i))); } return py::array(py::cast(names)); }, arg("element_type"), arg("elements")); elemental.def( "get_elements", [](Elemental& e, ElementType elem_type) { const auto all_elements = e.AllElementsUntyped(elem_type); return py::array_t(all_elements.size(), all_elements.data()); }, arg("element_type")); // Export attribute operations. ForEach( // NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat) [&elemental](const Descriptor&) { using AttrType = typename Descriptor::AttrType; using ValueType = typename Descriptor::ValueType; using Key = AttrKeyFor; elemental.def("clear_attr", &Elemental::AttrClear, arg("attr")); // Get: elemental.def( "get_attr", [](const Elemental& e, AttrType attr, const Key& key) { return ThrowIfError( e.GetAttr(attr, key)); }, arg("attr"), arg("key")); elemental.def( "get_attrs", [](const Elemental& e, const AttrType a, py::array_t keys) { return MapToArray( AttrKeyArrayView(keys), [&](Key key) { return ThrowIfError( e.GetAttr(a, key)); }); }, arg("attr"), arg("keys")); // Set: elemental.def( "set_attr", [](Elemental& e, AttrType attr, const Key& key, const ValueTypeFor value) { return ThrowIfError( e.SetAttr(attr, key, value)); }, arg("attr"), arg("key"), arg("value")); elemental.def( "set_attrs", [](Elemental& e, AttrType attr, const py::array_t& keys, const py::array_t& values) { const AttrKeyArrayView keys_view(keys); // We need to check for duplicates and existence first, as we // don't want to end up with a partially mutated state on error. ThrowIfError(CheckForDuplicates(keys_view)); ThrowIfError(CheckForElementExistence(e, attr, keys_view)); const auto values_view = values.template unchecked<1>(); const int64_t num_elements = keys_view.size(); for (int i = 0; i < num_elements; ++i) { e.SetAttr(attr, keys_view[i], values_view[i]); } }, arg("attr"), arg("key"), arg("value")); // IsNonDefault: elemental.def( "is_attr_non_default", [](const Elemental& e, AttrType attr, const Key& key) { return ThrowIfError( e.AttrIsNonDefault(attr, key)); }, arg("attr"), arg("key")); elemental.def( "bulk_is_attr_non_default", [](const Elemental& e, AttrType attr, const py::array_t& keys) { return MapToArray( AttrKeyArrayView(keys), [&](Key key) { return ThrowIfError( e.AttrIsNonDefault(attr, key)); }); }, arg("attr"), arg("key")); if constexpr (GetAttrKeySize() >= 1) { elemental.def("slice_attr", &slice_attr); elemental.def("get_attr_slice_size", &get_attr_slice_size); } // NumNonDefaults: elemental.def("get_attr_num_non_defaults", &Elemental::AttrNumNonDefaults, arg("attr")); // GetNonDefaults: elemental.def( "get_attr_non_defaults", [](const Elemental& e, AttrType attr) { const std::vector> non_defaults = e.AttrNonDefaults(attr); return ConvertAttrKeysToNpArray( absl::MakeConstSpan(non_defaults)); }, arg("attr")); }, AllAttrTypeDescriptors{}); } } // namespace operations_research::math_opt