Files
2025-12-15 13:59:54 +01:00

165 lines
5.2 KiB
C++

// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/codegen/gen_python.h"
#include <memory>
#include <set>
#include <string>
#include "absl/base/optimization.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/math_opt/elemental/codegen/gen.h"
namespace operations_research::math_opt::codegen {
namespace {
// Returns the (empty) set of attribute-operation functions to generate for
// python: the python backend only generates enums, never functions.
// The object is created once and intentionally never destroyed.
const AttrOpFunctionInfos* GetPythonFunctionInfos() {
  static const AttrOpFunctionInfos* const kNoFunctions =
      new AttrOpFunctionInfos();
  return kNoFunctions;
}
// Emits one python enumerator line per entry of `names`, in order. Each line
// is `<NAME> = <index>`, where NAME is the entry upper-cased and index is its
// zero-based position in `names`.
//
// The emitted lines form the body of a python `class` statement that the
// caller has just emitted. Python rejects an empty class body, so when `names`
// is empty a `pass` statement is emitted instead.
void EmitEnumerators(const absl::Span<const absl::string_view> names,
                     std::string* out) {
  if (names.empty()) {
    absl::StrAppend(out, " pass\n");
    return;
  }
  // `size_t` matches `names.size()` and avoids a signed/unsigned comparison;
  // absl's `%i` formats any integral type.
  for (size_t i = 0; i < names.size(); ++i) {
    absl::StrAppendFormat(out, " %s = %i\n", absl::AsciiStrToUpper(names[i]),
                          i);
  }
}
// Maps an attribute value type to the name of the builtin python type used in
// the generated annotations: "bool", "int" or "float".
absl::string_view GetAttrPyValueType(
    const CodegenAttrTypeDescriptor::ValueType& value_type) {
  using ValueType = CodegenAttrTypeDescriptor::ValueType;
  switch (value_type) {
    case ValueType::kDouble:
      return "float";
    case ValueType::kInt64:
      return "int";
    case ValueType::kBool:
      return "bool";
  }
  // All enumerators are handled above; no `default` so the compiler can warn
  // about missing cases.
  ABSL_UNREACHABLE();
}
// Returns the numpy scalar type (as python source text, e.g. "np.float64") for
// the given value type. The generated module imports numpy as `np`.
absl::string_view GetAttrNumpyValueType(
    const CodegenAttrTypeDescriptor::ValueType& value_type) {
  switch (value_type) {
    case CodegenAttrTypeDescriptor::ValueType::kBool:
      return "np.bool_";
    case CodegenAttrTypeDescriptor::ValueType::kInt64:
      return "np.int64";
    case CodegenAttrTypeDescriptor::ValueType::kDouble:
      return "np.float64";
  }
  // Every enumerator is handled above; no `default` so the compiler can warn
  // when a new value type is added.
  ABSL_UNREACHABLE();
}
class PythonEnumsGenerator : public CodeGenerator {
public:
PythonEnumsGenerator() : CodeGenerator(GetPythonFunctionInfos()) {}
void EmitHeader(std::string* out) const override {
absl::StrAppend(out, R"(
'''DO NOT EDIT: This file is autogenerated.'''
import enum
from typing import Generic, TypeVar, Union
import numpy as np
)");
}
void EmitElements(absl::Span<const absl::string_view> elements,
std::string* out) const override {
// Generate an enum for the elements.
absl::StrAppend(out, "class ElementType(enum.Enum):\n");
EmitEnumerators(elements, out);
absl::StrAppend(out, "\n");
}
void EmitAttributes(absl::Span<const CodegenAttrTypeDescriptor> descriptors,
std::string* out) const override {
absl::StrAppend(out, "\n");
{
// Collect the list of unique types:
std::set<absl::string_view> value_types;
for (const auto& descriptor : descriptors) {
value_types.insert(GetAttrNumpyValueType(descriptor.value_type));
}
// Emit `AttrValueType`, a type variable for all attribute value types.
absl::StrAppend(out, "AttrValueType = TypeVar('AttrValueType', ",
absl::StrJoin(value_types, ", "), ")\n");
}
absl::StrAppend(out, "\n");
{
std::set<absl::string_view> py_value_types;
for (const auto& descriptor : descriptors) {
py_value_types.insert(GetAttrPyValueType(descriptor.value_type));
}
absl::StrAppend(out, "AttrPyValueType = TypeVar('AttrPyValueType', ",
absl::StrJoin(py_value_types, ", "), ")\n");
}
// `Attr` is an attribute with any value type.
absl::StrAppend(out, R"(
class Attr(Generic[AttrValueType]):
pass
)");
// `PyAttr` is an attribute with any value type.
absl::StrAppend(out, R"(
class PyAttr(Generic[AttrPyValueType]):
pass
)");
// Generate an enum for the attribute type.
for (const auto& descriptor : descriptors) {
absl::StrAppendFormat(
out, "\nclass %s(Attr[%s], PyAttr[%s], int, enum.Enum):\n",
descriptor.name, GetAttrNumpyValueType(descriptor.value_type),
GetAttrPyValueType(descriptor.value_type));
EmitEnumerators(descriptor.attribute_names, out);
absl::StrAppend(out, "\n");
}
// Add a type alias for the union of all attribute types.
absl::StrAppend(
out, "AnyAttr = Union[",
absl::StrJoin(
descriptors, ", ",
[](std::string* out, const CodegenAttrTypeDescriptor& descriptor) {
absl::StrAppend(out, descriptor.name);
}),
"]\n");
}
};
} // namespace
// Factory: returns a new code generator that emits the python enum module.
// Ownership of the generator is transferred to the caller.
std::unique_ptr<CodeGenerator> PythonEnums() {
  std::unique_ptr<CodeGenerator> generator =
      std::make_unique<PythonEnumsGenerator>();
  return generator;
}
} // namespace operations_research::math_opt::codegen