Files
ortools-clone/ortools/sat/python/wrappers.cc
Corentin Le Molgat 1b4d75ceb3 sat: backport from main
2025-11-05 13:55:12 +01:00

462 lines
18 KiB
C++

// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/sat/python/wrappers.h"
#include <string>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/die_if_null.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "google/protobuf/descriptor.h"
namespace operations_research::sat::python {
// A class that generates pybind11 code for a proto message.
class Generator {
public:
struct Context {
static Context TopLevel(const google::protobuf::Descriptor& msg) {
const std::string cpp_name = GetQualifiedCppName(msg);
const std::string shared_name =
absl::StrCat("std::shared_ptr<", cpp_name, ">");
return {.cpp_name = cpp_name, .self_mutable_name = shared_name};
}
static Context Nested(const google::protobuf::Descriptor& msg) {
const std::string cpp_name = GetQualifiedCppName(msg);
return {.cpp_name = cpp_name,
.self_mutable_name = absl::StrCat(cpp_name, "*")};
}
std::string cpp_name;
std::string self_mutable_name;
};
explicit Generator(
absl::Span<const google::protobuf::Descriptor* absl_nonnull const> roots)
: message_stack_(roots.begin(), roots.end()) {
// DFS on root.
while (!message_stack_.empty()) {
const google::protobuf::Descriptor* const msg = message_stack_.back();
message_stack_.pop_back();
if (!visited_messages_.insert(msg).second) continue;
const bool is_top_level = absl::c_linear_search(roots, msg);
current_context_ =
is_top_level ? Context::TopLevel(*msg) : Context::Nested(*msg);
if (is_top_level) {
GenerateTopLevelMessageDecl(*msg);
} else {
GenerateMessageDecl(*msg);
}
GenerateMessageFields(*msg);
absl::StrAppend(&out_, ";\n");
}
// Now generate wrappers for enums, repeated and repeated ptr fields that
// were encountered along the way.
for (const google::protobuf::EnumDescriptor* pb_enum : enum_types_) {
GenerateEnumDecl(*pb_enum);
}
for (const google::protobuf::Descriptor* msg : repeated_ptr_types_) {
GenerateRepeatedPtrDecl(*msg);
}
for (const absl::string_view scalar_type : repeated_scalar_types_) {
GenerateRepeatedScalarDecl(scalar_type);
}
}
std::string Result() && { return std::move(out_); }
private:
template <typename DescriptorT>
static std::string GetQualifiedCppName(const DescriptorT& descriptor) {
return absl::StrReplaceAll(descriptor.full_name(), {{".", "::"}});
}
template <typename DescriptorT>
static std::string GetEscapedName(const DescriptorT& descriptor) {
return absl::StrReplaceAll(descriptor.full_name(), {{".", "_"}});
}
static std::string GetCppType(
const google::protobuf::FieldDescriptor::CppType cpp_type,
const google::protobuf::FieldDescriptor& field) {
switch (cpp_type) {
case google::protobuf::FieldDescriptor::CPPTYPE_INT32:
return "int32_t";
case google::protobuf::FieldDescriptor::CPPTYPE_INT64:
return "int64_t";
case google::protobuf::FieldDescriptor::CPPTYPE_UINT32:
return "uint32_t";
case google::protobuf::FieldDescriptor::CPPTYPE_UINT64:
return "uint64_t";
case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE:
return "double";
case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT:
return "float";
case google::protobuf::FieldDescriptor::CPPTYPE_BOOL:
return "bool";
case google::protobuf::FieldDescriptor::CPPTYPE_ENUM:
return GetQualifiedCppName(*field.enum_type());
case google::protobuf::FieldDescriptor::CPPTYPE_STRING:
return "std::string";
default:
LOG(FATAL) << "Unsupported type: " << cpp_type;
}
}
// Generates a pybind11 wrapper class declaration for a top level message.
void GenerateTopLevelMessageDecl(const google::protobuf::Descriptor& msg) {
CHECK(wrapper_id_.emplace(&msg, wrapper_id_.size()).second)
<< "duplicate message: " << msg.full_name();
absl::SubstituteAndAppend(&out_, R"(
const auto $0 = py::class_<$1, std::shared_ptr<$1>>($2, "$3"))",
GetWrapperName(&msg), current_context_.cpp_name,
GetWrapperName(msg.containing_type()),
msg.name());
// Add constructor and utilities.
absl::SubstituteAndAppend(&out_, R"(
.def(py::init<>())
.def("copy_from",
[](std::shared_ptr<$0> self, std::shared_ptr<$0> other) {
self->CopyFrom(*other);
})
.def("merge_from",
[](std::shared_ptr<$0> self, std::shared_ptr<$0> other) {
self->MergeFrom(*other);
})
.def("merge_text_format",
[](std::shared_ptr<$0> self, const std::string& text) {
return google::protobuf::TextFormat::MergeFromString(text, self.get());
})
.def("parse_text_format",
[](std::shared_ptr<$0> self, const std::string& text) {
return google::protobuf::TextFormat::ParseFromString(text, self.get());
})
.def("__copy__",
[](std::shared_ptr<$0> self) {
return self;
})
.def("__deepcopy__",
[](std::shared_ptr<$0> self, py::dict) {
std::shared_ptr<$0> result = std::make_shared<$0>();
result->CopyFrom(*self);
return result;
})
.def("__str__",
[](std::shared_ptr<$0> self) {
return operations_research::ProtobufDebugString(*self);
}))",
current_context_.cpp_name);
}
// Generates a pybind11 wrapper class declaration for a message.
void GenerateMessageDecl(const google::protobuf::Descriptor& msg) {
CHECK(wrapper_id_.emplace(&msg, wrapper_id_.size()).second)
<< "duplicate message: " << msg.full_name();
absl::SubstituteAndAppend(&out_, R"(
const auto $0 = py::class_<$1>($2, "$3"))",
GetWrapperName(&msg), current_context_.cpp_name,
GetWrapperName(msg.containing_type()),
msg.name());
// Add constructor and utilities.
absl::SubstituteAndAppend(&out_, R"(
.def(py::init<>())
.def("copy_from",
[]($0* self, const $0& other) { self->CopyFrom(other); })
.def("merge_from",
[]($0* self, const $0& other) { self->MergeFrom(other); })
.def("merge_text_format",
[]($0* self, const std::string& text) {
return google::protobuf::TextFormat::MergeFromString(text, self);
})
.def("parse_text_format",
[]($0* self, const std::string& text) {
return google::protobuf::TextFormat::ParseFromString(text, self);
})
.def("__copy__",
[]($0 self) {
return $0(self);
})
.def("__deepcopy__",
[]($0 self, py::dict) {
return $0(self);
})
.def("__str__",
[]($0 self) {
return operations_research::ProtobufDebugString(self);
}))",
current_context_.cpp_name);
}
// Generates a pybind11 wrapper class declaration for an enum.
void GenerateEnumDecl(const google::protobuf::EnumDescriptor& pb_enum) {
absl::SubstituteAndAppend(&out_, R"(
py::enum_<$0>($1, "$2"))",
GetQualifiedCppName(pb_enum),
GetWrapperName(pb_enum.containing_type()),
pb_enum.name());
for (int i = 0; i < pb_enum.value_count(); ++i) {
const google::protobuf::EnumValueDescriptor& value = *pb_enum.value(i);
absl::SubstituteAndAppend(&out_, R"(
.value("$0", $1))",
value.name(), GetQualifiedCppName(value));
}
absl::SubstituteAndAppend(&out_, R"(
.export_values();)");
}
// Generates a pybind11 wrapper class declaration & definitions for a repeated
// ptr.
void GenerateRepeatedPtrDecl(const google::protobuf::Descriptor& msg) {
absl::SubstituteAndAppend(&out_, R"(
py::class_<google::protobuf::RepeatedPtrField<$0>>(m, "repeated_$1")
.def("add",
[](google::protobuf::RepeatedPtrField<$0>* self) {
return self->Add();
},
py::return_value_policy::reference, py::keep_alive<0, 1>())
.def("append", [](google::protobuf::RepeatedPtrField<$0>* self, const $0& value) {
*self->Add() = value;
})
.def("extend",
[](google::protobuf::RepeatedPtrField<$0>* self, const std::vector<$0>& values) {
for (const $0& value : values) {
*self->Add() = value;
}
})
.def("clear", [](google::protobuf::RepeatedPtrField<$0>* self) {
self->Clear();
})
.def("__len__", &google::protobuf::RepeatedPtrField<$0>::size)
.def("__getitem__",
[](google::protobuf::RepeatedPtrField<$0>* self, int index) {
if (index >= self->size()) {
PyErr_SetString(PyExc_IndexError, "Index out of range");
throw py::error_already_set();
}
return self->Mutable(index);
},
py::return_value_policy::reference, py::keep_alive<0, 1>());)",
GetQualifiedCppName(msg), msg.name());
}
// Generates a pybind11 wrapper class declaration & definitions for a repeated
// scalar.
void GenerateRepeatedScalarDecl(absl::string_view scalar_type) {
if (scalar_type == "std::string") {
absl::StrAppend(&out_, R"(
py::class_<google::protobuf::RepeatedPtrField<std::string>>(m, "repeated_scalar_std_string")
.def("append",
[](google::protobuf::RepeatedPtrField<std::string>* self, std::string str) {
self->Add(std::move(str));
})
.def("extend",
[](google::protobuf::RepeatedPtrField<std::string>* self,
const std::vector<std::string>& values) {
self->Add(values.begin(), values.end());
})
.def("clear", [](google::protobuf::RepeatedPtrField<std::string>* self) {
self->Clear();
})
.def("__len__", [](const google::protobuf::RepeatedPtrField<std::string>& self) {
return self.size();
})
.def("__getitem__",
[](const google::protobuf::RepeatedPtrField<std::string>& self, int index) {
if (index >= self.size()) {
PyErr_SetString(PyExc_IndexError, "Index out of range");
throw py::error_already_set();
}
return self.Get(index);
},
py::return_value_policy::copy)
.def("__setitem__",
[](google::protobuf::RepeatedPtrField<std::string>* self,
int index, const std::string& value) {
self->at(index) = value;
})
.def("__str__", [](const google::protobuf::RepeatedPtrField<std::string>& self) {
return absl::StrCat("[", absl::StrJoin(self, ", "), "]");
});)");
} else {
absl::SubstituteAndAppend(
&out_, R"(
py::class_<google::protobuf::RepeatedField<$0>>(m, "repeated_scalar_$1")
.def("append", [](google::protobuf::RepeatedField<$0>* self, $0 value) {
self->Add(value);
})
.def("extend", [](google::protobuf::RepeatedField<$0>* self,
const std::vector<$0>& values) {
self->Add(values.begin(), values.end());
})
.def("clear", [](google::protobuf::RepeatedField<$0>* self) {
self->Clear();
})
.def("__len__", [](const google::protobuf::RepeatedField<$0>& self) {
return self.size();
})
.def("__getitem__", [](const google::protobuf::RepeatedField<$0>& self, int index) {
if (index >= self.size()) {
PyErr_SetString(PyExc_IndexError, "Index out of range");
throw py::error_already_set();
}
return self.Get(index);
})
.def("__setitem__", &google::protobuf::RepeatedField<$0>::Set)
.def("__str__", [](const google::protobuf::RepeatedField<$0>& self) {
return absl::StrCat("[", absl::StrJoin(self, ", "), "]");
});)",
scalar_type, absl::StrReplaceAll(scalar_type, {{"::", "_"}}));
}
}
void GenerateRepeatedField(const google::protobuf::FieldDescriptor& field) {
const google::protobuf::Descriptor* msg_type = field.message_type();
if (msg_type != nullptr) {
// Repeated message.
absl::SubstituteAndAppend(
&out_, R"(
.def_property_readonly(
"$0",
[]($1 self) { return self->mutable_$2(); },
py::return_value_policy::reference, py::keep_alive<0, 1>()))",
field.name(), current_context_.self_mutable_name, field.name());
// We'll need to generate the wrapping for
// `google::protobuf::RepeatedPtrField<$3>`.
repeated_ptr_types_.insert(msg_type);
// We'll need to generate the wrapping for this message type.
message_stack_.push_back(ABSL_DIE_IF_NULL(field.message_type()));
} else {
// Repeated scalar field.
absl::SubstituteAndAppend(&out_, R"(
.def_property_readonly(
"$0",
[]($1 self) { return self->mutable_$0(); },
py::return_value_policy::reference, py::keep_alive<0, 1>()))",
field.name(),
current_context_.self_mutable_name);
// We'll need to generate the wrapping for
// `google::protobuf::RepeatedField<$2>`.
repeated_scalar_types_.insert(GetCppType(field.cpp_type(), field));
}
}
void GenerateSingularField(const google::protobuf::FieldDescriptor& field) {
if (const google::protobuf::Descriptor* msg_type = field.message_type()) {
// Singular message.
absl::SubstituteAndAppend(&out_, R"(
.def_property(
"$0",
[]($1 self) { return self->mutable_$0(); },
[]($1 self, $2 arg) { *self->mutable_$0() = arg; },
py::return_value_policy::reference_internal)
.def("clear_$0", []($1 self) { self->clear_$0(); })
.def("has_$0", []($1 self) { return self->has_$0(); }))",
field.name(),
current_context_.self_mutable_name,
GetQualifiedCppName(*msg_type));
// We'll need to generate the wrapping for this message type.
message_stack_.push_back(ABSL_DIE_IF_NULL(field.message_type()));
} else {
if (const google::protobuf::EnumDescriptor* enum_type =
field.enum_type()) {
enum_types_.insert(enum_type);
}
// Singular scalar (int, string, ...).
absl::SubstituteAndAppend(&out_, R"(
.def_property(
"$0",
[]($1 msg) { return msg->$0(); },
[]($1 msg, $2 arg) { return msg->set_$0(arg); })
.def("clear_$0", []($1 self) { self->clear_$0(); }))",
field.name(),
current_context_.self_mutable_name,
GetCppType(field.cpp_type(), field));
}
}
// Generates definitions for accessing fields of a message.
void GenerateMessageFields(const google::protobuf::Descriptor& msg) {
const std::string msg_name = GetQualifiedCppName(msg);
for (int i = 0; i < msg.field_count(); ++i) {
const google::protobuf::FieldDescriptor& field =
*ABSL_DIE_IF_NULL(msg.field(i));
if (field.is_repeated()) {
GenerateRepeatedField(field);
} else {
GenerateSingularField(field);
}
}
}
// Returns the wrapper name for a message (or "m" if `msg` is null).
// Dies if the scope is not found.
std::string GetWrapperName(const google::protobuf::Descriptor* msg) {
const auto it = wrapper_id_.find(msg);
CHECK(it != wrapper_id_.end())
<< "wrapper id not found: " << msg->full_name();
if (msg == nullptr) return "m";
return absl::StrCat("gen_", it->second);
}
// This identifies the pybind11 wrapper variable for a `_class` declaration in
// the generated code. These are used to generate enums in the correct
// scope.
static constexpr int kNoScope = 0;
absl::flat_hash_map<const google::protobuf::Descriptor*, int> wrapper_id_ = {
{nullptr, kNoScope}};
// Output buffer.
std::string out_;
// Our DFS stack.
std::vector<const google::protobuf::Descriptor* absl_nonnull> message_stack_;
absl::flat_hash_set<const google::protobuf::Descriptor* absl_nonnull>
visited_messages_;
// A list of enum wrappers to generate.
absl::flat_hash_set<const google::protobuf::EnumDescriptor* absl_nonnull>
enum_types_;
// A list of repeated ptr wrappers to generate.
absl::flat_hash_set<const google::protobuf::Descriptor* absl_nonnull>
repeated_ptr_types_;
// A list of repeated scalar wrappers to generate.
absl::flat_hash_set<std::string> repeated_scalar_types_;
// Context for the current message being generated.
Context current_context_;
};
std::string GeneratePybindCode(
absl::Span<const google::protobuf::Descriptor* absl_nonnull const> roots) {
return Generator(roots).Result();
}
} // namespace operations_research::sat::python