math_opt: export from google3

* CMake has not been updated yet
* bazel was compiling at least last week

bazel: disable math opt facility_location.py

missing some dependencies...
This commit is contained in:
Corentin Le Molgat
2022-12-16 17:06:11 +01:00
parent 178084b3f7
commit 5bf70b691f
245 changed files with 28953 additions and 5680 deletions

View File

@@ -298,6 +298,8 @@ endforeach()
if(BUILD_MATH_OPT)
add_subdirectory(ortools/math_opt/core/python)
add_subdirectory(ortools/math_opt/elemental/python)
add_subdirectory(ortools/math_opt/io/python)
add_subdirectory(ortools/math_opt/python)
endif()
@@ -331,7 +333,12 @@ if(BUILD_MATH_OPT)
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/core/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/core/python/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/elemental/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/elemental/python/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/io/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/io/python/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/elemental/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/ipc/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/testing/__init__.py CONTENT "")
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/solvers/__init__.py CONTENT "")
@@ -366,27 +373,42 @@ file(COPY
ortools/linear_solver/python/model_builder_numbers.py
DESTINATION ${PYTHON_PROJECT_DIR}/linear_solver/python)
if(BUILD_MATH_OPT)
configure_file(
ortools/math_opt/elemental/python/enums.py.in
${PYTHON_PROJECT_DIR}/math_opt/elemental/python/enums.py
COPYONLY)
file(COPY
ortools/math_opt/python/bounded_expressions.py
ortools/math_opt/python/callback.py
ortools/math_opt/python/compute_infeasible_subsystem_result.py
ortools/math_opt/python/errors.py
ortools/math_opt/python/expressions.py
ortools/math_opt/python/from_model.py
ortools/math_opt/python/hash_model_storage.py
ortools/math_opt/python/indicator_constraints.py
ortools/math_opt/python/init_arguments.py
ortools/math_opt/python/linear_constraints.py
ortools/math_opt/python/mathopt.py
ortools/math_opt/python/message_callback.py
ortools/math_opt/python/model.py
ortools/math_opt/python/model_parameters.py
ortools/math_opt/python/model_storage.py
ortools/math_opt/python/normalized_inequality.py
ortools/math_opt/python/normalize.py
ortools/math_opt/python/objectives.py
ortools/math_opt/python/parameters.py
ortools/math_opt/python/quadratic_constraints.py
ortools/math_opt/python/result.py
ortools/math_opt/python/solution.py
ortools/math_opt/python/solve.py
ortools/math_opt/python/solver_resources.py
ortools/math_opt/python/sparse_containers.py
ortools/math_opt/python/statistics.py
ortools/math_opt/python/variables.py
DESTINATION ${PYTHON_PROJECT_DIR}/math_opt/python)
file(COPY
ortools/math_opt/python/elemental/elemental.py
DESTINATION ${PYTHON_PROJECT_DIR}/math_opt/python/elemental)
file(COPY
ortools/math_opt/python/ipc/proto_converter.py
ortools/math_opt/python/ipc/remote_http_solve.py
@@ -663,7 +685,13 @@ add_custom_command(
$<TARGET_FILE:model_builder_helper_pybind11> ${PYTHON_PROJECT}/linear_solver/python
COMMAND ${CMAKE_COMMAND} -E
$<IF:$<BOOL:${BUILD_MATH_OPT}>,copy,true>
$<TARGET_FILE:math_opt_pybind11> ${PYTHON_PROJECT}/math_opt/core/python
$<TARGET_FILE:math_opt_core_pybind11> ${PYTHON_PROJECT}/math_opt/core/python
COMMAND ${CMAKE_COMMAND} -E
$<IF:$<BOOL:${BUILD_MATH_OPT}>,copy,true>
$<TARGET_FILE:math_opt_elemental_pybind11> ${PYTHON_PROJECT}/math_opt/elemental/python
COMMAND ${CMAKE_COMMAND} -E
$<IF:$<BOOL:${BUILD_MATH_OPT}>,copy,true>
$<TARGET_FILE:math_opt_io_pybind11> ${PYTHON_PROJECT}/math_opt/io/python
COMMAND ${CMAKE_COMMAND} -E
$<IF:$<BOOL:${BUILD_MATH_OPT}>,copy,true>
$<TARGET_FILE:status_py_extension_stub> ${PYTHON_PROJECT}/../pybind11_abseil
@@ -697,7 +725,9 @@ add_custom_command(
routing_pybind11
pywraplp
model_builder_helper_pybind11
math_opt_pybind11
$<$<BOOL:${BUILD_MATH_OPT}>:math_opt_core_pybind11>
$<$<BOOL:${BUILD_MATH_OPT}>:math_opt_elemental_pybind11>
$<$<BOOL:${BUILD_MATH_OPT}>:math_opt_io_pybind11>
$<TARGET_NAME_IF_EXISTS:pdlp_pybind11>
cp_model_helper_pybind11
rcpsp_pybind11
@@ -760,6 +790,10 @@ search_python_module(
search_python_module(
NAME wheel
PACKAGE wheel)
search_python_module(
NAME typing_extensions
PACKAGE typing-extensions
NO_VERSION)
add_custom_command(
OUTPUT python/dist_timestamp

View File

@@ -18,6 +18,7 @@ endif()
add_subdirectory(core)
add_subdirectory(constraints)
add_subdirectory(cpp)
add_subdirectory(elemental)
add_subdirectory(io)
add_subdirectory(labs)
add_subdirectory(solver_tests)
@@ -31,6 +32,7 @@ target_sources(${NAME} PUBLIC
$<TARGET_OBJECTS:${NAME}_core>
$<TARGET_OBJECTS:${NAME}_core_c_api>
$<TARGET_OBJECTS:${NAME}_cpp>
$<TARGET_OBJECTS:${NAME}_elemental>
$<TARGET_OBJECTS:${NAME}_io>
$<TARGET_OBJECTS:${NAME}_labs>
$<TARGET_OBJECTS:${NAME}_solvers>

View File

@@ -0,0 +1,18 @@
# math_opt
The code in this directory provides a generic way of accessing mathematical
optimization solvers (sometimes called mathematical programming solvers), such
as GLOP, CP-SAT, SCIP and Gurobi. In particular, a single API is provided to
make these solvers largely interoperable.
New code should prefer MathOpt to `MPSolver`, as defined in
[linear_solver.h](../linear_solver/linear_solver.h)
when possible.
MathOpt has client libraries in C++, Python, and Java that most users should use
to build and solve their models. A proto API is also provided, but this is not
recommended for most users.
See
[parameters.proto](../math_opt/parameters.proto?q=SolverTypeProto)
for the list of supported solvers.

View File

@@ -158,7 +158,17 @@ message CallbackResultProto {
bool is_lazy = 4;
}
// Ends the solve early.
// When true it tells the solver to interrupt the solve as soon as possible.
//
// It can be set from any event. This is equivalent to using a
// SolveInterrupter and triggering it from the callback.
//
// Some solvers don't support interruption, in that case this is simply
// ignored and the solve terminates as usual. On top of that solvers may not
// immediately stop the solve. Thus the user should expect the callback to
// still be called after they set `terminate` to true in a previous
// call. Returning with `terminate` false after having previously returned
// true won't cancel the interruption.
bool terminate = 1;
// TODO(b/172214608): SCIP allows to reject a feasible solution without

View File

@@ -18,10 +18,11 @@ cc_library(
srcs = ["indicator_constraint.cc"],
hdrs = ["indicator_constraint.h"],
deps = [
"//ortools/base:intops",
"//ortools/math_opt/constraints/util:model_util",
"//ortools/math_opt/cpp:variable_and_expressions",
"//ortools/math_opt/elemental:elements",
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"@abseil-cpp//absl/strings",
],
)
@@ -51,6 +52,7 @@ cc_library(
"//ortools/math_opt:model_update_cc_proto",
"//ortools/math_opt:sparse_containers_cc_proto",
"//ortools/math_opt/core:sorted",
"//ortools/math_opt/elemental:elements",
"//ortools/math_opt/storage:atomic_constraint_storage",
"//ortools/math_opt/storage:sparse_coefficient_map",
"@abseil-cpp//absl/container:flat_hash_set",

View File

@@ -18,8 +18,6 @@
#include <string>
#include <utility>
#include "absl/strings/string_view.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/constraints/util/model_util.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/storage/model_storage.h"
@@ -27,22 +25,22 @@
namespace operations_research::math_opt {
BoundedLinearExpression IndicatorConstraint::ImpliedConstraint() const {
const IndicatorConstraintData& data = storage()->constraint_data(id_);
const IndicatorConstraintData& data = storage()->constraint_data(typed_id());
// NOTE: The following makes a copy of `data.linear_terms`. This can be made
// more efficient if the need arises.
LinearExpression expr = ToLinearExpression(
*storage_, {.coeffs = data.linear_terms, .offset = 0.0});
*storage(), {.coeffs = data.linear_terms, .offset = 0.0});
return data.lower_bound <= std::move(expr) <= data.upper_bound;
}
std::string IndicatorConstraint::ToString() const {
if (!storage()->has_constraint(id_)) {
if (!storage()->has_constraint(typed_id())) {
return std::string(kDeletedConstraintDefaultDescription);
}
const IndicatorConstraintData& data = storage()->constraint_data(id_);
const IndicatorConstraintData& data = storage()->constraint_data(typed_id());
std::stringstream str;
if (data.indicator.has_value()) {
str << Variable(storage_, *data.indicator)
str << Variable(storage(), *data.indicator)
<< (data.activate_on_zero ? " = 0" : " = 1");
} else {
str << "[unset indicator variable]";

View File

@@ -16,38 +16,28 @@
#ifndef OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_
#define OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_
#include <cstdint>
#include <optional>
#include <ostream>
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/constraints/util/model_util.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
namespace operations_research::math_opt {
// A value type that references an indicator constraint from ModelStorage.
// Usually this type is passed by copy.
//
// This type implements https://abseil.io/docs/cpp/guides/hash.
class IndicatorConstraint {
class IndicatorConstraint final
: public ModelStorageElement<ElementType::kIndicatorConstraint,
IndicatorConstraint> {
public:
// The typed integer used for ids.
using IdType = IndicatorConstraintId;
using ModelStorageElement::ModelStorageElement;
inline IndicatorConstraint(const ModelStorage* storage,
IndicatorConstraintId id);
inline int64_t id() const;
inline IndicatorConstraintId typed_id() const;
inline const ModelStorage* storage() const;
inline absl::string_view name() const;
absl::string_view name() const;
// Returns nullopt if the indicator variable is unset (this is a valid state,
// in which the constraint is functionally ignored).
@@ -65,91 +55,36 @@ class IndicatorConstraint {
// Returns a detailed string description of the contents of the constraint
// (not its name, use `<<` for that instead).
std::string ToString() const;
friend inline bool operator==(const IndicatorConstraint& lhs,
const IndicatorConstraint& rhs);
friend inline bool operator!=(const IndicatorConstraint& lhs,
const IndicatorConstraint& rhs);
template <typename H>
friend H AbslHashValue(H h, const IndicatorConstraint& constraint);
friend std::ostream& operator<<(std::ostream& ostr,
const IndicatorConstraint& constraint);
private:
const ModelStorage* storage_;
IndicatorConstraintId id_;
};
// Streams the name of the constraint, as registered upon constraint creation,
// or a short default if none was provided.
inline std::ostream& operator<<(std::ostream& ostr,
const IndicatorConstraint& constraint);
////////////////////////////////////////////////////////////////////////////////
// Inline function implementations
////////////////////////////////////////////////////////////////////////////////
int64_t IndicatorConstraint::id() const { return id_.value(); }
IndicatorConstraintId IndicatorConstraint::typed_id() const { return id_; }
const ModelStorage* IndicatorConstraint::storage() const { return storage_; }
absl::string_view IndicatorConstraint::name() const {
if (storage_->has_constraint(id_)) {
return storage_->constraint_data(id_).name;
inline absl::string_view IndicatorConstraint::name() const {
if (storage()->has_constraint(typed_id())) {
return storage()->constraint_data(typed_id()).name;
}
return kDeletedConstraintDefaultDescription;
}
std::optional<Variable> IndicatorConstraint::indicator_variable() const {
const std::optional<VariableId> maybe_indicator =
storage_->constraint_data(id_).indicator;
storage()->constraint_data(typed_id()).indicator;
if (!maybe_indicator.has_value()) {
return std::nullopt;
}
return Variable(storage_, *maybe_indicator);
return Variable(storage(), *maybe_indicator);
}
bool IndicatorConstraint::activate_on_zero() const {
return storage_->constraint_data(id_).activate_on_zero;
return storage()->constraint_data(typed_id()).activate_on_zero;
}
std::vector<Variable> IndicatorConstraint::NonzeroVariables() const {
return AtomicConstraintNonzeroVariables(*storage_, id_);
return AtomicConstraintNonzeroVariables(*storage(), typed_id());
}
bool operator==(const IndicatorConstraint& lhs,
const IndicatorConstraint& rhs) {
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
}
bool operator!=(const IndicatorConstraint& lhs,
const IndicatorConstraint& rhs) {
return !(lhs == rhs);
}
template <typename H>
H AbslHashValue(H h, const IndicatorConstraint& constraint) {
return H::combine(std::move(h), constraint.id_.value(), constraint.storage_);
}
std::ostream& operator<<(std::ostream& ostr,
const IndicatorConstraint& constraint) {
// TODO(b/170992529): handle quoting of invalid characters in the name.
const absl::string_view name = constraint.name();
if (name.empty()) {
ostr << "__indic_con#" << constraint.id() << "__";
} else {
ostr << name;
}
return ostr;
}
IndicatorConstraint::IndicatorConstraint(const ModelStorage* const storage,
const IndicatorConstraintId id)
: storage_(storage), id_(id) {}
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_

View File

@@ -19,6 +19,7 @@
#include <string>
#include <vector>
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/model.pb.h"
#include "ortools/math_opt/model_update.pb.h"
#include "ortools/math_opt/storage/atomic_constraint_storage.h"
@@ -34,6 +35,8 @@ struct IndicatorConstraintData {
using IdType = IndicatorConstraintId;
using ProtoType = IndicatorConstraintProto;
using UpdatesProtoType = IndicatorConstraintUpdatesProto;
static constexpr ElementType kElementType = ElementType::kIndicatorConstraint;
static constexpr bool kSupportsElemental = true;
// The `in_proto` must be in a valid state; see the inline comments on
// `IndicatorConstraintProto` for details.
@@ -55,6 +58,8 @@ struct IndicatorConstraintData {
template <>
struct AtomicConstraintTraits<IndicatorConstraintId> {
using ConstraintData = IndicatorConstraintData;
static constexpr ElementType kElementType = ElementType::kIndicatorConstraint;
static constexpr bool kSupportsElemental = true;
};
} // namespace operations_research::math_opt

View File

@@ -22,7 +22,9 @@ cc_library(
"//ortools/math_opt/constraints/util:model_util",
"//ortools/math_opt/cpp:key_types",
"//ortools/math_opt/cpp:variable_and_expressions",
"//ortools/math_opt/elemental:elements",
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"//ortools/math_opt/storage:sparse_coefficient_map",
"//ortools/math_opt/storage:sparse_matrix",
"@abseil-cpp//absl/log:check",
@@ -54,6 +56,7 @@ cc_library(
"//ortools/math_opt:model_cc_proto",
"//ortools/math_opt:model_update_cc_proto",
"//ortools/math_opt:sparse_containers_cc_proto",
"//ortools/math_opt/elemental:elements",
"//ortools/math_opt/storage:atomic_constraint_storage",
"//ortools/math_opt/storage:model_storage_types",
"//ortools/math_opt/storage:sparse_coefficient_map",

View File

@@ -26,7 +26,7 @@ namespace operations_research::math_opt {
BoundedQuadraticExpression QuadraticConstraint::AsBoundedQuadraticExpression()
const {
QuadraticExpression expression;
const QuadraticConstraintData& data = storage()->constraint_data(id_);
const QuadraticConstraintData& data = storage()->constraint_data(typed_id());
for (const auto [var, coeff] : data.linear_terms.terms()) {
expression += coeff * Variable(storage(), var);
}

View File

@@ -18,19 +18,18 @@
#ifndef OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_
#define OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_
#include <cstdint>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>
#include "absl/log/check.h"
#include "absl/strings/string_view.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/constraints/util/model_util.h"
#include "ortools/math_opt/cpp/key_types.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
#include "ortools/math_opt/storage/sparse_coefficient_map.h"
#include "ortools/math_opt/storage/sparse_matrix.h"
@@ -38,20 +37,11 @@ namespace operations_research::math_opt {
// A value type that references a quadratic constraint from ModelStorage.
// Usually this type is passed by copy.
//
// This type implements https://abseil.io/docs/cpp/guides/hash.
class QuadraticConstraint {
class QuadraticConstraint final
: public ModelStorageElement<ElementType::kQuadraticConstraint,
QuadraticConstraint> {
public:
// The typed integer used for ids.
using IdType = QuadraticConstraintId;
inline QuadraticConstraint(const ModelStorage* storage,
QuadraticConstraintId id);
inline int64_t id() const;
inline QuadraticConstraintId typed_id() const;
inline const ModelStorage* storage() const;
using ModelStorageElement::ModelStorageElement;
inline double lower_bound() const;
inline double upper_bound() const;
@@ -89,47 +79,23 @@ class QuadraticConstraint {
// Returns a detailed string description of the contents of the constraint
// (not its name, use `<<` for that instead).
inline std::string ToString() const;
friend inline bool operator==(const QuadraticConstraint& lhs,
const QuadraticConstraint& rhs);
friend inline bool operator!=(const QuadraticConstraint& lhs,
const QuadraticConstraint& rhs);
template <typename H>
friend H AbslHashValue(H h, const QuadraticConstraint& quadratic_constraint);
friend std::ostream& operator<<(
std::ostream& ostr, const QuadraticConstraint& quadratic_constraint);
private:
const ModelStorage* storage_;
QuadraticConstraintId id_;
};
// Streams the name of the constraint, as registered upon constraint creation,
// or a short default if none was provided.
inline std::ostream& operator<<(std::ostream& ostr,
const QuadraticConstraint& constraint);
////////////////////////////////////////////////////////////////////////////////
// Inline function implementations
////////////////////////////////////////////////////////////////////////////////
int64_t QuadraticConstraint::id() const { return id_.value(); }
QuadraticConstraintId QuadraticConstraint::typed_id() const { return id_; }
const ModelStorage* QuadraticConstraint::storage() const { return storage_; }
double QuadraticConstraint::lower_bound() const {
return storage_->constraint_data(id_).lower_bound;
return storage()->constraint_data(typed_id()).lower_bound;
}
double QuadraticConstraint::upper_bound() const {
return storage_->constraint_data(id_).upper_bound;
return storage()->constraint_data(typed_id()).upper_bound;
}
absl::string_view QuadraticConstraint::name() const {
if (storage_->has_constraint(id_)) {
return storage_->constraint_data(id_).name;
if (storage()->has_constraint(typed_id())) {
return storage()->constraint_data(typed_id()).name;
}
return kDeletedConstraintDefaultDescription;
}
@@ -145,27 +111,31 @@ bool QuadraticConstraint::is_quadratic_coefficient_nonzero(
}
double QuadraticConstraint::linear_coefficient(const Variable variable) const {
CHECK_EQ(variable.storage(), storage_)
CHECK_EQ(variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
return storage_->constraint_data(id_).linear_terms.get(variable.typed_id());
return storage()
->constraint_data(typed_id())
.linear_terms.get(variable.typed_id());
}
double QuadraticConstraint::quadratic_coefficient(
const Variable first_variable, const Variable second_variable) const {
CHECK_EQ(first_variable.storage(), storage_)
CHECK_EQ(first_variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
CHECK_EQ(second_variable.storage(), storage_)
CHECK_EQ(second_variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
return storage_->constraint_data(id_).quadratic_terms.get(
first_variable.typed_id(), second_variable.typed_id());
return storage()
->constraint_data(typed_id())
.quadratic_terms.get(first_variable.typed_id(),
second_variable.typed_id());
}
std::vector<Variable> QuadraticConstraint::NonzeroVariables() const {
return AtomicConstraintNonzeroVariables(*storage_, id_);
return AtomicConstraintNonzeroVariables(*storage(), typed_id());
}
std::string QuadraticConstraint::ToString() const {
if (!storage()->has_constraint(id_)) {
if (!storage()->has_constraint(typed_id())) {
return std::string(kDeletedConstraintDefaultDescription);
}
std::stringstream str;
@@ -173,38 +143,6 @@ std::string QuadraticConstraint::ToString() const {
return str.str();
}
bool operator==(const QuadraticConstraint& lhs,
const QuadraticConstraint& rhs) {
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
}
bool operator!=(const QuadraticConstraint& lhs,
const QuadraticConstraint& rhs) {
return !(lhs == rhs);
}
template <typename H>
H AbslHashValue(H h, const QuadraticConstraint& quadratic_constraint) {
return H::combine(std::move(h), quadratic_constraint.id_.value(),
quadratic_constraint.storage_);
}
std::ostream& operator<<(std::ostream& ostr,
const QuadraticConstraint& constraint) {
// TODO(b/170992529): handle quoting of invalid characters in the name.
const absl::string_view name = constraint.name();
if (name.empty()) {
ostr << "__quad_con#" << constraint.id() << "__";
} else {
ostr << name;
}
return ostr;
}
QuadraticConstraint::QuadraticConstraint(const ModelStorage* const storage,
const QuadraticConstraintId id)
: storage_(storage), id_(id) {}
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_

View File

@@ -18,6 +18,7 @@
#include <string>
#include <vector>
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/model.pb.h"
#include "ortools/math_opt/model_update.pb.h"
#include "ortools/math_opt/storage/atomic_constraint_storage.h"
@@ -36,6 +37,9 @@ struct QuadraticConstraintData {
using ProtoType = QuadraticConstraintProto;
using UpdatesProtoType = QuadraticConstraintUpdatesProto;
static constexpr ElementType kElementType = ElementType::kQuadraticConstraint;
static constexpr bool kSupportsElemental = true;
// The `in_proto` must be in a valid state; see the inline comments on
// `QuadraticConstraintProto` for details.
static QuadraticConstraintData FromProto(const ProtoType& in_proto);
@@ -53,6 +57,8 @@ struct QuadraticConstraintData {
template <>
struct AtomicConstraintTraits<QuadraticConstraintId> {
using ConstraintData = QuadraticConstraintData;
static constexpr ElementType kElementType = ElementType::kQuadraticConstraint;
static constexpr bool kSupportsElemental = true;
};
} // namespace operations_research::math_opt

View File

@@ -24,6 +24,7 @@ cc_library(
"//ortools/math_opt/cpp:variable_and_expressions",
"//ortools/math_opt/storage:linear_expression_data",
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"//ortools/math_opt/storage:model_storage_types",
"@abseil-cpp//absl/strings",
],

View File

@@ -30,7 +30,7 @@
namespace operations_research::math_opt {
LinearExpression SecondOrderConeConstraint::UpperBound() const {
return ToLinearExpression(*storage_,
return ToLinearExpression(*storage(),
storage()->constraint_data(id_).upper_bound);
}
@@ -40,7 +40,7 @@ std::vector<LinearExpression> SecondOrderConeConstraint::ArgumentsToNorm()
std::vector<LinearExpression> args;
args.reserve(data.arguments_to_norm.size());
for (const LinearExpressionData& arg_data : data.arguments_to_norm) {
args.push_back(ToLinearExpression(*storage_, arg_data));
args.push_back(ToLinearExpression(*storage(), arg_data));
}
return args;
}
@@ -58,9 +58,9 @@ std::string SecondOrderConeConstraint::ToString() const {
str << ", ";
}
leading_comma = true;
str << ToLinearExpression(*storage_, arg_data);
str << ToLinearExpression(*storage(), arg_data);
}
str << "}||₂ ≤ " << ToLinearExpression(*storage_, data.upper_bound);
str << "}||₂ ≤ " << ToLinearExpression(*storage(), data.upper_bound);
return str.str();
}

View File

@@ -17,7 +17,6 @@
#define OR_TOOLS_MATH_OPT_CONSTRAINTS_SECOND_ORDER_CONE_SECOND_ORDER_CONE_CONSTRAINT_H_
#include <cstdint>
#include <optional>
#include <ostream>
#include <string>
#include <vector>
@@ -28,6 +27,7 @@
#include "ortools/math_opt/constraints/util/model_util.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
#include "ortools/math_opt/storage/model_storage_types.h"
namespace operations_research::math_opt {
@@ -36,18 +36,17 @@ namespace operations_research::math_opt {
// ModelStorage. Usually this type is passed by copy.
//
// This type implements https://abseil.io/docs/cpp/guides/hash.
class SecondOrderConeConstraint {
class SecondOrderConeConstraint final : public ModelStorageItem {
public:
// The typed integer used for ids.
using IdType = SecondOrderConeConstraintId;
inline SecondOrderConeConstraint(const ModelStorage* storage,
inline SecondOrderConeConstraint(ModelStorageCPtr storage,
SecondOrderConeConstraintId id);
inline int64_t id() const;
inline SecondOrderConeConstraintId typed_id() const;
inline const ModelStorage* storage() const;
inline absl::string_view name() const;
@@ -77,7 +76,6 @@ class SecondOrderConeConstraint {
const SecondOrderConeConstraint& constraint);
private:
const ModelStorage* storage_;
SecondOrderConeConstraintId id_;
};
@@ -96,24 +94,20 @@ SecondOrderConeConstraintId SecondOrderConeConstraint::typed_id() const {
return id_;
}
const ModelStorage* SecondOrderConeConstraint::storage() const {
return storage_;
}
absl::string_view SecondOrderConeConstraint::name() const {
if (storage_->has_constraint(id_)) {
return storage_->constraint_data(id_).name;
if (storage()->has_constraint(id_)) {
return storage()->constraint_data(id_).name;
}
return kDeletedConstraintDefaultDescription;
}
std::vector<Variable> SecondOrderConeConstraint::NonzeroVariables() const {
return AtomicConstraintNonzeroVariables(*storage_, id_);
return AtomicConstraintNonzeroVariables(*storage(), id_);
}
bool operator==(const SecondOrderConeConstraint& lhs,
const SecondOrderConeConstraint& rhs) {
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage();
}
bool operator!=(const SecondOrderConeConstraint& lhs,
@@ -123,7 +117,7 @@ bool operator!=(const SecondOrderConeConstraint& lhs,
template <typename H>
H AbslHashValue(H h, const SecondOrderConeConstraint& constraint) {
return H::combine(std::move(h), constraint.id_.value(), constraint.storage_);
return H::combine(std::move(h), constraint.id_.value(), constraint.storage());
}
std::ostream& operator<<(std::ostream& ostr,
@@ -139,8 +133,8 @@ std::ostream& operator<<(std::ostream& ostr,
}
SecondOrderConeConstraint::SecondOrderConeConstraint(
const ModelStorage* const storage, const SecondOrderConeConstraintId id)
: storage_(storage), id_(id) {}
const ModelStorageCPtr storage, const SecondOrderConeConstraintId id)
: ModelStorageItem(storage), id_(id) {}
} // namespace operations_research::math_opt

View File

@@ -34,6 +34,7 @@ struct SecondOrderConeConstraintData {
using IdType = SecondOrderConeConstraintId;
using ProtoType = SecondOrderConeConstraintProto;
using UpdatesProtoType = SecondOrderConeConstraintUpdatesProto;
static constexpr bool kSupportsElemental = false;
// The `in_proto` must be in a valid state; see the inline comments on
// `SecondOrderConeConstraintProto` for details.
@@ -50,6 +51,7 @@ struct SecondOrderConeConstraintData {
template <>
struct AtomicConstraintTraits<SecondOrderConeConstraintId> {
using ConstraintData = SecondOrderConeConstraintData;
static constexpr bool kSupportsElemental = false;
};
} // namespace operations_research::math_opt

View File

@@ -24,7 +24,9 @@ cc_library(
"//ortools/math_opt/cpp:variable_and_expressions",
"//ortools/math_opt/storage:linear_expression_data",
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"//ortools/math_opt/storage:sparse_coefficient_map",
"@abseil-cpp//absl/base:nullability",
"@abseil-cpp//absl/strings",
],
)
@@ -57,6 +59,7 @@ cc_library(
"//ortools/math_opt/cpp:variable_and_expressions",
"//ortools/math_opt/storage:linear_expression_data",
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"//ortools/math_opt/storage:sparse_coefficient_map",
"@abseil-cpp//absl/strings",
],

View File

@@ -23,10 +23,10 @@ namespace operations_research::math_opt {
LinearExpression Sos1Constraint::Expression(int index) const {
const LinearExpressionData& storage_expr =
storage_->constraint_data(id_).expression(index);
storage()->constraint_data(id_).expression(index);
LinearExpression out_expr = storage_expr.offset;
for (const auto [var_id, coeff] : storage_expr.coeffs.terms()) {
out_expr += coeff * Variable(storage_, var_id);
out_expr += coeff * Variable(storage(), var_id);
}
return out_expr;
}

View File

@@ -21,12 +21,14 @@
#include <string>
#include <vector>
#include "absl/base/nullability.h"
#include "absl/strings/string_view.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/constraints/sos/util.h"
#include "ortools/math_opt/constraints/util/model_util.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
namespace operations_research::math_opt {
@@ -34,17 +36,16 @@ namespace operations_research::math_opt {
// Usually this type is passed by copy.
//
// This type implements https://abseil.io/docs/cpp/guides/hash.
class Sos1Constraint {
class Sos1Constraint final : public ModelStorageItem {
public:
// The typed integer used for ids.
using IdType = Sos1ConstraintId;
inline Sos1Constraint(const ModelStorage* storage, Sos1ConstraintId id);
inline Sos1Constraint(ModelStorageCPtr storage, Sos1ConstraintId id);
inline int64_t id() const;
inline Sos1ConstraintId typed_id() const;
inline const ModelStorage* storage() const;
inline int64_t num_expressions() const;
LinearExpression Expression(int index) const;
@@ -69,7 +70,6 @@ class Sos1Constraint {
const Sos1Constraint& constraint);
private:
const ModelStorage* storage_;
Sos1ConstraintId id_;
};
@@ -86,33 +86,31 @@ int64_t Sos1Constraint::id() const { return id_.value(); }
Sos1ConstraintId Sos1Constraint::typed_id() const { return id_; }
const ModelStorage* Sos1Constraint::storage() const { return storage_; }
int64_t Sos1Constraint::num_expressions() const {
return storage_->constraint_data(id_).num_expressions();
return storage()->constraint_data(id_).num_expressions();
}
bool Sos1Constraint::has_weights() const {
return storage_->constraint_data(id_).has_weights();
return storage()->constraint_data(id_).has_weights();
}
double Sos1Constraint::weight(int index) const {
return storage_->constraint_data(id_).weight(index);
return storage()->constraint_data(id_).weight(index);
}
absl::string_view Sos1Constraint::name() const {
if (storage_->has_constraint(id_)) {
return storage_->constraint_data(id_).name();
if (storage()->has_constraint(id_)) {
return storage()->constraint_data(id_).name();
}
return kDeletedConstraintDefaultDescription;
}
std::vector<Variable> Sos1Constraint::NonzeroVariables() const {
return AtomicConstraintNonzeroVariables(*storage_, id_);
return AtomicConstraintNonzeroVariables(*storage(), id_);
}
bool operator==(const Sos1Constraint& lhs, const Sos1Constraint& rhs) {
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage();
}
bool operator!=(const Sos1Constraint& lhs, const Sos1Constraint& rhs) {
@@ -121,7 +119,7 @@ bool operator!=(const Sos1Constraint& lhs, const Sos1Constraint& rhs) {
template <typename H>
H AbslHashValue(H h, const Sos1Constraint& constraint) {
return H::combine(std::move(h), constraint.id_.value(), constraint.storage_);
return H::combine(std::move(h), constraint.id_.value(), constraint.storage());
}
std::ostream& operator<<(std::ostream& ostr, const Sos1Constraint& constraint) {
@@ -136,15 +134,15 @@ std::ostream& operator<<(std::ostream& ostr, const Sos1Constraint& constraint) {
}
std::string Sos1Constraint::ToString() const {
if (storage_->has_constraint(id_)) {
if (storage()->has_constraint(id_)) {
return internal::SosConstraintToString(*this, "SOS1");
}
return std::string(kDeletedConstraintDefaultDescription);
}
Sos1Constraint::Sos1Constraint(const ModelStorage* const storage,
Sos1Constraint::Sos1Constraint(const ModelStorageCPtr storage,
const Sos1ConstraintId id)
: storage_(storage), id_(id) {}
: ModelStorageItem(storage), id_(id) {}
} // namespace operations_research::math_opt

View File

@@ -23,10 +23,10 @@ namespace operations_research::math_opt {
LinearExpression Sos2Constraint::Expression(int index) const {
const LinearExpressionData& storage_expr =
storage_->constraint_data(id_).expression(index);
storage()->constraint_data(id_).expression(index);
LinearExpression out_expr = storage_expr.offset;
for (const auto [var_id, coeff] : storage_expr.coeffs.terms()) {
out_expr += coeff * Variable(storage_, var_id);
out_expr += coeff * Variable(storage(), var_id);
}
return out_expr;
}

View File

@@ -27,6 +27,7 @@
#include "ortools/math_opt/constraints/util/model_util.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
namespace operations_research::math_opt {
@@ -34,17 +35,16 @@ namespace operations_research::math_opt {
// Usually this type is passed by copy.
//
// This type implements https://abseil.io/docs/cpp/guides/hash.
class Sos2Constraint {
class Sos2Constraint final : public ModelStorageItem {
public:
// The typed integer used for ids.
using IdType = Sos2ConstraintId;
inline Sos2Constraint(const ModelStorage* storage, Sos2ConstraintId id);
inline Sos2Constraint(ModelStorageCPtr storage, Sos2ConstraintId id);
inline int64_t id() const;
inline Sos2ConstraintId typed_id() const;
inline const ModelStorage* storage() const;
inline int64_t num_expressions() const;
LinearExpression Expression(int index) const;
@@ -70,7 +70,6 @@ class Sos2Constraint {
const Sos2Constraint& constraint);
private:
const ModelStorage* storage_;
Sos2ConstraintId id_;
};
@@ -87,33 +86,31 @@ int64_t Sos2Constraint::id() const { return id_.value(); }
Sos2ConstraintId Sos2Constraint::typed_id() const { return id_; }
const ModelStorage* Sos2Constraint::storage() const { return storage_; }
int64_t Sos2Constraint::num_expressions() const {
return storage_->constraint_data(id_).num_expressions();
return storage()->constraint_data(id_).num_expressions();
}
bool Sos2Constraint::has_weights() const {
return storage_->constraint_data(id_).has_weights();
return storage()->constraint_data(id_).has_weights();
}
double Sos2Constraint::weight(int index) const {
return storage_->constraint_data(id_).weight(index);
return storage()->constraint_data(id_).weight(index);
}
absl::string_view Sos2Constraint::name() const {
if (storage_->has_constraint(id_)) {
return storage_->constraint_data(id_).name();
if (storage()->has_constraint(id_)) {
return storage()->constraint_data(id_).name();
}
return kDeletedConstraintDefaultDescription;
}
std::vector<Variable> Sos2Constraint::NonzeroVariables() const {
return AtomicConstraintNonzeroVariables(*storage_, id_);
return AtomicConstraintNonzeroVariables(*storage(), id_);
}
bool operator==(const Sos2Constraint& lhs, const Sos2Constraint& rhs) {
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage();
}
bool operator!=(const Sos2Constraint& lhs, const Sos2Constraint& rhs) {
@@ -122,7 +119,7 @@ bool operator!=(const Sos2Constraint& lhs, const Sos2Constraint& rhs) {
template <typename H>
H AbslHashValue(H h, const Sos2Constraint& constraint) {
return H::combine(std::move(h), constraint.id_.value(), constraint.storage_);
return H::combine(std::move(h), constraint.id_.value(), constraint.storage());
}
std::ostream& operator<<(std::ostream& ostr, const Sos2Constraint& constraint) {
@@ -137,15 +134,15 @@ std::ostream& operator<<(std::ostream& ostr, const Sos2Constraint& constraint) {
}
std::string Sos2Constraint::ToString() const {
if (storage_->has_constraint(id_)) {
if (storage()->has_constraint(id_)) {
return internal::SosConstraintToString(*this, "SOS2");
}
return std::string(kDeletedConstraintDefaultDescription);
}
Sos2Constraint::Sos2Constraint(const ModelStorage* const storage,
Sos2Constraint::Sos2Constraint(const ModelStorageCPtr storage,
const Sos2ConstraintId id)
: storage_(storage), id_(id) {}
: ModelStorageItem(storage), id_(id) {}
} // namespace operations_research::math_opt

View File

@@ -43,6 +43,7 @@ class SosConstraintData {
using IdType = ConstraintId;
using ProtoType = SosConstraintProto;
using UpdatesProtoType = SosConstraintUpdatesProto;
static constexpr bool kSupportsElemental = false;
static_assert(
std::disjunction_v<std::is_same<ConstraintId, Sos1ConstraintId>,
@@ -101,11 +102,13 @@ using Sos2ConstraintData = internal::SosConstraintData<Sos2ConstraintId>;
template <>
struct AtomicConstraintTraits<Sos1ConstraintId> {
using ConstraintData = Sos1ConstraintData;
static constexpr bool kSupportsElemental = false;
};
template <>
struct AtomicConstraintTraits<Sos2ConstraintId> {
using ConstraintData = Sos2ConstraintData;
static constexpr bool kSupportsElemental = false;
};
////////////////////////////////////////////////////////////////////////////////

View File

@@ -52,7 +52,7 @@ std::vector<Variable> AtomicConstraintNonzeroVariables(
}
// Duck-types on `ConstraintType` having a typedef for the associated `IdType`,
// and having a `(const ModelStorage*, IdType)` constructor.
// and having a `(ModelStorageCPtr, IdType)` constructor.
template <typename ConstraintType>
std::vector<ConstraintType> AtomicConstraints(const ModelStorage& storage) {
using IdType = typename ConstraintType::IdType;

View File

@@ -14,6 +14,7 @@
#ifndef OR_TOOLS_MATH_OPT_CORE_INVALID_INDICATORS_H_
#define OR_TOOLS_MATH_OPT_CORE_INVALID_INDICATORS_H_
#include <cstddef>
#include <cstdint>
#include <vector>

View File

@@ -14,6 +14,7 @@
#ifndef OR_TOOLS_MATH_OPT_CORE_INVERTED_BOUNDS_H_
#define OR_TOOLS_MATH_OPT_CORE_INVERTED_BOUNDS_H_
#include <cstddef>
#include <cstdint>
#include <vector>

View File

@@ -23,6 +23,7 @@
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "ortools/base/logging.h"

View File

@@ -11,29 +11,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
pybind11_add_module(math_opt_pybind11 MODULE solver.cc)
set_target_properties(math_opt_pybind11 PROPERTIES
pybind11_add_module(math_opt_core_pybind11 MODULE solver.cc)
set_target_properties(math_opt_core_pybind11 PROPERTIES
LIBRARY_OUTPUT_NAME "solver")
# note: macOS is APPLE and also UNIX !
if(APPLE)
set_target_properties(math_opt_pybind11 PROPERTIES
set_target_properties(math_opt_core_pybind11 PROPERTIES
SUFFIX ".so"
INSTALL_RPATH
"@loader_path;@loader_path/../../../../${PYTHON_PROJECT}/.libs;@loader_path/../../../../pybind11_abseil")
elseif(UNIX)
set_target_properties(math_opt_pybind11 PROPERTIES
set_target_properties(math_opt_core_pybind11 PROPERTIES
INSTALL_RPATH
"$ORIGIN:$ORIGIN/../../../../${PYTHON_PROJECT}/.libs:$ORIGIN/../../../../pybind11_abseil")
endif()
target_link_libraries(math_opt_pybind11 PRIVATE
target_link_libraries(math_opt_core_pybind11 PRIVATE
${PROJECT_NAMESPACE}::ortools
pybind11_abseil::absl_casters
pybind11_abseil::status_casters
pybind11_native_proto_caster
protobuf::libprotobuf)
add_library(${PROJECT_NAMESPACE}::math_opt_pybind11 ALIAS math_opt_pybind11)
add_library(${PROJECT_NAMESPACE}::math_opt_core_pybind11 ALIAS math_opt_core_pybind11)
if(BUILD_TESTING)
file(GLOB PYTHON_SRCS "*_test.py")

View File

@@ -17,7 +17,7 @@ import threading
from typing import Callable, Optional, Sequence
from absl.testing import absltest
from absl.testing import parameterized
from pybind11_abseil.status import StatusNotOk
from pybind11_abseil import status
from ortools.math_opt import callback_pb2
from ortools.math_opt import model_parameters_pb2
from ortools.math_opt import model_pb2
@@ -116,7 +116,7 @@ class PybindSolverTest(parameterized.TestCase):
with self.assertRaisesRegex(RuntimeError, "id 7 not found"):
_solve_model(model, use_solver_class=use_solver_class)
else:
with self.assertRaisesRegex(StatusNotOk, "id 7 not found"):
with self.assertRaisesRegex(status.StatusNotOk, "id 7 not found"):
_solve_model(model, use_solver_class=use_solver_class)
@parameterized.named_parameters(

View File

@@ -49,7 +49,7 @@ namespace {
// Returns an InternalError with the input status message if the input status is
// not OK.
absl::Status ToInternalError(const absl::Status original) {
absl::Status ToInternalError(absl::Status original) {
if (original.ok()) {
return original;
}
@@ -201,7 +201,7 @@ Solver::ComputeInfeasibleSubsystem(
RETURN_IF_ERROR(ValidateSolveParameters(arguments.parameters))
<< "invalid parameters";
ASSIGN_OR_RETURN(const ComputeInfeasibleSubsystemResultProto result,
ASSIGN_OR_RETURN(ComputeInfeasibleSubsystemResultProto result,
underlying_solver_->ComputeInfeasibleSubsystem(
arguments.parameters, arguments.message_callback,
arguments.interrupter));

View File

@@ -28,7 +28,8 @@ namespace internal {
// This variable is intended to be used by MathOpt unit tests in other languages
// to test the proper garbage collection. It should never be used in any other
// context.
OR_DLL extern std::atomic<int64_t> debug_num_solver;
OR_DLL
extern std::atomic<int64_t> debug_num_solver;
} // namespace internal
} // namespace math_opt

View File

@@ -94,7 +94,9 @@ cc_library(
"//ortools/math_opt/storage:sparse_coefficient_map",
"//ortools/math_opt/storage:sparse_matrix",
"//ortools/util:fp_roundtrip_conv",
"@abseil-cpp//absl/base:nullability",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/log:die_if_null",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
@@ -113,6 +115,7 @@ cc_library(
"//ortools/base:intops",
"//ortools/base:map_util",
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"//ortools/math_opt/storage:model_storage_types",
"//ortools/util:fp_roundtrip_conv",
"@abseil-cpp//absl/base:core_headers",
@@ -130,8 +133,8 @@ cc_library(
deps = [
":key_types",
":variable_and_expressions",
"//ortools/base:intops",
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"//ortools/math_opt/storage:model_storage_types",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings",
@@ -144,10 +147,11 @@ cc_library(
deps = [
":key_types",
":variable_and_expressions",
"//ortools/base:intops",
"//ortools/math_opt/constraints/util:model_util",
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"//ortools/math_opt/storage:model_storage_types",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings",
],
@@ -259,6 +263,7 @@ cc_library(
hdrs = ["key_types.h"],
deps = [
"//ortools/math_opt/storage:model_storage",
"//ortools/math_opt/storage:model_storage_item",
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/status",

View File

@@ -18,9 +18,7 @@
#define OR_TOOLS_MATH_OPT_CPP_BASIS_STATUS_H_
#include <cstdint>
#include <optional>
#include "absl/types/span.h"
#include "ortools/math_opt/cpp/enums.h" // IWYU pragma: export
#include "ortools/math_opt/solution.pb.h"

View File

@@ -70,7 +70,7 @@ CallbackData::CallbackData(const CallbackEvent event,
const absl::Duration runtime)
: event(event), runtime(runtime) {}
CallbackData::CallbackData(const ModelStorage* storage,
CallbackData::CallbackData(const ModelStorageCPtr storage,
const CallbackDataProto& proto)
// iOS 11 does not support .value() hence we use operator* here and CHECK
// below that we have a value.
@@ -90,7 +90,7 @@ CallbackData::CallbackData(const ModelStorage* storage,
}
absl::Status CallbackRegistration::CheckModelStorage(
const ModelStorage* const expected_storage) const {
const ModelStorageCPtr expected_storage) const {
RETURN_IF_ERROR(mip_node_filter.CheckModelStorage(expected_storage))
<< "invalid mip_node_filter";
RETURN_IF_ERROR(mip_solution_filter.CheckModelStorage(expected_storage))
@@ -113,7 +113,7 @@ CallbackRegistrationProto CallbackRegistration::Proto() const {
}
absl::Status CallbackResult::CheckModelStorage(
const ModelStorage* const expected_storage) const {
const ModelStorageCPtr expected_storage) const {
for (const GeneratedLinearConstraint& constraint : new_constraints) {
RETURN_IF_ERROR(
internal::CheckModelStorage(/*storage=*/constraint.storage(),

View File

@@ -144,7 +144,7 @@ MATH_OPT_DEFINE_ENUM(CallbackEvent, CALLBACK_EVENT_UNSPECIFIED);
struct CallbackRegistration {
// Returns a failure if the referenced variables don't belong to the input
// expected_storage (which must not be nullptr).
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
// Returns the proto equivalent of this object.
//
@@ -189,7 +189,7 @@ struct CallbackData {
// Users will typically not need this function.
// Will CHECK fail if proto is not valid.
CallbackData(const ModelStorage* storage, const CallbackDataProto& proto);
CallbackData(ModelStorageCPtr storage, const CallbackDataProto& proto);
// The current state of the underlying solver.
CallbackEvent event;
@@ -229,7 +229,7 @@ struct CallbackResult {
BoundedLinearExpression linear_constraint;
bool is_lazy = false;
const ModelStorage* storage() const {
NullableModelStorageCPtr storage() const {
return linear_constraint.expression.storage();
}
};
@@ -249,7 +249,7 @@ struct CallbackResult {
// Returns a failure if the referenced variables don't belong to the input
// expected_storage (which must not be nullptr).
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
// Returns the proto equivalent of this object.
//
@@ -257,7 +257,17 @@ struct CallbackResult {
// internal consistency of the referenced variables.
CallbackResultProto Proto() const;
// Stop the solve process and return early. Can be called from any event.
// When true it tells the solver to interrupt the solve as soon as possible.
//
// It can be set from any event. This is equivalent to using a
// SolveInterrupter and triggering it from the callback.
//
// Some solvers don't support interruption, in that case this is simply
// ignored and the solve terminates as usual. On top of that solvers may not
// immediately stop the solve. Thus the user should expect the callback to
// still be called after they set `terminate` to true in a previous
// call. Returning with `terminate` false after having previously returned
// true won't cancel the interruption.
bool terminate = false;
// The user cuts and lazy constraints added. Prefer AddUserCut() and

View File

@@ -76,7 +76,7 @@ template <typename K>
absl::Status BoundsMapProtoToCpp(
const google::protobuf::Map<int64_t, ModelSubsetProto::Bounds>& source,
absl::flat_hash_map<K, ModelSubset::Bounds>& target,
const ModelStorage* const model,
const ModelStorageCPtr model,
bool (ModelStorage::* const contains_strong_id)(typename K::IdType id)
const,
const absl::string_view object_name) {
@@ -95,7 +95,7 @@ absl::Status BoundsMapProtoToCpp(
template <typename K>
absl::Status RepeatedIdsProtoToCpp(
const google::protobuf::RepeatedField<int64_t>& source,
absl::flat_hash_set<K>& target, const ModelStorage* const model,
absl::flat_hash_set<K>& target, const ModelStorageCPtr model,
bool (ModelStorage::* const contains_strong_id)(typename K::IdType id)
const,
const absl::string_view object_name) {
@@ -134,7 +134,7 @@ google::protobuf::RepeatedField<int64_t> RepeatedIdsCppToProto(
} // namespace
absl::StatusOr<ModelSubset> ModelSubset::FromProto(
const ModelStorage* const model, const ModelSubsetProto& proto) {
const ModelStorageCPtr model, const ModelSubsetProto& proto) {
ModelSubset model_subset;
RETURN_IF_ERROR(BoundsMapProtoToCpp(proto.variable_bounds(),
model_subset.variable_bounds, model,
@@ -184,7 +184,7 @@ ModelSubsetProto ModelSubset::Proto() const {
}
absl::Status ModelSubset::CheckModelStorage(
const ModelStorage* const expected_storage) const {
const ModelStorageCPtr expected_storage) const {
const auto validate_map_keys =
[expected_storage](const auto& map,
const absl::string_view name) -> absl::Status {
@@ -348,7 +348,7 @@ std::ostream& operator<<(std::ostream& out, const ModelSubset& model_subset) {
absl::StatusOr<ComputeInfeasibleSubsystemResult>
ComputeInfeasibleSubsystemResult::FromProto(
const ModelStorage* const model,
const ModelStorageCPtr model,
const ComputeInfeasibleSubsystemResultProto& result_proto) {
ComputeInfeasibleSubsystemResult result;
const std::optional<FeasibilityStatus> feasibility =
@@ -383,7 +383,7 @@ ComputeInfeasibleSubsystemResultProto ComputeInfeasibleSubsystemResult::Proto()
}
absl::Status ComputeInfeasibleSubsystemResult::CheckModelStorage(
const ModelStorage* const expected_storage) const {
const ModelStorageCPtr expected_storage) const {
return infeasible_subsystem.CheckModelStorage(expected_storage);
}

View File

@@ -64,7 +64,7 @@ struct ModelSubset {
//
// Returns an error when `model` does not contain a variable or constraint
// associated with an index present in `proto`.
static absl::StatusOr<ModelSubset> FromProto(const ModelStorage* model,
static absl::StatusOr<ModelSubset> FromProto(ModelStorageCPtr model,
const ModelSubsetProto& proto);
// Returns the proto equivalent of this object.
@@ -74,8 +74,8 @@ struct ModelSubset {
ModelSubsetProto Proto() const;
// Returns a failure if the `Variable` and Constraints contained in the fields
// do not belong to the input expected_storage (which must not be nullptr).
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
// do not belong to the input expected_storage.
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
// True if this object corresponds to the empty subset.
bool empty() const;
@@ -105,7 +105,7 @@ struct ComputeInfeasibleSubsystemResult {
// index present in `proto.infeasible_subsystem`.
// * ValidateComputeInfeasibleSubsystemResultNoModel(result_proto) fails.
static absl::StatusOr<ComputeInfeasibleSubsystemResult> FromProto(
const ModelStorage* model,
ModelStorageCPtr model,
const ComputeInfeasibleSubsystemResultProto& result_proto);
// Returns the proto equivalent of this object.
@@ -116,8 +116,8 @@ struct ComputeInfeasibleSubsystemResult {
ComputeInfeasibleSubsystemResultProto Proto() const;
// Returns a failure if this object contains references to a model other than
// `expected_storage` (which must not be nullptr).
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
// `expected_storage`.
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
// The primal feasibility status of the model, as determined by the solver.
FeasibilityStatus feasibility = FeasibilityStatus::kUndetermined;

View File

@@ -25,14 +25,16 @@
//
// A key type K must match the following requirements:
// - K::IdType is a value type used for indices.
// - K has a constructor K(const ModelStorage*, K::IdType).
// - K has a constructor K(ModelStorageCPtr, K::IdType).
// - K is a value-semantic type.
// - K has a function with signature `K::IdType K::typed_id() const`.
// - K has a function with signature `const ModelStorage* K::storage() const`.
// It must return a non-null pointer.
// - K has a function with signature `ModelStorageCPtr K::storage() const`.
// - K::IdType is a valid key for absl::flat_hash_map or absl::flat_hash_set
// (supports hash and ==).
// - the is_key_type_v<> below should include them.
// TODO(b/396580721): Those requirements are those of `ModelStorageElement`.
// Once we've migrated most key types to `ModelStorageElement`, we should be
// able to simplify this code.
#ifndef OR_TOOLS_MATH_OPT_CPP_KEY_TYPES_H_
#define OR_TOOLS_MATH_OPT_CPP_KEY_TYPES_H_
@@ -45,6 +47,7 @@
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
namespace operations_research::math_opt {
@@ -70,11 +73,9 @@ class Objective;
// the values in the hash map are in the math_opt namespace.
template <typename T>
constexpr inline bool is_key_type_v =
(std::is_same_v<T, Variable> || std::is_same_v<T, LinearConstraint> ||
std::is_same_v<T, QuadraticConstraint> ||
(is_model_storage_element<T>::value ||
std::is_same_v<T, SecondOrderConeConstraint> ||
std::is_same_v<T, Sos1Constraint> || std::is_same_v<T, Sos2Constraint> ||
std::is_same_v<T, IndicatorConstraint> ||
std::is_same_v<T, QuadraticTermKey> || std::is_same_v<T, Objective>);
// Returns the keys of the map sorted by their (storage(), type_id()).
@@ -162,12 +163,12 @@ inline constexpr absl::string_view kInputFromInvalidModelStorage =
"the input does not belong to the same model";
// Returns a failure when the input pointer is not nullptr and points to a
// different model storage than expected_storage (which must not be nullptr).
// different model storage than expected_storage.
//
// Failure message is kInputFromInvalidModelStorage.
inline absl::Status CheckModelStorage(
const ModelStorage* const storage,
const ModelStorage* const expected_storage) {
inline absl::Status CheckModelStorage(const NullableModelStorageCPtr storage,
const ModelStorageCPtr expected_storage) {
// This is not allowed by the contract, but let's be safe.
if (expected_storage == nullptr) {
return absl::InternalError("expected_storage is nullptr");
}

View File

@@ -18,19 +18,18 @@
#ifndef OR_TOOLS_MATH_OPT_CPP_LINEAR_CONSTRAINT_H_
#define OR_TOOLS_MATH_OPT_CPP_LINEAR_CONSTRAINT_H_
#include <cstdint>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/strings/string_view.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/constraints/util/model_util.h"
#include "ortools/math_opt/cpp/key_types.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
#include "ortools/math_opt/storage/model_storage_types.h"
namespace operations_research {
@@ -38,19 +37,11 @@ namespace math_opt {
// A value type that references a linear constraint from ModelStorage. Usually
// this type is passed by copy.
//
// This type implements https://abseil.io/docs/cpp/guides/hash.
class LinearConstraint {
class LinearConstraint final
: public ModelStorageElement<ElementType::kLinearConstraint,
LinearConstraint> {
public:
// The typed integer used for ids.
using IdType = LinearConstraintId;
inline LinearConstraint(const ModelStorage* storage, LinearConstraintId id);
inline int64_t id() const;
inline LinearConstraintId typed_id() const;
inline const ModelStorage* storage() const;
using ModelStorageElement::ModelStorageElement;
inline double lower_bound() const;
inline double upper_bound() const;
@@ -76,65 +67,42 @@ class LinearConstraint {
// Returns a detailed string description of the contents of the constraint
// (not its name, use `<<` for that instead).
inline std::string ToString() const;
friend inline bool operator==(const LinearConstraint& lhs,
const LinearConstraint& rhs);
friend inline bool operator!=(const LinearConstraint& lhs,
const LinearConstraint& rhs);
template <typename H>
friend H AbslHashValue(H h, const LinearConstraint& linear_constraint);
friend std::ostream& operator<<(std::ostream& ostr,
const LinearConstraint& linear_constraint);
private:
const ModelStorage* storage_;
LinearConstraintId id_;
};
template <typename V>
using LinearConstraintMap = absl::flat_hash_map<LinearConstraint, V>;
// Streams the name of the constraint, as registered upon constraint creation,
// or a short default if none was provided.
inline std::ostream& operator<<(std::ostream& ostr,
const LinearConstraint& linear_constraint);
////////////////////////////////////////////////////////////////////////////////
// Inline function implementations
////////////////////////////////////////////////////////////////////////////////
int64_t LinearConstraint::id() const { return id_.value(); }
LinearConstraintId LinearConstraint::typed_id() const { return id_; }
const ModelStorage* LinearConstraint::storage() const { return storage_; }
double LinearConstraint::lower_bound() const {
return storage_->linear_constraint_lower_bound(id_);
return storage()->linear_constraint_lower_bound(typed_id());
}
double LinearConstraint::upper_bound() const {
return storage_->linear_constraint_upper_bound(id_);
return storage()->linear_constraint_upper_bound(typed_id());
}
absl::string_view LinearConstraint::name() const {
if (storage()->has_linear_constraint(id_)) {
return storage_->linear_constraint_name(id_);
if (storage()->has_linear_constraint(typed_id())) {
return storage()->linear_constraint_name(typed_id());
}
return kDeletedConstraintDefaultDescription;
}
bool LinearConstraint::is_coefficient_nonzero(const Variable variable) const {
CHECK_EQ(variable.storage(), storage_)
CHECK_EQ(variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
return storage_->is_linear_constraint_coefficient_nonzero(
id_, variable.typed_id());
return storage()->is_linear_constraint_coefficient_nonzero(
typed_id(), variable.typed_id());
}
double LinearConstraint::coefficient(const Variable variable) const {
CHECK_EQ(variable.storage(), storage_)
CHECK_EQ(variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
return storage_->linear_constraint_coefficient(id_, variable.typed_id());
return storage()->linear_constraint_coefficient(typed_id(),
variable.typed_id());
}
BoundedLinearExpression LinearConstraint::AsBoundedLinearExpression() const {
@@ -150,7 +118,7 @@ BoundedLinearExpression LinearConstraint::AsBoundedLinearExpression() const {
}
std::string LinearConstraint::ToString() const {
if (!storage()->has_linear_constraint(id_)) {
if (!storage()->has_linear_constraint(typed_id())) {
return std::string(kDeletedConstraintDefaultDescription);
}
std::stringstream str;
@@ -158,36 +126,6 @@ std::string LinearConstraint::ToString() const {
return str.str();
}
bool operator==(const LinearConstraint& lhs, const LinearConstraint& rhs) {
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
}
bool operator!=(const LinearConstraint& lhs, const LinearConstraint& rhs) {
return !(lhs == rhs);
}
template <typename H>
H AbslHashValue(H h, const LinearConstraint& linear_constraint) {
return H::combine(std::move(h), linear_constraint.id_.value(),
linear_constraint.storage_);
}
std::ostream& operator<<(std::ostream& ostr,
const LinearConstraint& linear_constraint) {
// TODO(b/170992529): handle quoting of invalid characters in the name.
const absl::string_view name = linear_constraint.name();
if (name.empty()) {
ostr << "__lin_con#" << linear_constraint.id() << "__";
} else {
ostr << name;
}
return ostr;
}
LinearConstraint::LinearConstraint(const ModelStorage* const storage,
const LinearConstraintId id)
: storage_(storage), id_(id) {}
} // namespace math_opt
} // namespace operations_research

View File

@@ -100,7 +100,7 @@ struct MapFilter {
// Returns a failure if the keys don't belong to the input expected_storage
// (which must not be nullptr).
inline absl::Status CheckModelStorage(
const ModelStorage* expected_storage) const;
ModelStorageCPtr expected_storage) const;
// Returns the proto corresponding to this filter.
//
@@ -192,7 +192,7 @@ MapFilter<KeyType> MakeKeepKeysFilter(std::initializer_list<KeyType> keys) {
template <typename KeyType>
absl::Status MapFilter<KeyType>::CheckModelStorage(
const ModelStorage* expected_storage) const {
const ModelStorageCPtr expected_storage) const {
if (!filtered_keys.has_value()) {
return absl::OkStatus();
}

View File

@@ -25,6 +25,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "gtest/gtest.h"
@@ -871,7 +872,7 @@ std::vector<TerminationReason> CompatibleReasons(
}
Matcher<std::vector<Solution>> CheckSolutions(
const std::vector<Solution>& expected_solutions,
absl::Span<const Solution> expected_solutions,
const SolveResultMatcherOptions& options) {
if (options.first_solution_only && !expected_solutions.empty()) {
return FirstElementIs(

View File

@@ -37,7 +37,7 @@ class PrinterMessageCallbackImpl {
const absl::string_view prefix)
: output_stream_(output_stream), prefix_(prefix) {}
void Call(const std::vector<std::string>& messages) {
void Call(absl::Span<const std::string> messages) {
const absl::MutexLock lock(&mutex_);
for (const std::string& message : messages) {
output_stream_ << prefix_ << message << '\n';
@@ -56,7 +56,7 @@ void PushBack(absl::Span<const std::string> messages,
sink->insert(sink->end(), messages.begin(), messages.end());
}
void PushBack(const std::vector<std::string>& messages,
void PushBack(absl::Span<const std::string> messages,
google::protobuf::RepeatedPtrField<std::string>* const sink) {
std::copy(messages.begin(), messages.end(),
google::protobuf::RepeatedFieldBackInserter(sink));
@@ -68,7 +68,7 @@ class VectorLikeMessageCallbackImpl {
explicit VectorLikeMessageCallbackImpl(Sink* const sink)
: sink_(ABSL_DIE_IF_NULL(sink)) {}
void Call(const std::vector<std::string>& messages) {
void Call(absl::Span<const std::string> messages) {
const absl::MutexLock lock(&mutex_);
PushBack(messages, sink_);
}
@@ -102,7 +102,7 @@ MessageCallback InfoLoggerMessageCallback(const absl::string_view prefix,
MessageCallback VLoggerMessageCallback(int level, absl::string_view prefix,
absl::SourceLocation loc) {
return [=](const std::vector<std::string>& messages) {
return [=](absl::Span<const std::string> messages) {
for (const std::string& message : messages) {
VLOG(level).AtLocation(loc.file_name(), loc.line()) << prefix << message;
}
@@ -117,8 +117,7 @@ MessageCallback VectorMessageCallback(std::vector<std::string>* sink) {
const auto impl =
std::make_shared<VectorLikeMessageCallbackImpl<std::vector<std::string>>>(
sink);
return
[=](const std::vector<std::string>& messages) { impl->Call(messages); };
return [=](absl::Span<const std::string> messages) { impl->Call(messages); };
}
MessageCallback RepeatedPtrFieldMessageCallback(

View File

@@ -22,7 +22,9 @@
#include <utility>
#include <vector>
#include "absl/base/nullability.h"
#include "absl/log/check.h"
#include "absl/log/die_if_null.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@@ -53,7 +55,7 @@ constexpr double kInf = std::numeric_limits<double>::infinity();
absl::StatusOr<std::unique_ptr<Model>> Model::FromModelProto(
const ModelProto& model_proto) {
ASSIGN_OR_RETURN(std::unique_ptr<ModelStorage> storage,
ASSIGN_OR_RETURN(absl::Nonnull<std::unique_ptr<ModelStorage>> storage,
ModelStorage::FromModelProto(model_proto));
return std::make_unique<Model>(std::move(storage));
}
@@ -61,10 +63,10 @@ absl::StatusOr<std::unique_ptr<Model>> Model::FromModelProto(
Model::Model(const absl::string_view name)
: storage_(std::make_shared<ModelStorage>(name)) {}
Model::Model(std::unique_ptr<ModelStorage> storage)
: storage_(std::move(storage)) {}
Model::Model(absl::Nonnull<std::unique_ptr<ModelStorage>> storage)
: storage_(ABSL_DIE_IF_NULL(std::move(storage))) {}
std::unique_ptr<Model> Model::Clone(
absl::Nonnull<std::unique_ptr<Model>> Model::Clone(
const std::optional<absl::string_view> new_name) const {
return std::make_unique<Model>(storage_->Clone(new_name));
}

View File

@@ -23,6 +23,7 @@
#include <ostream>
#include <vector>
#include "absl/base/nullability.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@@ -136,7 +137,7 @@ class Model {
// This constructor is used when loading a model, for example from a
// ModelProto or an MPS file. Note that in those cases the FromModelProto()
// should be used.
explicit Model(std::unique_ptr<ModelStorage> storage);
explicit Model(absl::Nonnull<std::unique_ptr<ModelStorage>> storage);
Model(const Model&) = delete;
Model& operator=(const Model&) = delete;
@@ -158,7 +159,7 @@ class Model {
// * in an arbitrary order using Variables() and LinearConstraints().
//
// Note that the returned model does not have any update tracker.
std::unique_ptr<Model> Clone(
absl::Nonnull<std::unique_ptr<Model>> Clone(
std::optional<absl::string_view> new_name = std::nullopt) const;
inline absl::string_view name() const;
@@ -893,13 +894,13 @@ class Model {
//
// This API is for internal use only and regular users should have no need for
// it.
const ModelStorage* storage() const { return storage_.get(); }
ModelStorageCPtr storage() const { return storage_.get(); }
// Returns a pointer to the underlying model storage.
//
// This API is for internal use only and regular users should have no need for
// it.
ModelStorage* storage() { return storage_.get(); }
ModelStoragePtr storage() { return storage_.get(); }
// Prints the objective, the constraints and the variables of the model over
// several lines in a human-readable way. Includes a new line at the end of
@@ -911,12 +912,12 @@ class Model {
// points to the same model as storage_.
//
// Use CheckModel() when nullptr is not a valid value.
inline void CheckOptionalModel(const ModelStorage* other_storage) const;
inline void CheckOptionalModel(NullableModelStorageCPtr other_storage) const;
// Asserts (with CHECK) that the input pointer is the same as storage_.
//
// Use CheckOptionalModel() if nullptr is a valid value too.
inline void CheckModel(const ModelStorage* other_storage) const;
inline void CheckModel(ModelStorageCPtr other_storage) const;
// Don't use storage_ directly; prefer to use storage() so that const member
// functions don't have modifying access to the underlying storage.
@@ -924,7 +925,7 @@ class Model {
// We use a shared_ptr here so that the UpdateTracker class can have a
// weak_ptr on the ModelStorage. This let it have a destructor that don't
// crash when called after the destruction of the associated Model.
const std::shared_ptr<ModelStorage> storage_;
const absl::Nonnull<std::shared_ptr<ModelStorage>> storage_;
};
////////////////////////////////////////////////////////////////////////////////
@@ -973,7 +974,7 @@ int64_t Model::next_variable_id() const {
}
bool Model::has_variable(const int64_t id) const {
return has_variable(VariableId(id));
return id < 0 ? false : has_variable(VariableId(id));
}
bool Model::has_variable(const VariableId id) const {
@@ -1074,7 +1075,7 @@ int64_t Model::next_linear_constraint_id() const {
}
bool Model::has_linear_constraint(const int64_t id) const {
return has_linear_constraint(LinearConstraintId(id));
return id < 0 ? false : has_linear_constraint(LinearConstraintId(id));
}
bool Model::has_linear_constraint(const LinearConstraintId id) const {
@@ -1082,6 +1083,7 @@ bool Model::has_linear_constraint(const LinearConstraintId id) const {
}
LinearConstraint Model::linear_constraint(const int64_t id) const {
CHECK_GE(id, 0) << "negative linear constraint id: " << id;
return linear_constraint(LinearConstraintId(id));
}
@@ -1176,7 +1178,7 @@ int64_t Model::next_quadratic_constraint_id() const {
}
bool Model::has_quadratic_constraint(const int64_t id) const {
return has_quadratic_constraint(QuadraticConstraintId(id));
return id < 0 ? false : has_quadratic_constraint(QuadraticConstraintId(id));
}
bool Model::has_quadratic_constraint(const QuadraticConstraintId id) const {
@@ -1184,6 +1186,7 @@ bool Model::has_quadratic_constraint(const QuadraticConstraintId id) const {
}
QuadraticConstraint Model::quadratic_constraint(const int64_t id) const {
CHECK_GE(id, 0) << "negative quadratic constraint id: " << id;
return quadratic_constraint(QuadraticConstraintId(id));
}
@@ -1219,7 +1222,9 @@ int64_t Model::next_second_order_cone_constraint_id() const {
}
bool Model::has_second_order_cone_constraint(const int64_t id) const {
return has_second_order_cone_constraint(SecondOrderConeConstraintId(id));
return id < 0 ? false
: has_second_order_cone_constraint(
SecondOrderConeConstraintId(id));
}
bool Model::has_second_order_cone_constraint(
@@ -1265,7 +1270,7 @@ int64_t Model::next_sos1_constraint_id() const {
}
bool Model::has_sos1_constraint(const int64_t id) const {
return has_sos1_constraint(Sos1ConstraintId(id));
return id < 0 ? false : has_sos1_constraint(Sos1ConstraintId(id));
}
bool Model::has_sos1_constraint(const Sos1ConstraintId id) const {
@@ -1306,7 +1311,7 @@ int64_t Model::next_sos2_constraint_id() const {
}
bool Model::has_sos2_constraint(const int64_t id) const {
return has_sos2_constraint(Sos2ConstraintId(id));
return id < 0 ? false : has_sos2_constraint(Sos2ConstraintId(id));
}
bool Model::has_sos2_constraint(const Sos2ConstraintId id) const {
@@ -1347,7 +1352,7 @@ int64_t Model::next_indicator_constraint_id() const {
}
bool Model::has_indicator_constraint(const int64_t id) const {
return has_indicator_constraint(IndicatorConstraintId(id));
return id < 0 ? false : has_indicator_constraint(IndicatorConstraintId(id));
}
bool Model::has_indicator_constraint(const IndicatorConstraintId id) const {
@@ -1555,7 +1560,7 @@ int64_t Model::next_auxiliary_objective_id() const {
}
bool Model::has_auxiliary_objective(const int64_t id) const {
return has_auxiliary_objective(AuxiliaryObjectiveId(id));
return id < 0 ? false : has_auxiliary_objective(AuxiliaryObjectiveId(id));
}
bool Model::has_auxiliary_objective(const AuxiliaryObjectiveId id) const {
@@ -1563,6 +1568,7 @@ bool Model::has_auxiliary_objective(const AuxiliaryObjectiveId id) const {
}
Objective Model::auxiliary_objective(const int64_t id) const {
CHECK_GE(id, 0) << "negative auxiliary objective id: " << id;
return auxiliary_objective(AuxiliaryObjectiveId(id));
}
@@ -1618,14 +1624,15 @@ void Model::set_is_maximize(const Objective objective, const bool is_maximize) {
storage()->set_is_maximize(objective.typed_id(), is_maximize);
}
void Model::CheckOptionalModel(const ModelStorage* const other_storage) const {
void Model::CheckOptionalModel(
const NullableModelStorageCPtr other_storage) const {
if (other_storage != nullptr) {
CHECK_EQ(other_storage, storage())
<< internal::kObjectsFromOtherModelStorage;
}
}
void Model::CheckModel(const ModelStorage* const other_storage) const {
void Model::CheckModel(const ModelStorageCPtr other_storage) const {
CHECK_EQ(other_storage, storage()) << internal::kObjectsFromOtherModelStorage;
}

View File

@@ -58,7 +58,7 @@ ModelSolveParameters ModelSolveParameters::OnlySomePrimalVariables(
}
absl::Status ModelSolveParameters::CheckModelStorage(
const ModelStorage* const expected_storage) const {
const ModelStorageCPtr expected_storage) const {
for (const SolutionHint& hint : solution_hints) {
RETURN_IF_ERROR(hint.CheckModelStorage(expected_storage))
<< "invalid hint in solution_hints";
@@ -100,7 +100,7 @@ absl::Status ModelSolveParameters::CheckModelStorage(
}
absl::Status ModelSolveParameters::SolutionHint::CheckModelStorage(
const ModelStorage* expected_storage) const {
const ModelStorageCPtr expected_storage) const {
for (const auto& [v, _] : variable_values) {
RETURN_IF_ERROR(internal::CheckModelStorage(
/*storage=*/v.storage(),

View File

@@ -128,7 +128,7 @@ struct ModelSolveParameters {
// Returns a failure if the referenced variables and constraints don't
// belong to the input expected_storage (which must not be nullptr).
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
// Returns the proto equivalent of this object.
//
@@ -215,7 +215,7 @@ struct ModelSolveParameters {
// Returns a failure if the referenced variables and constraints do not belong
// to the input expected_storage (which must not be nullptr).
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
// Returns the proto equivalent of this object.
//

View File

@@ -31,18 +31,18 @@ LinearExpression Objective::AsLinearExpression() const {
<< "The objective function contains quadratic terms and cannot be "
"represented as a LinearExpression";
LinearExpression objective = offset();
for (const auto [raw_var_id, coeff] : storage_->linear_objective(id_)) {
objective += coeff * Variable(storage_, raw_var_id);
for (const auto [raw_var_id, coeff] : storage()->linear_objective(id_)) {
objective += coeff * Variable(storage(), raw_var_id);
}
return objective;
}
QuadraticExpression Objective::AsQuadraticExpression() const {
QuadraticExpression result = offset();
for (const auto& [v, coef] : storage_->linear_objective(id_)) {
for (const auto& [v, coef] : storage()->linear_objective(id_)) {
result += coef * Variable(storage(), v);
}
for (const auto& [v1, v2, coef] : storage_->quadratic_objective_terms(id_)) {
for (const auto& [v1, v2, coef] : storage()->quadratic_objective_terms(id_)) {
result +=
QuadraticTerm(Variable(storage(), v1), Variable(storage(), v2), coef);
}

View File

@@ -25,10 +25,10 @@
#include "absl/log/check.h"
#include "absl/strings/string_view.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/cpp/key_types.h"
#include "ortools/math_opt/cpp/variable_and_expressions.h"
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
#include "ortools/math_opt/storage/model_storage_types.h"
namespace operations_research::math_opt {
@@ -40,15 +40,15 @@ constexpr absl::string_view kDeletedObjectiveDefaultDescription =
// ModelStorage. Usually this type is passed by copy.
//
// This type implements https://abseil.io/docs/cpp/guides/hash.
class Objective {
class Objective final : public ModelStorageItem {
public:
// The type used for ids.
using IdType = AuxiliaryObjectiveId;
// Returns an object that refers to the primary objective of the model.
inline static Objective Primary(const ModelStorage* storage);
inline static Objective Primary(ModelStorageCPtr storage);
// Returns an object that refers to an auxiliary objective of the model.
inline static Objective Auxiliary(const ModelStorage* storage,
inline static Objective Auxiliary(ModelStorageCPtr storage,
AuxiliaryObjectiveId id);
// Returns the raw integer ID associated with the objective: nullopt for the
@@ -57,8 +57,6 @@ class Objective {
// Returns the strong int ID associated with the objective: nullopt for the
// primary objective, an AuxiliaryObjectiveId for an auxiliary objective.
inline ObjectiveId typed_id() const;
// Returns a const-pointer to the underlying storage object for the model.
inline const ModelStorage* storage() const;
// Returns true if the ID corresponds to the primary objective, and false if
// it is an auxiliary objective.
@@ -113,9 +111,8 @@ class Objective {
const Objective& objective);
private:
inline Objective(const ModelStorage* storage, ObjectiveId id);
inline Objective(ModelStorageCPtr storage, ObjectiveId id);
const ModelStorage* storage_;
ObjectiveId id_;
};
@@ -139,68 +136,66 @@ std::optional<int64_t> Objective::id() const {
ObjectiveId Objective::typed_id() const { return id_; }
const ModelStorage* Objective::storage() const { return storage_; }
bool Objective::is_primary() const { return id_ == kPrimaryObjectiveId; }
int64_t Objective::priority() const {
return storage_->objective_priority(id_);
return storage()->objective_priority(id_);
}
bool Objective::maximize() const { return storage_->is_maximize(id_); }
bool Objective::maximize() const { return storage()->is_maximize(id_); }
absl::string_view Objective::name() const {
if (is_primary() || storage_->has_auxiliary_objective(*id_)) {
return storage_->objective_name(id_);
if (is_primary() || storage()->has_auxiliary_objective(*id_)) {
return storage()->objective_name(id_);
}
return kDeletedObjectiveDefaultDescription;
}
double Objective::offset() const { return storage_->objective_offset(id_); }
double Objective::offset() const { return storage()->objective_offset(id_); }
int64_t Objective::num_quadratic_terms() const {
return storage_->num_quadratic_objective_terms(id_);
return storage()->num_quadratic_objective_terms(id_);
}
int64_t Objective::num_linear_terms() const {
return storage_->num_linear_objective_terms(id_);
return storage()->num_linear_objective_terms(id_);
}
double Objective::coefficient(const Variable variable) const {
CHECK_EQ(variable.storage(), storage_)
CHECK_EQ(variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
return storage_->linear_objective_coefficient(id_, variable.typed_id());
return storage()->linear_objective_coefficient(id_, variable.typed_id());
}
double Objective::coefficient(const Variable first_variable,
const Variable second_variable) const {
CHECK_EQ(first_variable.storage(), storage_)
CHECK_EQ(first_variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
CHECK_EQ(second_variable.storage(), storage_)
CHECK_EQ(second_variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
return storage_->quadratic_objective_coefficient(
return storage()->quadratic_objective_coefficient(
id_, first_variable.typed_id(), second_variable.typed_id());
}
bool Objective::is_coefficient_nonzero(const Variable variable) const {
CHECK_EQ(variable.storage(), storage_)
CHECK_EQ(variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
return storage_->is_linear_objective_coefficient_nonzero(id_,
variable.typed_id());
return storage()->is_linear_objective_coefficient_nonzero(
id_, variable.typed_id());
}
bool Objective::is_coefficient_nonzero(const Variable first_variable,
const Variable second_variable) const {
CHECK_EQ(first_variable.storage(), storage_)
CHECK_EQ(first_variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
CHECK_EQ(second_variable.storage(), storage_)
CHECK_EQ(second_variable.storage(), storage())
<< internal::kObjectsFromOtherModelStorage;
return storage_->is_quadratic_objective_coefficient_nonzero(
return storage()->is_quadratic_objective_coefficient_nonzero(
id_, first_variable.typed_id(), second_variable.typed_id());
}
bool operator==(const Objective& lhs, const Objective& rhs) {
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage();
}
bool operator!=(const Objective& lhs, const Objective& rhs) {
@@ -209,17 +204,17 @@ bool operator!=(const Objective& lhs, const Objective& rhs) {
template <typename H>
H AbslHashValue(H h, const Objective& objective) {
return H::combine(std::move(h), objective.id_, objective.storage_);
return H::combine(std::move(h), objective.id_, objective.storage());
}
Objective::Objective(const ModelStorage* const storage, const ObjectiveId id)
: storage_(storage), id_(id) {}
Objective::Objective(const ModelStorageCPtr storage, const ObjectiveId id)
: ModelStorageItem(storage), id_(id) {}
Objective Objective::Primary(const ModelStorage* const storage) {
Objective Objective::Primary(const ModelStorageCPtr storage) {
return Objective(storage, kPrimaryObjectiveId);
}
Objective Objective::Auxiliary(const ModelStorage* const storage,
Objective Objective::Auxiliary(const ModelStorageCPtr storage,
const AuxiliaryObjectiveId id) {
return Objective(storage, id);
}

View File

@@ -13,7 +13,6 @@
#include "ortools/math_opt/cpp/parameters.h"
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>
@@ -86,7 +85,7 @@ std::optional<absl::string_view> Enum<SolverType>::ToOptString(
case SolverType::kSantorini:
return "santorini";
case SolverType::kXpress:
return "xpress";
return "xpress";
}
return std::nullopt;
}
@@ -96,7 +95,7 @@ absl::Span<const SolverType> Enum<SolverType>::AllValues() {
SolverType::kGscip, SolverType::kGurobi, SolverType::kGlop,
SolverType::kCpSat, SolverType::kPdlp, SolverType::kGlpk,
SolverType::kEcos, SolverType::kScs, SolverType::kHighs,
SolverType::kSantorini,
SolverType::kSantorini, SolverType::kXpress,
};
return absl::MakeConstSpan(kSolverTypeValues);
}

View File

@@ -24,7 +24,6 @@
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "absl/types/span.h"
#include "ortools/base/linked_hash_map.h"
#include "ortools/glop/parameters.pb.h" // IWYU pragma: export
#include "ortools/gscip/gscip.pb.h" // IWYU pragma: export
@@ -114,7 +113,7 @@ enum class SolverType {
//
// Supports LP, MIP, and nonconvex integer quadratic problems.
// A fast option, but has special licensing.
kXpress = SOLVER_TYPE_XPRESS
kXpress = SOLVER_TYPE_XPRESS,
};
MATH_OPT_DEFINE_ENUM(SolverType, SOLVER_TYPE_UNSPECIFIED);

View File

@@ -18,10 +18,10 @@
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/base/logging.h"
#include "ortools/base/status_builder.h"
#include "ortools/base/status_macros.h"
#include "ortools/math_opt/cpp/sparse_containers.h"
@@ -57,7 +57,7 @@ absl::Span<const SolutionStatus> Enum<SolutionStatus>::AllValues() {
}
absl::StatusOr<PrimalSolution> PrimalSolution::FromProto(
const ModelStorage* model,
const ModelStorageCPtr model,
const PrimalSolutionProto& primal_solution_proto) {
PrimalSolution primal_solution;
OR_ASSIGN_OR_RETURN3(
@@ -104,7 +104,7 @@ double PrimalSolution::get_objective_value(const Objective objective) const {
}
absl::StatusOr<PrimalRay> PrimalRay::FromProto(
const ModelStorage* model, const PrimalRayProto& primal_ray_proto) {
const ModelStorageCPtr model, const PrimalRayProto& primal_ray_proto) {
PrimalRay result;
OR_ASSIGN_OR_RETURN3(
result.variable_values,
@@ -120,7 +120,8 @@ PrimalRayProto PrimalRay::Proto() const {
}
absl::StatusOr<DualSolution> DualSolution::FromProto(
const ModelStorage* model, const DualSolutionProto& dual_solution_proto) {
const ModelStorageCPtr model,
const DualSolutionProto& dual_solution_proto) {
DualSolution dual_solution;
OR_ASSIGN_OR_RETURN3(
dual_solution.dual_values,
@@ -159,7 +160,7 @@ DualSolutionProto DualSolution::Proto() const {
return result;
}
absl::StatusOr<DualRay> DualRay::FromProto(const ModelStorage* model,
absl::StatusOr<DualRay> DualRay::FromProto(const ModelStorageCPtr model,
const DualRayProto& dual_ray_proto) {
DualRay result;
OR_ASSIGN_OR_RETURN3(
@@ -180,7 +181,7 @@ DualRayProto DualRay::Proto() const {
return result;
}
absl::StatusOr<Basis> Basis::FromProto(const ModelStorage* model,
absl::StatusOr<Basis> Basis::FromProto(const ModelStorageCPtr model,
const BasisProto& basis_proto) {
Basis basis;
OR_ASSIGN_OR_RETURN3(
@@ -197,7 +198,7 @@ absl::StatusOr<Basis> Basis::FromProto(const ModelStorage* model,
}
absl::Status Basis::CheckModelStorage(
const ModelStorage* const expected_storage) const {
const ModelStorageCPtr expected_storage) const {
for (const auto& [v, _] : variable_status) {
RETURN_IF_ERROR(
internal::CheckModelStorage(/*storage=*/v.storage(),
@@ -223,7 +224,7 @@ BasisProto Basis::Proto() const {
}
absl::StatusOr<Solution> Solution::FromProto(
const ModelStorage* model, const SolutionProto& solution_proto) {
const ModelStorageCPtr model, const SolutionProto& solution_proto) {
Solution solution;
if (solution_proto.has_primal_solution()) {
OR_ASSIGN_OR_RETURN3(

View File

@@ -67,8 +67,7 @@ struct PrimalSolution {
// * VariableValuesFromProto(primal_solution_proto.variable_values) fails.
// * the feasibility_status is not specified.
static absl::StatusOr<PrimalSolution> FromProto(
const ModelStorage* model,
const PrimalSolutionProto& primal_solution_proto);
ModelStorageCPtr model, const PrimalSolutionProto& primal_solution_proto);
// Returns the proto equivalent of this.
PrimalSolutionProto Proto() const;
@@ -112,7 +111,7 @@ struct PrimalRay {
// Returns an error when
// VariableValuesFromProto(primal_ray_proto.variable_values) fails.
static absl::StatusOr<PrimalRay> FromProto(
const ModelStorage* model, const PrimalRayProto& primal_ray_proto);
ModelStorageCPtr model, const PrimalRayProto& primal_ray_proto);
// Returns the proto equivalent of this.
PrimalRayProto Proto() const;
@@ -139,7 +138,7 @@ struct DualSolution {
// * LinearConstraintValuesFromProto(dual_solution_proto.dual_values) fails.
// * dual_solution_proto.feasibility_status is not specified.
static absl::StatusOr<DualSolution> FromProto(
const ModelStorage* model, const DualSolutionProto& dual_solution_proto);
ModelStorageCPtr model, const DualSolutionProto& dual_solution_proto);
// Returns the proto equivalent of this.
DualSolutionProto Proto() const;
@@ -179,7 +178,7 @@ struct DualRay {
// Returns an error when either of:
// * VariableValuesFromProto(dual_ray_proto.reduced_costs) fails.
// * LinearConstraintValuesFromProto(dual_ray_proto.dual_values) fails.
static absl::StatusOr<DualRay> FromProto(const ModelStorage* model,
static absl::StatusOr<DualRay> FromProto(ModelStorageCPtr model,
const DualRayProto& dual_ray_proto);
// Returns the proto equivalent of this.
@@ -218,12 +217,12 @@ struct Basis {
// Returns an error if:
// * VariableBasisFromProto(basis_proto.variable_status) fails.
// * LinearConstraintBasisFromProto(basis_proto.constraint_status) fails.
static absl::StatusOr<Basis> FromProto(const ModelStorage* model,
static absl::StatusOr<Basis> FromProto(ModelStorageCPtr model,
const BasisProto& basis_proto);
// Returns a failure if the referenced variables don't belong to the input
// expected_storage (which must not be nullptr).
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
// Returns the proto equivalent of this object.
//
@@ -262,7 +261,7 @@ struct Solution {
// Returns an error if FromProto() fails on any field that is not std::nullopt
// (see the static FromProto() functions for each field type for details).
static absl::StatusOr<Solution> FromProto(
const ModelStorage* model, const SolutionProto& solution_proto);
ModelStorageCPtr model, const SolutionProto& solution_proto);
// Returns the proto equivalent of this.
SolutionProto Proto() const;

View File

@@ -25,7 +25,7 @@
namespace operations_research::math_opt {
absl::Status SolveArguments::CheckModelStorageAndCallback(
const ModelStorage* const expected_storage) const {
const ModelStorageCPtr expected_storage) const {
RETURN_IF_ERROR(model_parameters.CheckModelStorage(expected_storage))
<< "invalid model_parameters";
RETURN_IF_ERROR(callback_registration.CheckModelStorage(expected_storage))

View File

@@ -88,7 +88,7 @@ struct SolveArguments {
// to the input expected_storage (which must not be nullptr). Also returns a
// failure if callback events are registered but no callback is provided.
absl::Status CheckModelStorageAndCallback(
const ModelStorage* expected_storage) const;
ModelStorageCPtr expected_storage) const;
};
} // namespace operations_research::math_opt

View File

@@ -40,9 +40,10 @@
namespace operations_research::math_opt::internal {
namespace {
absl::StatusOr<SolveResult> CallSolve(
BaseSolver& solver, const ModelStorage* const expected_storage,
const SolveArguments& arguments, SolveInterrupter& local_canceller) {
absl::StatusOr<SolveResult> CallSolve(BaseSolver& solver,
const ModelStorageCPtr expected_storage,
const SolveArguments& arguments,
SolveInterrupter& local_canceller) {
RETURN_IF_ERROR(arguments.CheckModelStorageAndCallback(expected_storage));
BaseSolver::Callback cb = nullptr;
@@ -104,7 +105,7 @@ absl::StatusOr<SolveResult> CallSolve(
}
absl::StatusOr<ComputeInfeasibleSubsystemResult> CallComputeInfeasibleSubsystem(
BaseSolver& solver, const ModelStorage* const expected_storage,
BaseSolver& solver, const ModelStorageCPtr expected_storage,
const ComputeInfeasibleSubsystemArguments& arguments,
SolveInterrupter& local_canceller) {
ASSIGN_OR_RETURN(
@@ -180,7 +181,7 @@ IncrementalSolverImpl::IncrementalSolverImpl(
BaseSolverFactory solver_factory, SolverType solver_type,
const bool remove_names, std::shared_ptr<SolveInterrupter> local_canceller,
std::unique_ptr<const ScopedSolveInterrupterCallback> user_canceller_cb,
const ModelStorage* const expected_storage,
const ModelStorageCPtr expected_storage,
std::unique_ptr<UpdateTracker> update_tracker,
std::unique_ptr<BaseSolver> solver)
: solver_factory_(std::move(solver_factory)),

View File

@@ -102,7 +102,7 @@ class IncrementalSolverImpl : public IncrementalSolver {
BaseSolverFactory solver_factory, SolverType solver_type,
bool remove_names, std::shared_ptr<SolveInterrupter> local_canceller,
std::unique_ptr<const ScopedSolveInterrupterCallback> user_canceller_cb,
const ModelStorage* expected_storage,
ModelStorageCPtr expected_storage,
std::unique_ptr<UpdateTracker> update_tracker,
std::unique_ptr<BaseSolver> solver);
@@ -114,7 +114,7 @@ class IncrementalSolverImpl : public IncrementalSolver {
// can be destroyed after local_canceller_ without risk.
std::shared_ptr<SolveInterrupter> local_canceller_;
std::unique_ptr<const ScopedSolveInterrupterCallback> user_canceller_cb_;
const ModelStorage* const expected_storage_;
const ModelStorageCPtr expected_storage_;
const std::unique_ptr<UpdateTracker> update_tracker_;
std::unique_ptr<BaseSolver> solver_;
};

View File

@@ -536,7 +536,7 @@ TerminationProto UpgradedTerminationProtoForStatsMigration(
} // namespace
absl::StatusOr<SolveResult> SolveResult::FromProto(
const ModelStorage* model, const SolveResultProto& solve_result_proto) {
const ModelStorageCPtr model, const SolveResultProto& solve_result_proto) {
OR_ASSIGN_OR_RETURN3(
auto termination,
Termination::FromProto(

View File

@@ -518,7 +518,7 @@ struct SolveResult {
// validation, or not rely on the strong guarantees of ValidateResult()
// and just treat SolveResult as a simple struct.
static absl::StatusOr<SolveResult> FromProto(
const ModelStorage* model, const SolveResultProto& solve_result_proto);
ModelStorageCPtr model, const SolveResultProto& solve_result_proto);
// Returns the proto equivalent of this.
//

View File

@@ -49,8 +49,7 @@ absl::Status CheckSparseVectorProto(const SparseVectorProtoType& vec) {
template <typename Key>
absl::StatusOr<absl::flat_hash_map<Key, BasisStatus>> BasisVectorFromProto(
const ModelStorage* const model,
const SparseBasisStatusVector& basis_proto) {
const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) {
using IdType = typename Key::IdType;
absl::flat_hash_map<Key, BasisStatus> map;
map.reserve(basis_proto.ids_size());
@@ -104,7 +103,7 @@ SparseBasisStatusVector BasisMapToProto(
return result;
}
absl::Status VariableIdsExist(const ModelStorage* const model,
absl::Status VariableIdsExist(const ModelStorageCPtr model,
const absl::Span<const int64_t> ids) {
for (const int64_t id : ids) {
if (!model->has_variable(VariableId(id))) {
@@ -115,7 +114,7 @@ absl::Status VariableIdsExist(const ModelStorage* const model,
return absl::OkStatus();
}
absl::Status LinearConstraintIdsExist(const ModelStorage* const model,
absl::Status LinearConstraintIdsExist(const ModelStorageCPtr model,
const absl::Span<const int64_t> ids) {
for (const int64_t id : ids) {
if (!model->has_linear_constraint(LinearConstraintId(id))) {
@@ -126,7 +125,7 @@ absl::Status LinearConstraintIdsExist(const ModelStorage* const model,
return absl::OkStatus();
}
absl::Status QuadraticConstraintIdsExist(const ModelStorage* const model,
absl::Status QuadraticConstraintIdsExist(const ModelStorageCPtr model,
const absl::Span<const int64_t> ids) {
for (const int64_t id : ids) {
if (!model->has_constraint(QuadraticConstraintId(id))) {
@@ -140,15 +139,14 @@ absl::Status QuadraticConstraintIdsExist(const ModelStorage* const model,
} // namespace
absl::StatusOr<VariableMap<double>> VariableValuesFromProto(
const ModelStorage* const model,
const SparseDoubleVectorProto& vars_proto) {
const ModelStorageCPtr model, const SparseDoubleVectorProto& vars_proto) {
RETURN_IF_ERROR(CheckSparseVectorProto(vars_proto));
RETURN_IF_ERROR(VariableIdsExist(model, vars_proto.ids()));
return MakeView(vars_proto).as_map<Variable>(model);
}
absl::StatusOr<VariableMap<int32_t>> VariableValuesFromProto(
const ModelStorage* model, const SparseInt32VectorProto& vars_proto) {
const ModelStorageCPtr model, const SparseInt32VectorProto& vars_proto) {
RETURN_IF_ERROR(CheckSparseVectorProto(vars_proto));
RETURN_IF_ERROR(VariableIdsExist(model, vars_proto.ids()));
return MakeView(vars_proto).as_map<Variable>(model);
@@ -161,7 +159,7 @@ SparseDoubleVectorProto VariableValuesToProto(
absl::StatusOr<absl::flat_hash_map<Objective, double>>
AuxiliaryObjectiveValuesFromProto(
const ModelStorage* const model,
const ModelStorageCPtr model,
const google::protobuf::Map<int64_t, double>& aux_obj_proto) {
absl::flat_hash_map<Objective, double> result;
for (const auto [raw_id, value] : aux_obj_proto) {
@@ -187,7 +185,7 @@ google::protobuf::Map<int64_t, double> AuxiliaryObjectiveValuesToProto(
}
absl::StatusOr<LinearConstraintMap<double>> LinearConstraintValuesFromProto(
const ModelStorage* const model,
const ModelStorageCPtr model,
const SparseDoubleVectorProto& lin_cons_proto) {
RETURN_IF_ERROR(CheckSparseVectorProto(lin_cons_proto));
RETURN_IF_ERROR(LinearConstraintIdsExist(model, lin_cons_proto.ids()));
@@ -201,7 +199,7 @@ SparseDoubleVectorProto LinearConstraintValuesToProto(
absl::StatusOr<absl::flat_hash_map<QuadraticConstraint, double>>
QuadraticConstraintValuesFromProto(
const ModelStorage* const model,
const ModelStorageCPtr model,
const SparseDoubleVectorProto& quad_cons_proto) {
RETURN_IF_ERROR(CheckSparseVectorProto(quad_cons_proto));
RETURN_IF_ERROR(QuadraticConstraintIdsExist(model, quad_cons_proto.ids()));
@@ -215,8 +213,7 @@ SparseDoubleVectorProto QuadraticConstraintValuesToProto(
}
absl::StatusOr<VariableMap<BasisStatus>> VariableBasisFromProto(
const ModelStorage* const model,
const SparseBasisStatusVector& basis_proto) {
const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) {
RETURN_IF_ERROR(CheckSparseVectorProto(basis_proto));
RETURN_IF_ERROR(VariableIdsExist(model, basis_proto.ids()));
return BasisVectorFromProto<Variable>(model, basis_proto);
@@ -228,8 +225,7 @@ SparseBasisStatusVector VariableBasisToProto(
}
absl::StatusOr<LinearConstraintMap<BasisStatus>> LinearConstraintBasisFromProto(
const ModelStorage* const model,
const SparseBasisStatusVector& basis_proto) {
const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) {
RETURN_IF_ERROR(CheckSparseVectorProto(basis_proto));
RETURN_IF_ERROR(LinearConstraintIdsExist(model, basis_proto.ids()));
return BasisVectorFromProto<LinearConstraint>(model, basis_proto);

View File

@@ -42,7 +42,7 @@ namespace operations_research::math_opt {
//
// Note that the values of vars_proto.values are not checked (it may have NaNs).
absl::StatusOr<VariableMap<double>> VariableValuesFromProto(
const ModelStorage* model, const SparseDoubleVectorProto& vars_proto);
ModelStorageCPtr model, const SparseDoubleVectorProto& vars_proto);
// Returns the VariableMap<int32_t> equivalent to `vars_proto`.
//
@@ -52,7 +52,7 @@ absl::StatusOr<VariableMap<double>> VariableValuesFromProto(
// * vars_proto.ids has elements that are variables in `model` (this implies
// that each id is in [0, max(int64_t))).
absl::StatusOr<VariableMap<int32_t>> VariableValuesFromProto(
const ModelStorage* model, const SparseInt32VectorProto& vars_proto);
ModelStorageCPtr model, const SparseInt32VectorProto& vars_proto);
// Returns the proto equivalent of variable_values.
SparseDoubleVectorProto VariableValuesToProto(
@@ -67,7 +67,7 @@ SparseDoubleVectorProto VariableValuesToProto(
// Note that the values of `aux_obj_proto` are not checked (it may have NaNs).
absl::StatusOr<absl::flat_hash_map<Objective, double>>
AuxiliaryObjectiveValuesFromProto(
const ModelStorage* model,
ModelStorageCPtr model,
const google::protobuf::Map<int64_t, double>& aux_obj_proto);
// Returns the proto equivalent of auxiliary_obj_values.
@@ -88,7 +88,7 @@ google::protobuf::Map<int64_t, double> AuxiliaryObjectiveValuesToProto(
// Note that the values of lin_cons_proto.values are not checked (it may have
// NaNs).
absl::StatusOr<LinearConstraintMap<double>> LinearConstraintValuesFromProto(
const ModelStorage* model, const SparseDoubleVectorProto& lin_cons_proto);
ModelStorageCPtr model, const SparseDoubleVectorProto& lin_cons_proto);
// Returns the proto equivalent of linear_constraint_values.
SparseDoubleVectorProto LinearConstraintValuesToProto(
@@ -107,7 +107,7 @@ SparseDoubleVectorProto LinearConstraintValuesToProto(
// NaNs).
absl::StatusOr<absl::flat_hash_map<QuadraticConstraint, double>>
QuadraticConstraintValuesFromProto(
const ModelStorage* model, const SparseDoubleVectorProto& quad_cons_proto);
ModelStorageCPtr model, const SparseDoubleVectorProto& quad_cons_proto);
// Returns the proto equivalent of quadratic_constraint_values.
SparseDoubleVectorProto QuadraticConstraintValuesToProto(
@@ -123,7 +123,7 @@ SparseDoubleVectorProto QuadraticConstraintValuesToProto(
// that each id is in [0, max(int64_t))).
// * basis_proto.values does not contain UNSPECIFIED and has valid enum values.
absl::StatusOr<VariableMap<BasisStatus>> VariableBasisFromProto(
const ModelStorage* model, const SparseBasisStatusVector& basis_proto);
ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto);
// Returns the proto equivalent of basis_values.
SparseBasisStatusVector VariableBasisToProto(
@@ -138,7 +138,7 @@ SparseBasisStatusVector VariableBasisToProto(
// implies that each id is in [0, max(int64_t))).
// * basis_proto.values does not contain UNSPECIFIED and has valid enum values.
absl::StatusOr<LinearConstraintMap<BasisStatus>> LinearConstraintBasisFromProto(
const ModelStorage* model, const SparseBasisStatusVector& basis_proto);
ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto);
// Returns the proto equivalent of basis_values.
SparseBasisStatusVector LinearConstraintBasisToProto(

View File

@@ -24,6 +24,9 @@
#include "ortools/base/map_util.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/cpp/formatters.h"
#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS
#include "ortools/math_opt/storage/model_storage_item.h"
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
#include "ortools/util/fp_roundtrip_conv.h"
namespace operations_research {
@@ -35,7 +38,9 @@ constexpr double kInf = std::numeric_limits<double>::infinity();
LinearExpression::LinearExpression() { ++num_calls_default_constructor_; }
LinearExpression::LinearExpression(const LinearExpression& other)
: storage_(other.storage_), terms_(other.terms_), offset_(other.offset_) {
: ModelStorageItemContainer(other.storage()),
terms_(other.terms_),
offset_(other.offset_) {
++num_calls_copy_constructor_;
}
@@ -203,7 +208,7 @@ std::ostream& operator<<(std::ostream& ostr,
QuadraticExpression::QuadraticExpression() { ++num_calls_default_constructor_; }
QuadraticExpression::QuadraticExpression(const QuadraticExpression& other)
: storage_(other.storage_),
: ModelStorageItemContainer(other),
quadratic_terms_(other.quadratic_terms_),
linear_terms_(other.linear_terms_),
offset_(other.offset_) {

View File

@@ -103,11 +103,11 @@
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "ortools/base/logging.h"
#include "ortools/base/strong_int.h"
#include "ortools/math_opt/cpp/key_types.h" // IWYU pragma: export
#include "ortools/math_opt/storage/model_storage.h"
#include "ortools/math_opt/storage/model_storage_item.h"
#include "ortools/math_opt/storage/model_storage_types.h"
namespace operations_research {
@@ -118,40 +118,20 @@ class LinearExpression;
// A value type that references a variable from ModelStorage. Usually this type
// is passed by copy.
//
// This type implements https://abseil.io/docs/cpp/guides/hash (see
// VariablesEquality for details about how operator== works).
class Variable {
class Variable final : public ModelStorageElement<
ElementType::kVariable, Variable,
// This type has a special equality operator
// (see `VariablesEquality` below).
ModelStorageElementEquality::kWithoutEquality> {
public:
// The typed integer used for ids.
using IdType = VariableId;
// Usually users will obtain variables using Model::AddVariable(). There
// should be little for users to build this object from an ModelStorage.
inline Variable(const ModelStorage* storage, VariableId id);
// Each call to AddVariable will produce Variables id() increasing by one,
// starting at zero. Deleted ids are NOT reused. Thus, if no variables are
// deleted, the ids in the model will be consecutive.
inline int64_t id() const;
inline VariableId typed_id() const;
inline const ModelStorage* storage() const;
using ModelStorageElement::ModelStorageElement;
inline double lower_bound() const;
inline double upper_bound() const;
inline bool is_integer() const;
inline absl::string_view name() const;
template <typename H>
friend H AbslHashValue(H h, const Variable& variable);
friend std::ostream& operator<<(std::ostream& ostr, const Variable& variable);
inline LinearExpression operator-() const;
private:
const ModelStorage* storage_;
VariableId id_;
};
namespace internal {
@@ -186,8 +166,6 @@ inline bool operator!=(const Variable& lhs, const Variable& rhs);
template <typename V>
using VariableMap = absl::flat_hash_map<Variable, V>;
inline std::ostream& operator<<(std::ostream& ostr, const Variable& variable);
// A term in an sum of variables multiplied by coefficients.
struct LinearTerm {
// Usually this constructor is never called explicitly by users. Instead it
@@ -226,7 +204,7 @@ class QuadraticExpression;
// TODO(b/169415098): add a function to remove zero terms.
// TODO(b/169415834): study if exact zeros should be automatically removed.
// TODO(b/169415103): add tests that some expressions don't compile.
class LinearExpression {
class LinearExpression final : public ModelStorageItemContainer {
public:
// For unit testing purpose, we define optional counters. We have to
// explicitly define the default constructor, copy constructor and assignment
@@ -238,9 +216,6 @@ class LinearExpression {
LinearExpression();
LinearExpression(const LinearExpression& other);
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
// We have to define a custom move constructor as we need to reset storage_ to
// nullptr.
inline LinearExpression(LinearExpression&& other) noexcept;
// Usually users should use the overloads of operators to build linear
// expressions. For example, assuming `x` and `y` are Variable, then `x + 2*y
// + 5` will build a LinearExpression automatically.
@@ -250,8 +225,9 @@ class LinearExpression {
inline LinearExpression(Variable variable); // NOLINT
inline LinearExpression(const LinearTerm& term); // NOLINT
LinearExpression& operator=(const LinearExpression& other) = default;
// We have to define a custom move assignment operator as we need to reset
// storage_ to nullptr.
// A moved-from `LinearExpression` is the zero expression: it's not associated
// to a storage, has no terms and its offset is zero.
inline LinearExpression(LinearExpression&& other) noexcept;
inline LinearExpression& operator=(LinearExpression&& other) noexcept;
inline LinearExpression& operator+=(const LinearExpression& other);
@@ -367,8 +343,6 @@ class LinearExpression {
double EvaluateWithDefaultZero(
const VariableMap<double>& variable_values) const;
inline const ModelStorage* storage() const;
#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS
static thread_local int num_calls_default_constructor_;
static thread_local int num_calls_copy_constructor_;
@@ -384,14 +358,9 @@ class LinearExpression {
const LinearExpression& expression);
friend QuadraticExpression;
// Sets the storage_ to the input value if nullptr, else CHECKs that it is
// equal. Also CHECKs that the input value is not nullptr.
inline void SetOrCheckStorage(const ModelStorage* storage);
// Invariants:
// * nullptr, if terms_ is empty
// * equal to Variable::storage() of each key of terms_, else
const ModelStorage* storage_ = nullptr;
// * storage() == v.storage()for each v in terms_;
// * storage() == nullptr, if terms_ is empty.
VariableMap<double> terms_;
double offset_ = 0.0;
};
@@ -647,14 +616,14 @@ using QuadraticProductId = std::pair<VariableId, VariableId>;
// silently correct this if not satisfied by the inputs.
//
// This type can be used as a key in ABSL hash containers.
class QuadraticTermKey {
class QuadraticTermKey final : public ModelStorageItem {
public:
// NOTE: this definition is for use by IdMap; clients should not rely upon it.
using IdType = QuadraticProductId;
// NOTE: This constructor will silently re-order the passed id so that, upon
// exiting the constructor, variable_ids_.first <= variable_ids_.second.
inline QuadraticTermKey(const ModelStorage* storage, QuadraticProductId id);
inline QuadraticTermKey(ModelStorageCPtr storage, QuadraticProductId id);
// NOTE: This constructor will CHECK fail if the variable models do not agree,
// i.e. first_variable.storage() != second_variable.storage(). It will also
// silently re-order the passed id so that, upon exiting the constructor,
@@ -662,19 +631,17 @@ class QuadraticTermKey {
inline QuadraticTermKey(Variable first_variable, Variable second_variable);
inline QuadraticProductId typed_id() const;
inline const ModelStorage* storage() const;
// Returns the Variable with the smallest id.
Variable first() const { return Variable(storage_, variable_ids_.first); }
Variable first() const { return Variable(storage(), variable_ids_.first); }
// Returns the Variable the largest id.
Variable second() const { return Variable(storage_, variable_ids_.second); }
Variable second() const { return Variable(storage(), variable_ids_.second); }
template <typename H>
friend H AbslHashValue(H h, const QuadraticTermKey& key);
private:
const ModelStorage* storage_;
QuadraticProductId variable_ids_;
};
@@ -744,7 +711,7 @@ using QuadraticTermMap = absl::flat_hash_map<QuadraticTermKey, V>;
// is, it is forbidden that both are non-null and not equal. Use
// CheckModelsAgree() and the initializer_list constructor to enforce this
// invariant in any class or friend method.
class QuadraticExpression {
class QuadraticExpression final : public ModelStorageItemContainer {
public:
// For unit testing purpose, we define optional counters. We have to
// explicitly define the default constructor, copy constructor and assignment
@@ -756,9 +723,6 @@ class QuadraticExpression {
QuadraticExpression();
QuadraticExpression(const QuadraticExpression& other);
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
// We have to define a custom move constructor as we need to reset storage_ to
// nullptr.
inline QuadraticExpression(QuadraticExpression&& other) noexcept;
// Users should prefer the default constructor and operator overloads to build
// expressions.
inline QuadraticExpression(
@@ -770,8 +734,9 @@ class QuadraticExpression {
inline QuadraticExpression(LinearExpression expr); // NOLINT
inline QuadraticExpression(const QuadraticTerm& term); // NOLINT
QuadraticExpression& operator=(const QuadraticExpression& other) = default;
// We have to define a custom move assignment operator as we need to reset
// storage_ to nullptr.
// A moved-from `LinearExpression` is the zero expression: it's not associated
// to a storage, has no terms and its offset is zero.
inline QuadraticExpression(QuadraticExpression&& other) noexcept;
inline QuadraticExpression& operator=(QuadraticExpression&& other) noexcept;
inline double offset() const;
@@ -933,8 +898,6 @@ class QuadraticExpression {
double EvaluateWithDefaultZero(
const VariableMap<double>& variable_values) const;
inline const ModelStorage* storage() const;
#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS
static thread_local int num_calls_default_constructor_;
static thread_local int num_calls_copy_constructor_;
@@ -950,15 +913,10 @@ class QuadraticExpression {
friend std::ostream& operator<<(std::ostream& ostr,
const QuadraticExpression& expr);
// Sets the storage_ to the input value if nullptr, else CHECKs that it is
// equal. Also CHECKs that the input value is not nullptr.
inline void SetOrCheckStorage(const ModelStorage* storage);
// Invariants:
// * nullptr, if both quadratic_terms_ and linear_terms_ are empty
// * equal to Variable::storage() of each key of linear_terms_ and
// QuadraticTermKey::storage() of each key of quadratic_terms_, else
const ModelStorage* storage_ = nullptr;
// * storage() == v.storage() for each v in linear_terms_;
// * storage() == v.storage() for each v in quadratic_terms_;
// * storage() == nullptr, if both terms_ and quadratic_terms_ are empty.
QuadraticTermMap<double> quadratic_terms_;
VariableMap<double> linear_terms_;
double offset_ = 0.0;
@@ -1268,50 +1226,25 @@ inline BoundedQuadraticExpression operator==(double lhs,
// Variable
////////////////////////////////////////////////////////////////////////////////
Variable::Variable(const ModelStorage* const storage, const VariableId id)
: storage_(storage), id_(id) {
DCHECK(storage != nullptr);
}
int64_t Variable::id() const { return id_.value(); }
VariableId Variable::typed_id() const { return id_; }
const ModelStorage* Variable::storage() const { return storage_; }
double Variable::lower_bound() const {
return storage_->variable_lower_bound(id_);
return storage()->variable_lower_bound(typed_id());
}
double Variable::upper_bound() const {
return storage_->variable_upper_bound(id_);
return storage()->variable_upper_bound(typed_id());
}
bool Variable::is_integer() const { return storage_->is_variable_integer(id_); }
bool Variable::is_integer() const {
return storage()->is_variable_integer(typed_id());
}
absl::string_view Variable::name() const {
if (storage()->has_variable(id_)) {
return storage_->variable_name(id_);
if (storage()->has_variable(typed_id())) {
return storage()->variable_name(typed_id());
}
return "[variable deleted from model]";
}
template <typename H>
H AbslHashValue(H h, const Variable& variable) {
return H::combine(std::move(h), variable.id_.value(), variable.storage_);
}
std::ostream& operator<<(std::ostream& ostr, const Variable& variable) {
// TODO(b/170992529): handle quoting of invalid characters in the name.
const absl::string_view name = variable.name();
if (name.empty()) {
ostr << "__var#" << variable.id() << "__";
} else {
ostr << name;
}
return ostr;
}
LinearExpression Variable::operator-() const {
return LinearExpression({LinearTerm(*this, -1.0)}, 0.0);
}
@@ -1368,17 +1301,9 @@ LinearTerm operator/(Variable variable, const double coefficient) {
// LinearExpression
////////////////////////////////////////////////////////////////////////////////
void LinearExpression::SetOrCheckStorage(const ModelStorage* const storage) {
CHECK(storage != nullptr) << internal::kKeyHasNullModelStorage;
if (storage_ == nullptr) {
storage_ = storage;
return;
}
CHECK_EQ(storage, storage_) << internal::kObjectsFromOtherModelStorage;
}
LinearExpression::LinearExpression(LinearExpression&& other) noexcept
: storage_(std::exchange(other.storage_, nullptr)),
: ModelStorageItemContainer(
static_cast<ModelStorageItemContainer&&>(other)),
terms_(std::move(other.terms_)),
offset_(std::exchange(other.offset_, 0.0)) {
other.terms_.clear();
@@ -1389,7 +1314,8 @@ LinearExpression::LinearExpression(LinearExpression&& other) noexcept
LinearExpression& LinearExpression::operator=(
LinearExpression&& other) noexcept {
storage_ = std::exchange(other.storage_, nullptr);
ModelStorageItemContainer::operator=(
static_cast<ModelStorageItemContainer&&>(other));
terms_ = std::move(other.terms_);
other.terms_.clear();
offset_ = std::exchange(other.offset_, 0.0);
@@ -1403,7 +1329,7 @@ LinearExpression::LinearExpression(std::initializer_list<LinearTerm> terms,
++num_calls_initializer_list_constructor_;
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
for (const auto& term : terms) {
SetOrCheckStorage(term.variable.storage());
SetOrCheckStorage(term.variable);
// The same variable may appear multiple times in the input list; we must
// accumulate the coefficients.
terms_[term.variable] += term.coefficient;
@@ -1579,7 +1505,7 @@ LinearExpression& LinearExpression::operator+=(const LinearExpression& other) {
// thus we don't need to compare in the loop. Of course this only applies if
// the other has terms.
if (!other.terms_.empty()) {
SetOrCheckStorage(other.storage());
SetOrCheckStorage(other);
for (const auto& [v, coeff] : other.terms_) {
terms_[v] += coeff;
}
@@ -1589,13 +1515,13 @@ LinearExpression& LinearExpression::operator+=(const LinearExpression& other) {
}
LinearExpression& LinearExpression::operator+=(const LinearTerm& term) {
SetOrCheckStorage(term.variable.storage());
SetOrCheckStorage(term.variable);
terms_[term.variable] += term.coefficient;
return *this;
}
LinearExpression& LinearExpression::operator+=(const Variable variable) {
SetOrCheckStorage(variable.storage());
SetOrCheckStorage(variable);
return *this += LinearTerm(variable, 1.0);
}
@@ -1607,7 +1533,7 @@ LinearExpression& LinearExpression::operator+=(const double value) {
LinearExpression& LinearExpression::operator-=(const LinearExpression& other) {
// See operator+=.
if (!other.terms_.empty()) {
SetOrCheckStorage(other.storage());
SetOrCheckStorage(other);
for (const auto& [v, coeff] : other.terms_) {
terms_[v] -= coeff;
}
@@ -1617,13 +1543,13 @@ LinearExpression& LinearExpression::operator-=(const LinearExpression& other) {
}
LinearExpression& LinearExpression::operator-=(const LinearTerm& term) {
SetOrCheckStorage(term.variable.storage());
SetOrCheckStorage(term.variable);
terms_[term.variable] -= term.coefficient;
return *this;
}
LinearExpression& LinearExpression::operator-=(const Variable variable) {
SetOrCheckStorage(variable.storage());
SetOrCheckStorage(variable);
return *this -= LinearTerm(variable, 1.0);
}
@@ -1713,8 +1639,6 @@ const VariableMap<double>& LinearExpression::terms() const { return terms_; }
double LinearExpression::offset() const { return offset_; }
const ModelStorage* LinearExpression::storage() const { return storage_; }
////////////////////////////////////////////////////////////////////////////////
// VariablesEquality
////////////////////////////////////////////////////////////////////////////////
@@ -2055,9 +1979,9 @@ BoundedLinearExpression operator==(const double lhs, const Variable rhs) {
// QuadraticTermKey
////////////////////////////////////////////////////////////////////////////////
QuadraticTermKey::QuadraticTermKey(const ModelStorage* storage,
QuadraticTermKey::QuadraticTermKey(const ModelStorageCPtr storage,
const QuadraticProductId id)
: storage_(storage), variable_ids_(id) {
: ModelStorageItem(storage), variable_ids_(id) {
if (variable_ids_.first > variable_ids_.second) {
// See https://en.cppreference.com/w/cpp/named_req/Swappable for details.
using std::swap;
@@ -2075,8 +1999,6 @@ QuadraticTermKey::QuadraticTermKey(const Variable first_variable,
QuadraticProductId QuadraticTermKey::typed_id() const { return variable_ids_; }
const ModelStorage* QuadraticTermKey::storage() const { return storage_; }
template <typename H>
H AbslHashValue(H h, const QuadraticTermKey& key) {
return H::combine(std::move(h), key.typed_id().first.value(),
@@ -2124,17 +2046,9 @@ QuadraticTermKey QuadraticTerm::GetKey() const {
// QuadraticExpression (no arithmetic)
////////////////////////////////////////////////////////////////////////////////
void QuadraticExpression::SetOrCheckStorage(const ModelStorage* const storage) {
CHECK(storage != nullptr) << internal::kKeyHasNullModelStorage;
if (storage_ == nullptr) {
storage_ = storage;
return;
}
CHECK_EQ(storage, storage_) << internal::kObjectsFromOtherModelStorage;
}
QuadraticExpression::QuadraticExpression(QuadraticExpression&& other) noexcept
: storage_(std::exchange(other.storage_, nullptr)),
: ModelStorageItemContainer(
static_cast<ModelStorageItemContainer&&>(other)),
quadratic_terms_(std::move(other.quadratic_terms_)),
linear_terms_(std::move(other.linear_terms_)),
offset_(std::exchange(other.offset_, 0.0)) {
@@ -2147,7 +2061,8 @@ QuadraticExpression::QuadraticExpression(QuadraticExpression&& other) noexcept
QuadraticExpression& QuadraticExpression::operator=(
QuadraticExpression&& other) noexcept {
storage_ = std::exchange(other.storage_, nullptr);
ModelStorageItemContainer::operator=(
static_cast<ModelStorageItemContainer&&>(other));
quadratic_terms_ = std::move(other.quadratic_terms_);
other.quadratic_terms_.clear();
linear_terms_ = std::move(other.linear_terms_);
@@ -2164,12 +2079,12 @@ QuadraticExpression::QuadraticExpression(
++num_calls_initializer_list_constructor_;
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
for (const LinearTerm& term : linear_terms) {
SetOrCheckStorage(term.variable.storage());
SetOrCheckStorage(term.variable);
linear_terms_[term.variable] += term.coefficient;
}
for (const QuadraticTerm& term : quadratic_terms) {
const QuadraticTermKey key = term.GetKey();
SetOrCheckStorage(key.storage());
SetOrCheckStorage(key);
quadratic_terms_[key] += term.coefficient();
}
}
@@ -2184,9 +2099,9 @@ QuadraticExpression::QuadraticExpression(const LinearTerm& term)
: QuadraticExpression({}, {term}, 0.0) {}
QuadraticExpression::QuadraticExpression(LinearExpression expr)
: storage_(std::exchange(expr.storage_, nullptr)),
: ModelStorageItemContainer(expr.storage()),
linear_terms_(std::move(expr.terms_)),
offset_(std::exchange(expr.offset_, 0.0)) {
offset_(expr.offset_) {
#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS
++num_calls_linear_expression_constructor_;
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
@@ -2195,8 +2110,6 @@ QuadraticExpression::QuadraticExpression(LinearExpression expr)
QuadraticExpression::QuadraticExpression(const QuadraticTerm& term)
: QuadraticExpression({term}, {}, 0.0) {}
const ModelStorage* QuadraticExpression::storage() const { return storage_; }
double QuadraticExpression::offset() const { return offset_; }
const VariableMap<double>& QuadraticExpression::linear_terms() const {
@@ -2582,13 +2495,13 @@ QuadraticExpression& QuadraticExpression::operator+=(const double value) {
}
QuadraticExpression& QuadraticExpression::operator+=(const Variable variable) {
SetOrCheckStorage(variable.storage());
SetOrCheckStorage(variable);
linear_terms_[variable] += 1;
return *this;
}
QuadraticExpression& QuadraticExpression::operator+=(const LinearTerm& term) {
SetOrCheckStorage(term.variable.storage());
SetOrCheckStorage(term.variable);
linear_terms_[term.variable] += term.coefficient;
return *this;
}
@@ -2598,7 +2511,7 @@ QuadraticExpression& QuadraticExpression::operator+=(
offset_ += expr.offset();
// See comment in LinearExpression::operator+=.
if (!expr.terms().empty()) {
SetOrCheckStorage(expr.storage());
SetOrCheckStorage(expr);
for (const auto& [v, coeff] : expr.terms()) {
linear_terms_[v] += coeff;
}
@@ -2609,7 +2522,7 @@ QuadraticExpression& QuadraticExpression::operator+=(
QuadraticExpression& QuadraticExpression::operator+=(
const QuadraticTerm& term) {
const QuadraticTermKey key = term.GetKey();
SetOrCheckStorage(key.storage());
SetOrCheckStorage(key);
quadratic_terms_[key] += term.coefficient();
return *this;
}
@@ -2619,7 +2532,7 @@ QuadraticExpression& QuadraticExpression::operator+=(
offset_ += expr.offset();
// See comment in LinearExpression::operator+=.
if (!expr.linear_terms().empty() || !expr.quadratic_terms().empty()) {
SetOrCheckStorage(expr.storage());
SetOrCheckStorage(expr);
for (const auto& [v, coeff] : expr.linear_terms()) {
linear_terms_[v] += coeff;
}
@@ -2637,13 +2550,13 @@ QuadraticExpression& QuadraticExpression::operator-=(const double value) {
}
QuadraticExpression& QuadraticExpression::operator-=(const Variable variable) {
SetOrCheckStorage(variable.storage());
SetOrCheckStorage(variable);
linear_terms_[variable] -= 1;
return *this;
}
QuadraticExpression& QuadraticExpression::operator-=(const LinearTerm& term) {
SetOrCheckStorage(term.variable.storage());
SetOrCheckStorage(term.variable);
linear_terms_[term.variable] -= term.coefficient;
return *this;
}
@@ -2653,7 +2566,7 @@ QuadraticExpression& QuadraticExpression::operator-=(
offset_ -= expr.offset();
// See comment in LinearExpression::operator+=.
if (!expr.terms().empty()) {
SetOrCheckStorage(expr.storage());
SetOrCheckStorage(expr);
for (const auto& [v, coeff] : expr.terms()) {
linear_terms_[v] -= coeff;
}
@@ -2664,7 +2577,7 @@ QuadraticExpression& QuadraticExpression::operator-=(
QuadraticExpression& QuadraticExpression::operator-=(
const QuadraticTerm& term) {
const QuadraticTermKey key = term.GetKey();
SetOrCheckStorage(key.storage());
SetOrCheckStorage(key);
quadratic_terms_[key] -= term.coefficient();
return *this;
}
@@ -2674,7 +2587,7 @@ QuadraticExpression& QuadraticExpression::operator-=(
offset_ -= expr.offset();
// See comment in LinearExpression::operator+=.
if (!expr.linear_terms().empty() || !expr.quadratic_terms().empty()) {
SetOrCheckStorage(expr.storage());
SetOrCheckStorage(expr);
for (const auto& [v, coeff] : expr.linear_terms()) {
linear_terms_[v] -= coeff;
}

View File

@@ -0,0 +1,553 @@
# Copyright 2010-2025 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
cc_library(
name = "attributes",
hdrs = ["attributes.h"],
visibility = ["//ortools/math_opt:__subpackages__"],
deps = [
":arrays",
":elements",
":symmetry",
"//ortools/base:array",
"@abseil-cpp//absl/strings:string_view",
],
)
cc_test(
name = "attributes_test",
srcs = ["attributes_test.cc"],
deps = [
":arrays",
":attributes",
"//ortools/base:gmock_main",
"//ortools/math_opt/testing:stream",
"@abseil-cpp//absl/strings",
],
)
cc_library(
name = "elemental",
srcs = [
"elemental.cc",
"elemental_export_model.cc",
"elemental_from_proto.cc",
"elemental_to_string.cc",
],
hdrs = ["elemental.h"],
visibility = ["//ortools/math_opt:__subpackages__"],
deps = [
":arrays",
":attr_key",
":attr_storage",
":attributes",
":derived_data",
":diff",
":element_ref_tracker",
":element_storage",
":elements",
":symmetry",
":thread_safe_id_map",
"//ortools/base:status_macros",
"//ortools/math_opt:model_cc_proto",
"//ortools/math_opt:model_update_cc_proto",
"//ortools/math_opt:sparse_containers_cc_proto",
"//ortools/math_opt/core:model_summary",
"//ortools/math_opt/validators:model_validator",
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/log:die_if_null",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:string_view",
"@abseil-cpp//absl/types:span",
"@com_google_protobuf//:protobuf",
],
)
cc_test(
name = "elemental_test",
srcs = ["elemental_test.cc"],
deps = [
":attr_key",
":attributes",
":derived_data",
":diff",
":elemental",
":elemental_matcher",
":elements",
":symmetry",
":testing",
"//ortools/base:gmock_main",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/status",
"@com_google_benchmark//:benchmark",
],
)
cc_library(
name = "derived_data",
hdrs = ["derived_data.h"],
visibility = ["//ortools/math_opt:__subpackages__"],
deps = [
":arrays",
":attr_key",
":attributes",
":elements",
"//ortools/util:fp_roundtrip_conv",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/strings",
],
)
cc_test(
name = "derived_data_test",
srcs = ["derived_data_test.cc"],
deps = [
":arrays",
":attr_key",
":attributes",
":derived_data",
":elements",
":symmetry",
"//ortools/base:gmock_main",
"//ortools/math_opt/testing:stream",
],
)
cc_library(
name = "element_storage",
srcs = ["element_storage.cc"],
hdrs = ["element_storage.h"],
deps = [
"//ortools/base:status_macros",
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings:string_view",
],
)
cc_test(
name = "element_storage_test",
srcs = ["element_storage_test.cc"],
deps = [
":element_storage",
"//ortools/base:gmock_main",
"@abseil-cpp//absl/status",
"@com_google_benchmark//:benchmark",
],
)
cc_library(
name = "element_diff",
hdrs = ["element_diff.h"],
deps = ["@abseil-cpp//absl/container:flat_hash_set"],
)
cc_test(
name = "element_diff_test",
srcs = ["element_diff_test.cc"],
deps = [
":element_diff",
"//ortools/base:gmock_main",
],
)
cc_library(
name = "diff",
srcs = ["diff.cc"],
hdrs = ["diff.h"],
deps = [
"derived_data",
":attr_diff",
":attr_key",
":attributes",
":element_diff",
":elements",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/types:span",
],
)
cc_test(
name = "diff_test",
srcs = ["diff_test.cc"],
deps = [
":attr_key",
":attributes",
":diff",
":elements",
"//ortools/base:gmock_main",
"@abseil-cpp//absl/types:span",
],
)
cc_library(
name = "attr_storage",
hdrs = ["attr_storage.h"],
deps = [
":attr_key",
":symmetry",
"//ortools/base:map_util",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/functional:function_ref",
],
)
cc_test(
name = "attr_storage_test",
srcs = ["attr_storage_test.cc"],
deps = [
":attr_key",
":attr_storage",
":symmetry",
"//ortools/base:gmock_main",
"@com_google_benchmark//:benchmark",
],
)
cc_library(
name = "attr_diff",
hdrs = ["attr_diff.h"],
deps = [
":attr_key",
"@abseil-cpp//absl/container:flat_hash_set",
],
)
cc_test(
name = "attr_diff_test",
srcs = ["attr_diff_test.cc"],
deps = [
":attr_diff",
":attr_key",
":symmetry",
"//ortools/base:gmock_main",
],
)
cc_library(
name = "attr_key",
hdrs = ["attr_key.h"],
visibility = ["//ortools/math_opt:__subpackages__"],
deps = [
":elements",
":symmetry",
"//ortools/base:status_macros",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/types:span",
],
)
cc_test(
name = "attr_key_test",
srcs = ["attr_key_test.cc"],
deps = [
":attr_key",
":elements",
":symmetry",
":testing",
"//ortools/base:gmock_main",
"//ortools/math_opt/testing:stream",
"@abseil-cpp//absl/algorithm:container",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/hash:hash_testing",
"@abseil-cpp//absl/meta:type_traits",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/strings",
"@com_google_benchmark//:benchmark",
],
)
cc_library(
name = "arrays",
hdrs = ["arrays.h"],
visibility = ["//ortools/math_opt/elemental:__subpackages__"],
)
cc_library(
name = "elemental_differencer",
srcs = ["elemental_differencer.cc"],
hdrs = ["elemental_differencer.h"],
deps = [
":attr_key",
":attributes",
":derived_data",
":elemental",
":elements",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
],
)
cc_test(
name = "elemental_differencer_test",
srcs = ["elemental_differencer_test.cc"],
deps = [
":attr_key",
":attributes",
":elemental",
":elemental_differencer",
":elements",
"//ortools/base:gmock_main",
"@abseil-cpp//absl/container:flat_hash_set",
],
)
cc_test(
name = "arrays_test",
srcs = ["arrays_test.cc"],
deps = [
":arrays",
"//ortools/base:array",
"//ortools/base:gmock_main",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:string_view",
],
)
cc_test(
name = "elemental_export_model_test",
srcs = ["elemental_export_model_test.cc"],
deps = [
":attr_key",
":attributes",
":derived_data",
":elemental",
":elements",
"//ortools/base:gmock_main",
"//ortools/math_opt:model_cc_proto",
"//ortools/math_opt:sparse_containers_cc_proto",
],
)
cc_test(
name = "elemental_to_string_test",
srcs = ["elemental_to_string_test.cc"],
deps = [
":attr_key",
":attributes",
":elemental",
":elements",
"//ortools/base:gmock_main",
"//ortools/math_opt/testing:stream",
"@abseil-cpp//absl/strings",
],
)
cc_test(
name = "safe_attr_ops_test",
srcs = ["safe_attr_ops_test.cc"],
deps = [
":attr_key",
":attributes",
":elemental",
":elements",
":safe_attr_ops",
"//ortools/base:gmock_main",
"@abseil-cpp//absl/status",
],
)
cc_library(
name = "safe_attr_ops",
hdrs = ["safe_attr_ops.h"],
visibility = ["//ortools/math_opt/elemental/c_api:__subpackages__"],
deps = [
":derived_data",
":elemental",
"//ortools/base:status_macros",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
],
)
cc_library(
name = "testing",
testonly = 1,
hdrs = ["testing.h"],
deps = [":attr_key"],
)
cc_library(
name = "symmetry",
hdrs = ["symmetry.h"],
deps = [
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings:str_format",
"@abseil-cpp//absl/strings:string_view",
],
)
cc_library(
name = "elemental_matcher",
testonly = 1,
srcs = ["elemental_matcher.cc"],
hdrs = ["elemental_matcher.h"],
deps = [
":elemental",
":elemental_differencer",
"//ortools/base:gmock",
"@abseil-cpp//absl/base:core_headers",
],
)
cc_library(
name = "element_ref_tracker",
hdrs = ["element_ref_tracker.h"],
deps = [
":attr_key",
":elements",
"//ortools/base:map_util",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/container:flat_hash_set",
],
)
cc_library(
name = "elements",
srcs = ["elements.cc"],
hdrs = ["elements.h"],
visibility = ["//ortools/math_opt:__subpackages__"],
deps = [
"//ortools/base:array",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:str_format",
"@abseil-cpp//absl/strings:string_view",
],
)
cc_test(
name = "elements_test",
srcs = ["elements_test.cc"],
deps = [
":elements",
"//ortools/base:gmock_main",
"//ortools/math_opt/testing:stream",
"@abseil-cpp//absl/hash:hash_testing",
"@abseil-cpp//absl/strings",
],
)
cc_test(
name = "elemental_matcher_test",
srcs = ["elemental_matcher_test.cc"],
deps = [
":elemental",
":elemental_differencer",
":elemental_matcher",
":elements",
"//ortools/base:gmock_main",
],
)
cc_test(
name = "element_ref_tracker_test",
srcs = ["element_ref_tracker_test.cc"],
deps = [
":attr_key",
":element_ref_tracker",
":elements",
":symmetry",
"//ortools/base:gmock_main",
],
)
cc_library(
name = "thread_safe_id_map",
hdrs = ["thread_safe_id_map.h"],
deps = [
"//ortools/base:stl_util",
"@abseil-cpp//absl/base:core_headers",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/synchronization",
"@abseil-cpp//absl/types:span",
],
)
cc_test(
name = "thread_safe_id_map_test",
srcs = ["thread_safe_id_map_test.cc"],
deps = [
":thread_safe_id_map",
"//ortools/base:gmock_main",
],
)
cc_test(
name = "elemental_from_proto_test",
srcs = ["elemental_from_proto_test.cc"],
deps = [
":attr_key",
":attributes",
":derived_data",
":elemental",
":elemental_matcher",
":elements",
"//ortools/base:gmock_main",
"//ortools/math_opt:model_cc_proto",
"//ortools/math_opt:sparse_containers_cc_proto",
"@abseil-cpp//absl/status",
],
)
cc_test(
name = "elemental_from_proto_fuzz_test",
srcs = ["elemental_from_proto_fuzz_test.cc"],
tags = ["componentid:1147829"],
deps = [
":elemental",
":elemental_matcher",
"//ortools/base:fuzztest",
"//ortools/base:gmock_main",
"//ortools/math_opt:model_update_cc_proto",
"@abseil-cpp//absl/status:statusor",
],
)
cc_test(
name = "elemental_update_from_proto_test",
srcs = ["elemental_update_from_proto_test.cc"],
deps = [
":attr_key",
":attributes",
":derived_data",
":elemental",
":elemental_matcher",
":elements",
"//ortools/base:gmock_main",
"//ortools/math_opt:model_cc_proto",
"//ortools/math_opt:model_update_cc_proto",
"//ortools/math_opt:sparse_containers_cc_proto",
"@abseil-cpp//absl/status",
],
)

View File

@@ -0,0 +1,30 @@
# Copyright 2010-2025 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set(NAME ${PROJECT_NAME}_math_opt_elemental)
add_library(${NAME} OBJECT)
file(GLOB_RECURSE _SRCS "*.h" "*.cc")
list(FILTER _SRCS EXCLUDE REGEX ".*/.*_test.cc")
list(FILTER _SRCS EXCLUDE REGEX "/elemental_matcher.*")
list(FILTER _SRCS EXCLUDE REGEX "/python/.*")
list(FILTER _SRCS EXCLUDE REGEX "/codegen/codegen.cc")
target_sources(${NAME} PRIVATE ${_SRCS})
set_target_properties(${NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_include_directories(${NAME} PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}>
$<BUILD_INTERFACE:${PROJECT_BINARY_DIR}>)
target_link_libraries(${NAME} PRIVATE
${PROJECT_NAMESPACE}::math_opt_proto
absl::strings
)

View File

@@ -0,0 +1,3 @@
# Elemental
See go/math-opt-elemental and g/math-opt-dev/c/0cgOO6qkoWM.

View File

@@ -0,0 +1,74 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Utilities to apply template functors on index ranges.
// See tests for examples.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_
#include <tuple>
#include <utility>
namespace operations_research::math_opt {
// Calls `fn<0, ..., n-1>()`, and returns the result. Typically used for
// simple reduce operations that can be expressed as a fold.
//
// Examples:
// - Sum of elements from 0 to 5 (result is 15):
// `ApplyOnIndexRange<6>([]<int... i>() { return (i + ... + 0); });`
//
// - Sum of elements of array `a`:
// ```
// ApplyOnIndexRange<a.size()>([&a]<int... i>() {
// return (a[i] + ... + 0);
// });
// ```
template <int n, typename Fn>
constexpr decltype(auto) ApplyOnIndexRange(Fn&& fn) {
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
return [&fn]<int... is>(std::integer_sequence<int, is...>) mutable {
return fn.template operator()<is...>();
}(std::make_integer_sequence<int, n>());
}
// Calls (fn<0>(), ..., fn<n-1>()).
// Typically used for independent operations on elements, or more complex reduce
// operations that cannot be expressed with a fold.
//
// Example (independent operations): Log each array element for some array `a`:
// `ForEachIndex<a.size()>([&a]<int i>() { LOG(ERROR) << a[i]; });`
//
// NOTE: this returns the result of the last call, which allows returning some
// internal state (and avoids capturing an external variable by reference) for
// complex fold operations. See `CollectTest` for an example.
template <int n, typename Fn>
constexpr decltype(auto) ForEachIndex(Fn&& fn) {
return ApplyOnIndexRange<n>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[&fn]<int... is>() { return (fn.template operator()<is>(), ...); });
}
// Calls `fn` of each element of `tuple`, and returns the result of the
// last invocation.
template <typename Fn, typename Tuple>
constexpr decltype(auto) ForEach(Fn&& fn, Tuple&& tuple) {
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
return std::apply([&fn]<typename... Ts>(
Ts&&... ts) { return (fn(std::forward<Ts>(ts)), ...); },
std::forward<Tuple>(tuple));
}
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_

View File

@@ -0,0 +1,179 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/arrays.h"
#include <array>
#include <iterator>
#include <string>
#include <tuple>
#include <vector>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "gtest/gtest.h"
#include "ortools/base/array.h"
#include "ortools/base/gmock.h"
namespace operations_research::math_opt {
namespace {
using ::testing::ElementsAre;
// Sums the elements of an array-like object `a`.
template <auto a>
constexpr int Sum() {
return ApplyOnIndexRange<std::size(a)>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<int... i>() { return (a[i] + ... + 0); });
}
// Same as `Sum`, but starts at 1.
template <auto a>
constexpr int SumPlusOne() {
return ApplyOnIndexRange<std::size(a)>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<int... i>() { return (a[i] + ... + 1); });
}
#if __cplusplus >= 202002L
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
TEST(ApplyOnIndexRangeTest, Sum) {
EXPECT_EQ(Sum<gtl::to_array({5, 3, 1})>(), 9);
EXPECT_EQ(SumPlusOne<gtl::to_array({5, 3, 1})>(), 10);
}
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
#endif
// Returns the weighted sum of the elements of an array-like object `a`, where
// weights are indices.
template <auto a>
constexpr double ScaledSum() {
return ApplyOnIndexRange<std::size(a)>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<int... i>() { return ((i * a[i]) + ... + 0.0); });
}
#if __cplusplus >= 202002L
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
TEST(ApplyOnIndexRangeTest, ScaledSum) {
EXPECT_EQ(ScaledSum<gtl::to_array({5, 3, 1})>(), 5.0);
}
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
#endif
// Returns the number of even elements in an array-like object `a`.
template <auto a>
constexpr int CountEven() {
return ApplyOnIndexRange<std::size(a)>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<int... i>() { return ((a[i] % 2 == 0 ? 1 : 0) + ... + 0); });
}
#if __cplusplus >= 202002L
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
TEST(ApplyOnIndexRangeTest, CountEven) {
EXPECT_EQ(CountEven<gtl::to_array({5, 4, 8, 1, 10})>(), 3);
}
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
#endif
// Returns array of doubles of the same size as `a`, where each element has been
// halved.
template <auto a>
constexpr std::array<double, std::size(a)> Half() {
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
return ApplyOnIndexRange<std::size(a)>([]<int... i>() {
return std::array<double, std::size(a)>(
{(static_cast<double>(a[i]) / 2.0)...});
});
}
#if __cplusplus >= 202002L
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
TEST(ApplyOnIndexRangeTest, Half) {
EXPECT_THAT(Half<gtl::to_array({5, 4, 8, 1, 10})>(),
ElementsAre(2.5, 2.0, 4.0, 0.5, 5.0));
}
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
#endif
// Returns true of all elements of `a` are even.
template <auto a>
constexpr int AllEven() {
return ApplyOnIndexRange<std::size(a)>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<int... i>() { return (((a[i] % 2) == 0) && ...); });
}
// Returns true of any element of `a` is even.
template <auto a>
constexpr int AnyEven() {
return ApplyOnIndexRange<std::size(a)>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<int... i>() { return (((a[i] % 2) == 0) || ...); });
}
#if __cplusplus >= 202002L
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
TEST(ApplyOnIndexRangeTest, Even) {
EXPECT_FALSE(AllEven<gtl::to_array({5, 4, 8, 1, 10})>());
EXPECT_TRUE(AnyEven<gtl::to_array({5, 4, 8, 1, 10})>());
EXPECT_TRUE(AllEven<gtl::to_array({8, 2, 6})>());
EXPECT_TRUE(AnyEven<gtl::to_array({8, 2, 6})>());
EXPECT_FALSE(AllEven<gtl::to_array({3, 7, 1})>());
EXPECT_FALSE(AnyEven<gtl::to_array({3, 7, 1})>());
}
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
#endif
// A example of a more complex reduce operation using `ForEachIndex`. Here, we
// want to collect a list of integers for which an operation (`may_fail`)
// failed.
TEST(ForEachIndexTest, CollectTest) {
constexpr auto may_fail = [](int i) {
if (i == 3 || i == 7 || i == 42) {
return absl::InvalidArgumentError("bad number");
}
return absl::OkStatus();
};
EXPECT_THAT(
ForEachIndex<21>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[&may_fail, failed_indices = std::vector<int>()]<int i>() mutable
-> const std::vector<int>& {
if (!may_fail(i).ok()) {
failed_indices.push_back(i);
}
return failed_indices;
}),
ElementsAre(3, 7));
}
TEST(ForEachTest, StrCatHeterogeneousTypes) {
EXPECT_EQ(
ForEach(
[r = std::string()](const auto& v) mutable -> absl::string_view {
absl::StrAppend(&r, " ", v);
return r;
},
std::make_tuple("a", 1, 0.5)),
" a 1 0.5");
}
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,58 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_
#include <cstdint>
#include <utility>
#include "absl/container/flat_hash_set.h"
#include "ortools/math_opt/elemental/attr_key.h"
namespace operations_research::math_opt {
// Tracks modifications to an Attribute with a key size of n (e.g., variable
// lower bound has a key size of 1).
template <int n, typename Symmetry>
class AttrDiff {
public:
using Key = AttrKey<n, Symmetry>;
// On creation, the attribute is not modified for any key.
AttrDiff() = default;
// Clear all tracked modifications.
void Advance() { modified_keys_.clear(); }
// Mark the attribute as modified for `key`.
void SetModified(const Key key) { modified_keys_.insert(key); }
// Returns the attribute keys that have been modified for this attribute (the
// elements where set_modified() was called without a subsequent call to
// Advance()).
const AttrKeyHashSet<Key>& modified_keys() const { return modified_keys_; }
bool has_modified_keys() const { return !modified_keys_.empty(); }
// Stop tracking modifications for this attribute key. (Typically invoked when
// an element in the key was deleted from the model.)
void Erase(const Key key) { modified_keys_.erase(key); }
private:
AttrKeyHashSet<Key> modified_keys_;
};
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_

View File

@@ -0,0 +1,168 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/attr_diff.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/symmetry.h"
namespace operations_research::math_opt {
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
////////////////////////////////////////////////////////////////////////////////
// AttrDiff<0>
////////////////////////////////////////////////////////////////////////////////
TEST(AttrDiff0Test, InitNotModified) {
AttrDiff<0, NoSymmetry> diff;
EXPECT_THAT(diff.modified_keys(), IsEmpty());
}
TEST(AttrDiff0Test, SetModified) {
AttrDiff<0, NoSymmetry> diff;
diff.SetModified(AttrKey());
EXPECT_THAT(diff.modified_keys(), UnorderedElementsAre(AttrKey()));
}
TEST(AttrDiff0Test, Advance) {
AttrDiff<0, NoSymmetry> diff;
diff.SetModified(AttrKey());
diff.Advance();
EXPECT_THAT(diff.modified_keys(), IsEmpty());
}
////////////////////////////////////////////////////////////////////////////////
// Attr1Diff
////////////////////////////////////////////////////////////////////////////////
TEST(AttrDiff1Test, InitNotModified) {
AttrDiff<1, NoSymmetry> diff;
EXPECT_THAT(diff.modified_keys(), IsEmpty());
}
TEST(AttrDiff1Test, SetModified) {
AttrDiff<1, NoSymmetry> diff;
diff.SetModified(AttrKey(2));
diff.SetModified(AttrKey(5));
diff.SetModified(AttrKey(6));
EXPECT_THAT(diff.modified_keys(),
UnorderedElementsAre(AttrKey(2), AttrKey(5), AttrKey(6)));
}
TEST(AttrDiff1Test, Advance) {
AttrDiff<1, NoSymmetry> diff;
diff.SetModified(AttrKey(2));
diff.SetModified(AttrKey(5));
diff.Advance();
EXPECT_THAT(diff.modified_keys(), IsEmpty());
}
TEST(AttrDiff1Test, EraseIsModifiedGetsRemoved) {
AttrDiff<1, NoSymmetry> diff;
diff.SetModified(AttrKey(2));
diff.SetModified(AttrKey(5));
diff.SetModified(AttrKey(6));
diff.Erase(AttrKey(5));
EXPECT_THAT(diff.modified_keys(),
UnorderedElementsAre(AttrKey(2), AttrKey(6)));
}
TEST(AttrDiff1Test, EraseNotModifiedNoEffect) {
AttrDiff<1, NoSymmetry> diff;
diff.SetModified(AttrKey(2));
diff.SetModified(AttrKey(5));
diff.Erase(AttrKey(1));
EXPECT_THAT(diff.modified_keys(),
UnorderedElementsAre(AttrKey(2), AttrKey(5)));
}
////////////////////////////////////////////////////////////////////////////////
// Attr2Diff
////////////////////////////////////////////////////////////////////////////////
TEST(AttrDiffTest2, InitNotModified) {
AttrDiff<2, NoSymmetry> diff;
EXPECT_THAT(diff.modified_keys(), IsEmpty());
}
TEST(AttrDiffTest2, SetModified) {
AttrDiff<2, NoSymmetry> diff;
diff.SetModified(AttrKey(2, 4));
diff.SetModified(AttrKey(5, 2));
diff.SetModified(AttrKey(2, 5));
diff.SetModified(AttrKey(6, 6));
EXPECT_THAT(diff.modified_keys(),
UnorderedElementsAre(AttrKey(2, 4), AttrKey(5, 2), AttrKey(2, 5),
AttrKey(6, 6)));
}
TEST(AttrDiffTest2, Advance) {
AttrDiff<2, NoSymmetry> diff;
diff.SetModified(AttrKey(2, 3));
diff.SetModified(AttrKey(2, 8));
diff.Advance();
EXPECT_THAT(diff.modified_keys(), IsEmpty());
}
TEST(AttrDiffTest2, EraseIsModifiedGetsRemoved) {
AttrDiff<2, NoSymmetry> diff;
diff.SetModified(AttrKey(2, 5));
diff.SetModified(AttrKey(4, 3));
diff.SetModified(AttrKey(3, 4));
diff.SetModified(AttrKey(6, 6));
EXPECT_THAT(diff.modified_keys(),
UnorderedElementsAre(AttrKey(2, 5), AttrKey(3, 4), AttrKey(4, 3),
AttrKey(6, 6)));
diff.Erase(AttrKey(4, 3));
EXPECT_THAT(
diff.modified_keys(),
UnorderedElementsAre(AttrKey(2, 5), AttrKey(3, 4), AttrKey(6, 6)));
}
TEST(AttrDiffTest2, EraseIsModifiedGetsRemovedSymmetric) {
using Diff = AttrDiff<2, ElementSymmetry<0, 1>>;
using Key = Diff::Key;
Diff diff;
diff.SetModified(Key(2, 5));
diff.SetModified(Key(4, 3));
diff.SetModified(Key(3, 4)); // Noop, same as (4,3).
diff.SetModified(Key(6, 6));
EXPECT_THAT(diff.modified_keys(),
UnorderedElementsAre(Key(2, 5), Key(3, 4), Key(6, 6)));
diff.Erase(Key(4, 3));
EXPECT_THAT(diff.modified_keys(), UnorderedElementsAre(Key(2, 5), Key(6, 6)));
}
TEST(AttrDiffTest2, EraseNotModifiedNoEffect) {
AttrDiff<2, NoSymmetry> diff;
diff.SetModified(AttrKey(2, 5));
diff.SetModified(AttrKey(6, 6));
diff.Erase(AttrKey(1, 3));
EXPECT_THAT(diff.modified_keys(),
UnorderedElementsAre(AttrKey(2, 5), AttrKey(6, 6)));
}
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,360 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_
#include <array>
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <type_traits>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "ortools/base/status_builder.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/elemental/symmetry.h"
namespace operations_research::math_opt {
// An attribute key for an attribute keyed on `n` elements.
// `AttrKey` is a value type.
template <int n, typename Symmetry = NoSymmetry>
class AttrKey {
public:
using value_type = int64_t;
using SymmetryT = Symmetry;
// Default constructor: values are uninitialized.
constexpr AttrKey() {} // NOLINT: uninitialized on purpose.
template <typename... Ints,
typename = std::enable_if_t<(sizeof...(Ints) == n &&
(std::is_integral_v<Ints> && ...))>>
explicit constexpr AttrKey(const Ints... ids) {
auto push_back = [this, i = 0](auto e) mutable { element_ids_[i++] = e; };
(push_back(ids), ...);
Symmetry::Enforce(element_ids_);
}
template <ElementType... element_types,
typename = std::enable_if_t<(sizeof...(element_types) == n)>>
explicit constexpr AttrKey(const ElementId<element_types>... ids)
: AttrKey(ids.value()...) {}
constexpr AttrKey(std::array<value_type, n> ids) // NOLINT: pybind11.
: element_ids_(ids) {
Symmetry::Enforce(element_ids_);
}
// Canonicalizes a non-canonical key, i.e., enforces the symmetries
static constexpr AttrKey Canonicalize(AttrKey<n, NoSymmetry> key) {
return AttrKey(key.element_ids_);
}
// Creates a key from a range of `n` elements.
static absl::StatusOr<AttrKey> FromRange(absl::Span<const int64_t> range) {
if (range.size() != n) {
return ::util::InvalidArgumentErrorBuilder()
<< "cannot build AttrKey<" << n << "> from a range of size "
<< range.size();
}
AttrKey result;
std::copy(range.begin(), range.end(), result.element_ids_.begin());
Symmetry::Enforce(result.element_ids_);
return result;
}
constexpr AttrKey(const AttrKey&) = default;
constexpr AttrKey& operator=(const AttrKey&) = default;
static constexpr int size() { return n; }
// Element access.
constexpr value_type operator[](const int dim) const {
DCHECK_LT(dim, n);
DCHECK_GE(dim, 0);
return element_ids_[dim];
}
constexpr value_type& operator[](const int dim) {
DCHECK_LT(dim, n);
DCHECK_GE(dim, 0);
return element_ids_[dim];
}
// Element iteration.
constexpr const value_type* begin() const { return element_ids_.begin(); }
constexpr const value_type* end() const { return element_ids_.end(); }
// `AttrKey` is comparable (ordering is lexicographic) and hashable.
//
// TODO(b/365998156): post C++ 20, replace by spaceship operator (with all
// comparison operators below). Do NOT use the default generated operator (see
// below).
friend constexpr bool operator==(const AttrKey& l, const AttrKey& r) {
// This is much faster than using the default generated `operator==`.
for (int i = 0; i < n; ++i) {
if (l.element_ids_[i] != r.element_ids_[i]) {
return false;
}
}
return true;
}
friend constexpr bool operator<(const AttrKey& l, const AttrKey& r) {
// This is much faster than using the default generated `operator<`.
for (int i = 0; i < n; ++i) {
if (l.element_ids_[i] < r.element_ids_[i]) {
return true;
}
if (l.element_ids_[i] > r.element_ids_[i]) {
return false;
}
}
return false;
}
friend constexpr bool operator<=(const AttrKey& l, const AttrKey& r) {
// This is much faster than using the default generated `operator<`.
for (int i = 0; i < n; ++i) {
if (l.element_ids_[i] < r.element_ids_[i]) {
return true;
}
if (l.element_ids_[i] > r.element_ids_[i]) {
return false;
}
}
return true;
}
friend constexpr bool operator>(const AttrKey& l, const AttrKey& r) {
return r < l;
}
friend constexpr bool operator>=(const AttrKey& l, const AttrKey& r) {
return r <= l;
}
template <typename H>
friend H AbslHashValue(H h, const AttrKey& a) {
return H::combine_contiguous(std::move(h), a.element_ids_.data(), n);
}
// `AttrKey` is printable for logging and tests.
template <typename Sink>
friend void AbslStringify(Sink& sink, const AttrKey& key) {
sink.Append(absl::StrCat(
"AttrKey(", absl::StrJoin(absl::MakeSpan(key.element_ids_), ", "),
")"));
}
// Removes the element at dimension `dim` from the key and returns a key with
// only remaining dimensions.
template <int dim>
AttrKey<n - 1, NoSymmetry> RemoveElement() const {
static_assert(dim >= 0);
static_assert(dim < n);
AttrKey<n - 1, NoSymmetry> result;
for (int i = 0; i < dim; ++i) {
result.element_ids_[i] = element_ids_[i];
}
for (int i = dim + 1; i < n; ++i) {
result.element_ids_[i - 1] = element_ids_[i];
}
return result;
}
// Adds element `elem` at dimension `dim` and returns the result.
// The result must respect `NewSymmetry` (we `DCHECK` this).
template <int dim, typename NewSymmetry>
AttrKey<n + 1, NewSymmetry> AddElement(const value_type elem) const {
static_assert(dim >= 0);
static_assert(dim < n + 1);
AttrKey<n + 1, NewSymmetry> result;
for (int i = 0; i < dim; ++i) {
result.element_ids_[i] = element_ids_[i];
}
result.element_ids_[dim] = elem;
for (int i = dim + 1; i < n + 1; ++i) {
result.element_ids_[i] = element_ids_[i - 1];
}
DCHECK(NewSymmetry::Validate(result.element_ids_))
<< result << " does not have `" << NewSymmetry::GetName()
<< "` symmetry";
return result;
}
private:
template <int other_n, typename OtherSymmetry>
friend class AttrKey;
std::array<value_type, n> element_ids_;
};
// CTAD for `AttrKey(1,2)`.
template <typename... Ints>
AttrKey(Ints... dims) -> AttrKey<sizeof...(Ints), NoSymmetry>;
// Traits to detect whether `T` is an `AttrKey`.
template <typename T>
struct is_attr_key : public std::false_type {};
template <int n, typename Symmetry>
struct is_attr_key<AttrKey<n, Symmetry>> : public std::true_type {};
template <typename T>
static constexpr inline bool is_attr_key_v = is_attr_key<T>::value;
// Required for open-source `StatusBuilder` support.
template <int n, typename Symmetry>
std::ostream& operator<<(std::ostream& ostr, const AttrKey<n, Symmetry>& key) {
ostr << absl::StrCat(key);
return ostr;
}
namespace detail {
// A set of zero or one `AttrKey<0, Symmetry>, V`. This is used to make
// implementations of `AttrDiff` and `AttrStorage` uniform.
// `V` must by default constructible, trivially destructible and copyable
// (we'll fail to compile otherwise).
// After c++26, optional is a sequence container, so this can pretty much become
// `std::optional<AttrKey<0, Symmetry>>` + `find()`.
template <typename Symmetry, typename V>
class AttrKey0RawSet {
public:
using value_type = V;
using Key = AttrKey<0, Symmetry>;
template <typename ValueT>
class IteratorImpl {
public:
IteratorImpl() = default;
// `iterator` converts to `const_iterator`.
IteratorImpl(const IteratorImpl<std::remove_cv_t<ValueT>>& other) // NOLINT
: value_(other.value_) {}
// Dereference.
ValueT& operator*() const {
DCHECK_NE(value_, nullptr);
return *value_;
}
ValueT* operator->() const {
DCHECK_NE(value_, nullptr);
return value_;
}
// Increment.
IteratorImpl& operator++() {
DCHECK_NE(value_, nullptr);
value_ = nullptr;
return *this;
}
// Equality.
friend bool operator==(const IteratorImpl& l, const IteratorImpl& r) {
return l.value_ == r.value_;
}
friend bool operator!=(const IteratorImpl& l, const IteratorImpl& r) {
return !(l == r);
}
private:
friend class AttrKey0RawSet;
explicit IteratorImpl(ValueT& value) : value_(&value) {}
ValueT* value_ = nullptr;
};
using iterator = IteratorImpl<value_type>;
using const_iterator = IteratorImpl<const value_type>;
AttrKey0RawSet() = default;
bool empty() const { return !engaged_; }
size_t size() const { return engaged_ ? 1 : 0; }
const_iterator begin() const {
return engaged_ ? const_iterator(value_) : const_iterator();
}
const_iterator end() const { return const_iterator(); }
iterator begin() { return engaged_ ? iterator(value_) : iterator(); }
iterator end() { return iterator(); }
bool contains(Key) const { return engaged_; }
const_iterator find(Key) const { return begin(); }
iterator find(Key) { return begin(); }
void clear() { engaged_ = false; }
size_t erase(Key) {
if (engaged_) {
engaged_ = false;
return 1;
}
return 0;
}
size_t erase(const_iterator) { return erase(Key()); }
template <typename... Args>
std::pair<iterator, bool> try_emplace(Key, Args&&... args) {
if (engaged_) {
return std::make_pair(iterator(value_), false);
}
value_ = value_type(Key(), std::forward<Args>(args)...);
engaged_ = true;
return std::make_pair(iterator(value_), true);
}
std::pair<iterator, bool> insert(const value_type& v) {
if (engaged_) {
return std::make_pair(iterator(value_), false);
}
value_ = v;
engaged_ = true;
return std::make_pair(iterator(value_), true);
}
private:
// The following greatly simplifies the implementation because we don't have
// to worry about side effects of the dtor (see e.g. `clear()`).
static_assert(std::is_trivially_destructible_v<value_type>);
bool engaged_ = false;
value_type value_;
};
} // namespace detail
// A hash set of `AttrKeyT`, where `AttrKeyT` is an `AttrKey<n, Symmetry>`.
template <typename AttrKeyT,
typename = std::enable_if_t<is_attr_key_v<AttrKeyT>>>
using AttrKeyHashSet = std::conditional_t<
(AttrKeyT::size() > 0), absl::flat_hash_set<AttrKeyT>,
detail::AttrKey0RawSet<typename AttrKeyT::SymmetryT, AttrKeyT>>;
// A hash map of `AttrKeyT` to `V`, where `AttrKeyT` is an
// `AttrKey<n, Symmetry>`.
template <typename AttrKeyT, typename V,
typename = std::enable_if_t<is_attr_key_v<AttrKeyT>>>
using AttrKeyHashMap =
std::conditional_t<(AttrKeyT::size() > 0), absl::flat_hash_map<AttrKeyT, V>,
detail::AttrKey0RawSet<typename AttrKeyT::SymmetryT,
std::pair<AttrKeyT, V>>>;
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_

View File

@@ -0,0 +1,335 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/attr_key.h"
#include <cstdint>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash_testing.h"
#include "absl/meta/type_traits.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "benchmark/benchmark.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/elemental/symmetry.h"
#include "ortools/math_opt/elemental/testing.h"
#include "ortools/math_opt/testing/stream.h"
namespace operations_research::math_opt {
namespace {
using testing::ElementsAre;
using testing::HasSubstr;
using testing::IsEmpty;
using testing::Pair;
using testing::SizeIs;
using testing::UnorderedElementsAre;
using testing::status::IsOkAndHolds;
using testing::status::StatusIs;
static_assert(sizeof(AttrKey<0>) <= sizeof(uint64_t));
static_assert(sizeof(AttrKey<1>) == sizeof(uint64_t));
static_assert(sizeof(AttrKey<2>) == 2 * sizeof(uint64_t));
static_assert(sizeof(AttrKey<2, ElementSymmetry<0, 1>>) ==
2 * sizeof(uint64_t));
// Make sure that passing AttrKey by value really puts it in registers rather
// than leaving it in the caller's frame (see
// https://itanium-cxx-abi.github.io/cxx-abi/abi.html#non-trivial).
static_assert(absl::is_trivially_relocatable<AttrKey<0>>());
static_assert(absl::is_trivially_relocatable<AttrKey<1>>());
static_assert(absl::is_trivially_relocatable<AttrKey<2>>());
static_assert(
absl::is_trivially_relocatable<AttrKey<2, ElementSymmetry<0, 1>>>());
TEST(AttrKeyTest, CtorAndIteration) {
EXPECT_THAT(AttrKey(), ElementsAre());
EXPECT_THAT(AttrKey(1), ElementsAre(1));
EXPECT_THAT(AttrKey(1, 2), ElementsAre(1, 2));
}
TEST(AttrKeyTest, ElementIdCtor) {
EXPECT_THAT(AttrKey(ElementId<ElementType::kVariable>(1)), ElementsAre(1));
EXPECT_THAT(AttrKey(ElementId<ElementType::kVariable>(1),
ElementId<ElementType::kLinearConstraint>(2)),
ElementsAre(1, 2));
}
TEST(AttrKeyTest, ElementAccess) {
const AttrKey key(1, 2);
EXPECT_EQ(key[0], 1);
EXPECT_EQ(key[1], 2);
AttrKey mutable_key(1, 2);
EXPECT_EQ(mutable_key[0], 1);
EXPECT_EQ(mutable_key[1], 2);
}
TEST(AttrKeyTest, SupportsAbslHash1) {
EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({
AttrKey(1),
AttrKey(2),
AttrKey(0),
}));
}
TEST(AttrKeyTest, SupportsAbslHash2) {
EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({
AttrKey(1, 2),
AttrKey(2, 3),
AttrKey(0, 0),
}));
}
TEST(AttrKeyTest, Stringify) {
EXPECT_EQ(absl::StrCat(AttrKey(1, 2, 3)), "AttrKey(1, 2, 3)");
EXPECT_EQ(StreamToString(AttrKey(1, 2, 3)), "AttrKey(1, 2, 3)");
}
TEST(AttrKeyTest, AddRemove) {
const AttrKey key0;
EXPECT_THAT(key0, ElementsAre());
const AttrKey key1 = key0.AddElement<0, NoSymmetry>(3);
EXPECT_THAT(key1, ElementsAre(3));
const AttrKey key2 = key1.AddElement<0, NoSymmetry>(1);
EXPECT_THAT(key2, ElementsAre(1, 3));
const AttrKey key3 = key2.AddElement<1, NoSymmetry>(2);
EXPECT_THAT(key3, ElementsAre(1, 2, 3));
const AttrKey key4 = key3.AddElement<3, NoSymmetry>(4);
EXPECT_THAT(key4, ElementsAre(1, 2, 3, 4));
}
TEST(AttrKeyTest, AddRemoveNotSymmetric) {
using NoSym = NoSymmetry;
EXPECT_THAT((AttrKey(0, 2).AddElement<1, NoSym>(1)), ElementsAre(0, 1, 2));
EXPECT_THAT((AttrKey(0, 1).AddElement<2, NoSym>(2)), ElementsAre(0, 1, 2));
EXPECT_THAT((AttrKey(0, 1).AddElement<1, NoSym>(2)), ElementsAre(0, 2, 1));
EXPECT_THAT((AttrKey(0, 2).AddElement<2, NoSym>(1)), ElementsAre(0, 2, 1));
}
TEST(AttrKeyDeathTest, AddRemoveSymmetric) {
using Sym01 = ElementSymmetry<1, 2>;
EXPECT_THAT((AttrKey(0, 2).AddElement<1, Sym01>(1)), ElementsAre(0, 1, 2));
EXPECT_THAT((AttrKey(0, 1).AddElement<2, Sym01>(2)), ElementsAre(0, 1, 2));
#ifndef NDEBUG
EXPECT_DEATH(
(AttrKey(0, 1).AddElement<1, Sym01>(2)),
HasSubstr(
"AttrKey(0, 2, 1) does not have `ElementSymmetry<1, 2>` symmetry"));
EXPECT_DEATH(
(AttrKey(0, 2).AddElement<2, Sym01>(1)),
HasSubstr(
"AttrKey(0, 2, 1) does not have `ElementSymmetry<1, 2>` symmetry"));
#endif
}
TEST(AttrKeyTest, ComparisonOperators) {
// a[0] < a[1] < a[2] < a[3] < a[4]
const std::vector<AttrKey<4>> a = {AttrKey(1, 0, 0, 0), AttrKey(2, 5, 1, 12),
AttrKey(2, 5, 3, 10), AttrKey(2, 5, 3, 11),
AttrKey(3, 0, 0, 0)};
// Now test each of the operators
for (int i = 0; i < a.size(); ++i) {
SCOPED_TRACE(absl::StrCat(i));
for (int j = 0; j < a.size(); ++j) {
SCOPED_TRACE(absl::StrCat(j));
if (i == j) {
EXPECT_FALSE(a[i] < a[j]);
EXPECT_TRUE(a[i] <= a[j]);
EXPECT_TRUE(a[i] == a[j]);
EXPECT_TRUE(a[i] >= a[j]);
EXPECT_FALSE(a[i] > a[j]);
} else if (i < j) {
EXPECT_TRUE(a[i] < a[j]);
EXPECT_TRUE(a[i] <= a[j]);
EXPECT_FALSE(a[i] == a[j]);
EXPECT_FALSE(a[i] >= a[j]);
EXPECT_FALSE(a[i] > a[j]);
} else {
EXPECT_FALSE(a[i] < a[j]);
EXPECT_FALSE(a[i] <= a[j]);
EXPECT_FALSE(a[i] == a[j]);
EXPECT_TRUE(a[i] >= a[j]);
EXPECT_TRUE(a[i] > a[j]);
}
}
}
}
TEST(AttrKey0SetTest, Works) {
AttrKeyHashSet<AttrKey<0>> set;
EXPECT_THAT(set, IsEmpty());
EXPECT_THAT(set, SizeIs(0));
EXPECT_THAT(set, UnorderedElementsAre());
EXPECT_FALSE(set.contains(AttrKey()));
EXPECT_TRUE(set.find(AttrKey()) == set.end());
EXPECT_EQ(set.erase(AttrKey()), 0);
set.insert(AttrKey());
EXPECT_THAT(set, Not(IsEmpty()));
EXPECT_THAT(set, SizeIs(1));
EXPECT_THAT(set, UnorderedElementsAre(AttrKey()));
EXPECT_TRUE(set.contains(AttrKey()));
EXPECT_TRUE(set.find(AttrKey()) == set.begin());
EXPECT_EQ(set.erase(AttrKey()), 1);
EXPECT_THAT(set, IsEmpty());
set.insert(AttrKey());
set.clear();
EXPECT_THAT(set, IsEmpty());
set.insert(AttrKey());
set.erase(AttrKey());
EXPECT_THAT(set, IsEmpty());
}
TEST(AttrKey0MapTest, Works) {
AttrKeyHashMap<AttrKey<0>, int> map;
EXPECT_THAT(map, IsEmpty());
EXPECT_THAT(map, SizeIs(0));
EXPECT_THAT(map, UnorderedElementsAre());
EXPECT_FALSE(map.contains(AttrKey()));
EXPECT_TRUE(map.find(AttrKey()) == map.end());
EXPECT_EQ(map.erase(AttrKey()), 0);
map.try_emplace(AttrKey(), 42);
EXPECT_THAT(map, Not(IsEmpty()));
EXPECT_THAT(map, SizeIs(1));
EXPECT_THAT(map, UnorderedElementsAre(Pair(AttrKey(), 42)));
EXPECT_EQ(map.begin()->first, AttrKey());
EXPECT_EQ(map.begin()->second, 42);
EXPECT_TRUE(map.contains(AttrKey()));
EXPECT_TRUE(map.find(AttrKey()) == map.begin());
EXPECT_EQ(map.erase(AttrKey()), 1);
EXPECT_THAT(map, IsEmpty());
map.insert({AttrKey(), 43});
map.clear();
EXPECT_THAT(map, IsEmpty());
map.try_emplace(AttrKey(), 43);
map.erase(AttrKey());
EXPECT_THAT(map, IsEmpty());
map.try_emplace(AttrKey(), 43);
map.erase(map.begin());
EXPECT_THAT(map, IsEmpty());
}
TEST(AttrKeyTest, FromRange) {
EXPECT_THAT((AttrKey<0>::FromRange({})), IsOkAndHolds(AttrKey()));
EXPECT_THAT((AttrKey<1>::FromRange({1})), IsOkAndHolds(AttrKey(1)));
EXPECT_THAT((AttrKey<2>::FromRange({1, 2})), IsOkAndHolds(AttrKey(1, 2)));
EXPECT_THAT((AttrKey<0>::FromRange({1})),
StatusIs(absl::StatusCode::kInvalidArgument));
EXPECT_THAT((AttrKey<1>::FromRange({})),
StatusIs(absl::StatusCode::kInvalidArgument));
EXPECT_THAT((AttrKey<2>::FromRange({1})),
StatusIs(absl::StatusCode::kInvalidArgument));
}
TEST(AttrKeyTest, FromRangeSymmetric) {
using Key = AttrKey<3, ElementSymmetry<1, 2>>;
EXPECT_THAT((Key::FromRange({0, 1, 2})), IsOkAndHolds(Key(0, 1, 2)));
EXPECT_THAT((Key::FromRange({0, 2, 1})), IsOkAndHolds(Key(0, 1, 2)));
EXPECT_THAT((Key::FromRange({3, 1, 2})), IsOkAndHolds(Key(3, 1, 2)));
EXPECT_THAT((Key::FromRange({3, 2, 1})), IsOkAndHolds(Key(3, 1, 2)));
}
TEST(AttrKeyTest, IsAttrKey) {
EXPECT_TRUE(is_attr_key_v<AttrKey<0>>);
EXPECT_TRUE(is_attr_key_v<AttrKey<1>>);
EXPECT_FALSE(is_attr_key_v<int>);
}
constexpr int kBenchmarkSize = 30;
template <typename SetT>
void BM_HashSet0(benchmark::State& state) {
SetT set;
for (const auto s : state) {
auto it = set.find(AttrKey());
benchmark::DoNotOptimize(it);
}
}
BENCHMARK(BM_HashSet0<AttrKeyHashSet<AttrKey<0>>>);
BENCHMARK(BM_HashSet0<absl::flat_hash_set<AttrKey<0>>>);
template <typename T>
void BM_HashMap1(benchmark::State& state) {
absl::flat_hash_map<T, int> map;
for (int i = 0; i < kBenchmarkSize * kBenchmarkSize; ++i) {
if (i % 2 > 0) { // Half of the lookups are hits.
map[T(i)] = i;
}
}
for (const auto s : state) {
for (int i = 0; i < kBenchmarkSize * kBenchmarkSize; ++i) {
auto it = map.find(T(i));
benchmark::DoNotOptimize(it);
}
}
}
BENCHMARK(BM_HashMap1<AttrKey<1>>);
BENCHMARK(BM_HashMap1<int64_t>);
template <typename T>
void BM_HashMap2(benchmark::State& state) {
absl::flat_hash_map<T, int> map;
for (int i = 0; i < kBenchmarkSize; ++i) {
for (int j = 0; j < kBenchmarkSize; ++j) {
if ((i * kBenchmarkSize + j) % 2 > 0) { // Half of the lookups are hits.
map[T(i, j)] = i;
}
}
}
for (const auto s : state) {
for (int i = 0; i < kBenchmarkSize; ++i) {
for (int j = 0; j < kBenchmarkSize; ++j) {
auto it = map.find(T(i, j));
benchmark::DoNotOptimize(it);
}
}
}
}
BENCHMARK(BM_HashMap2<AttrKey<2>>);
BENCHMARK(BM_HashMap2<std::pair<int64_t, int64_t>>);
template <int n>
void BM_SortAttrKeys(benchmark::State& state) {
const std::vector<AttrKey<n>> keys =
MakeRandomAttrKeys<n, NoSymmetry>(state.range(0), state.range(0));
for (const auto s : state) {
auto copy = keys;
absl::c_sort(copy);
benchmark::DoNotOptimize(copy);
}
}
BENCHMARK(BM_SortAttrKeys<1>)->Arg(100)->Arg(10000);
BENCHMARK(BM_SortAttrKeys<2>)->Arg(100)->Arg(10000);
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,429 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_
#include <array>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "ortools/base/map_util.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/symmetry.h"
namespace operations_research::math_opt {
namespace detail {
// A non-default key set based on a vector. This is very efficient for
// insertions, reads, and slicing, but does not support deletions.
template <int n>
class DenseKeySet {
public:
// {Dense,Sparse}KeySet stores symmetric keys, symmetry is handled by
// `SlicingStorage`.
using Key = AttrKey<n, NoSymmetry>;
DenseKeySet() = default;
size_t size() const { return key_set_.size(); }
template <typename F>
// requires std::invocable<F, const Key&>
void ForEach(F f) const {
for (const Key& key : key_set_) {
f(key);
}
}
// Note: this does not check for duplicates. This is fine because inserting
// into this map is gated on inserting into the AttrStorage, which does check
// for duplicates.
void Insert(const Key& key) { key_set_.push_back(key); }
auto begin() const { return key_set_.begin(); }
auto end() const { return key_set_.end(); }
private:
std::vector<Key> key_set_;
};
// A non-default key set based on a hash set. Simple, but requires a hash lookup
// for each insertion and deletion.
template <int n>
class SparseKeySet {
public:
// {Dense,Sparse}KeySet stores symmetric keys, symmetry is handled by
// `SlicingStorage`.
using Key = AttrKey<n, NoSymmetry>;
explicit SparseKeySet(const DenseKeySet<n>& dense_set)
: key_set_(dense_set.begin(), dense_set.end()) {}
size_t size() const { return key_set_.size(); }
template <typename F>
// requires std::invocable<F, const Key&>
void ForEach(F f) const {
for (const Key& key : key_set_) {
f(key);
}
}
void Erase(const Key& key) { key_set_.erase(key); }
void Insert(const Key& key) { key_set_.insert(key); }
private:
absl::flat_hash_set<Key> key_set_;
};
// A non-default key set that switches between implementations
// opportunistically: It starts dense, and switches to sparse if there are
// deletions.
template <int n>
class KeySet {
public:
using Key = AttrKey<n, NoSymmetry>;
size_t size() const {
return std::visit([](const auto& impl) { return impl.size(); }, impl_);
}
// We can't do begin/end because the iterator types are not the same.
template <typename F>
// requires std::invocable<F, const Key&>
void ForEach(F f) const {
return std::visit(
[f = std::move(f)](const auto& impl) {
return impl.ForEach(std::move(f));
},
impl_);
}
auto Erase(const Key& key) { return AsSparse().Erase(key); }
void Insert(const Key& key) {
std::visit([&](auto& impl) { impl.Insert(key); }, impl_);
}
private:
SparseKeySet<n>& AsSparse() {
if (auto* sparse = std::get_if<SparseKeySet<n>>(&impl_)) {
return *sparse;
}
// Switch to a sparse representation.
impl_ = SparseKeySet<n>(std::get<DenseKeySet<n>>(impl_));
return std::get<SparseKeySet<n>>(impl_);
}
std::variant<DenseKeySet<n>, SparseKeySet<n>> impl_;
};
// When we have two or more dimensions, we need to store the nondefaults for
// each dimension to support slicing.
template <int n, typename Symmetry, typename = void>
class SlicingSupport {
public:
using Key = AttrKey<n, Symmetry>;
void AddRowsAndColumns(const Key key) {
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
ForEachDimension([this, key]<int i>() {
if (MustInsertNondefault<i>(key, Symmetry{})) {
key_nondefaults_[i][key[i]].Insert(key.template RemoveElement<i>());
}
});
}
// Requires key is currently stored with a non-default value.
void ClearRowsAndColumns(Key key) {
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
ForEachDimension([this, key]<int i>() {
const auto& key_elem = key[i];
auto& nondefaults = key_nondefaults_[i];
if (nondefaults[key_elem].size() == 1) {
nondefaults.erase(key_elem);
} else {
nondefaults[key_elem].Erase(key.template RemoveElement<i>());
}
});
}
void Clear() {
for (auto& key_nondefaults : key_nondefaults_) {
key_nondefaults.clear();
}
}
template <int i>
std::vector<Key> Slice(const int64_t key_elem) const {
return SliceImpl<i>(
key_elem, Symmetry{},
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[key_elem]<int... is>(KeySetExpansion<is>... expansions) {
std::vector<Key> slice((expansions.key_set.size() + ...));
Key* out = slice.data();
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
const auto append = [key_elem, &out]<int j>(
const KeySetExpansion<j>& expansion) {
expansion.key_set.ForEach(
[key_elem, &out](const AttrKey<n - 1> other_elems) {
*out = other_elems.template AddElement<j, Symmetry>(key_elem);
++out;
});
};
(append(expansions), ...);
return slice;
});
}
template <int i>
int64_t GetSliceSize(const int64_t key_elem) const {
return SliceImpl<i>(key_elem, Symmetry{}, [](const auto... expansions) {
return (expansions.key_set.size() + ...);
});
}
private:
// We store the nondefaults for a given id along a given dimension as a set of
// `AttrKey<n-1, NoSymmetry>` (the current dimension is not stored).
using NonDefaultKeySet = KeySet<n - 1>;
// For each dimension, we store the nondefaults for each id.
using NonDefaultKeySetById = absl::flat_hash_map<int64_t, NonDefaultKeySet>;
// We need one such set per dimension.
using NonDefaultsPerDimension = std::array<NonDefaultKeySetById, n>;
// Represents a NonDefaultKeySet to be expanded by inserting an element on
// dimension `i`.
template <int i>
struct KeySetExpansion {
KeySetExpansion(const NonDefaultsPerDimension& key_nondefaults,
int64_t key_elem)
: key_set(gtl::FindWithDefault(key_nondefaults[i], key_elem)) {}
const NonDefaultKeySet& key_set;
};
// A helper function that calls `F` `i` times with template arguments `n-1` to
// `0`.
template <typename F, int i = n - 1>
static void ForEachDimension(const F& f) {
f.template operator()<i>();
if constexpr (i > 0) {
ForEachDimension<F, i - 1>(f);
}
}
template <int i>
static bool MustInsertNondefault(const Key&, NoSymmetry) {
return true;
}
template <int i, int k, int l>
static bool MustInsertNondefault(const Key& key, ElementSymmetry<k, l>) {
// For attributes that are symmetric on `k` and `l`, elements on the
// diagonal need to be in only one of the nondefaults for `k` or `l`
// (otherwise they would be counted twice in `Slice()`). We arbitrarily pick
// `k`.
if constexpr (i == l) {
const bool is_diagonal = key[k] == key[l];
return !is_diagonal;
}
return true;
}
// `Fn` should be a functor that takes any number of `KeySetExpansion`
// arguments.
template <int i, typename Fn>
auto SliceImpl(const int64_t key_elem, NoSymmetry, const Fn& fn) const {
static_assert(n > 1);
return fn(KeySetExpansion<i>(key_nondefaults_, key_elem));
}
template <int i, int k, int l, typename Fn>
auto SliceImpl(const int64_t key_elem, ElementSymmetry<k, l>,
const Fn& fn) const {
static_assert(n > 1);
if constexpr (i != k && i != l) {
// This is a normal dimension, not a symmetric one.
return SliceImpl<i>(key_elem, NoSymmetry(), fn);
} else {
// For symmetric dimensions, we need to look up the keys on both
// dimensions `l` and `k`.
return fn(KeySetExpansion<k>(key_nondefaults_, key_elem),
KeySetExpansion<l>(key_nondefaults_, key_elem));
}
}
NonDefaultsPerDimension key_nondefaults_;
};
// Without slicing we don't need to track anything.
template <int n, typename Symmetry>
struct SlicingSupport<n, Symmetry, std::enable_if_t<(n < 2), std::void_t<>>> {
using Key = AttrKey<n, Symmetry>;
void AddRowsAndColumns(Key) {}
void ClearRowsAndColumns(Key) {}
void Clear() {}
};
} // namespace detail
// Stores the value of an attribute keyed on n elements (e.g.
// linear_constraint_coefficient is a double valued attribute keyed first on
// LinearConstraint and then on Variable).
//
// Memory usage:
// Storing `k` elements with non-default values in a `AttrStorage<V, n>` uses
// `sizeof(V) * (n^2 + 1) * k / load_factor` (where load_factor is the absl
// hash map load factor, typically 0.8), plus a small allocation overhead of
// `O(k)`.
template <typename V, int n, typename Symmetry>
class AttrStorage {
public:
using Key = AttrKey<n, Symmetry>;
// If this no longer holds, we should sprinkle the code with `move`s and
// return `V`s by ref.
static_assert(std::is_trivially_copyable_v<V>);
// Generally avoid, provided to make working with std::array easier.
explicit AttrStorage() : AttrStorage({}) {}
// The default value of the attribute is its value when the model is created
// (e.g. for linear_constraint_coefficient, 0.0).
explicit AttrStorage(const V default_value) : default_value_(default_value) {}
AttrStorage(const AttrStorage&) = default;
AttrStorage& operator=(const AttrStorage&) = default;
AttrStorage(AttrStorage&&) = default;
AttrStorage& operator=(AttrStorage&&) = default;
// Returns true if the attribute for `key` has a value different from its
// default.
bool IsNonDefault(const Key key) const {
return non_default_values_.contains(key);
}
// Returns the previous value if value has changed, otherwise returns
// `std::nullopt`.
std::optional<V> Set(const Key key, const V value) {
bool is_default = value == default_value_;
if (is_default) {
const auto it = non_default_values_.find(key);
if (it == non_default_values_.end()) {
return std::nullopt;
}
const V prev_value = it->second;
non_default_values_.erase(it);
slicing_support_.ClearRowsAndColumns(key);
return prev_value;
}
const auto [it, inserted] = non_default_values_.try_emplace(key, value);
if (inserted) {
slicing_support_.AddRowsAndColumns(key);
return default_value_;
}
// !is_default and !inserted
if (value == it->second) {
return std::nullopt;
}
return std::exchange(it->second, value);
}
// Returns the value of the attribute for `key` (return the default value if
// the attribute value for `key` is unset).
V Get(const Key key) const {
return GetIfNonDefault(key).value_or(default_value_);
}
// Returns the value of the attribute for `key`, or nullopt.
std::optional<V> GetIfNonDefault(const Key key) const {
auto it = non_default_values_.find(key);
if (it == non_default_values_.end()) {
return std::nullopt;
}
return it->second;
}
// Sets the value of the attribute for `key` to the default value.
void Erase(const Key key) {
if (non_default_values_.erase(key)) {
slicing_support_.ClearRowsAndColumns(key);
}
}
// Returns the keys (ids pairs) the of the elements with a non-default value
// for this attribute.
std::vector<Key> NonDefaults() const {
std::vector<Key> result;
result.reserve(non_default_values_.size());
for (const auto& [key, unused] : non_default_values_) {
result.push_back(key);
}
return result;
}
// Returns the set of all keys `K` such that:
// - There exists `k_{0}..k_{n-1}` such that
// `K == AttrKey(k_{0}, ..., k_{i-1}, key_elem, k_{i+1}, ..., k_{n-1})`, and
// - `K` has a non-default value for this attribute.
template <int i>
std::vector<Key> Slice(const int64_t key_elem) const {
static_assert(n >= 1);
if constexpr (n == 1) {
return non_default_values_.contains(Key(key_elem))
? std::vector<Key>({Key(key_elem)})
: std::vector<Key>();
} else {
return slicing_support_.template Slice<i>(key_elem);
}
}
// Returns the size of the given slice: This is equivalent to
// `Slice(key_elem).size()`, but `O(1)`.
template <int i>
int64_t GetSliceSize(const int64_t key_elem) const {
static_assert(n >= 1);
if constexpr (n == 1) {
return non_default_values_.count(Key(key_elem));
} else {
return slicing_support_.template GetSliceSize<i>(key_elem);
}
}
// Returns the number of keys (element pairs) with non-default values for this
// attribute.
int64_t num_non_defaults() const { return non_default_values_.size(); }
// Restore all elements to their default value for this attribute.
void Clear() {
non_default_values_.clear();
slicing_support_.Clear();
}
private:
V default_value_;
AttrKeyHashMap<Key, V> non_default_values_;
detail::SlicingSupport<n, Symmetry> slicing_support_;
};
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_

View File

@@ -0,0 +1,574 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/attr_storage.h"
#include <cstdint>
#include <optional>
#include <vector>
#include "benchmark/benchmark.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/symmetry.h"
namespace operations_research::math_opt {
namespace {
using ::testing::IsEmpty;
using ::testing::Optional;
using ::testing::UnorderedElementsAre;
TEST(Attr0StorageTest, EmptyGetters) {
const AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0);
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey()));
}
TEST(Attr0StorageTest, SetDefaultToDefault) {
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
EXPECT_FALSE(attr_storage.Set(AttrKey(), 1.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0);
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey()));
}
TEST(Attr0StorageTest, SetDefaultToNonDefault) {
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
EXPECT_TRUE(attr_storage.Set(AttrKey(), 10.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 10.0);
EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey()));
}
TEST(Attr0StorageTest, SetNonDefaultToDefault) {
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
EXPECT_TRUE(attr_storage.Set(AttrKey(), 10.0));
EXPECT_TRUE(attr_storage.Set(AttrKey(), 1.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0);
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey()));
}
TEST(Attr0StorageTest, SetNonDefaultToNonDefaultDifferent) {
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(), 10.0);
EXPECT_TRUE(attr_storage.Set(AttrKey(), 20.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 20.0);
EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey()));
}
TEST(Attr0StorageTest, SetNonDefaultToNonDefaultSame) {
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(), 10.0);
EXPECT_FALSE(attr_storage.Set(AttrKey(), 10.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 10.0);
EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey()));
}
////////////////////////////////////////////////////////////////////////////////
// Attr1Storage
////////////////////////////////////////////////////////////////////////////////
TEST(Attr1StorageTest, EmptyGetters) {
const AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(0)), 1.0);
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey(0)));
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
EXPECT_EQ(attr_storage.num_non_defaults(), 0);
EXPECT_THAT(attr_storage.Slice<0>(0), IsEmpty());
}
TEST(Attr1StorageTest, GettersNonEmpty) {
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2), 10.0);
attr_storage.Set(AttrKey(3), 11.0);
attr_storage.Set(AttrKey(5), 12.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 11.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(4)), 1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5)), 12.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(6)), 1.0);
EXPECT_THAT(attr_storage.NonDefaults(),
UnorderedElementsAre(AttrKey(2), AttrKey(3), AttrKey(5)));
EXPECT_EQ(attr_storage.num_non_defaults(), 3);
EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(AttrKey(3)));
}
TEST(Attr1StorageTest, SetDefaultToDefault) {
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
EXPECT_FALSE(attr_storage.Set(AttrKey(2), 1.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0);
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
}
TEST(Attr1StorageTest, SetDefaultToNonDefault) {
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
EXPECT_TRUE(attr_storage.Set(AttrKey(2), 10.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0);
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2)));
}
TEST(Attr1StorageTest, SetNonDefaultToDefault) {
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2), 10.0);
EXPECT_TRUE(attr_storage.Set(AttrKey(2), 1.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0);
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
}
TEST(Attr1StorageTest, SetNonDefaultToNonDefaultDifferent) {
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2), 5.0);
EXPECT_TRUE(attr_storage.Set(AttrKey(2), 10.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0);
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2)));
}
TEST(Attr1StorageTest, SetNonDefaultToNonDefaultSame) {
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2), 10.0);
EXPECT_FALSE(attr_storage.Set(AttrKey(2), 10.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0);
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2)));
}
TEST(Attr1StorageTest, Clear) {
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2), 10.0);
attr_storage.Set(AttrKey(3), 11.0);
EXPECT_THAT(attr_storage.NonDefaults(),
UnorderedElementsAre(AttrKey(2), AttrKey(3)));
EXPECT_EQ(attr_storage.num_non_defaults(), 2);
attr_storage.Clear();
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 1.0);
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
EXPECT_EQ(attr_storage.num_non_defaults(), 0);
}
TEST(Attr1StorageTest, Erase) {
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2), 10.0);
attr_storage.Set(AttrKey(3), 11.0);
attr_storage.Erase(AttrKey(2));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 11.0);
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3)));
EXPECT_EQ(attr_storage.num_non_defaults(), 1);
}
////////////////////////////////////////////////////////////////////////////////
// Attr2Storage
////////////////////////////////////////////////////////////////////////////////
TEST(Attr2StorageTest, EmptyGetters) {
const AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(0, 0)), 1.0);
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey(0, 0)));
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
EXPECT_EQ(attr_storage.num_non_defaults(), 0);
EXPECT_THAT(attr_storage.Slice<1>(0), IsEmpty());
EXPECT_THAT(attr_storage.GetSliceSize<1>(0), 0);
EXPECT_THAT(attr_storage.Slice<0>(0), IsEmpty());
EXPECT_THAT(attr_storage.GetSliceSize<0>(0), 0);
}
TEST(Attr2StorageTest, GettersNonEmpty) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2, 3), 10.0);
attr_storage.Set(AttrKey(2, 5), 11.0);
attr_storage.Set(AttrKey(5, 5), 12.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 5)), 11.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5, 5)), 12.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5, 2)), 1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 2)), 1.0);
EXPECT_THAT(
attr_storage.NonDefaults(),
UnorderedElementsAre(AttrKey(2, 3), AttrKey(2, 5), AttrKey(5, 5)));
EXPECT_EQ(attr_storage.num_non_defaults(), 3);
EXPECT_THAT(attr_storage.Slice<0>(2),
UnorderedElementsAre(AttrKey(2, 3), AttrKey(2, 5)));
EXPECT_THAT(attr_storage.GetSliceSize<0>(2), 2);
EXPECT_THAT(attr_storage.Slice<0>(3), IsEmpty());
EXPECT_THAT(attr_storage.GetSliceSize<0>(3), 0);
EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 5)));
EXPECT_THAT(attr_storage.GetSliceSize<0>(5), 1);
EXPECT_THAT(attr_storage.Slice<1>(2), IsEmpty());
EXPECT_THAT(attr_storage.GetSliceSize<1>(2), 0);
EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(AttrKey(2, 3)));
EXPECT_THAT(attr_storage.GetSliceSize<1>(3), 1);
EXPECT_THAT(attr_storage.Slice<1>(5),
UnorderedElementsAre(AttrKey(2, 5), AttrKey(5, 5)));
EXPECT_THAT(attr_storage.GetSliceSize<1>(5), 2);
}
TEST(Attr2StorageTest, GettersNonEmptySymmetric) {
// Dim 0
// | 0 1 2 3 4 5
// --+------------------------
// 0 | 0
// D 1 | 0 0
// i 2 | 0 0 0
// m 3 | 0 10 0 0
// 1 4 | 0 0 0 0 0
// 5 | 0 11 0 0 0 12
//
using Storage = AttrStorage<double, 2, ElementSymmetry<0, 1>>;
using Key = Storage::Key;
Storage attr_storage(1.0);
attr_storage.Set(Key(2, 3), 10.0);
attr_storage.Set(Key(2, 5), 123.0);
attr_storage.Set(Key(5, 2), 11.0); // Overwrites 123.0.
attr_storage.Set(Key(5, 5), 12.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 3)), 10.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 5)), 11.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(5, 5)), 12.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(3, 2)), 10.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(5, 2)), 11.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 2)), 1.0);
EXPECT_THAT(attr_storage.NonDefaults(),
UnorderedElementsAre(Key(2, 3), Key(2, 5), Key(5, 5)));
EXPECT_EQ(attr_storage.num_non_defaults(), 3);
EXPECT_THAT(attr_storage.Slice<0>(2),
UnorderedElementsAre(Key(2, 3), Key(2, 5)));
EXPECT_THAT(attr_storage.GetSliceSize<0>(2), 2);
EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(Key(2, 3)));
EXPECT_THAT(attr_storage.GetSliceSize<0>(3), 1);
EXPECT_THAT(attr_storage.Slice<0>(4), IsEmpty());
EXPECT_THAT(attr_storage.GetSliceSize<0>(4), 0);
EXPECT_THAT(attr_storage.Slice<0>(5),
UnorderedElementsAre(Key(2, 5), Key(5, 5)));
EXPECT_THAT(attr_storage.GetSliceSize<0>(5), 2);
EXPECT_THAT(attr_storage.Slice<1>(2),
UnorderedElementsAre(Key(2, 3), Key(2, 5)));
EXPECT_THAT(attr_storage.GetSliceSize<1>(2), 2);
EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(Key(2, 3)));
EXPECT_THAT(attr_storage.GetSliceSize<1>(3), 1);
EXPECT_THAT(attr_storage.Slice<1>(4), IsEmpty());
EXPECT_THAT(attr_storage.GetSliceSize<1>(4), 0);
EXPECT_THAT(attr_storage.Slice<1>(5),
UnorderedElementsAre(Key(2, 5), Key(5, 5)));
EXPECT_THAT(attr_storage.GetSliceSize<1>(5), 2);
}
TEST(Attr2StorageTest, SetDefaultToDefault) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
EXPECT_FALSE(attr_storage.Set(AttrKey(2, 3), 1.0).has_value());
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0);
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
}
TEST(Attr2StorageTest, SetDefaultToNonDefault) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 10.0), Optional(1.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0);
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3)));
}
TEST(Attr2StorageTest, SetNonDefaultToDefault) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2, 3), 10.0);
EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 1.0), Optional(10.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0);
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
}
TEST(Attr2StorageTest, SetNonDefaultToNonDefaultDifferent) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2, 3), 5.0);
EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 10.0), Optional(5.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0);
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3)));
}
TEST(Attr2StorageTest, SetNonDefaultToNonDefaultSame) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2, 3), 10.0);
EXPECT_FALSE(attr_storage.Set(AttrKey(2, 3), 10.0));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0);
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3)));
}
TEST(Attr2StorageTest, Clear) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2, 3), 10.0);
attr_storage.Set(AttrKey(3, 4), 11.0);
EXPECT_THAT(attr_storage.NonDefaults(),
UnorderedElementsAre(AttrKey(2, 3), AttrKey(3, 4)));
EXPECT_EQ(attr_storage.num_non_defaults(), 2);
attr_storage.Clear();
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3, 4)), 1.0);
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
EXPECT_EQ(attr_storage.num_non_defaults(), 0);
}
TEST(Attr2StorageTest, Erase) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2, 3), 10.0);
attr_storage.Set(AttrKey(3, 4), 11.0);
attr_storage.Erase(AttrKey(2, 3));
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0);
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3, 4)), 11.0);
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3, 4)));
EXPECT_EQ(attr_storage.num_non_defaults(), 1);
}
TEST(Attr2StorageTest, EraseColumnLives) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(2, 3), 10.0);
attr_storage.Set(AttrKey(5, 3), 11.0);
attr_storage.Erase(AttrKey(2, 3));
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(5, 3)));
EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 3)));
EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(AttrKey(5, 3)));
// Insert again.
attr_storage.Set(AttrKey(2, 3), 12.0);
EXPECT_THAT(attr_storage.NonDefaults(),
UnorderedElementsAre(AttrKey(2, 3), AttrKey(5, 3)));
EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 3)));
EXPECT_THAT(attr_storage.Slice<1>(3),
UnorderedElementsAre(AttrKey(2, 3), AttrKey(5, 3)));
}
TEST(Attr2StorageTest, EraseRowLives) {
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
attr_storage.Set(AttrKey(3, 2), 10.0);
attr_storage.Set(AttrKey(3, 5), 11.0);
attr_storage.Erase(AttrKey(3, 2));
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3, 5)));
EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(AttrKey(3, 5)));
EXPECT_THAT(attr_storage.Slice<1>(5), UnorderedElementsAre(AttrKey(3, 5)));
}
// Makes a set of `n` 1-dimensional keys.
std::vector<AttrKey<1>> Make1DKeys(int n) {
std::vector<AttrKey<1>> keys;
for (int64_t i = 0; i < n; ++i) {
keys.emplace_back(i);
}
return keys;
}
// Makes a set of `n^2` 2-dimensional keys.
// NOTE: depending in `Symmetry` this might create duplicate keys. This is
// intentional, as we want to have the same number of keys to be able to compare
// the performance of different symmetries.
template <typename Symmetry>
std::vector<AttrKey<2, Symmetry>> Make2DKeys(int n) {
std::vector<AttrKey<2, Symmetry>> keys;
for (int64_t i = 0; i < n; ++i) {
for (int64_t j = 0; j < n; ++j) {
keys.emplace_back(i, j);
}
}
return keys;
}
// A functor that returns true every N calls, false otherwise.
template <int N>
struct TrueEvery {
int n = 0;
bool operator()() {
if (n == N) {
n = 0;
return true;
}
++n;
return false;
}
};
void BM_Attr0StorageSet(benchmark::State& state) {
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
for (const auto s : state) {
attr_storage.Set(AttrKey(), 10.0);
benchmark::DoNotOptimize(attr_storage);
}
}
BENCHMARK(BM_Attr0StorageSet);
void BM_Attr1StorageSet(benchmark::State& state) {
const int n = state.range(0);
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
const auto keys = Make1DKeys(n);
for (const auto s : state) {
for (const auto& key : keys) {
attr_storage.Set(key, 10.0);
}
}
}
BENCHMARK(BM_Attr1StorageSet)->Arg(900);
template <typename Symmetry>
void BM_Attr2StorageSet(benchmark::State& state) {
const int n = state.range(0);
const auto keys = Make2DKeys<Symmetry>(n);
std::optional<AttrStorage<double, 2, Symmetry>> attr_storage(1.0);
for (const auto s : state) {
for (const auto& key : keys) {
attr_storage->Set(key, 10.0);
}
state.PauseTiming();
attr_storage.emplace(1.0);
state.ResumeTiming();
}
}
BENCHMARK(BM_Attr2StorageSet<NoSymmetry>)->Arg(30);
BENCHMARK(BM_Attr2StorageSet<ElementSymmetry<0, 1>>)->Arg(30);
void BM_Attr0StorageGet(benchmark::State& state) {
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
for (const auto s : state) {
double v = attr_storage.Get(AttrKey());
benchmark::DoNotOptimize(v);
}
}
BENCHMARK(BM_Attr0StorageGet);
void BM_Attr1StorageGet(benchmark::State& state) {
const int n = state.range(0);
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
const auto keys = Make1DKeys(n);
// Insert half the keys.
TrueEvery<2> sample;
for (const auto& key : keys) {
if (sample()) {
attr_storage.Set(key, 10.0);
}
}
for (const auto s : state) {
for (const auto& key : keys) {
double v = attr_storage.Get(key);
benchmark::DoNotOptimize(v);
}
}
}
BENCHMARK(BM_Attr1StorageGet)->Arg(900);
template <typename Symmetry>
void BM_Attr2StorageGet(benchmark::State& state) {
const int n = state.range(0);
AttrStorage<double, 2, Symmetry> attr_storage(1.0);
const auto keys = Make2DKeys<Symmetry>(n);
// Insert half the keys.
TrueEvery<2> sample;
for (const auto& key : keys) {
if (sample()) {
attr_storage.Set(key, 10.0);
}
}
for (const auto s : state) {
for (const auto& key : keys) {
double v = attr_storage.Get(key);
benchmark::DoNotOptimize(v);
}
}
}
BENCHMARK(BM_Attr2StorageGet<NoSymmetry>)->Arg(30);
BENCHMARK(BM_Attr2StorageGet<ElementSymmetry<0, 1>>)->Arg(30);
template <typename Symmetry>
void BM_Attr2StorageSlice(benchmark::State& state) {
const int n = state.range(0);
AttrStorage<double, 2, Symmetry> attr_storage(1.0);
const auto keys = Make2DKeys<Symmetry>(n);
// Insert 5% of the keys.
TrueEvery<20> sample;
for (const auto& key : keys) {
if (sample()) {
attr_storage.Set(key, 10.0);
}
}
for (const auto s : state) {
for (int key_id = 0; key_id < n; ++key_id) {
auto slice0 = attr_storage.template Slice<0>(key_id);
auto slice1 = attr_storage.template Slice<1>(key_id);
benchmark::DoNotOptimize(slice0);
benchmark::DoNotOptimize(slice1);
}
}
}
BENCHMARK(BM_Attr2StorageSlice<NoSymmetry>)->Arg(30);
BENCHMARK(BM_Attr2StorageSlice<ElementSymmetry<0, 1>>)->Arg(30);
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,349 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_
#include <array>
#include <cstdint>
#include <limits>
#include <ostream>
#include <tuple>
#include <type_traits>
#include "absl/strings/string_view.h"
#include "ortools/base/array.h"
#include "ortools/math_opt/elemental/arrays.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/elemental/symmetry.h"
namespace operations_research::math_opt {
// A base class for all attribute type descriptors.
// `ValueTypeT` is the attribute value type, and `n` is the number of key
// elements (e.g. `Double2` attribute has `ValueType` == `double` and `n` == 2).
// This uses
// [CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) in
// `Impl` to deduce common descriptor properties from `Impl`. `Impl` must
// inherit from `AttrTypeDescriptor` and define the following entities:
// - `static constexpr absl::string_view kName`: The name of the attribute
// type.
// - `enum class AttrType`: The attribute type, with `k` enumerators
// corresponding to attributes for this type. Enumerators must be numbered
// `0..(k-1)` (a good way to do this is to leave them unnumbered).
// - `std::array<AttrDescriptor, k> kAttrDescriptors`: A descriptor for each
// of the `k` attributes for this type.
template <typename ValueTypeT, int n, typename SymmetryT, typename Impl>
struct AttrTypeDescriptor {
// The type of attribute values (e.g. `bool`, `int64_t`, `double`).
using ValueType = ValueTypeT;
// The number of key elements.
static constexpr int kNumKeyElements = n;
// The key symmetry. For example, this can be used to enforce that
// quadratic objective coefficients are the same for `(i, j)` and `(j, i)`
// (see `kObjQuadCoef` below).
using Symmetry = SymmetryT;
// A descriptor of an attribute of this attribute type.
// E.g., this could describe the attribute `DoubleAttr1::kVarLb`.
struct AttrDescriptor {
// The name of the attribute value.
absl::string_view name;
// The default value.
ValueType default_value;
// The types of the `n` key elements.
std::array<ElementType, n> key_types;
};
// Returns the number of attributes of this attribute type.
static constexpr int NumAttrs() { return Impl::kAttrDescriptors.size(); }
// Returns an array with all attributes of this attribute type.
static constexpr auto Enumerate() {
std::array<typename Impl::AttrType, NumAttrs()> result;
for (int i = 0; i < NumAttrs(); ++i) {
result[i] = {static_cast<typename Impl::AttrType>(i)};
}
return result;
}
};
struct BoolAttr0TypeDescriptor
: public AttrTypeDescriptor<bool, 0, NoSymmetry, BoolAttr0TypeDescriptor> {
static constexpr absl::string_view kName = "BoolAttr0";
enum class AttrType { kMaximize };
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>(
{{.name = "maximize", .default_value = false, .key_types = {}}});
};
struct BoolAttr1TypeDescriptor
: public AttrTypeDescriptor<bool, 1, NoSymmetry, BoolAttr1TypeDescriptor> {
static constexpr absl::string_view kName = "BoolAttr1";
enum class AttrType {
kVarInteger,
kAuxObjMaximize,
kIndConActivateOnZero,
};
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>(
{{.name = "variable_integer",
.default_value = false,
.key_types = {ElementType::kVariable}},
{.name = "auxiliary_objective_maximize",
.default_value = false,
.key_types = {ElementType::kAuxiliaryObjective}},
{.name = "indicator_constraint_activate_on_zero",
.default_value = false,
.key_types = {ElementType::kIndicatorConstraint}}});
};
struct IntAttr0TypeDescriptor
: public AttrTypeDescriptor<int64_t, 0, NoSymmetry,
IntAttr0TypeDescriptor> {
static constexpr absl::string_view kName = "IntAttr0";
enum class AttrType {
kObjPriority,
};
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
{.name = "objective_priority", .default_value = 0, .key_types = {}},
});
};
struct IntAttr1TypeDescriptor
: public AttrTypeDescriptor<int64_t, 1, NoSymmetry,
IntAttr1TypeDescriptor> {
static constexpr absl::string_view kName = "IntAttr1";
enum class AttrType {
kAuxObjPriority,
};
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
{.name = "auxiliary_objective_priority",
.default_value = 0,
.key_types = {ElementType::kAuxiliaryObjective}},
});
};
struct DoubleAttr0TypeDescriptor
: public AttrTypeDescriptor<double, 0, NoSymmetry,
DoubleAttr0TypeDescriptor> {
static constexpr absl::string_view kName = "DoubleAttr0";
enum class AttrType { kObjOffset };
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>(
{{.name = "objective_offset", .default_value = 0.0, .key_types = {}}});
};
struct DoubleAttr1TypeDescriptor
: public AttrTypeDescriptor<double, 1, NoSymmetry,
DoubleAttr1TypeDescriptor> {
static constexpr absl::string_view kName = "DoubleAttr1";
enum class AttrType {
kVarLb,
kVarUb,
kObjLinCoef,
kLinConLb,
kLinConUb,
kAuxObjOffset,
kQuadConLb,
kQuadConUb,
kIndConLb,
kIndConUb,
};
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
{.name = "variable_lower_bound",
.default_value = -std::numeric_limits<double>::infinity(),
.key_types = {ElementType::kVariable}},
{.name = "variable_upper_bound",
.default_value = std::numeric_limits<double>::infinity(),
.key_types = {ElementType::kVariable}},
{.name = "objective_linear_coefficient",
.default_value = 0.0,
.key_types = {ElementType::kVariable}},
{.name = "linear_constraint_lower_bound",
.default_value = -std::numeric_limits<double>::infinity(),
.key_types = {ElementType::kLinearConstraint}},
{.name = "linear_constraint_upper_bound",
.default_value = std::numeric_limits<double>::infinity(),
.key_types = {ElementType::kLinearConstraint}},
{.name = "auxiliary_objective_offset",
.default_value = 0.0,
.key_types = {ElementType::kAuxiliaryObjective}},
{.name = "quadratic_constraint_lower_bound",
.default_value = -std::numeric_limits<double>::infinity(),
.key_types = {ElementType::kQuadraticConstraint}},
{.name = "quadratic_constraint_upper_bound",
.default_value = std::numeric_limits<double>::infinity(),
.key_types = {ElementType::kQuadraticConstraint}},
{.name = "indicator_constraint_lower_bound",
.default_value = -std::numeric_limits<double>::infinity(),
.key_types = {ElementType::kIndicatorConstraint}},
{.name = "indicator_constraint_upper_bound",
.default_value = std::numeric_limits<double>::infinity(),
.key_types = {ElementType::kIndicatorConstraint}},
});
};
struct DoubleAttr2TypeDescriptor
: public AttrTypeDescriptor<double, 2, NoSymmetry,
DoubleAttr2TypeDescriptor> {
static constexpr absl::string_view kName = "DoubleAttr2";
enum class AttrType {
kLinConCoef,
kAuxObjLinCoef,
kQuadConLinCoef,
kIndConLinCoef
};
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
{.name = "linear_constraint_coefficient",
.default_value = 0.0,
.key_types = {ElementType::kLinearConstraint, ElementType::kVariable}},
{.name = "auxiliary_objective_linear_coefficient",
.default_value = 0.0,
.key_types = {ElementType::kAuxiliaryObjective, ElementType::kVariable}},
{.name = "quadratic_constraint_linear_coefficient",
.default_value = 0.0,
.key_types = {ElementType::kQuadraticConstraint,
ElementType::kVariable}},
{.name = "indicator_constraint_linear_coefficient",
.default_value = 0.0,
.key_types = {ElementType::kIndicatorConstraint,
ElementType::kVariable}},
});
};
struct SymmetricDoubleAttr2TypeDescriptor
: public AttrTypeDescriptor<double, 2, ElementSymmetry<0, 1>,
SymmetricDoubleAttr2TypeDescriptor> {
static constexpr absl::string_view kName = "SymmetricDoubleAttr2";
enum class AttrType {
kObjQuadCoef,
};
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
{.name = "objective_quadratic_coefficient",
.default_value = 0.0,
.key_types = {ElementType::kVariable, ElementType::kVariable}},
});
};
// Note: For this type, we pick the symmetric elements to be the last 2 elements
// of the key (index 1 and 2).
struct SymmetricDoubleAttr3TypeDescriptor
: public AttrTypeDescriptor<double, 3, ElementSymmetry<1, 2>,
SymmetricDoubleAttr3TypeDescriptor> {
static constexpr absl::string_view kName = "SymmetricDoubleAttr3";
enum class AttrType {
kQuadConQuadCoef,
};
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
{.name = "quadratic_constraint_quadratic_coefficient",
.default_value = 0.0,
.key_types = {ElementType::kQuadraticConstraint, ElementType::kVariable,
ElementType::kVariable}},
});
};
struct VariableAttr1TypeDescriptor
: public AttrTypeDescriptor<VariableId, 1, NoSymmetry,
VariableAttr1TypeDescriptor> {
static constexpr absl::string_view kName = "VariableAttr1";
enum class AttrType {
kIndConIndicator,
};
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
{.name = "indicator_constraint_indicator",
.default_value = VariableId(),
.key_types = {ElementType::kIndicatorConstraint}},
});
};
// The list of all available attribute descriptors. This is typically
// manipulated using the `AllAttrs` helper in `derived_data.h`.
using AllAttrTypeDescriptors =
std::tuple<BoolAttr0TypeDescriptor, BoolAttr1TypeDescriptor,
IntAttr0TypeDescriptor, IntAttr1TypeDescriptor,
DoubleAttr0TypeDescriptor, DoubleAttr1TypeDescriptor,
DoubleAttr2TypeDescriptor, SymmetricDoubleAttr2TypeDescriptor,
SymmetricDoubleAttr3TypeDescriptor, VariableAttr1TypeDescriptor>;
// Aliases for types.
using BoolAttr0 = BoolAttr0TypeDescriptor::AttrType;
using BoolAttr1 = BoolAttr1TypeDescriptor::AttrType;
using IntAttr0 = IntAttr0TypeDescriptor::AttrType;
using IntAttr1 = IntAttr1TypeDescriptor::AttrType;
using DoubleAttr0 = DoubleAttr0TypeDescriptor::AttrType;
using DoubleAttr1 = DoubleAttr1TypeDescriptor::AttrType;
using DoubleAttr2 = DoubleAttr2TypeDescriptor::AttrType;
using SymmetricDoubleAttr2 = SymmetricDoubleAttr2TypeDescriptor::AttrType;
using SymmetricDoubleAttr3 = SymmetricDoubleAttr3TypeDescriptor::AttrType;
using VariableAttr1 = VariableAttr1TypeDescriptor::AttrType;
// Returns the index of `AttrT` in `AllAttrTypes` if `AttrT` is an attribute
// type, -1 otherwise.
template <typename AttrT>
static constexpr int GetIndexIfAttr() {
using Tuple = AllAttrTypeDescriptors;
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
return ApplyOnIndexRange<std::tuple_size_v<Tuple>>([]<int... i>() {
return ((std::is_same_v<std::remove_cv_t<std::remove_reference_t<AttrT>>,
typename std::tuple_element_t<i, Tuple>::AttrType>
? (i + 1)
: 0) +
... + -1);
});
}
template <typename AttrT,
typename = std::enable_if_t<(GetIndexIfAttr<AttrT>() >= 0)>>
absl::string_view ToString(const AttrT attr) {
using Descriptor =
std::tuple_element_t<GetIndexIfAttr<AttrT>(), AllAttrTypeDescriptors>;
const int attr_index = static_cast<int>(attr);
return Descriptor::kAttrDescriptors[attr_index].name;
}
template <typename Sink, typename AttrT,
typename = std::enable_if_t<(GetIndexIfAttr<AttrT>() >= 0)>>
void AbslStringify(Sink& sink, const AttrT attr_type) {
sink.Append(ToString(attr_type));
}
template <typename AttrT>
std::enable_if_t<(GetIndexIfAttr<AttrT>() >= 0), std::ostream&> operator<<(
std::ostream& ostr, AttrT attr) {
ostr << ToString(attr);
return ostr;
}
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_

View File

@@ -0,0 +1,58 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/attributes.h"
#include "absl/strings/str_cat.h"
#include "gtest/gtest.h"
#include "ortools/math_opt/elemental/arrays.h"
#include "ortools/math_opt/testing/stream.h"
namespace operations_research::math_opt {
namespace {
TEST(ToStringTests, EachTypeCanConvert) {
EXPECT_EQ(ToString(BoolAttr0::kMaximize), "maximize");
EXPECT_EQ(ToString(BoolAttr1::kVarInteger), "variable_integer");
EXPECT_EQ(ToString(IntAttr0::kObjPriority), "objective_priority");
EXPECT_EQ(ToString(IntAttr1::kAuxObjPriority),
"auxiliary_objective_priority");
EXPECT_EQ(ToString(DoubleAttr0::kObjOffset), "objective_offset");
EXPECT_EQ(ToString(DoubleAttr1::kVarLb), "variable_lower_bound");
EXPECT_EQ(ToString(DoubleAttr2::kLinConCoef),
"linear_constraint_coefficient");
EXPECT_EQ(ToString(SymmetricDoubleAttr2::kObjQuadCoef),
"objective_quadratic_coefficient");
EXPECT_EQ(ToString(SymmetricDoubleAttr3::kQuadConQuadCoef),
"quadratic_constraint_quadratic_coefficient");
// Now check that absl::Stringify wraps ToString()
EXPECT_EQ(absl::StrCat(BoolAttr0::kMaximize), "maximize");
// Now check that << wraps ToString()
EXPECT_EQ(StreamToString(BoolAttr0::kMaximize), "maximize");
}
// Validate that for all symmetric attribute types, the symmetry is consistent
// with element types.
TEST(SymmetryTest, AllSymmetricTypesAreCorrect) {
ForEach(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<typename Descriptor>(const Descriptor&) {
for (const auto& attr : Descriptor::kAttrDescriptors) {
Descriptor::Symmetry::CheckElementTypes(attr.key_types);
}
},
AllAttrTypeDescriptors{});
}
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,81 @@
# Copyright 2010-2025 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_cc//cc:cc_binary.bzl", "cc_binary")
load("@rules_cc//cc:cc_library.bzl", "cc_library")
cc_library(
name = "gen",
srcs = ["gen.cc"],
hdrs = ["gen.h"],
deps = [
"//ortools/math_opt/elemental:arrays",
"//ortools/math_opt/elemental:attributes",
"//ortools/math_opt/elemental:elements",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:string_view",
"@abseil-cpp//absl/types:span",
],
)
cc_library(
name = "gen_c",
srcs = ["gen_c.cc"],
hdrs = ["gen_c.h"],
deps = [
":gen",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:str_format",
"@abseil-cpp//absl/types:span",
],
)
cc_library(
name = "testing",
testonly = 1,
hdrs = ["testing.h"],
deps = [":gen"],
)
cc_library(
name = "gen_python",
srcs = ["gen_python.cc"],
hdrs = ["gen_python.h"],
deps = [
":gen",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:str_format",
"@abseil-cpp//absl/strings:string_view",
"@abseil-cpp//absl/types:span",
],
)
cc_binary(
name = "codegen",
srcs = ["codegen.cc"],
visibility = [
"//ortools/math_opt/elemental/c_api:__subpackages__",
"//ortools/math_opt/elemental/python:__subpackages__",
],
deps = [
":gen",
":gen_c",
":gen_python",
"//ortools/base",
"@abseil-cpp//absl/flags:flag",
"@abseil-cpp//absl/log",
"@abseil-cpp//absl/log:check",
"@abseil-cpp//absl/strings:string_view",
],
)

View File

@@ -0,0 +1,52 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <memory>
#include <string>
#include "absl/flags/flag.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "ortools/base/init_google.h"
#include "ortools/math_opt/elemental/codegen/gen.h"
#include "ortools/math_opt/elemental/codegen/gen_c.h"
#include "ortools/math_opt/elemental/codegen/gen_python.h"
ABSL_FLAG(std::string, binding_type, "", "The binding type to generate.");
namespace operations_research::math_opt::codegen {
namespace {
void Main() {
const std::string binding_type = absl::GetFlag(FLAGS_binding_type);
if (binding_type == "c99_h") {
std::cout << C99Declarations()->GenerateCode();
} else if (binding_type == "c99_cc") {
std::cout << C99Definitions()->GenerateCode();
} else if (binding_type == "python_enums") {
std::cout << PythonEnums()->GenerateCode();
} else {
LOG(FATAL) << "unknown binding type: '" << binding_type << "'";
}
}
} // namespace
} // namespace operations_research::math_opt::codegen
int main(int argc, char** argv) {
InitGoogle(argv[0], &argc, &argv, /*remove_flags=*/true);
operations_research::math_opt::codegen::Main();
return 0;
}

View File

@@ -0,0 +1,158 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/codegen/gen.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/math_opt/elemental/arrays.h"
#include "ortools/math_opt/elemental/attributes.h"
#include "ortools/math_opt/elemental/elements.h"
namespace operations_research::math_opt::codegen {
namespace {
class NamedType : public Type {
public:
explicit NamedType(std::string name) : name_(std::move(name)) {}
void Print(absl::string_view, std::string* out) const final {
absl::StrAppend(out, name_);
}
private:
std::string name_;
};
class PointerType : public Type {
public:
explicit PointerType(std::shared_ptr<Type> pointee)
: pointee_(std::move(pointee)) {}
void Print(absl::string_view attr_value_type, std::string* out) const final {
pointee_->Print(attr_value_type, out);
absl::StrAppend(out, "*");
}
private:
std::shared_ptr<Type> pointee_;
};
class AttrValueTypeType : public Type {
public:
void Print(absl::string_view attr_value_type, std::string* out) const final {
absl::StrAppend(out, attr_value_type);
}
};
} // namespace
std::shared_ptr<Type> Type::Named(std::string name) {
return std::make_shared<NamedType>(std::move(name));
}
std::shared_ptr<Type> Type::Pointer(std::shared_ptr<Type> pointee) {
return std::make_shared<PointerType>(std::move(pointee));
}
std::shared_ptr<Type> Type::AttrValueType() {
return std::make_shared<AttrValueTypeType>();
}
Type::~Type() = default;
CodegenAttrTypeDescriptor::ValueType GetValueType(bool) {
return CodegenAttrTypeDescriptor::ValueType::kBool;
}
CodegenAttrTypeDescriptor::ValueType GetValueType(int64_t) {
return CodegenAttrTypeDescriptor::ValueType::kInt64;
}
CodegenAttrTypeDescriptor::ValueType GetValueType(double) {
return CodegenAttrTypeDescriptor::ValueType::kDouble;
}
template <ElementType element_type>
CodegenAttrTypeDescriptor::ValueType GetValueType(ElementId<element_type>) {
// Element ids are untyped in wrapped APIs.
return CodegenAttrTypeDescriptor::ValueType::kInt64;
}
template <typename Descriptor>
CodegenAttrTypeDescriptor MakeAttrTypeDescriptor() {
CodegenAttrTypeDescriptor descriptor;
descriptor.value_type = GetValueType(typename Descriptor::ValueType{});
descriptor.name = Descriptor::kName;
descriptor.num_key_elements = Descriptor::kNumKeyElements;
descriptor.symmetry = Descriptor::Symmetry::GetName();
descriptor.attribute_names.reserve(Descriptor::NumAttrs());
for (const auto& attr_descriptor : Descriptor::kAttrDescriptors) {
descriptor.attribute_names.push_back(attr_descriptor.name);
}
return descriptor;
}
constexpr absl::string_view kOpNames[static_cast<int>(AttrOp::kNumOps)] = {
"Get", "Set", "IsNonDefault", "NumNonDefaults", "GetNonDefaults"};
void CodeGenerator::EmitAttrType(const CodegenAttrTypeDescriptor& descriptor,
std::string* out) const {
StartAttrType(descriptor, out);
for (int op = 0; op < kNumAttrOps; ++op) {
const AttrOpFunctionInfo& op_info = attr_op_function_infos_[op];
EmitAttrOp(kOpNames[op], descriptor, op_info, out);
}
}
void CodeGenerator::EmitAttributes(
absl::Span<const CodegenAttrTypeDescriptor> descriptors,
std::string* out) const {
for (const auto& descriptor : descriptors) {
StartAttrType(descriptor, out);
for (int i = 0; i < kNumAttrOps; ++i) {
EmitAttrOp(kOpNames[i], descriptor, attr_op_function_infos_[i], out);
}
}
}
std::string CodeGenerator::GenerateCode() const {
std::string out;
EmitHeader(&out);
// Generate elements.
EmitElements(kElementNames, &out);
// Generate attributes.
std::vector<CodegenAttrTypeDescriptor> attr_type_descriptors;
ForEach(
[&attr_type_descriptors](auto type_descriptor) {
attr_type_descriptors.push_back(
MakeAttrTypeDescriptor<decltype(type_descriptor)>());
},
AllAttrTypeDescriptors{});
EmitAttributes(attr_type_descriptors, &out);
return out;
}
} // namespace operations_research::math_opt::codegen

View File

@@ -0,0 +1,143 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Language-agnostic utilities for `Elemental` codegen.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_
#include <array>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
namespace operations_research::math_opt::codegen {
// The list of attribute operations supported by `Elemental`.
enum class AttrOp {
kGet,
kSet,
kIsNonDefault,
kNumNonDefaults,
kGetNonDefaults,
// Do not use.
kNumOps,
};
static constexpr int kNumAttrOps = static_cast<int>(AttrOp::kNumOps);
// A struct to represent an attribute type descriptor during codegen.
struct CodegenAttrTypeDescriptor {
// The attribute type name.
absl::string_view name;
// The value type of the attribute.
enum class ValueType {
kBool,
kInt64,
kDouble,
};
ValueType value_type;
// The number of key elements.
int num_key_elements;
// The key symmetry.
std::string symmetry;
// The names of the attributes of this type.
std::vector<absl::string_view> attribute_names;
};
// Representations for types.
class Type {
public:
// A named type, e.g. "double".
static std::shared_ptr<Type> Named(std::string name);
// A pointer type.
static std::shared_ptr<Type> Pointer(std::shared_ptr<Type> pointee);
// A placeholder for the attribute value type, which is yet unknown when types
// are defined. This gets replaced by `attr_value_type` when calling `Print`.
static std::shared_ptr<Type> AttrValueType();
virtual ~Type();
// Prints the type to `out`, replacing `AttrValueType` placeholders with
// `attr_value_type`.
virtual void Print(absl::string_view attr_value_type,
std::string* out) const = 0;
};
// Information about how to codegen a given `AttrOp` in a given language.
struct AttrOpFunctionInfo {
// The return type of the function.
std::shared_ptr<Type> return_type;
// If true, the function has an `AttrKey` parameter.
bool has_key_parameter;
// Extra parameters (e.g. {"double", "value"} for `Set` operations).
struct ExtraParameter {
std::shared_ptr<Type> type;
std::string name;
};
std::vector<ExtraParameter> extra_parameters;
};
using AttrOpFunctionInfos = std::array<AttrOpFunctionInfo, kNumAttrOps>;
// The code generator interface.
class CodeGenerator {
public:
explicit CodeGenerator(const AttrOpFunctionInfos* attr_op_function_infos)
: attr_op_function_infos_(*attr_op_function_infos) {}
virtual ~CodeGenerator() = default;
// Generates code.
std::string GenerateCode() const;
// Emits the header for the generated code.
virtual void EmitHeader(std::string* out) const {}
// Emits code for elements.
virtual void EmitElements(absl::Span<const absl::string_view> elements,
std::string* out) const {}
// Emits code for attributes. By default, this iterates attributes and for
// each attribute:
// - calls `StartAttrType`, and
// - calls `EmitAttrOp` for each operation.
virtual void EmitAttributes(
absl::Span<const CodegenAttrTypeDescriptor> descriptors,
std::string* out) const;
// Called before generating code for an attribute type.
virtual void StartAttrType(const CodegenAttrTypeDescriptor& descriptor,
std::string* out) const {}
// Emits code for operation `info` for attribute described by `descriptor`.
virtual void EmitAttrOp(absl::string_view op_name,
const CodegenAttrTypeDescriptor& descriptor,
const AttrOpFunctionInfo& info,
std::string* out) const {}
private:
void EmitAttrType(const CodegenAttrTypeDescriptor& descriptor,
std::string* out) const;
const AttrOpFunctionInfos& attr_op_function_infos_;
};
} // namespace operations_research::math_opt::codegen
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_

View File

@@ -0,0 +1,245 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/codegen/gen_c.h"
#include <memory>
#include <string>
#include "absl/log/check.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/math_opt/elemental/codegen/gen.h"
namespace operations_research::math_opt::codegen {
namespace {
// A helper to generate parameters to pass `n` key element indices, e.g:
// ", int64_t key_0, int64_t key_1" (parameters)
void AddKeyParams(int n, std::string* out) {
for (int i = 0; i < n; ++i) {
absl::StrAppend(out, ", int64_t key_", i);
}
}
// A helper to generate an AttrKey argument to pass `n` key element indices,
// e.g: "AttrKey<2, NoSymmetry>(key_0, key_1)".
void AddAttrKeyArg(int n, absl::string_view symmetry, std::string* out) {
absl::StrAppendFormat(out, ", AttrKey<%i, %s>(", n, symmetry);
for (int i = 0; i < n; ++i) {
if (i != 0) {
absl::StrAppend(out, ", ");
}
absl::StrAppend(out, "key_", i);
}
absl::StrAppend(out, ")");
}
// Returns the C99 name for the given type.
absl::string_view GetCTypeName(
CodegenAttrTypeDescriptor::ValueType value_type) {
switch (value_type) {
case CodegenAttrTypeDescriptor::ValueType::kBool:
return "_Bool";
case CodegenAttrTypeDescriptor::ValueType::kInt64:
return "int64_t";
case CodegenAttrTypeDescriptor::ValueType::kDouble:
return "double";
}
}
// Turns an element/attribute name (e.g. "some_name") into a camel case name
// (e.g. "SomeName").
std::string NameToCamelCase(absl::string_view attr_name) {
std::string result;
result.reserve(attr_name.size());
CHECK(!attr_name.empty());
const char first = attr_name[0];
CHECK(absl::ascii_isalpha(first) && absl::ascii_islower(first))
<< "invalid attr name: " << attr_name;
result.push_back(absl::ascii_toupper(first));
for (int i = 1; i < attr_name.size(); ++i) {
const char c = attr_name[i];
if (c == '_') {
++i;
CHECK(i < attr_name.size()) << "invalid attr name: " << attr_name;
const char next_c = attr_name[i];
CHECK(absl::ascii_isalnum(next_c)) << "invalid attr name: " << attr_name;
result.push_back(absl::ascii_toupper(next_c));
} else {
CHECK(absl::ascii_isalnum(c)) << "invalid attr name: " << attr_name;
CHECK(absl::ascii_islower(c)) << "invalid attr name: " << attr_name;
result.push_back(c);
}
}
return result;
}
// Returns the type of the C status.
std::shared_ptr<Type> GetStatusType() { return Type::Named("int"); }
const AttrOpFunctionInfos* GetC99FunctionInfos() {
static const auto* const kResult = new AttrOpFunctionInfos({
// Get.
AttrOpFunctionInfo{
.return_type = GetStatusType(),
.has_key_parameter = true,
.extra_parameters = {{.type = Type::Pointer(Type::AttrValueType()),
.name = "value"}}},
// Set.
AttrOpFunctionInfo{.return_type = GetStatusType(),
.has_key_parameter = true,
.extra_parameters = {{.type = Type::AttrValueType(),
.name = "value"}}},
// IsNonDefault.
AttrOpFunctionInfo{
.return_type = GetStatusType(),
.has_key_parameter = true,
.extra_parameters = {{.type = Type::Pointer(Type::Named("_Bool")),
.name = "out_is_non_default"}}},
// NumNonDefaults.
AttrOpFunctionInfo{
.return_type = GetStatusType(),
.has_key_parameter = false,
.extra_parameters = {{.type = Type::Pointer(Type::Named("int64_t")),
.name = "out_num_non_defaults"}}},
// GetNonDefaults.
AttrOpFunctionInfo{
.return_type = GetStatusType(),
.has_key_parameter = false,
.extra_parameters =
{
{.type = Type::Pointer(Type::Named("int64_t")),
.name = "out_num_non_defaults"},
{.type = Type::Pointer(Type::Pointer(Type::Named("int64_t"))),
.name = "out_non_defaults"},
}},
});
return kResult;
}
class C99CodeGeneratorBase : public CodeGenerator {
public:
using CodeGenerator::CodeGenerator;
void EmitHeader(std::string* out) const final {
absl::StrAppend(out, R"(
/* DO NOT EDIT: This file is autogenerated. */
#ifndef MATHOPTH_GENERATED
#error "this file is intended to be included, do not use directly"
#endif
)");
}
};
// Emits the prototype for a function.
void EmitPrototype(absl::string_view op_name,
const CodegenAttrTypeDescriptor& descriptor,
const AttrOpFunctionInfo& info, std::string* out) {
absl::string_view attr_value_type = GetCTypeName(descriptor.value_type);
// Adds the return type, function name and common parameters.
info.return_type->Print(attr_value_type, out);
absl::StrAppendFormat(out,
" MathOpt%s%s(struct "
"MathOptElemental* e, int attr",
descriptor.name, op_name);
// Add the key.
if (info.has_key_parameter) {
AddKeyParams(descriptor.num_key_elements, out);
}
// Add extra parameters.
for (const auto& extra_param : info.extra_parameters) {
absl::StrAppend(out, ", ");
extra_param.type->Print(attr_value_type, out);
absl::StrAppend(out, " ", extra_param.name);
}
// Finish prototype.
absl::StrAppend(out, ")");
}
class C99DeclarationsGenerator : public C99CodeGeneratorBase {
public:
C99DeclarationsGenerator() : C99CodeGeneratorBase(GetC99FunctionInfos()) {}
void EmitElements(absl::Span<const absl::string_view> elements,
std::string* out) const override {
// Generate an enum for the elements.
absl::StrAppend(out,
"// The type of an element in the model.\n"
"enum MathOptElementType {\n");
for (const auto& element_name : elements) {
absl::StrAppendFormat(out, " kMathOpt%s,\n",
NameToCamelCase(element_name));
}
absl::StrAppend(out, "};\n\n");
}
void EmitAttrOp(absl::string_view op_name,
const CodegenAttrTypeDescriptor& descriptor,
const AttrOpFunctionInfo& info,
std::string* out) const override {
// Just emit a prototype.
EmitPrototype(op_name, descriptor, info, out);
absl::StrAppend(out, ";\n");
}
void StartAttrType(const CodegenAttrTypeDescriptor& descriptor,
std::string* out) const override {
// Generate an enum for the attribute type.
absl::StrAppendFormat(out, "typedef enum {\n");
for (absl::string_view attr_name : descriptor.attribute_names) {
absl::StrAppendFormat(out, " kMathOpt%s%s,\n", descriptor.name,
NameToCamelCase(attr_name));
}
absl::StrAppendFormat(out, "} MathOpt%s;\n", descriptor.name);
}
};
class C99DefinitionsGenerator : public C99CodeGeneratorBase {
public:
C99DefinitionsGenerator() : C99CodeGeneratorBase(GetC99FunctionInfos()) {}
void EmitAttrOp(absl::string_view op_name,
const CodegenAttrTypeDescriptor& descriptor,
const AttrOpFunctionInfo& info,
std::string* out) const override {
EmitPrototype(op_name, descriptor, info, out);
// Emit a call to the wrapper (e.g. `CAttrOp<Descriptor>::Op`).
absl::StrAppendFormat(out, " {\n return CAttrOp<%s>::%s(e, attr",
descriptor.name, op_name);
// Add the key argument.
if (info.has_key_parameter) {
AddAttrKeyArg(descriptor.num_key_elements, descriptor.symmetry, out);
}
// Add extra parameter arguments.
for (const auto& extra_param : info.extra_parameters) {
absl::StrAppend(out, ", ", extra_param.name);
}
absl::StrAppend(out, ");\n}\n");
}
};
} // namespace
std::unique_ptr<CodeGenerator> C99Declarations() {
return std::make_unique<C99DeclarationsGenerator>();
}
std::unique_ptr<CodeGenerator> C99Definitions() {
return std::make_unique<C99DefinitionsGenerator>();
}
} // namespace operations_research::math_opt::codegen

View File

@@ -0,0 +1,32 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The C99 code generator.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_
#include <memory>
#include "ortools/math_opt/elemental/codegen/gen.h"
namespace operations_research::math_opt::codegen {
// Returns a generator for C99 declarations.
std::unique_ptr<CodeGenerator> C99Declarations();
// Returns a generator for C99 definitions.
std::unique_ptr<CodeGenerator> C99Definitions();
} // namespace operations_research::math_opt::codegen
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_

View File

@@ -0,0 +1,100 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/codegen/gen_c.h"
#include <string>
#include "absl/strings/string_view.h"
#include "gtest/gtest.h"
#include "ortools/math_opt/elemental/codegen/gen.h"
#include "ortools/math_opt/elemental/codegen/testing.h"
namespace operations_research::math_opt::codegen {
namespace {
TEST(GenC99DeclarationsTest, EmitElements) {
std::string code;
C99Declarations()->EmitElements({"some_name", "other_name"}, &code);
EXPECT_EQ(code,
R"(// The type of an element in the model.
enum MathOptElementType {
kMathOptSomeName,
kMathOptOtherName,
};
)");
}
TEST(GenC99DeclarationsTest, StartAttrType) {
std::string code;
C99Declarations()->StartAttrType(GetTestDescriptor(), &code);
EXPECT_EQ(code,
R"(typedef enum {
kMathOptTestAttr2AName,
kMathOptTestAttr2BName,
} MathOptTestAttr2;
)");
}
TEST(GenC99DeclarationsTest, WithoutKey) {
std::string code;
C99Declarations()->EmitAttrOp("Op", GetTestDescriptor(),
GetTestFunctionInfo(false), &code);
EXPECT_EQ(
code,
R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, ExtraParam extra_param);
)");
}
TEST(GenC99DeclarationsTest, WithKey) {
std::string code;
C99Declarations()->EmitAttrOp("Op", GetTestDescriptor(),
GetTestFunctionInfo(true), &code);
EXPECT_EQ(
code,
R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, int64_t key_0, int64_t key_1, ExtraParam extra_param);
)");
}
TEST(GenC99DefinitionsTest, WithoutKey) {
std::string code;
C99Definitions()->EmitAttrOp("Op", GetTestDescriptor(),
GetTestFunctionInfo(false), &code);
EXPECT_EQ(
code,
R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, ExtraParam extra_param) {
return CAttrOp<TestAttr2>::Op(e, attr, extra_param);
}
)");
}
TEST(GenC99DefinitionsTest, WithKey) {
std::string code;
C99Definitions()->EmitAttrOp("Op", GetTestDescriptor(),
GetTestFunctionInfo(true), &code);
EXPECT_EQ(
code,
R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, int64_t key_0, int64_t key_1, ExtraParam extra_param) {
return CAttrOp<TestAttr2>::Op(e, attr, AttrKey<2, SomeSymmetry>(key_0, key_1), extra_param);
}
)");
}
TEST(GenC99DefinitionsTest, StartAttrType) {
std::string code;
C99Definitions()->StartAttrType(GetTestDescriptor(), &code);
EXPECT_EQ(code, "");
}
} // namespace
} // namespace operations_research::math_opt::codegen

View File

@@ -0,0 +1,161 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/codegen/gen_python.h"
#include <memory>
#include <set>
#include <string>
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/math_opt/elemental/codegen/gen.h"
namespace operations_research::math_opt::codegen {
namespace {
const AttrOpFunctionInfos* GetPythonFunctionInfos() {
// We're not generating functions for python, only enums.
static const auto* const kResult = new AttrOpFunctionInfos();
return kResult;
}
// Emits a set of numbered python enumerators for the given range.
void EmitEnumerators(const absl::Span<const absl::string_view> names,
std::string* out) {
for (int i = 0; i < names.size(); ++i) {
absl::StrAppendFormat(out, " %s = %i\n", absl::AsciiStrToUpper(names[i]),
i);
}
}
// Returns the python type for the given value type.
absl::string_view GetAttrPyValueType(
const CodegenAttrTypeDescriptor::ValueType& value_type) {
switch (value_type) {
case CodegenAttrTypeDescriptor::ValueType::kBool:
return "bool";
case CodegenAttrTypeDescriptor::ValueType::kInt64:
return "int";
case CodegenAttrTypeDescriptor::ValueType::kDouble:
return "float";
}
}
// Returns the python type for the given value type.
absl::string_view GetAttrNumpyValueType(
const CodegenAttrTypeDescriptor::ValueType& value_type) {
switch (value_type) {
case CodegenAttrTypeDescriptor::ValueType::kBool:
return "np.bool_";
case CodegenAttrTypeDescriptor::ValueType::kInt64:
return "np.int64";
case CodegenAttrTypeDescriptor::ValueType::kDouble:
return "np.float64";
}
}
class PythonEnumsGenerator : public CodeGenerator {
public:
PythonEnumsGenerator() : CodeGenerator(GetPythonFunctionInfos()) {}
void EmitHeader(std::string* out) const override {
absl::StrAppend(out, R"(
'''DO NOT EDIT: This file is autogenerated.'''
import enum
from typing import Generic, TypeVar, Union
import numpy as np
)");
}
void EmitElements(absl::Span<const absl::string_view> elements,
std::string* out) const override {
// Generate an enum for the elements.
absl::StrAppend(out, "class ElementType(enum.Enum):\n");
EmitEnumerators(elements, out);
absl::StrAppend(out, "\n");
}
void EmitAttributes(absl::Span<const CodegenAttrTypeDescriptor> descriptors,
std::string* out) const override {
absl::StrAppend(out, "\n");
{
// Collect the list of unique types:
std::set<absl::string_view> value_types;
for (const auto& descriptor : descriptors) {
value_types.insert(GetAttrNumpyValueType(descriptor.value_type));
}
// Emit `AttrValueType`, a type variable for all attribute value types.
absl::StrAppend(out, "AttrValueType = TypeVar('AttrValueType', ",
absl::StrJoin(value_types, ", "), ")\n");
}
absl::StrAppend(out, "\n");
{
std::set<absl::string_view> py_value_types;
for (const auto& descriptor : descriptors) {
py_value_types.insert(GetAttrPyValueType(descriptor.value_type));
}
absl::StrAppend(out, "AttrPyValueType = TypeVar('AttrPyValueType', ",
absl::StrJoin(py_value_types, ", "), ")\n");
}
// `Attr` is an attribute with any value type.
absl::StrAppend(out, R"(
class Attr(Generic[AttrValueType]):
pass
)");
// `PyAttr` is an attribute with any value type.
absl::StrAppend(out, R"(
class PyAttr(Generic[AttrPyValueType]):
pass
)");
// Generate an enum for the attribute type.
for (const auto& descriptor : descriptors) {
absl::StrAppendFormat(
out, "\nclass %s(Attr[%s], PyAttr[%s], int, enum.Enum):\n",
descriptor.name, GetAttrNumpyValueType(descriptor.value_type),
GetAttrPyValueType(descriptor.value_type));
EmitEnumerators(descriptor.attribute_names, out);
absl::StrAppend(out, "\n");
}
// Add a type alias for the union of all attribute types.
absl::StrAppend(
out, "AnyAttr = Union[",
absl::StrJoin(
descriptors, ", ",
[](std::string* out, const CodegenAttrTypeDescriptor& descriptor) {
absl::StrAppend(out, descriptor.name);
}),
"]\n");
}
};
} // namespace
std::unique_ptr<CodeGenerator> PythonEnums() {
return std::make_unique<PythonEnumsGenerator>();
}
} // namespace operations_research::math_opt::codegen

View File

@@ -0,0 +1,30 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The python code generator.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_
#include <memory>
#include "ortools/math_opt/elemental/codegen/gen.h"
namespace operations_research::math_opt::codegen {
// Returns a generator for python enums, independent of the actual
// implementation. These are used by the protocol.
std::unique_ptr<CodeGenerator> PythonEnums();
} // namespace operations_research::math_opt::codegen
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_

View File

@@ -0,0 +1,60 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/codegen/gen_python.h"
#include <string>
#include "gtest/gtest.h"
#include "ortools/math_opt/elemental/codegen/testing.h"
namespace operations_research::math_opt::codegen {
namespace {
TEST(GenPythonEnumsTest, EmitElements) {
std::string code;
PythonEnums()->EmitElements({"some_name", "other_name"}, &code);
EXPECT_EQ(code,
R"(class ElementType(enum.Enum):
SOME_NAME = 0
OTHER_NAME = 1
)");
}
TEST(GenPythonEnumsTest, EmitAttributes) {
std::string code;
PythonEnums()->EmitAttributes({GetTestDescriptor()}, &code);
EXPECT_EQ(code,
R"(
AttrValueType = TypeVar('AttrValueType', np.float64)
AttrPyValueType = TypeVar('AttrPyValueType', float)
class Attr(Generic[AttrValueType]):
pass
class PyAttr(Generic[AttrPyValueType]):
pass
class TestAttr2(Attr[np.float64], PyAttr[float], int, enum.Enum):
A_NAME = 0
B_NAME = 1
AnyAttr = Union[TestAttr2]
)");
}
} // namespace
} // namespace operations_research::math_opt::codegen

View File

@@ -0,0 +1,96 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/codegen/gen.h"
#include <memory>
#include <string>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
namespace operations_research::math_opt::codegen {
namespace {
using testing::HasSubstr;
using testing::StartsWith;
const AttrOpFunctionInfos* GetFunctionInfos() {
static const auto* const kResult = new AttrOpFunctionInfos({
{.return_type = Type::Named("TypeForGet"),
.has_key_parameter = false,
.extra_parameters = {}},
{.return_type = Type::Pointer(Type::AttrValueType()),
.has_key_parameter = true,
.extra_parameters = {}},
{.return_type = Type::Named("T"),
.has_key_parameter = false,
.extra_parameters = {}},
{.return_type = Type::Named("T"),
.has_key_parameter = false,
.extra_parameters = {}},
{.return_type = Type::Named("T"),
.has_key_parameter = false,
.extra_parameters = {}},
});
return kResult;
}
class TestCodeGenerator : public CodeGenerator {
public:
TestCodeGenerator() : CodeGenerator(GetFunctionInfos()) {}
void EmitHeader(std::string* out) const override {
absl::StrAppend(out, "# DO NOT EDIT: Test\n");
}
void EmitElements(absl::Span<const absl::string_view> elements,
std::string* out) const override {
absl::StrAppend(out, "Elements: ", absl::StrJoin(elements, ", "), "\n");
}
void StartAttrType(const CodegenAttrTypeDescriptor&,
std::string* out) const override {
absl::StrAppend(out, "\n");
}
void EmitAttrOp(absl::string_view op_name,
const CodegenAttrTypeDescriptor& descriptor,
const AttrOpFunctionInfo& info,
std::string* out) const override {
info.return_type->Print("fake_type", out);
absl::StrAppend(out, " ", descriptor.name, op_name, "\n");
}
};
TEST(GenerateCodeTest, Attrs) {
const std::string code = TestCodeGenerator().GenerateCode();
EXPECT_THAT(code, StartsWith("# DO NOT EDIT: Test\n"));
EXPECT_THAT(code, HasSubstr("Elements: variable, linear_constraint, "));
EXPECT_THAT(code, HasSubstr("TypeForGet BoolAttr0Get\n"
"fake_type* BoolAttr0Set\n"
"T BoolAttr0IsNonDefault\n"
"T BoolAttr0NumNonDefaults\n"
"T BoolAttr0GetNonDefaults\n"));
EXPECT_THAT(code, HasSubstr("TypeForGet DoubleAttr1Get\n"
"fake_type* DoubleAttr1Set\n"
"T DoubleAttr1IsNonDefault\n"
"T DoubleAttr1NumNonDefaults\n"
"T DoubleAttr1GetNonDefaults\n"));
}
} // namespace
} // namespace operations_research::math_opt::codegen

View File

@@ -0,0 +1,40 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Test descriptors. This avoids depending on attributes from `attributes.h`
// in the tests to decouple the codegen tests from `attributes.h`.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_
#include "ortools/math_opt/elemental/codegen/gen.h"
namespace operations_research::math_opt::codegen {
inline CodegenAttrTypeDescriptor GetTestDescriptor() {
return {.name = "TestAttr2",
.value_type = CodegenAttrTypeDescriptor::ValueType::kDouble,
.num_key_elements = 2,
.symmetry = "SomeSymmetry",
.attribute_names = {"a_name", "b_name"}};
}
inline AttrOpFunctionInfo GetTestFunctionInfo(bool with_key_parameter) {
return {.return_type = Type::Named("ReturnType"),
.has_key_parameter = with_key_parameter,
.extra_parameters = {
{{.type = Type::Named("ExtraParam"), .name = "extra_param"}}}};
}
} // namespace operations_research::math_opt::codegen
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_

View File

@@ -0,0 +1,218 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_DERIVED_DATA_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_DERIVED_DATA_H_
#include <array>
#include <string>
#include <tuple>
#include <type_traits>
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "ortools/math_opt/elemental/arrays.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/attributes.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/util/fp_roundtrip_conv.h"
namespace operations_research::math_opt {
// A helper to manipulate the list of attributes.
struct AllAttrs {
// The number of available attribute types.
static constexpr int kNumAttrTypes =
std::tuple_size_v<AllAttrTypeDescriptors>;
// Returns the descriptor of the `i-th` attribute type in the list.
template <int i>
using TypeDescriptor = std::tuple_element_t<i, AllAttrTypeDescriptors>;
// Returns the `i-th` attribute type in the list.
template <int i>
using Type = typename TypeDescriptor<i>::AttrType;
// Returns the index of attribute type `AttrT`.
// Fails to compile if `AttrT` is not an attribute.
template <typename AttrT>
static constexpr int GetIndex() {
constexpr int index = GetIndexIfAttr<AttrT>();
// This weird construct is to show `AttrT` explicitly instead of letting
// the user fish it out of the stack trace when the static_assert fails.
static_assert(
std::is_const_v<std::conditional_t<(index >= 0), const AttrT, AttrT>>,
"no such attribute");
return index;
}
// Applies `fn` on each value for each attribute type. `fn` must have a
// overload set of `operator(AttrType)` that accepts a `AttrType` for
// each attribute type.
template <typename Fn>
static void ForEachAttr(Fn&& fn) {
ForEach(
[&fn](const auto& descriptor) {
for (auto attr : descriptor.Enumerate()) {
fn(attr);
}
},
AllAttrTypeDescriptors{});
}
};
// Returns the descriptor for attribute `AttrT`.
template <typename AttrT>
using AttrTypeDescriptorT =
AllAttrs::TypeDescriptor<AllAttrs::GetIndex<AttrT>()>;
// Returns the default value for the attribute type `attr`.
//
// For example GetAttrDefaultValue<DoubleAttr2::kLinConCoef>() returns 0.0.
template <auto attr>
constexpr typename AttrTypeDescriptorT<decltype(attr)>::ValueType
GetAttrDefaultValue() {
return AttrTypeDescriptorT<decltype(attr)>::kAttrDescriptors[static_cast<int>(
attr)]
.default_value;
}
// Returns the number of elements in a key for the attribute type `AttrType`.
//
// For example `GetAttrKeySize<DoubleAttr2>()` returns 2.
template <typename AttrType>
constexpr int GetAttrKeySize() {
return AttrTypeDescriptorT<AttrType>::kNumKeyElements;
}
template <auto attr>
constexpr int GetAttrKeySize() {
return GetAttrKeySize<decltype(attr)>();
}
// The type of the `AttrKey` for attribute type `AttrType`.
template <typename AttrType>
using AttrKeyFor = AttrKey<AttrTypeDescriptorT<AttrType>::kNumKeyElements,
typename AttrTypeDescriptorT<AttrType>::Symmetry>;
// The value type for attribute type `AttrType`.
template <typename AttrType>
using ValueTypeFor = typename AttrTypeDescriptorT<AttrType>::ValueType;
// Returns the array of elements for the key for the attribute type `attr`.
//
// For example, GetElementTypes<DoubleAttr2>() returns the array
// {ElementType::kLinearConstraint, ElementType::kVariable}.
template <typename AttrType>
constexpr std::array<ElementType, GetAttrKeySize<AttrType>()> GetElementTypes(
const AttrType attr) {
return AttrTypeDescriptorT<AttrType>::kAttrDescriptors[static_cast<int>(attr)]
.key_types;
}
template <auto attr>
constexpr std::array<ElementType, GetAttrKeySize<attr>()> GetElementTypes() {
return GetElementTypes(attr);
}
// After C++20, this can be replaced by a lambda. C++17 does not allow lambdas
// in unevaluated contexts.
template <template <int i> typename ValueType>
struct EnumeratedTupleCpp17Helper {
template <int... i>
auto operator()() const {
return std::make_tuple(ValueType<i>()...);
}
};
// A tuple of `ValueType<i>` for `i` in `0..n`.
template <int n, template <int i> typename ValueType>
using EnumeratedTuple =
decltype(ApplyOnIndexRange<n>(EnumeratedTupleCpp17Helper<ValueType>{}));
// A map of attribute to `ValueType<i>`, where `i` is the index of the attribute
// type.
// See `AttrMapTest` for example usage.
// NOTE: this is formally a map (it maps attributes to values), but internally
// uses dense storage.
template <template <int i> typename ValueType>
class AttrMap {
public:
template <typename AttrT>
ValueType<AllAttrs::GetIndex<AttrT>()>& operator[](AttrT a) {
// TODO(b/365997645): post C++ 23, prefer `std::to_underlying(a)`.
return std::get<AllAttrs::GetIndex<AttrT>()>(
array_tuple_)[static_cast<int>(a)];
}
template <typename AttrT>
const ValueType<AllAttrs::GetIndex<AttrT>()>& operator[](AttrT a) const {
// The `const_cast` is fine because non-const `operator[]` does not mutate
// anything and we're casting the return type back to const.
return (*const_cast<AttrMap*>(this))[a];
}
// Applies `fn` on each value for each attribute type. `fn` must have a
// overload set of `operator()` that accepts a `ValueType<i>` for `i` in
// `0..AllAttrs::kSize`.
// This cannot be an iterator because value types are not homogeneous.
template <typename Fn>
void ForEachAttrValue(Fn&& fn) {
ForEach(
[&fn](auto& array) {
for (auto& value : array) {
fn(value);
}
},
array_tuple_);
}
private:
template <int i>
using ArrayType =
std::array<ValueType<i>, AllAttrs::TypeDescriptor<i>::NumAttrs()>;
EnumeratedTuple<AllAttrs::kNumAttrTypes, ArrayType> array_tuple_;
};
// Calls `fn<attr>()`.
template <typename AttrType, typename Fn, int n = 0>
decltype(auto) CallForAttr(AttrType attr, Fn&& fn) {
using Descriptor = AttrTypeDescriptorT<AttrType>;
if constexpr (n < Descriptor::NumAttrs()) {
constexpr AttrType a = static_cast<AttrType>(n);
if (a == attr) {
return fn.template operator()<a>();
}
return CallForAttr<AttrType, Fn, n + 1>(attr, std::forward<Fn>(fn));
} else {
LOG(FATAL) << "impossible";
return decltype(fn.template operator()<AttrType{}>()) {};
}
}
template <typename ValueType>
std::string FormatAttrValue(const ValueType v) {
return absl::StrCat(v);
}
template <>
inline std::string FormatAttrValue(const double v) {
return RoundTripDoubleFormat::ToString(v);
}
template <>
inline std::string FormatAttrValue(const bool v) {
return v ? "true" : "false";
}
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_DERIVED_DATA_H_

View File

@@ -0,0 +1,172 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/derived_data.h"
#include <cstdint>
#include <limits>
#include <string>
#include <type_traits>
#include <vector>
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/math_opt/elemental/arrays.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/attributes.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/elemental/symmetry.h"
#include "ortools/math_opt/testing/stream.h"
namespace operations_research::math_opt {
namespace {
using ::testing::ElementsAreArray;
using ::testing::IsSupersetOf;
TEST(GetAttrDefaultValueTest, HasRightDefault) {
EXPECT_EQ(GetAttrDefaultValue<DoubleAttr0::kObjOffset>(), 0.0);
EXPECT_EQ(GetAttrDefaultValue<BoolAttr1::kVarInteger>(), false);
EXPECT_EQ(GetAttrDefaultValue<DoubleAttr1::kVarUb>(),
std::numeric_limits<double>::infinity());
EXPECT_EQ(GetAttrDefaultValue<DoubleAttr2::kLinConCoef>(), 0.0);
}
TEST(AttrKeyForTest, Works) {
EXPECT_TRUE((std::is_same_v<AttrKeyFor<BoolAttr0>, AttrKey<0>>));
EXPECT_TRUE((std::is_same_v<AttrKeyFor<DoubleAttr0>, AttrKey<0>>));
EXPECT_TRUE((std::is_same_v<AttrKeyFor<DoubleAttr1>, AttrKey<1>>));
EXPECT_TRUE((std::is_same_v<AttrKeyFor<DoubleAttr2>, AttrKey<2>>));
EXPECT_TRUE((std::is_same_v<AttrKeyFor<SymmetricDoubleAttr2>,
AttrKey<2, ElementSymmetry<0, 1>>>));
}
TEST(ValueTypeForTest, Works) {
EXPECT_TRUE((std::is_same_v<ValueTypeFor<BoolAttr0>, bool>));
EXPECT_TRUE((std::is_same_v<ValueTypeFor<DoubleAttr0>, double>));
EXPECT_TRUE((std::is_same_v<ValueTypeFor<DoubleAttr1>, double>));
EXPECT_TRUE((std::is_same_v<ValueTypeFor<DoubleAttr2>, double>));
}
TEST(GetAttrKeySizeTest, IsRightSize) {
EXPECT_EQ(GetAttrKeySize<DoubleAttr0>(), 0);
EXPECT_EQ(GetAttrKeySize<BoolAttr1>(), 1);
EXPECT_EQ(GetAttrKeySize<DoubleAttr2>(), 2);
EXPECT_EQ(GetAttrKeySize<DoubleAttr0::kObjOffset>(), 0);
EXPECT_EQ(GetAttrKeySize<BoolAttr1::kVarInteger>(), 1);
EXPECT_EQ(GetAttrKeySize<DoubleAttr2::kLinConCoef>(), 2);
}
TEST(GetElementTypesTest, Attr1HasElement) {
EXPECT_EQ(GetElementTypes<BoolAttr1::kVarInteger>()[0],
ElementType::kVariable);
}
TEST(GetElementTypesTest, Attr2HasElements) {
EXPECT_EQ(GetElementTypes<DoubleAttr2::kLinConCoef>()[0],
ElementType::kLinearConstraint);
EXPECT_EQ(GetElementTypes<DoubleAttr2::kLinConCoef>()[1],
ElementType::kVariable);
}
TEST(AllAttrsTest, Indexing) {
ForEachIndex<AllAttrs::kNumAttrTypes>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<int i>() { EXPECT_EQ((AllAttrs::GetIndex<AllAttrs::Type<i>>()), i); });
}
TEST(AllAttrsTest, ForEachAttribute) {
std::vector<std::string> invoked;
AllAttrs::ForEachAttr(
[&invoked](auto attr) { invoked.push_back(StreamToString(attr)); });
EXPECT_THAT(invoked, IsSupersetOf({"objective_offset", "maximize",
"variable_integer", "variable_lower_bound",
"linear_constraint_coefficient"}));
}
template <int i>
struct Value {
Value() = default;
explicit Value(int v) : value(v) {}
int value = i;
};
TEST(AttrMapTest, GetSet) {
AttrMap<Value> attr_map;
constexpr int kBoolAttr0Index = AllAttrs::GetIndex<BoolAttr0>();
constexpr int kBoolAttr1Index = AllAttrs::GetIndex<BoolAttr1>();
constexpr int kDoubleAttr1Index = AllAttrs::GetIndex<DoubleAttr1>();
constexpr int kDoubleAttr2Index = AllAttrs::GetIndex<DoubleAttr2>();
// Default initialization.
EXPECT_EQ(attr_map[BoolAttr0::kMaximize].value, kBoolAttr0Index);
EXPECT_EQ(attr_map[BoolAttr1::kVarInteger].value, kBoolAttr1Index);
EXPECT_EQ(attr_map[DoubleAttr1::kVarLb].value, kDoubleAttr1Index);
EXPECT_EQ(attr_map[DoubleAttr1::kVarUb].value, kDoubleAttr1Index);
EXPECT_EQ(attr_map[DoubleAttr2::kLinConCoef].value, kDoubleAttr2Index);
// Mutation (typed).
attr_map[BoolAttr0::kMaximize] = Value<kBoolAttr0Index>(42);
attr_map[BoolAttr1::kVarInteger] = Value<kBoolAttr1Index>(43);
attr_map[DoubleAttr1::kVarLb] = Value<kDoubleAttr1Index>(44);
attr_map[DoubleAttr1::kVarUb] = Value<kDoubleAttr1Index>(45);
attr_map[DoubleAttr2::kLinConCoef] = Value<kDoubleAttr2Index>(46);
EXPECT_EQ(attr_map[BoolAttr0::kMaximize].value, 42);
EXPECT_EQ(attr_map[BoolAttr1::kVarInteger].value, 43);
EXPECT_EQ(attr_map[DoubleAttr1::kVarLb].value, 44);
EXPECT_EQ(attr_map[DoubleAttr1::kVarUb].value, 45);
EXPECT_EQ(attr_map[DoubleAttr2::kLinConCoef].value, 46);
}
TEST(AttrMapTest, Iteration) {
AttrMap<Value> attr_map;
// Collect all values in the default-initialized map.s
std::vector<int> values;
attr_map.ForEachAttrValue(
[&values](const auto& v) { values.emplace_back(v.value); });
// We should have `NumAttrs()` values `i` per attribute.
std::vector<int> expected_values;
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
ForEachIndex<AllAttrs::kNumAttrTypes>([&expected_values]<int i>() {
for (int k = 0; k < AllAttrs::TypeDescriptor<i>::NumAttrs(); ++k) {
expected_values.push_back(i);
}
});
EXPECT_THAT(values, ElementsAreArray(expected_values));
}
TEST(CallForAttrTest, Works) {
EXPECT_EQ(CallForAttr(DoubleAttr1::kVarUb,
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[]<DoubleAttr1 a>() { return static_cast<int>(a); }),
static_cast<int>(DoubleAttr1::kVarUb));
}
TEST(FormatAttrValueTest, FormatsBool) {
EXPECT_EQ(FormatAttrValue(true), "true");
}
TEST(FormatAttrValueTest, FormatsInt64) {
EXPECT_EQ(FormatAttrValue(int64_t{12}), "12");
}
TEST(FormatAttrValueTest, FormatsDouble) {
EXPECT_EQ(FormatAttrValue(4.2), "4.2");
}
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,30 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/diff.h"
#include <array>
#include <cstdint>
#include "ortools/math_opt/elemental/elements.h"
namespace operations_research::math_opt {
void Diff::Advance(const std::array<int64_t, kNumElements>& checkpoints) {
for (int i = 0; i < kNumElements; ++i) {
element_diffs_[i].Advance(checkpoints[i]);
}
attr_diffs_.ForEachAttrValue([](auto& diff) { diff.Advance(); });
}
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,157 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_DIFF_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_DIFF_H_
#include <array>
#include <cstdint>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "ortools/math_opt/elemental/attr_diff.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/attributes.h"
#include "ortools/math_opt/elemental/derived_data.h"
#include "ortools/math_opt/elemental/element_diff.h"
#include "ortools/math_opt/elemental/elements.h"
namespace operations_research::math_opt {
// Stores the modifications to the model since the previous checkpoint (or since
// creation of the Diff if Advance() has never been called).
//
// Only the following modifications are tracked explicitly:
// * elements before the checkpoint
// * attributes with all elements in the key before the checkpoint
// as all changes involving an element after the checkpoint are implied to be in
// the difference.
//
// Note: users of ElementalImpl can only access a const Diff.
//
// When a element is deleted from the model, the creator of the Diff is
// responsible both for:
// 1. Calling Diff::DeleteElement() on the element,
// 2. For each Attr with a key element on the element type, calling
// Diff::EraseKeysForAttr()
// We cannot do this all at once for the user, as we do not have access to the
// relevant related keys in steps 3/4 above.
class Diff {
public:
Diff() = default;
// Discards all tracked modifications, and in the future, track only
// modifications where all elements are at most checkpoint.
//
// Generally, checkpoints should be component-wise non-decreasing with each
// invocation of Advance(), but this is not checked here.
void Advance(const std::array<int64_t, kNumElements>& checkpoints);
//////////////////////////////////////////////////////////////////////////////
// Elements
//////////////////////////////////////////////////////////////////////////////
// The current checkpoint for the element type `e`.
//
// This equals the next element id for the element type `e` when Advance() was
// last called (or at creation time if advance was never called).
int64_t checkpoint(const ElementType e) const {
return element_diff(e).checkpoint();
}
// The elements of element type `e` that have been deleted since the last call
// to Advance() with id less than the checkpoint.
const absl::flat_hash_set<int64_t>& deleted_elements(
const ElementType e) const {
return element_diff(e).deleted();
}
// Tracks the element `id` of element type `e` as deleted if it is less than
// the checkpoint.
//
// WARNING: this does not update any related attributes.
void DeleteElement(const ElementType e, int64_t id) {
mutable_element_diff(e).Delete(id);
}
//////////////////////////////////////////////////////////////////////////////
// Attributes
//////////////////////////////////////////////////////////////////////////////
// Returns the keys with all elements below the checkpoint where the Attr2 `a`
// was modified since the last call to Advance().
template <typename AttrType>
const AttrKeyHashSet<AttrKeyFor<AttrType>>& modified_keys(
const AttrType a) const {
return attr_diffs_[a].modified_keys();
}
// Marks that the attribute `a` has been modified for `attr_key`.
template <typename AttrType>
void SetModified(const AttrType a, const AttrKeyFor<AttrType> attr_key) {
if (IsBeforeCheckpoint(a, attr_key)) {
attr_diffs_[a].SetModified(attr_key);
}
}
// Discard any tracked modifications for attribute `a` on `keys`.
//
// Typically invoke when the element with id `keys[i]` is deleted from the
// model, and where `keys` are the keys `k` for all elements `e` where the `e`
// has a non-default value for `a`.
template <typename AttrType>
void EraseKeysForAttr(const AttrType a,
absl::Span<const AttrKeyFor<AttrType>> keys) {
if (!attr_diffs_[a].has_modified_keys()) {
return;
}
for (const auto& attr_key : keys) {
if (IsBeforeCheckpoint(a, attr_key)) {
attr_diffs_[a].Erase(attr_key);
}
}
}
private:
const ElementDiff& element_diff(const ElementType e) const {
// TODO(b/365997645): post C++ 23, prefer std::to_underlying(e).
return element_diffs_[static_cast<int>(e)];
}
ElementDiff& mutable_element_diff(const ElementType e) {
return element_diffs_[static_cast<int>(e)];
}
// Returns true if all elements if `key` are before their respective
// checkpoints.
template <typename AttrType>
bool IsBeforeCheckpoint(const AttrType a, const AttrKeyFor<AttrType> key) {
for (int i = 0; i < GetAttrKeySize<AttrType>(); ++i) {
if (key[i] >= element_diff(GetElementTypes(a)[i]).checkpoint()) {
return false;
}
}
return true;
}
std::array<ElementDiff, kNumElements> element_diffs_;
template <int i>
using DiffForAttr = AttrDiff<AllAttrs::TypeDescriptor<i>::kNumKeyElements,
typename AllAttrs::TypeDescriptor<i>::Symmetry>;
AttrMap<DiffForAttr> attr_diffs_;
};
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_DIFF_H_

View File

@@ -0,0 +1,289 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/diff.h"
#include <array>
#include <cstdint>
#include "absl/types/span.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/attributes.h"
#include "ortools/math_opt/elemental/elements.h"
namespace operations_research::math_opt {
namespace {
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
std::array<int64_t, kNumElements> MakeUniformCheckpoint(int64_t id) {
std::array<int64_t, kNumElements> result;
result.fill(id);
return result;
}
////////////////////////////////////////////////////////////////////////////////
// Element tests
////////////////////////////////////////////////////////////////////////////////
TEST(DiffTest, InitDiffElementsEmpty) {
Diff diff;
EXPECT_EQ(diff.checkpoint(ElementType::kVariable), 0);
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable), IsEmpty());
}
TEST(DiffTest, DeleteElementAfterCheckpointNoEffect) {
Diff diff;
diff.DeleteElement(ElementType::kVariable, 2);
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable), IsEmpty());
}
TEST(DiffTest, DeletesTrackedBelowCheckpoint) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(5));
EXPECT_EQ(diff.checkpoint(ElementType::kVariable), 5);
diff.DeleteElement(ElementType::kVariable, 3);
diff.DeleteElement(ElementType::kVariable, 1);
diff.DeleteElement(ElementType::kVariable, 8);
diff.DeleteElement(ElementType::kVariable, 5);
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable),
UnorderedElementsAre(3, 1));
}
TEST(DiffTest, AdvanceClearsDeletedElements) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(5));
diff.DeleteElement(ElementType::kVariable, 3);
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable),
UnorderedElementsAre(3));
diff.Advance(MakeUniformCheckpoint(5));
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable), IsEmpty());
}
////////////////////////////////////////////////////////////////////////////////
// Attr0 Tests
////////////////////////////////////////////////////////////////////////////////
TEST(DiffTest, InitBoolAttr0Empty) {
Diff diff;
EXPECT_THAT(diff.modified_keys(BoolAttr0::kMaximize), IsEmpty());
}
TEST(DiffTest, SetBoolAttr0ModifiedIsModified) {
Diff diff;
diff.SetModified(BoolAttr0::kMaximize, AttrKey());
EXPECT_THAT(diff.modified_keys(BoolAttr0::kMaximize),
UnorderedElementsAre(AttrKey()));
}
TEST(DiffTest, BoolAttr0AdvanceClearsModification) {
Diff diff;
diff.SetModified(BoolAttr0::kMaximize, AttrKey());
diff.Advance(MakeUniformCheckpoint(0));
EXPECT_THAT(diff.modified_keys(BoolAttr0::kMaximize), IsEmpty());
}
// Repeat all tests for double, they are short enough.
TEST(DiffTest, InitDoubleAttr0Empty) {
Diff diff;
EXPECT_THAT(diff.modified_keys(DoubleAttr0::kObjOffset), IsEmpty());
}
TEST(DiffTest, SetDoubleAttr0ModifiedIsModified) {
Diff diff;
diff.SetModified(DoubleAttr0::kObjOffset, AttrKey());
EXPECT_THAT(diff.modified_keys(DoubleAttr0::kObjOffset),
UnorderedElementsAre(AttrKey()));
}
TEST(DiffTest, DoubleAttr0AdvanceClearsModification) {
Diff diff;
diff.SetModified(DoubleAttr0::kObjOffset, AttrKey());
diff.Advance(MakeUniformCheckpoint(0));
EXPECT_THAT(diff.modified_keys(DoubleAttr0::kObjOffset), IsEmpty());
}
////////////////////////////////////////////////////////////////////////////////
// Attr1 Tests
////////////////////////////////////////////////////////////////////////////////
TEST(DiffTest, InitBoolAttr1Empty) {
Diff diff;
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger), IsEmpty());
}
TEST(DiffTest, SetBoolAttr1ModifiedBeforeCheckpointIsModified) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(1));
diff.SetModified(BoolAttr1::kVarInteger, AttrKey(0));
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger),
UnorderedElementsAre(AttrKey(0)));
}
TEST(DiffTest, SetBoolAttr1ModifiedAtleastCheckpointNotTracked) {
Diff diff;
diff.SetModified(BoolAttr1::kVarInteger, AttrKey(0));
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger), IsEmpty());
}
TEST(DiffTest, BoolAttr1AdvanceClearsModification) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(1));
diff.SetModified(BoolAttr1::kVarInteger, AttrKey(0));
diff.Advance(MakeUniformCheckpoint(1));
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger), IsEmpty());
}
TEST(DiffTest, EraseElementForBoolAttr1IsNoLongerTracked) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(1));
diff.SetModified(BoolAttr1::kVarInteger, AttrKey(0));
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger),
UnorderedElementsAre(AttrKey(0)));
diff.EraseKeysForAttr(BoolAttr1::kVarInteger, {AttrKey(0)});
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger), IsEmpty());
}
////////////////////////////////////////////////////////////////////////////////
// Repeat all tests for DoubleAttr1, not ideal
////////////////////////////////////////////////////////////////////////////////
TEST(DiffTest, InitDoubleAttr1Empty) {
Diff diff;
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb), IsEmpty());
}
TEST(DiffTest, SetDoubleAttr1ModifiedBeforeCheckpointIsModified) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(1));
diff.SetModified(DoubleAttr1::kLinConUb, AttrKey(0));
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb),
UnorderedElementsAre(AttrKey(0)));
}
TEST(DiffTest, SetDoubleAttr1ModifiedAtleastCheckpointNotTracked) {
Diff diff;
diff.SetModified(DoubleAttr1::kLinConUb, AttrKey(0));
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb), IsEmpty());
}
TEST(DiffTest, DoubleAttr1AdvanceClearsModification) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(1));
diff.SetModified(DoubleAttr1::kLinConUb, AttrKey(0));
diff.Advance(MakeUniformCheckpoint(1));
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb), IsEmpty());
}
TEST(DiffTest, EraseElementForDoubleAttr1IsNoLongerTracked) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(1));
diff.SetModified(DoubleAttr1::kLinConUb, AttrKey(0));
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb),
UnorderedElementsAre(AttrKey(0)));
diff.EraseKeysForAttr(DoubleAttr1::kLinConUb, {AttrKey(0)});
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb), IsEmpty());
}
////////////////////////////////////////////////////////////////////////////////
// Attr2 Tests
//
// UpdateAttr2OnFirstElementDeleted and UpdateAttr2OnSecondElementDeleted are a
// bit under tested.
////////////////////////////////////////////////////////////////////////////////
TEST(DiffTest, InitDoubleAttr2Empty) {
Diff diff;
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef), IsEmpty());
}
TEST(DiffTest, SetDoubleAttr2ModifiedBothKeysBeforeCheckpointIsModified) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(2));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(1, 0));
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef),
UnorderedElementsAre(AttrKey(1, 0)));
}
TEST(DiffTest, SetDoubleAttr2ModifiedFirstKeyAtleastCheckpointNotTracked) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(2));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(4, 0));
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef), IsEmpty());
}
TEST(DiffTest, SetDoubleAttr2ModifiedSecondtKeyAtleastCheckpointNotTracked) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(2));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(0, 4));
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef), IsEmpty());
}
TEST(DiffTest, DoubleAttr2AdvanceClearsModification) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(1));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(0, 0));
diff.Advance(MakeUniformCheckpoint(1));
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef), IsEmpty());
}
TEST(DiffTest, EraseFirstElementForDoubleAttr2IsNoLongerTracked) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(5));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(1, 0));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(1, 2));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(1, 4));
EXPECT_THAT(
diff.modified_keys(DoubleAttr2::kLinConCoef),
UnorderedElementsAre(AttrKey(1, 0), AttrKey(1, 2), AttrKey(1, 4)));
diff.EraseKeysForAttr(
DoubleAttr2::kLinConCoef,
absl::Span<const AttrKey<2>>({AttrKey(1, 0), AttrKey(1, 4)}));
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef),
UnorderedElementsAre(AttrKey(1, 2)));
}
TEST(DiffTest, EraseSecondElementForDoubleAttr2IsNoLongerTracked) {
Diff diff;
diff.Advance(MakeUniformCheckpoint(5));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(0, 1));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(2, 1));
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(4, 1));
EXPECT_THAT(
diff.modified_keys(DoubleAttr2::kLinConCoef),
UnorderedElementsAre(AttrKey(0, 1), AttrKey(2, 1), AttrKey(4, 1)));
diff.EraseKeysForAttr(
DoubleAttr2::kLinConCoef,
absl::Span<const AttrKey<2>>({AttrKey(0, 1), AttrKey(4, 1)}));
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef),
UnorderedElementsAre(AttrKey(2, 1)));
}
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,64 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_DIFF_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_DIFF_H_
#include <cstdint>
#include "absl/container/flat_hash_set.h"
namespace operations_research::math_opt {
// Tracks the ids of the elements in a model that:
// 1. Are less than the checkpoint for this element.
// 2. Have been deleted since the most recent time the checkpoint was advanced
// (or creation of the ElementDiff if advance was never called).
//
// Generally:
// * Element ids should be nonnegative.
// * Each element should be deleted at most once.
// * Sequential calls to Advance() should be called on non-decreasing
// checkpoints.
// However, these are enforced higher up the stack, not in this class.
class ElementDiff {
public:
// The current checkpoint for this element, generally the next_id for this
// element when Advance() was last called (or at creation time if advance was
// never called).
int64_t checkpoint() const { return checkpoint_; }
// The elements that have been deleted before the checkpoint.
const absl::flat_hash_set<int64_t>& deleted() const { return deleted_; }
// Tracks the element `id` as deleted if it is less than the checkpoint.
void Delete(int64_t id) {
if (id < checkpoint_) {
deleted_.insert(id);
}
}
// Update the checkpoint and clears all tracked deletions.
void Advance(int64_t checkpoint) {
checkpoint_ = checkpoint;
deleted_.clear();
}
private:
int64_t checkpoint_ = 0;
absl::flat_hash_set<int64_t> deleted_;
};
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_DIFF_H_

View File

@@ -0,0 +1,59 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/element_diff.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
namespace operations_research::math_opt {
namespace {
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
TEST(ElementDiffTest, EmptyDiff) {
ElementDiff diff;
EXPECT_EQ(diff.checkpoint(), 0);
EXPECT_THAT(diff.deleted(), IsEmpty());
diff.Delete(4);
EXPECT_THAT(diff.deleted(), IsEmpty());
}
TEST(ElementDiffTest, AddsPointsBelowCheckpoint) {
ElementDiff diff;
diff.Advance(4);
EXPECT_EQ(diff.checkpoint(), 4);
diff.Delete(1);
diff.Delete(3);
diff.Delete(4);
diff.Delete(5);
EXPECT_THAT(diff.deleted(), UnorderedElementsAre(1, 3));
}
TEST(ElementDiffTest, AdvanceClearsDiff) {
ElementDiff diff;
diff.Advance(4);
diff.Delete(1);
diff.Delete(3);
diff.Advance(5);
EXPECT_THAT(diff.deleted(), IsEmpty());
EXPECT_EQ(diff.checkpoint(), 5);
}
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,81 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_REF_TRACKER_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_REF_TRACKER_H_
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "ortools/base/map_util.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/elements.h"
namespace operations_research::math_opt {
template <typename ValueType, int n, typename Symmetry>
class ElementRefTracker;
// The `ElementRefTracker` for a given attribute type descriptor.
template <typename Descriptor>
using ElementRefTrackerForAttrTypeDescriptor =
ElementRefTracker<typename Descriptor::ValueType,
Descriptor::kNumKeyElements,
typename Descriptor::Symmetry>;
// A tracker for values that reference elements.
//
// This is used to delete attributes when the elements they reference are
// deleted.
template <ElementType element_type, int n, typename Symmetry>
class ElementRefTracker<ElementId<element_type>, n, Symmetry> {
public:
using ElemId = ElementId<element_type>;
using Key = AttrKey<n, Symmetry>;
// Returns the set of keys that reference element `id`.
const absl::flat_hash_set<Key>& GetKeysReferencing(const ElemId id) const {
return gtl::FindWithDefault(element_id_to_attr_keys_, id);
}
// Tracks the fact that attribute with key `key` has a value that references
// elements.
void Track(const Key key, const ElemId id) {
element_id_to_attr_keys_[id].insert(key);
}
void Untrack(const Key key, const ElemId id) {
const auto it = element_id_to_attr_keys_.find(id);
if (it != element_id_to_attr_keys_.end()) {
it->second.erase(key);
if (it->second.empty()) {
element_id_to_attr_keys_.erase(it);
}
}
}
void Clear() { element_id_to_attr_keys_.clear(); }
private:
// A map of element id to the list of attribute keys that have a non-default
// value for this element.
absl::flat_hash_map<ElemId, absl::flat_hash_set<Key>>
element_id_to_attr_keys_;
};
// Other value types do not need tracking.
template <typename ValueType, int n, typename Symmetry>
class ElementRefTracker {};
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_REF_TRACKER_H_

View File

@@ -0,0 +1,54 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/element_ref_tracker.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/elemental/symmetry.h"
namespace operations_research::math_opt {
namespace {
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
TEST(ElementRefTrackerTest, ElementValued) {
const VariableId x(0);
const VariableId y(1);
ElementRefTracker<VariableId, 1, NoSymmetry> tracker;
tracker.Track(AttrKey(1), x);
tracker.Track(AttrKey(2), x);
tracker.Track(AttrKey(3), y);
EXPECT_THAT(tracker.GetKeysReferencing(x),
UnorderedElementsAre(AttrKey(1), AttrKey(2)));
EXPECT_THAT(tracker.GetKeysReferencing(y), UnorderedElementsAre(AttrKey(3)));
tracker.Untrack(AttrKey(1), x);
EXPECT_THAT(tracker.GetKeysReferencing(x), UnorderedElementsAre(AttrKey(2)));
EXPECT_THAT(tracker.GetKeysReferencing(y), UnorderedElementsAre(AttrKey(3)));
tracker.Untrack(AttrKey(2), x);
EXPECT_THAT(tracker.GetKeysReferencing(x), IsEmpty());
EXPECT_THAT(tracker.GetKeysReferencing(y), UnorderedElementsAre(AttrKey(3)));
tracker.Untrack(AttrKey(3), y);
EXPECT_THAT(tracker.GetKeysReferencing(x), IsEmpty());
EXPECT_THAT(tracker.GetKeysReferencing(y), IsEmpty());
}
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,56 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/element_storage.h"
#include <cstdint>
#include <utility>
#include <variant>
#include <vector>
#include "absl/algorithm/container.h"
namespace operations_research::math_opt {
namespace detail {
std::vector<int64_t> SparseElementStorage::AllIds() const {
std::vector<int64_t> result;
result.reserve(elements_.size());
for (const auto& [id, unused] : elements_) {
result.push_back(id);
}
return result;
}
std::vector<int64_t> DenseElementStorage::AllIds() const {
std::vector<int64_t> result(elements_.size());
absl::c_iota(result, 0);
return result;
}
SparseElementStorage::SparseElementStorage(DenseElementStorage&& dense)
: next_id_(dense.size()) {
elements_.reserve(next_id_);
for (int i = 0; i < next_id_; ++i) {
elements_.emplace(i, std::move(dense.elements_[i]));
}
}
} // namespace detail
std::vector<int64_t> ElementStorage::AllIds() const {
return std::visit([](const auto& impl) { return impl.AllIds(); }, impl_);
}
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,192 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_STORAGE_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_STORAGE_H_
#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "ortools/base/status_builder.h"
namespace operations_research::math_opt {
namespace detail {
// A dense element storage, for use when no elements have been erased.
// Same API as `ElementStorage`, but no deletion.
// TODO(b/369972336): We should stay in dense mode if we have a small percentage
// of deletions.
class DenseElementStorage {
public:
int64_t Add(const absl::string_view name) {
const int64_t id = elements_.size();
elements_.emplace_back(name);
return id;
}
bool exists(const int64_t id) const {
return 0 <= id && id < elements_.size();
}
absl::StatusOr<absl::string_view> GetName(const int64_t id) const {
if (exists(id)) {
return elements_[id];
}
return util::InvalidArgumentErrorBuilder() << "no element with id " << id;
}
int64_t next_id() const { return size(); }
std::vector<int64_t> AllIds() const;
int64_t size() const { return elements_.size(); }
private:
friend class SparseElementStorage;
std::vector<std::string> elements_;
};
// A sparse element storage, which supports deletion.
class SparseElementStorage {
public:
explicit SparseElementStorage(DenseElementStorage&& dense);
int64_t Add(const absl::string_view name) {
const int64_t id = next_id_;
elements_.try_emplace(id, name);
++next_id_;
return id;
}
bool Erase(int64_t id) { return elements_.erase(id) > 0; }
bool exists(int64_t id) const { return elements_.contains(id); }
absl::StatusOr<absl::string_view> GetName(int64_t id) const {
if (const auto it = elements_.find(id); it != elements_.end()) {
return it->second;
}
return util::InvalidArgumentErrorBuilder() << "no element with id " << id;
}
int64_t next_id() const { return next_id_; }
std::vector<int64_t> AllIds() const;
int64_t size() const { return elements_.size(); }
void ensure_next_id_at_least(int64_t id) {
next_id_ = std::max(next_id_, id);
}
private:
absl::flat_hash_map<int64_t, std::string> elements_;
int64_t next_id_ = 0;
};
} // namespace detail
class ElementStorage {
// Functions with deduced return must be defined before they are used.
private:
// std::visit is very slow, see
// 5253596299885805568
//
// This function is static, taking Self as template argument, to avoid
// having const and non-const versions. Post C++ 23, prefer:
// https://en.cppreference.com/w/cpp/language/member_functions#Explicit_object_member_functions
template <typename Self, typename Fn>
static auto Visit(Self& self, Fn fn) {
if (std::holds_alternative<detail::DenseElementStorage>(self.impl_)) {
return fn(std::get<detail::DenseElementStorage>(self.impl_));
} else {
return fn(std::get<detail::SparseElementStorage>(self.impl_));
}
}
public:
// We start with a dense storage, which is more efficient, and switch to a
// sparse storage when an element is erased.
ElementStorage() : impl_(detail::DenseElementStorage()) {}
// Creates a new element and returns its id.
int64_t Add(const absl::string_view name) {
return Visit(*this, [name](auto& impl) { return impl.Add(name); });
}
// Deletes an element by id, returning true on success and false if no element
// was deleted (it was already deleted or the id was not from any existing
// element).
bool Erase(const int64_t id) { return AsSparse().Erase(id); }
// Returns true an element with this id was created and not yet erased.
bool Exists(const int64_t id) const {
return Visit(*this, [id](auto& impl) { return impl.exists(id); });
}
// Returns the name of this element, or CHECK fails if no element with this id
// exists.
absl::StatusOr<absl::string_view> GetName(const int64_t id) const {
return Visit(*this, [id](auto& impl) { return impl.GetName(id); });
}
// Returns the id that will be used for the next element added.
//
// NOTE: when no elements have been erased, this equals size().
int64_t next_id() const {
return Visit(*this, [](auto& impl) { return impl.next_id(); });
}
// Returns all ids of all elements in the model in an unsorted,
// non-deterministic order.
std::vector<int64_t> AllIds() const;
// Returns the number of elements added and not erased.
int64_t size() const {
return Visit(*this, [](auto& impl) { return impl.size(); });
}
// Increases next_id() to `id` if it is currently less than `id`.
//
// Useful for reading a model back from proto, most users should not need to
// call this directly.
void EnsureNextIdAtLeast(const int64_t id) {
if (id > next_id()) {
AsSparse().ensure_next_id_at_least(id);
}
}
private:
detail::SparseElementStorage& AsSparse() {
if (auto* sparse = std::get_if<detail::SparseElementStorage>(&impl_)) {
return *sparse;
}
impl_ = detail::SparseElementStorage(
std::move(std::get<detail::DenseElementStorage>(impl_)));
return std::get<detail::SparseElementStorage>(impl_);
}
std::variant<detail::DenseElementStorage, detail::SparseElementStorage> impl_;
};
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_STORAGE_H_

View File

@@ -0,0 +1,170 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/element_storage.h"
#include <cstdint>
#include "absl/status/status.h"
#include "benchmark/benchmark.h"
#include "gtest/gtest.h"
#include "ortools/base/gmock.h"
namespace operations_research::math_opt {
namespace {
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
using ::testing::status::IsOkAndHolds;
using ::testing::status::StatusIs;
TEST(ElementStorageTest, EmptyModelGetters) {
ElementStorage elements;
EXPECT_EQ(elements.size(), 0);
EXPECT_EQ(elements.next_id(), 0);
EXPECT_THAT(elements.AllIds(), IsEmpty());
EXPECT_FALSE(elements.Exists(0));
EXPECT_FALSE(elements.Exists(1));
EXPECT_FALSE(elements.Exists(-1));
}
TEST(ElementStorageTest, AddElement) {
ElementStorage elements;
EXPECT_EQ(elements.Add("x"), 0);
EXPECT_EQ(elements.Add("y"), 1);
EXPECT_EQ(elements.size(), 2);
EXPECT_EQ(elements.next_id(), 2);
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(0, 1));
EXPECT_FALSE(elements.Exists(2));
EXPECT_FALSE(elements.Exists(-1));
ASSERT_TRUE(elements.Exists(0));
ASSERT_TRUE(elements.Exists(1));
EXPECT_THAT(elements.GetName(0), IsOkAndHolds("x"));
EXPECT_THAT(elements.GetName(1), IsOkAndHolds("y"));
EXPECT_THAT(elements.GetName(-42),
StatusIs(absl::StatusCode::kInvalidArgument));
}
TEST(ElementStorageTest, AddElementDuplicateName) {
ElementStorage elements;
EXPECT_EQ(elements.Add("xyx"), 0);
EXPECT_EQ(elements.Add("xyx"), 1);
EXPECT_EQ(elements.size(), 2);
EXPECT_EQ(elements.next_id(), 2);
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(0, 1));
ASSERT_TRUE(elements.Exists(0));
ASSERT_TRUE(elements.Exists(1));
EXPECT_THAT(elements.GetName(0), IsOkAndHolds("xyx"));
EXPECT_THAT(elements.GetName(1), IsOkAndHolds("xyx"));
}
TEST(ElementStorageTest, DeleteElement) {
ElementStorage elements;
const int64_t x = elements.Add("x");
const int64_t y = elements.Add("y");
const int64_t z = elements.Add("z");
EXPECT_TRUE(elements.Erase(x));
EXPECT_TRUE(elements.Erase(z));
EXPECT_EQ(elements.size(), 1);
EXPECT_EQ(elements.next_id(), 3);
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(y));
EXPECT_FALSE(elements.Exists(x));
EXPECT_FALSE(elements.Exists(z));
EXPECT_FALSE(elements.Exists(3));
EXPECT_FALSE(elements.Exists(-1));
ASSERT_TRUE(elements.Exists(y));
EXPECT_THAT(elements.GetName(y), IsOkAndHolds("y"));
EXPECT_THAT(elements.GetName(x),
StatusIs(absl::StatusCode::kInvalidArgument));
}
TEST(ElementStorageTest, DeleteInvalidIdNoEffect) {
ElementStorage elements;
const int64_t x = elements.Add("x");
const int64_t y = elements.Add("y");
const int64_t z = elements.Add("z");
EXPECT_FALSE(elements.Erase(-2));
EXPECT_FALSE(elements.Erase(5));
EXPECT_EQ(elements.size(), 3);
EXPECT_EQ(elements.next_id(), 3);
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(x, y, z));
}
TEST(ElementStorageTest, DeleteTwiceNoAdditionalEffect) {
ElementStorage elements;
const int64_t x = elements.Add("x");
const int64_t y = elements.Add("y");
const int64_t z = elements.Add("z");
EXPECT_TRUE(elements.Erase(y));
EXPECT_FALSE(elements.Erase(y));
EXPECT_EQ(elements.size(), 2);
EXPECT_EQ(elements.next_id(), 3);
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(x, z));
}
TEST(ElementStorageTest, EnsureNextIdAtLeastIncreasesNextId) {
ElementStorage elements;
elements.EnsureNextIdAtLeast(5);
EXPECT_EQ(elements.next_id(), 5);
EXPECT_EQ(elements.Add("x"), 5);
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(5));
}
TEST(ElementStorageTest, EnsureNextIdAtLeastNoEffect) {
ElementStorage elements;
elements.Add("x");
elements.EnsureNextIdAtLeast(0);
EXPECT_EQ(elements.next_id(), 1);
EXPECT_EQ(elements.Add("y"), 1);
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(0, 1));
}
void BM_AddElements(benchmark::State& state) {
const int n = state.range(0);
for (auto s : state) {
ElementStorage storage;
for (int i = 0; i < n; ++i) {
storage.Add("");
}
benchmark::DoNotOptimize(storage);
}
}
BENCHMARK(BM_AddElements)->Arg(100)->Arg(10000);
void BM_Exists(benchmark::State& state) {
const int n = state.range(0);
ElementStorage storage;
for (int i = 0; i < n; ++i) {
storage.Add("");
}
for (auto s : state) {
for (int i = 0; i < 2 * n; ++i) {
bool e = storage.Exists(i);
benchmark::DoNotOptimize(e);
}
}
}
BENCHMARK(BM_Exists)->Arg(100)->Arg(10000);
} // namespace
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,169 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ortools/math_opt/elemental/elemental.h"
#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "ortools/math_opt/elemental/arrays.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/derived_data.h"
#include "ortools/math_opt/elemental/diff.h"
#include "ortools/math_opt/elemental/elements.h"
namespace operations_research::math_opt {
Elemental::Elemental(std::string model_name, std::string primary_objective_name)
: model_name_(std::move(model_name)),
primary_objective_name_(std::move(primary_objective_name)) {
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
ForEachIndex<AllAttrs::kNumAttrTypes>([this]<int attr_type_index>() {
using Descriptor = AllAttrs::TypeDescriptor<attr_type_index>;
for (const auto a : Descriptor::Enumerate()) {
const int index = static_cast<int>(a);
attrs_[a] = StorageForAttrType<attr_type_index>(
Descriptor::kAttrDescriptors[index].default_value);
}
});
}
std::optional<Elemental::DiffHandle> Elemental::GetDiffHandle(
const int64_t id) const {
if (diffs_->Get(id) == nullptr) {
return std::nullopt;
}
return DiffHandle(id, diffs_.get());
}
Elemental::DiffHandle Elemental::AddDiff() {
auto diff = std::make_unique<Diff>();
diff->Advance(CurrentCheckpoint());
const int64_t diff_id = diffs_->Insert(std::move(diff));
return DiffHandle(diff_id, diffs_.get());
}
bool Elemental::DeleteDiff(const DiffHandle diff) {
if (&diff.diffs_ != diffs_.get()) {
return false;
}
return diffs_->Erase(diff.diff_id_);
}
bool Elemental::Advance(const DiffHandle diff) {
if (diffs_.get() != &diff.diffs_) {
return false;
}
Diff* d = diffs_->UpdateAndGet(diff.diff_id_);
if (d == nullptr) {
return false;
}
d->Advance(CurrentCheckpoint());
return true;
}
bool Elemental::DeleteElementUntyped(const ElementType e, int64_t id) {
if (!mutable_element_storage(e).Erase(id)) {
return false;
}
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
diff->DeleteElement(e, id);
}
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
AllAttrs::ForEachAttr([this, e, id]<typename AttrType>(AttrType a) {
ForEachIndex<GetAttrKeySize<AttrType>()>(
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
[&]<int i>() {
if (GetElementTypes(a)[i] == e) {
UpdateAttrOnElementDeleted<AttrType, i>(a, id);
}
});
// If `a` is element-valued, we need to remove all keys that refer to the
// deleted element.
if constexpr (is_element_id_v<ValueTypeFor<AttrType>>) {
if (e == ValueTypeFor<AttrType>::type()) {
const auto keys = element_ref_trackers_[a].GetKeysReferencing(
ValueTypeFor<AttrType>(id));
for (const auto key : keys) {
SetAttr(a, key, ValueTypeFor<AttrType>());
}
}
}
});
return true;
}
template <typename AttrType, int i>
void Elemental::UpdateAttrOnElementDeleted(const AttrType a, const int64_t id) {
auto& attr_storage = attrs_[a];
// We consider the case of n == 1 separately so that we can ensure that
// for any attribute with a key size of one, the AttrDiff has no deleted
// elements. (If we did not specialize this code, we would need to check for
// deleted elements when building our ModelUpdateProto, see
// README.md#checkpoints-and-model-updates for an explanation.)
if constexpr (GetAttrKeySize<AttrType>() == 1) {
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
diff->EraseKeysForAttr(a, {AttrKey(id)});
}
attr_storage.Erase(AttrKey(id));
} else {
// NOTE: We explicitly spell out the type here, so that if `Slice` ever
// returns a reference in `attr_storage` instead of a copy, we are forced
// to update this code to make a copy of the slice (otherwise the slice
// would be invalidated by calls to `Erase()` below).
const std::vector<AttrKeyFor<AttrType>> keys =
attr_storage.template Slice<i>(id);
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
diff->EraseKeysForAttr(a, keys);
}
for (const auto& key : keys) {
attr_storage.Erase(key);
}
}
}
std::array<int64_t, kNumElements> Elemental::CurrentCheckpoint() const {
std::array<int64_t, kNumElements> result;
for (int i = 0; i < kNumElements; ++i) {
result[i] = elements_[i].next_id();
}
return result;
}
Elemental Elemental::Clone(
std::optional<absl::string_view> new_model_name) const {
Elemental result(std::string(new_model_name.value_or(model_name_)),
primary_objective_name_);
result.elements_ = elements_;
result.attrs_ = attrs_;
return result;
}
std::ostream& operator<<(std::ostream& ostr, const Elemental& elemental) {
ostr << elemental.DebugString();
return ostr;
}
} // namespace operations_research::math_opt

View File

@@ -0,0 +1,545 @@
// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENTAL_H_
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENTAL_H_
#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <vector>
#include "absl/log/check.h"
#include "absl/log/die_if_null.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "ortools/base/status_builder.h"
#include "ortools/math_opt/elemental/attr_key.h"
#include "ortools/math_opt/elemental/attr_storage.h"
#include "ortools/math_opt/elemental/derived_data.h"
#include "ortools/math_opt/elemental/diff.h"
#include "ortools/math_opt/elemental/element_ref_tracker.h"
#include "ortools/math_opt/elemental/element_storage.h"
#include "ortools/math_opt/elemental/elements.h"
#include "ortools/math_opt/elemental/thread_safe_id_map.h"
#include "ortools/math_opt/model.pb.h"
#include "ortools/math_opt/model_update.pb.h"
namespace operations_research::math_opt {
// A MathOpt optimization model and modification trackers.
//
// Holds the elements, the attribute values, and tracks modifications to the
// model by `Diff` objects, and keeps them all in sync. See README.md for
// details.
class Elemental {
public:
// An opaque value type for a reference to an underlying `Diff` (change
// tracker).
class DiffHandle {
public:
int64_t id() const { return diff_id_; }
private:
explicit DiffHandle(int64_t diff_id, ThreadSafeIdMap<Diff>* diffs)
: diff_id_(diff_id), diffs_(*ABSL_DIE_IF_NULL(diffs)) {}
int64_t diff_id_;
ThreadSafeIdMap<Diff>& diffs_;
friend class Elemental;
friend class ElementalTestPeer;
};
explicit Elemental(std::string model_name = "",
std::string primary_objective_name = "");
// The name of this optimization model.
const std::string& model_name() const { return model_name_; }
// The name of the primary objective of this optimization model.
const std::string& primary_objective_name() const {
return primary_objective_name_;
}
//////////////////////////////////////////////////////////////////////////////
// Elements
//////////////////////////////////////////////////////////////////////////////
// Creates and returns the id of a new element for the element type `e`.
template <ElementType e>
ElementId<e> AddElement(const absl::string_view name) {
return ElementId<e>(AddElementUntyped(e, name));
}
// Type-erased version of `AddElement`. Prefer the latter.
int64_t AddElementUntyped(const ElementType e, const absl::string_view name) {
return mutable_element_storage(e).Add(name);
}
// Deletes the element with `id` for element type `e`, returning true on
// success and false if no element was deleted (it was already deleted or the
// id was not from any existing element).
template <ElementType e>
bool DeleteElement(const ElementId<e> id) {
return DeleteElementUntyped(e, id.value());
}
// Type-erased version of `DeleteElement`. Prefer the latter.
bool DeleteElementUntyped(ElementType e, int64_t id);
// Returns true the element with `id` for element type `e` exists (it was
// created and not yet deleted).
template <ElementType e>
bool ElementExists(const ElementId<e> id) const {
return ElementExistsUntyped(e, id.value());
}
// Type-erased version of `ElementExists`. Prefer the latter.
bool ElementExistsUntyped(const ElementType e, const int64_t id) const {
return element_storage(e).Exists(id);
}
// Returns the name of the element with `id` for element type `e`, or an error
// if this element does not exist.
template <ElementType e>
absl::StatusOr<absl::string_view> GetElementName(
const ElementId<e> id) const {
return GetElementNameUntyped(e, id.value());
}
// Type-erased version of `GetElementName`. Prefer the latter.
absl::StatusOr<absl::string_view> GetElementNameUntyped(
const ElementType e, const int64_t id) const {
return element_storage(e).GetName(id);
}
// Returns the ids of all elements of element type `e` in the model in an
// unsorted, non-deterministic order.
template <ElementType e>
ElementIdsVector<e> AllElements() const {
return ElementIdsVector<e>(AllElementsUntyped(e));
}
// Type-erased version of `AllElements`. Prefer the latter.
std::vector<int64_t> AllElementsUntyped(const ElementType e) const {
return element_storage(e).AllIds();
}
// Returns the id of next element created for element type `e`.
//
// Equal to one plus the number of elements that were created for element type
// `e`. When no elements have been deleted, this equals num_elements(e).
int64_t NextElementId(const ElementType e) const {
return element_storage(e).next_id();
}
// Returns the number of elements in the model for element type `e`.
//
// Equal to the number of elements that were created minus the number
// deleted for element type `e`.
int64_t NumElements(const ElementType e) const {
return element_storage(e).size();
}
// Increases next_element_id(e) to `id` if it is currently less than `id`.
//
// Useful for reading a model back from proto, most users should not need to
// call this directly.
template <ElementType e>
void EnsureNextElementIdAtLeast(const ElementId<e> id) {
EnsureNextElementIdAtLeastUntyped(e, id.value());
}
// Type-erased version of `EnsureNextElementIdAtLeast`. Prefer the latter.
void EnsureNextElementIdAtLeastUntyped(const ElementType e, int64_t id) {
mutable_element_storage(e).EnsureNextIdAtLeast(id);
}
//////////////////////////////////////////////////////////////////////////////
// Attributes
//////////////////////////////////////////////////////////////////////////////
struct AlwaysOk {};
// In all the following functions, `key` must be a valid key for attribute `a`
// (i.e. elements must exists for all elements ids of `key`). When this is not
// the case, the behavior is defined by one of the following policies:
// - Checks whether each element of the key exists, and dies if not.
struct DiePolicy {
using CheckResultT = AlwaysOk;
template <typename T>
using Wrapped = T;
};
// - Checks whether each element of the key exists, and sets `*status` if
// not. When `*status` is not OK, the model is not modified and is still
// valid.
struct StatusPolicy {
using CheckResultT = absl::Status;
template <typename T>
using Wrapped = absl::StatusOr<T>;
};
// - Does not check whether key elements exists. UB if the key does not exist
// (DCHECK-fails in debug mode). Use if you know that the key exists and
// you care about performance.
struct UBPolicy {
using CheckResultT = AlwaysOk;
template <typename T>
using Wrapped = T;
};
// Restores the attribute `a` to its default value for all AttrKeys (or for
// an Attr0, its only value).
template <typename AttrType>
void AttrClear(AttrType a);
// Return the vector of attribute keys where a is non-default.
template <typename AttrType>
std::vector<AttrKeyFor<AttrType>> AttrNonDefaults(const AttrType a) const {
return attrs_[a].NonDefaults();
}
// Returns the number of keys where `a` is non-default.
template <typename AttrType>
int64_t AttrNumNonDefaults(const AttrType a) const {
return attrs_[a].num_non_defaults();
}
// Returns the value of the attr `a` for `key`:
// - `get_attr(DoubleAttr1::kVarUb, AttrKey(x))` returns a `double` value if
// element id `x` exists, and crashes otherwise. The returned value is
// the default value if the attribute has not been set for `x`.
// - `get_attr<StatusPolicy>(DoubleAttr1::kVarUb, AttrKey(x))` returns a
// valid `StatusOr<double>` if element id `x` exists, and an error
// otherwise.
template <typename Policy = DiePolicy, typename AttrType>
typename Policy::template Wrapped<ValueTypeFor<AttrType>> GetAttr(
AttrType a, AttrKeyFor<AttrType> key) const;
// Returns true if the attr `a` for `key` has a value different from its
// default.
template <typename Policy = DiePolicy, typename AttrType>
typename Policy::template Wrapped<bool> AttrIsNonDefault(
AttrType a, AttrKeyFor<AttrType> key) const;
// Sets the value of the attr `a` for the element `key` to `value`, and
// returns true if the value of the attribute has changed.
template <typename Policy = DiePolicy, typename AttrType>
typename Policy::CheckResultT SetAttr(AttrType a, AttrKeyFor<AttrType> key,
ValueTypeFor<AttrType> value);
// Returns the set of all keys `k` such that `k[i] == key_elem` and `k` has a
// non-default value for the attribute `a`.
template <int i, typename Policy = DiePolicy, typename AttrType = void>
typename Policy::template Wrapped<std::vector<AttrKeyFor<AttrType>>> Slice(
AttrType a, int64_t key_elem) const;
// Returns the size of the given slice: This is equivalent to `Slice(a,
// key_elem).size()`, but `O(1)`.
template <int i, typename Policy = DiePolicy, typename AttrType = void>
typename Policy::template Wrapped<int64_t> GetSliceSize(
AttrType a, int64_t key_elem) const;
// Returns a copy of this, but with no diffs. The name of the model can
// optionally be replaced by `new_model_name`.
Elemental Clone(
std::optional<absl::string_view> new_model_name = std::nullopt) const;
//////////////////////////////////////////////////////////////////////////////
// Working with proto
//////////////////////////////////////////////////////////////////////////////
// Returns an equivalent protocol buffer. Fails if the model is too big
// to fit in the in-memory representation of the proto (it has more than
// 2**31-1 elements of a type or non-defaults for an attribute).
absl::StatusOr<ModelProto> ExportModel(bool remove_names = false) const;
// Creates an equivalent Elemental to `proto`.
static absl::StatusOr<Elemental> FromModelProto(const ModelProto& proto);
// Applies the changes to the model in `update_proto`.
absl::Status ApplyUpdateProto(const ModelUpdateProto& update_proto);
//////////////////////////////////////////////////////////////////////////////
// Diffs
//////////////////////////////////////////////////////////////////////////////
// Returns the DiffHandle for `id`, if one exists, or nullopt otherwise.
std::optional<DiffHandle> GetDiffHandle(int64_t id) const;
// Returned handle is valid until passed `DeleteDiff` or `*this` is
// destructed.
DiffHandle AddDiff();
// Deletes `diff` & invalidates it. Returns false if the handle was invalid or
// from the wrong elemental). On success, invalidates `diff`.
bool DeleteDiff(DiffHandle diff);
// The number of diffs currently tracking this.
int64_t NumDiffs() const { return diffs_->Size(); }
// Returns true on success (fails if diff was null, deleted or from the wrong
// elemental). Warning: diff is modified (owned by this).
bool Advance(DiffHandle diff);
// Internal use only (users of Elemental cannot access Diff directly), but
// prefer to invoking Diff::modified_keys() directly..
//
// Returns the modified keys in a Diff for an attribute, filtering out the
// keys referring to an element that has been deleted.
//
// This is needed because in some situations where a variable is deleted
// we cannot clean up the diff, see README.md.
template <typename AttrType>
std::vector<AttrKeyFor<AttrType>> ModifiedKeysThatExist(
AttrType attr, const Diff& diff) const;
// Returns a proto describing all changes to the model for `diff` since the
// most recent call to `Advance(diff)` (or the creation of `diff` if
// `Advance()` was never called).
//
// Returns std::nullopt the resulting ModelUpdateProto would be the empty
// message (there have been no changes to the model to report).
//
// Fails if the update is too big to fit in the in-memory representation of
// the proto (it has more than 2**31-1 elements in a RepeatedField).
absl::StatusOr<std::optional<ModelUpdateProto>> ExportModelUpdate(
DiffHandle diff, bool remove_names = false) const;
// Prints out the model by element and attribute. If print_diffs is true, also
// prints out the deleted elements and modified keys for each attribute for
// each DiffHandle tracked.
//
// This is a debug format. Do not assume the output is consistent across CLs
// and do not parse this format.
std::string DebugString(bool print_diffs = true) const;
private:
DiePolicy::CheckResultT CheckElementExists(ElementType elem_type,
int64_t elem_id, DiePolicy) const;
StatusPolicy::CheckResultT CheckElementExists(ElementType elem_type,
int64_t elem_id,
StatusPolicy) const;
UBPolicy::CheckResultT CheckElementExists(ElementType elem_type,
int64_t elem_id, UBPolicy) const;
template <typename AttrType>
bool AttrKeyExists(AttrType attr, AttrKeyFor<AttrType> key) const;
template <typename Policy, typename AttrType>
typename Policy::CheckResultT CheckAttrKeyExists(
AttrType a, AttrKeyFor<AttrType> key) const;
template <typename AttrType, int i>
void UpdateAttrOnElementDeleted(AttrType a, int64_t id);
std::array<int64_t, kNumElements> CurrentCheckpoint() const;
const ElementStorage& element_storage(const ElementType e) const {
// TODO(b/365997645): post C++ 23, prefer std::to_underlying(e).
return elements_[static_cast<int>(e)];
}
ElementStorage& mutable_element_storage(const ElementType e) {
return elements_[static_cast<int>(e)];
}
std::string model_name_;
std::string primary_objective_name_;
std::array<ElementStorage, kNumElements> elements_;
template <int i>
using StorageForAttrType =
AttrStorage<typename AllAttrs::TypeDescriptor<i>::ValueType,
AllAttrs::TypeDescriptor<i>::kNumKeyElements,
typename AllAttrs::TypeDescriptor<i>::Symmetry>;
AttrMap<StorageForAttrType> attrs_;
// For each attribute whose value is an element, we need to keep a map of
// element to the set of keys whose value refers to that element. This
// is used to erase the attribute when the element is deleted.
// This is kept outside of `attrs_` so that we can update the diffs when
// element deletions trigger attribute deletions.
template <int i>
using ElementRefTrackerForAttrType =
ElementRefTrackerForAttrTypeDescriptor<AllAttrs::TypeDescriptor<i>>;
AttrMap<ElementRefTrackerForAttrType> element_ref_trackers_;
// Note: it is important that this is a unique_ptr for two reasons:
// 1. We need a stable memory address for diffs_ to refer to in DiffHandle,
// and the Elemental type is moveable.
// 2. We want Elemental to be moveable, but ThreadSafeIdMap<Diff> is not.
std::unique_ptr<ThreadSafeIdMap<Diff>> diffs_ =
std::make_unique<ThreadSafeIdMap<Diff>>();
};
template <typename Sink>
void AbslStringify(Sink& sink, const Elemental& elemental) {
sink.Append(elemental.DebugString());
}
std::ostream& operator<<(std::ostream& ostr, const Elemental& elemental);
///////////////////////////////////////////////////////////////////////////////
// Inline and template implementation
template <typename AttrType>
void Elemental::AttrClear(const AttrType a) {
// Note: this is slightly faster than setting each non-default back to the
// default value.
const std::vector<AttrKeyFor<AttrType>> non_defaults = AttrNonDefaults(a);
if (!non_defaults.empty()) {
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
for (const auto key : non_defaults) {
diff->SetModified(a, key);
}
}
}
attrs_[a].Clear();
if constexpr (is_element_id_v<ValueTypeFor<AttrType>>) {
element_ref_trackers_[a].Clear();
}
}
#define ELEMENTAL_RETURN_IF_ERROR(expr) \
{ \
auto error = (expr); \
if constexpr (!std::is_same_v<decltype(error), AlwaysOk>) { \
if (!error.ok()) { \
return error; \
} \
} \
}
template <typename Policy, typename AttrType>
typename Policy::template Wrapped<ValueTypeFor<AttrType>> Elemental::GetAttr(
const AttrType a, const AttrKeyFor<AttrType> key) const {
ELEMENTAL_RETURN_IF_ERROR(CheckAttrKeyExists<Policy>(a, key));
return attrs_[a].Get(key);
}
template <typename Policy, typename AttrType>
typename Policy::template Wrapped<bool> Elemental::AttrIsNonDefault(
const AttrType a, const AttrKeyFor<AttrType> key) const {
ELEMENTAL_RETURN_IF_ERROR(CheckAttrKeyExists<Policy>(a, key));
return attrs_[a].IsNonDefault(key);
}
template <typename Policy, typename AttrType>
typename Policy::CheckResultT Elemental::SetAttr(
const AttrType a, const AttrKeyFor<AttrType> key,
const ValueTypeFor<AttrType> value) {
ELEMENTAL_RETURN_IF_ERROR(CheckAttrKeyExists<Policy>(a, key));
const std::optional<ValueTypeFor<AttrType>> prev_value =
attrs_[a].Set(key, value);
if (prev_value.has_value()) {
if constexpr (is_element_id_v<ValueTypeFor<AttrType>>) {
element_ref_trackers_[a].Untrack(key, *prev_value);
}
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
diff->SetModified(a, key);
}
if constexpr (is_element_id_v<ValueTypeFor<AttrType>>) {
element_ref_trackers_[a].Track(key, value);
}
}
return {};
}
template <int i, typename Policy, typename AttrType>
typename Policy::template Wrapped<std::vector<AttrKeyFor<AttrType>>>
Elemental::Slice(const AttrType a, const int64_t key_elem) const {
ELEMENTAL_RETURN_IF_ERROR(
CheckElementExists(GetElementTypes<AttrType>(a)[i], key_elem, Policy{}));
return attrs_[a].template Slice<i>(key_elem);
}
template <int i, typename Policy, typename AttrType>
typename Policy::template Wrapped<int64_t> Elemental::GetSliceSize(
const AttrType a, const int64_t key_elem) const {
ELEMENTAL_RETURN_IF_ERROR(
CheckElementExists(GetElementTypes<AttrType>(a)[i], key_elem, Policy{}));
return attrs_[a].template GetSliceSize<i>(key_elem);
}
inline Elemental::DiePolicy::CheckResultT Elemental::CheckElementExists(
const ElementType elem_type, const int64_t elem_id, DiePolicy) const {
// This is ~30% faster than:
// CHECK_OK(CheckElementExistsUntyped(elem_type, elem_id, StatusPolicy{}));
CHECK(ElementExistsUntyped(elem_type, elem_id))
<< "no element with id " << elem_id << " for element type " << elem_type;
return {};
}
inline Elemental::StatusPolicy::CheckResultT Elemental::CheckElementExists(
const ElementType elem_type, const int64_t elem_id, StatusPolicy) const {
if (!ElementExistsUntyped(elem_type, elem_id)) {
return util::InvalidArgumentErrorBuilder()
<< "no element with id " << elem_id << " for element type "
<< elem_type;
}
return absl::OkStatus();
}
inline Elemental::UBPolicy::CheckResultT Elemental::CheckElementExists(
const ElementType elem_type, const int64_t elem_id, UBPolicy) const {
// Try to be useful in debug mode.
DCHECK_OK(CheckElementExists(elem_type, elem_id, StatusPolicy{}));
return {};
}
template <typename AttrType>
bool Elemental::AttrKeyExists(const AttrType attr,
const AttrKeyFor<AttrType> key) const {
for (int i = 0; i < key.size(); ++i) {
if (!ElementExistsUntyped(GetElementTypes<AttrType>(attr)[i], key[i])) {
return false;
}
}
return true;
}
template <typename Policy, typename AttrType>
typename Policy::CheckResultT Elemental::CheckAttrKeyExists(
const AttrType a, const AttrKeyFor<AttrType> key) const {
for (int i = 0; i < key.size(); ++i) {
ELEMENTAL_RETURN_IF_ERROR(
CheckElementExists(GetElementTypes<AttrType>(a)[i], key[i], Policy{}));
}
return {};
}
template <typename AttrType>
std::vector<AttrKeyFor<AttrType>> Elemental::ModifiedKeysThatExist(
AttrType attr, const Diff& diff) const {
using Key = AttrKeyFor<AttrType>;
std::vector<Key> keys;
// Can be a slight overestimate.
keys.reserve(diff.modified_keys(attr).size());
for (const Key key : diff.modified_keys(attr)) {
if constexpr (Key::size() > 1) {
if (AttrKeyExists(attr, key)) {
keys.push_back(key);
}
} else {
keys.push_back(key);
}
}
return keys;
}
#undef ELEMENTAL_RETURN_IF_ERROR
} // namespace operations_research::math_opt
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENTAL_H_

Some files were not shown because too many files have changed in this diff Show More