math_opt: export from google3
* CMake has not been updated yet * bazel was compiling at least last week bazel: disable math opt facility_location.py missing some dependencies...
This commit is contained in:
@@ -298,6 +298,8 @@ endforeach()
|
||||
|
||||
if(BUILD_MATH_OPT)
|
||||
add_subdirectory(ortools/math_opt/core/python)
|
||||
add_subdirectory(ortools/math_opt/elemental/python)
|
||||
add_subdirectory(ortools/math_opt/io/python)
|
||||
add_subdirectory(ortools/math_opt/python)
|
||||
endif()
|
||||
|
||||
@@ -331,7 +333,12 @@ if(BUILD_MATH_OPT)
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/core/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/core/python/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/elemental/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/elemental/python/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/io/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/io/python/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/elemental/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/ipc/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/python/testing/__init__.py CONTENT "")
|
||||
file(GENERATE OUTPUT ${PYTHON_PROJECT_DIR}/math_opt/solvers/__init__.py CONTENT "")
|
||||
@@ -366,27 +373,42 @@ file(COPY
|
||||
ortools/linear_solver/python/model_builder_numbers.py
|
||||
DESTINATION ${PYTHON_PROJECT_DIR}/linear_solver/python)
|
||||
if(BUILD_MATH_OPT)
|
||||
configure_file(
|
||||
ortools/math_opt/elemental/python/enums.py.in
|
||||
${PYTHON_PROJECT_DIR}/math_opt/elemental/python/enums.py
|
||||
COPYONLY)
|
||||
file(COPY
|
||||
ortools/math_opt/python/bounded_expressions.py
|
||||
ortools/math_opt/python/callback.py
|
||||
ortools/math_opt/python/compute_infeasible_subsystem_result.py
|
||||
ortools/math_opt/python/errors.py
|
||||
ortools/math_opt/python/expressions.py
|
||||
ortools/math_opt/python/from_model.py
|
||||
ortools/math_opt/python/hash_model_storage.py
|
||||
ortools/math_opt/python/indicator_constraints.py
|
||||
ortools/math_opt/python/init_arguments.py
|
||||
ortools/math_opt/python/linear_constraints.py
|
||||
ortools/math_opt/python/mathopt.py
|
||||
ortools/math_opt/python/message_callback.py
|
||||
ortools/math_opt/python/model.py
|
||||
ortools/math_opt/python/model_parameters.py
|
||||
ortools/math_opt/python/model_storage.py
|
||||
ortools/math_opt/python/normalized_inequality.py
|
||||
ortools/math_opt/python/normalize.py
|
||||
ortools/math_opt/python/objectives.py
|
||||
ortools/math_opt/python/parameters.py
|
||||
ortools/math_opt/python/quadratic_constraints.py
|
||||
ortools/math_opt/python/result.py
|
||||
ortools/math_opt/python/solution.py
|
||||
ortools/math_opt/python/solve.py
|
||||
ortools/math_opt/python/solver_resources.py
|
||||
ortools/math_opt/python/sparse_containers.py
|
||||
ortools/math_opt/python/statistics.py
|
||||
ortools/math_opt/python/variables.py
|
||||
DESTINATION ${PYTHON_PROJECT_DIR}/math_opt/python)
|
||||
file(COPY
|
||||
ortools/math_opt/python/elemental/elemental.py
|
||||
DESTINATION ${PYTHON_PROJECT_DIR}/math_opt/python/elemental)
|
||||
file(COPY
|
||||
ortools/math_opt/python/ipc/proto_converter.py
|
||||
ortools/math_opt/python/ipc/remote_http_solve.py
|
||||
@@ -663,7 +685,13 @@ add_custom_command(
|
||||
$<TARGET_FILE:model_builder_helper_pybind11> ${PYTHON_PROJECT}/linear_solver/python
|
||||
COMMAND ${CMAKE_COMMAND} -E
|
||||
$<IF:$<BOOL:${BUILD_MATH_OPT}>,copy,true>
|
||||
$<TARGET_FILE:math_opt_pybind11> ${PYTHON_PROJECT}/math_opt/core/python
|
||||
$<TARGET_FILE:math_opt_core_pybind11> ${PYTHON_PROJECT}/math_opt/core/python
|
||||
COMMAND ${CMAKE_COMMAND} -E
|
||||
$<IF:$<BOOL:${BUILD_MATH_OPT}>,copy,true>
|
||||
$<TARGET_FILE:math_opt_elemental_pybind11> ${PYTHON_PROJECT}/math_opt/elemental/python
|
||||
COMMAND ${CMAKE_COMMAND} -E
|
||||
$<IF:$<BOOL:${BUILD_MATH_OPT}>,copy,true>
|
||||
$<TARGET_FILE:math_opt_io_pybind11> ${PYTHON_PROJECT}/math_opt/io/python
|
||||
COMMAND ${CMAKE_COMMAND} -E
|
||||
$<IF:$<BOOL:${BUILD_MATH_OPT}>,copy,true>
|
||||
$<TARGET_FILE:status_py_extension_stub> ${PYTHON_PROJECT}/../pybind11_abseil
|
||||
@@ -697,7 +725,9 @@ add_custom_command(
|
||||
routing_pybind11
|
||||
pywraplp
|
||||
model_builder_helper_pybind11
|
||||
math_opt_pybind11
|
||||
$<$<BOOL:${BUILD_MATH_OPT}>:math_opt_core_pybind11>
|
||||
$<$<BOOL:${BUILD_MATH_OPT}>:math_opt_elemental_pybind11>
|
||||
$<$<BOOL:${BUILD_MATH_OPT}>:math_opt_io_pybind11>
|
||||
$<TARGET_NAME_IF_EXISTS:pdlp_pybind11>
|
||||
cp_model_helper_pybind11
|
||||
rcpsp_pybind11
|
||||
@@ -760,6 +790,10 @@ search_python_module(
|
||||
search_python_module(
|
||||
NAME wheel
|
||||
PACKAGE wheel)
|
||||
search_python_module(
|
||||
NAME typing_extensions
|
||||
PACKAGE typing-extensions
|
||||
NO_VERSION)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT python/dist_timestamp
|
||||
|
||||
@@ -18,6 +18,7 @@ endif()
|
||||
add_subdirectory(core)
|
||||
add_subdirectory(constraints)
|
||||
add_subdirectory(cpp)
|
||||
add_subdirectory(elemental)
|
||||
add_subdirectory(io)
|
||||
add_subdirectory(labs)
|
||||
add_subdirectory(solver_tests)
|
||||
@@ -31,6 +32,7 @@ target_sources(${NAME} PUBLIC
|
||||
$<TARGET_OBJECTS:${NAME}_core>
|
||||
$<TARGET_OBJECTS:${NAME}_core_c_api>
|
||||
$<TARGET_OBJECTS:${NAME}_cpp>
|
||||
$<TARGET_OBJECTS:${NAME}_elemental>
|
||||
$<TARGET_OBJECTS:${NAME}_io>
|
||||
$<TARGET_OBJECTS:${NAME}_labs>
|
||||
$<TARGET_OBJECTS:${NAME}_solvers>
|
||||
|
||||
18
ortools/math_opt/README.md
Normal file
18
ortools/math_opt/README.md
Normal file
@@ -0,0 +1,18 @@
|
||||
# math_opt
|
||||
|
||||
The code in this directory provides a generic way of accessing mathematical
|
||||
optimization solvers (sometimes called mathematical programming solvers), such
|
||||
as GLOP, CP-SAT, SCIP and Gurobi. In particular, a single API is provided to
|
||||
make these solvers largely interoperable.
|
||||
|
||||
New code should prefer MathOpt to `MPSolver`, as defined in
|
||||
[linear_solver.h](../linear_solver/linear_solver.h)
|
||||
when possible.
|
||||
|
||||
MathOpt has client libraries in C++, Python, and Java that most users should use
|
||||
to build and solve their models. A proto API is also provided, but this is not
|
||||
recommended for most users.
|
||||
|
||||
See
|
||||
[parameters.proto](../math_opt/parameters.proto?q=SolverTypeProto)
|
||||
for the list of supported solvers.
|
||||
@@ -158,7 +158,17 @@ message CallbackResultProto {
|
||||
bool is_lazy = 4;
|
||||
}
|
||||
|
||||
// Ends the solve early.
|
||||
// When true it tells the solver to interrupt the solve as soon as possible.
|
||||
//
|
||||
// It can be set from any event. This is equivalent to using a
|
||||
// SolveInterrupter and triggering it from the callback.
|
||||
//
|
||||
// Some solvers don't support interruption, in that case this is simply
|
||||
// ignored and the solve terminates as usual. On top of that solvers may not
|
||||
// immediately stop the solve. Thus the user should expect the callback to
|
||||
// still be called after they set `terminate` to true in a previous
|
||||
// call. Returning with `terminate` false after having previously returned
|
||||
// true won't cancel the interruption.
|
||||
bool terminate = 1;
|
||||
|
||||
// TODO(b/172214608): SCIP allows to reject a feasible solution without
|
||||
|
||||
@@ -18,10 +18,11 @@ cc_library(
|
||||
srcs = ["indicator_constraint.cc"],
|
||||
hdrs = ["indicator_constraint.h"],
|
||||
deps = [
|
||||
"//ortools/base:intops",
|
||||
"//ortools/math_opt/constraints/util:model_util",
|
||||
"//ortools/math_opt/cpp:variable_and_expressions",
|
||||
"//ortools/math_opt/elemental:elements",
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
)
|
||||
@@ -51,6 +52,7 @@ cc_library(
|
||||
"//ortools/math_opt:model_update_cc_proto",
|
||||
"//ortools/math_opt:sparse_containers_cc_proto",
|
||||
"//ortools/math_opt/core:sorted",
|
||||
"//ortools/math_opt/elemental:elements",
|
||||
"//ortools/math_opt/storage:atomic_constraint_storage",
|
||||
"//ortools/math_opt/storage:sparse_coefficient_map",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
|
||||
@@ -18,8 +18,6 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/math_opt/constraints/util/model_util.h"
|
||||
#include "ortools/math_opt/cpp/variable_and_expressions.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
@@ -27,22 +25,22 @@
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
BoundedLinearExpression IndicatorConstraint::ImpliedConstraint() const {
|
||||
const IndicatorConstraintData& data = storage()->constraint_data(id_);
|
||||
const IndicatorConstraintData& data = storage()->constraint_data(typed_id());
|
||||
// NOTE: The following makes a copy of `data.linear_terms`. This can be made
|
||||
// more efficient if the need arises.
|
||||
LinearExpression expr = ToLinearExpression(
|
||||
*storage_, {.coeffs = data.linear_terms, .offset = 0.0});
|
||||
*storage(), {.coeffs = data.linear_terms, .offset = 0.0});
|
||||
return data.lower_bound <= std::move(expr) <= data.upper_bound;
|
||||
}
|
||||
|
||||
std::string IndicatorConstraint::ToString() const {
|
||||
if (!storage()->has_constraint(id_)) {
|
||||
if (!storage()->has_constraint(typed_id())) {
|
||||
return std::string(kDeletedConstraintDefaultDescription);
|
||||
}
|
||||
const IndicatorConstraintData& data = storage()->constraint_data(id_);
|
||||
const IndicatorConstraintData& data = storage()->constraint_data(typed_id());
|
||||
std::stringstream str;
|
||||
if (data.indicator.has_value()) {
|
||||
str << Variable(storage_, *data.indicator)
|
||||
str << Variable(storage(), *data.indicator)
|
||||
<< (data.activate_on_zero ? " = 0" : " = 1");
|
||||
} else {
|
||||
str << "[unset indicator variable]";
|
||||
|
||||
@@ -16,38 +16,28 @@
|
||||
#ifndef OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_
|
||||
#define OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/math_opt/constraints/util/model_util.h"
|
||||
#include "ortools/math_opt/cpp/variable_and_expressions.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// A value type that references an indicator constraint from ModelStorage.
|
||||
// Usually this type is passed by copy.
|
||||
//
|
||||
// This type implements https://abseil.io/docs/cpp/guides/hash.
|
||||
class IndicatorConstraint {
|
||||
class IndicatorConstraint final
|
||||
: public ModelStorageElement<ElementType::kIndicatorConstraint,
|
||||
IndicatorConstraint> {
|
||||
public:
|
||||
// The typed integer used for ids.
|
||||
using IdType = IndicatorConstraintId;
|
||||
using ModelStorageElement::ModelStorageElement;
|
||||
|
||||
inline IndicatorConstraint(const ModelStorage* storage,
|
||||
IndicatorConstraintId id);
|
||||
|
||||
inline int64_t id() const;
|
||||
|
||||
inline IndicatorConstraintId typed_id() const;
|
||||
inline const ModelStorage* storage() const;
|
||||
|
||||
inline absl::string_view name() const;
|
||||
absl::string_view name() const;
|
||||
|
||||
// Returns nullopt if the indicator variable is unset (this is a valid state,
|
||||
// in which the constraint is functionally ignored).
|
||||
@@ -65,91 +55,36 @@ class IndicatorConstraint {
|
||||
// Returns a detailed string description of the contents of the constraint
|
||||
// (not its name, use `<<` for that instead).
|
||||
std::string ToString() const;
|
||||
|
||||
friend inline bool operator==(const IndicatorConstraint& lhs,
|
||||
const IndicatorConstraint& rhs);
|
||||
friend inline bool operator!=(const IndicatorConstraint& lhs,
|
||||
const IndicatorConstraint& rhs);
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const IndicatorConstraint& constraint);
|
||||
friend std::ostream& operator<<(std::ostream& ostr,
|
||||
const IndicatorConstraint& constraint);
|
||||
|
||||
private:
|
||||
const ModelStorage* storage_;
|
||||
IndicatorConstraintId id_;
|
||||
};
|
||||
|
||||
// Streams the name of the constraint, as registered upon constraint creation,
|
||||
// or a short default if none was provided.
|
||||
inline std::ostream& operator<<(std::ostream& ostr,
|
||||
const IndicatorConstraint& constraint);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Inline function implementations
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int64_t IndicatorConstraint::id() const { return id_.value(); }
|
||||
|
||||
IndicatorConstraintId IndicatorConstraint::typed_id() const { return id_; }
|
||||
|
||||
const ModelStorage* IndicatorConstraint::storage() const { return storage_; }
|
||||
|
||||
absl::string_view IndicatorConstraint::name() const {
|
||||
if (storage_->has_constraint(id_)) {
|
||||
return storage_->constraint_data(id_).name;
|
||||
inline absl::string_view IndicatorConstraint::name() const {
|
||||
if (storage()->has_constraint(typed_id())) {
|
||||
return storage()->constraint_data(typed_id()).name;
|
||||
}
|
||||
return kDeletedConstraintDefaultDescription;
|
||||
}
|
||||
|
||||
std::optional<Variable> IndicatorConstraint::indicator_variable() const {
|
||||
const std::optional<VariableId> maybe_indicator =
|
||||
storage_->constraint_data(id_).indicator;
|
||||
storage()->constraint_data(typed_id()).indicator;
|
||||
if (!maybe_indicator.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return Variable(storage_, *maybe_indicator);
|
||||
return Variable(storage(), *maybe_indicator);
|
||||
}
|
||||
|
||||
bool IndicatorConstraint::activate_on_zero() const {
|
||||
return storage_->constraint_data(id_).activate_on_zero;
|
||||
return storage()->constraint_data(typed_id()).activate_on_zero;
|
||||
}
|
||||
|
||||
std::vector<Variable> IndicatorConstraint::NonzeroVariables() const {
|
||||
return AtomicConstraintNonzeroVariables(*storage_, id_);
|
||||
return AtomicConstraintNonzeroVariables(*storage(), typed_id());
|
||||
}
|
||||
|
||||
bool operator==(const IndicatorConstraint& lhs,
|
||||
const IndicatorConstraint& rhs) {
|
||||
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
|
||||
}
|
||||
|
||||
bool operator!=(const IndicatorConstraint& lhs,
|
||||
const IndicatorConstraint& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const IndicatorConstraint& constraint) {
|
||||
return H::combine(std::move(h), constraint.id_.value(), constraint.storage_);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr,
|
||||
const IndicatorConstraint& constraint) {
|
||||
// TODO(b/170992529): handle quoting of invalid characters in the name.
|
||||
const absl::string_view name = constraint.name();
|
||||
if (name.empty()) {
|
||||
ostr << "__indic_con#" << constraint.id() << "__";
|
||||
} else {
|
||||
ostr << name;
|
||||
}
|
||||
return ostr;
|
||||
}
|
||||
|
||||
IndicatorConstraint::IndicatorConstraint(const ModelStorage* const storage,
|
||||
const IndicatorConstraintId id)
|
||||
: storage_(storage), id_(id) {}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_CONSTRAINTS_INDICATOR_INDICATOR_CONSTRAINT_H_
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/model.pb.h"
|
||||
#include "ortools/math_opt/model_update.pb.h"
|
||||
#include "ortools/math_opt/storage/atomic_constraint_storage.h"
|
||||
@@ -34,6 +35,8 @@ struct IndicatorConstraintData {
|
||||
using IdType = IndicatorConstraintId;
|
||||
using ProtoType = IndicatorConstraintProto;
|
||||
using UpdatesProtoType = IndicatorConstraintUpdatesProto;
|
||||
static constexpr ElementType kElementType = ElementType::kIndicatorConstraint;
|
||||
static constexpr bool kSupportsElemental = true;
|
||||
|
||||
// The `in_proto` must be in a valid state; see the inline comments on
|
||||
// `IndicatorConstraintProto` for details.
|
||||
@@ -55,6 +58,8 @@ struct IndicatorConstraintData {
|
||||
template <>
|
||||
struct AtomicConstraintTraits<IndicatorConstraintId> {
|
||||
using ConstraintData = IndicatorConstraintData;
|
||||
static constexpr ElementType kElementType = ElementType::kIndicatorConstraint;
|
||||
static constexpr bool kSupportsElemental = true;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
@@ -22,7 +22,9 @@ cc_library(
|
||||
"//ortools/math_opt/constraints/util:model_util",
|
||||
"//ortools/math_opt/cpp:key_types",
|
||||
"//ortools/math_opt/cpp:variable_and_expressions",
|
||||
"//ortools/math_opt/elemental:elements",
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"//ortools/math_opt/storage:sparse_coefficient_map",
|
||||
"//ortools/math_opt/storage:sparse_matrix",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
@@ -54,6 +56,7 @@ cc_library(
|
||||
"//ortools/math_opt:model_cc_proto",
|
||||
"//ortools/math_opt:model_update_cc_proto",
|
||||
"//ortools/math_opt:sparse_containers_cc_proto",
|
||||
"//ortools/math_opt/elemental:elements",
|
||||
"//ortools/math_opt/storage:atomic_constraint_storage",
|
||||
"//ortools/math_opt/storage:model_storage_types",
|
||||
"//ortools/math_opt/storage:sparse_coefficient_map",
|
||||
|
||||
@@ -26,7 +26,7 @@ namespace operations_research::math_opt {
|
||||
BoundedQuadraticExpression QuadraticConstraint::AsBoundedQuadraticExpression()
|
||||
const {
|
||||
QuadraticExpression expression;
|
||||
const QuadraticConstraintData& data = storage()->constraint_data(id_);
|
||||
const QuadraticConstraintData& data = storage()->constraint_data(typed_id());
|
||||
for (const auto [var, coeff] : data.linear_terms.terms()) {
|
||||
expression += coeff * Variable(storage(), var);
|
||||
}
|
||||
|
||||
@@ -18,19 +18,18 @@
|
||||
#ifndef OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_
|
||||
#define OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/math_opt/constraints/util/model_util.h"
|
||||
#include "ortools/math_opt/cpp/key_types.h"
|
||||
#include "ortools/math_opt/cpp/variable_and_expressions.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
#include "ortools/math_opt/storage/sparse_coefficient_map.h"
|
||||
#include "ortools/math_opt/storage/sparse_matrix.h"
|
||||
|
||||
@@ -38,20 +37,11 @@ namespace operations_research::math_opt {
|
||||
|
||||
// A value type that references a quadratic constraint from ModelStorage.
|
||||
// Usually this type is passed by copy.
|
||||
//
|
||||
// This type implements https://abseil.io/docs/cpp/guides/hash.
|
||||
class QuadraticConstraint {
|
||||
class QuadraticConstraint final
|
||||
: public ModelStorageElement<ElementType::kQuadraticConstraint,
|
||||
QuadraticConstraint> {
|
||||
public:
|
||||
// The typed integer used for ids.
|
||||
using IdType = QuadraticConstraintId;
|
||||
|
||||
inline QuadraticConstraint(const ModelStorage* storage,
|
||||
QuadraticConstraintId id);
|
||||
|
||||
inline int64_t id() const;
|
||||
|
||||
inline QuadraticConstraintId typed_id() const;
|
||||
inline const ModelStorage* storage() const;
|
||||
using ModelStorageElement::ModelStorageElement;
|
||||
|
||||
inline double lower_bound() const;
|
||||
inline double upper_bound() const;
|
||||
@@ -89,47 +79,23 @@ class QuadraticConstraint {
|
||||
// Returns a detailed string description of the contents of the constraint
|
||||
// (not its name, use `<<` for that instead).
|
||||
inline std::string ToString() const;
|
||||
|
||||
friend inline bool operator==(const QuadraticConstraint& lhs,
|
||||
const QuadraticConstraint& rhs);
|
||||
friend inline bool operator!=(const QuadraticConstraint& lhs,
|
||||
const QuadraticConstraint& rhs);
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const QuadraticConstraint& quadratic_constraint);
|
||||
friend std::ostream& operator<<(
|
||||
std::ostream& ostr, const QuadraticConstraint& quadratic_constraint);
|
||||
|
||||
private:
|
||||
const ModelStorage* storage_;
|
||||
QuadraticConstraintId id_;
|
||||
};
|
||||
|
||||
// Streams the name of the constraint, as registered upon constraint creation,
|
||||
// or a short default if none was provided.
|
||||
inline std::ostream& operator<<(std::ostream& ostr,
|
||||
const QuadraticConstraint& constraint);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Inline function implementations
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int64_t QuadraticConstraint::id() const { return id_.value(); }
|
||||
|
||||
QuadraticConstraintId QuadraticConstraint::typed_id() const { return id_; }
|
||||
|
||||
const ModelStorage* QuadraticConstraint::storage() const { return storage_; }
|
||||
|
||||
double QuadraticConstraint::lower_bound() const {
|
||||
return storage_->constraint_data(id_).lower_bound;
|
||||
return storage()->constraint_data(typed_id()).lower_bound;
|
||||
}
|
||||
|
||||
double QuadraticConstraint::upper_bound() const {
|
||||
return storage_->constraint_data(id_).upper_bound;
|
||||
return storage()->constraint_data(typed_id()).upper_bound;
|
||||
}
|
||||
|
||||
absl::string_view QuadraticConstraint::name() const {
|
||||
if (storage_->has_constraint(id_)) {
|
||||
return storage_->constraint_data(id_).name;
|
||||
if (storage()->has_constraint(typed_id())) {
|
||||
return storage()->constraint_data(typed_id()).name;
|
||||
}
|
||||
return kDeletedConstraintDefaultDescription;
|
||||
}
|
||||
@@ -145,27 +111,31 @@ bool QuadraticConstraint::is_quadratic_coefficient_nonzero(
|
||||
}
|
||||
|
||||
double QuadraticConstraint::linear_coefficient(const Variable variable) const {
|
||||
CHECK_EQ(variable.storage(), storage_)
|
||||
CHECK_EQ(variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
return storage_->constraint_data(id_).linear_terms.get(variable.typed_id());
|
||||
return storage()
|
||||
->constraint_data(typed_id())
|
||||
.linear_terms.get(variable.typed_id());
|
||||
}
|
||||
|
||||
double QuadraticConstraint::quadratic_coefficient(
|
||||
const Variable first_variable, const Variable second_variable) const {
|
||||
CHECK_EQ(first_variable.storage(), storage_)
|
||||
CHECK_EQ(first_variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
CHECK_EQ(second_variable.storage(), storage_)
|
||||
CHECK_EQ(second_variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
return storage_->constraint_data(id_).quadratic_terms.get(
|
||||
first_variable.typed_id(), second_variable.typed_id());
|
||||
return storage()
|
||||
->constraint_data(typed_id())
|
||||
.quadratic_terms.get(first_variable.typed_id(),
|
||||
second_variable.typed_id());
|
||||
}
|
||||
|
||||
std::vector<Variable> QuadraticConstraint::NonzeroVariables() const {
|
||||
return AtomicConstraintNonzeroVariables(*storage_, id_);
|
||||
return AtomicConstraintNonzeroVariables(*storage(), typed_id());
|
||||
}
|
||||
|
||||
std::string QuadraticConstraint::ToString() const {
|
||||
if (!storage()->has_constraint(id_)) {
|
||||
if (!storage()->has_constraint(typed_id())) {
|
||||
return std::string(kDeletedConstraintDefaultDescription);
|
||||
}
|
||||
std::stringstream str;
|
||||
@@ -173,38 +143,6 @@ std::string QuadraticConstraint::ToString() const {
|
||||
return str.str();
|
||||
}
|
||||
|
||||
bool operator==(const QuadraticConstraint& lhs,
|
||||
const QuadraticConstraint& rhs) {
|
||||
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
|
||||
}
|
||||
|
||||
bool operator!=(const QuadraticConstraint& lhs,
|
||||
const QuadraticConstraint& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const QuadraticConstraint& quadratic_constraint) {
|
||||
return H::combine(std::move(h), quadratic_constraint.id_.value(),
|
||||
quadratic_constraint.storage_);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr,
|
||||
const QuadraticConstraint& constraint) {
|
||||
// TODO(b/170992529): handle quoting of invalid characters in the name.
|
||||
const absl::string_view name = constraint.name();
|
||||
if (name.empty()) {
|
||||
ostr << "__quad_con#" << constraint.id() << "__";
|
||||
} else {
|
||||
ostr << name;
|
||||
}
|
||||
return ostr;
|
||||
}
|
||||
|
||||
QuadraticConstraint::QuadraticConstraint(const ModelStorage* const storage,
|
||||
const QuadraticConstraintId id)
|
||||
: storage_(storage), id_(id) {}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_CONSTRAINTS_QUADRATIC_QUADRATIC_CONSTRAINT_H_
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/model.pb.h"
|
||||
#include "ortools/math_opt/model_update.pb.h"
|
||||
#include "ortools/math_opt/storage/atomic_constraint_storage.h"
|
||||
@@ -36,6 +37,9 @@ struct QuadraticConstraintData {
|
||||
using ProtoType = QuadraticConstraintProto;
|
||||
using UpdatesProtoType = QuadraticConstraintUpdatesProto;
|
||||
|
||||
static constexpr ElementType kElementType = ElementType::kQuadraticConstraint;
|
||||
static constexpr bool kSupportsElemental = true;
|
||||
|
||||
// The `in_proto` must be in a valid state; see the inline comments on
|
||||
// `QuadraticConstraintProto` for details.
|
||||
static QuadraticConstraintData FromProto(const ProtoType& in_proto);
|
||||
@@ -53,6 +57,8 @@ struct QuadraticConstraintData {
|
||||
template <>
|
||||
struct AtomicConstraintTraits<QuadraticConstraintId> {
|
||||
using ConstraintData = QuadraticConstraintData;
|
||||
static constexpr ElementType kElementType = ElementType::kQuadraticConstraint;
|
||||
static constexpr bool kSupportsElemental = true;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
@@ -24,6 +24,7 @@ cc_library(
|
||||
"//ortools/math_opt/cpp:variable_and_expressions",
|
||||
"//ortools/math_opt/storage:linear_expression_data",
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"//ortools/math_opt/storage:model_storage_types",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
LinearExpression SecondOrderConeConstraint::UpperBound() const {
|
||||
return ToLinearExpression(*storage_,
|
||||
return ToLinearExpression(*storage(),
|
||||
storage()->constraint_data(id_).upper_bound);
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ std::vector<LinearExpression> SecondOrderConeConstraint::ArgumentsToNorm()
|
||||
std::vector<LinearExpression> args;
|
||||
args.reserve(data.arguments_to_norm.size());
|
||||
for (const LinearExpressionData& arg_data : data.arguments_to_norm) {
|
||||
args.push_back(ToLinearExpression(*storage_, arg_data));
|
||||
args.push_back(ToLinearExpression(*storage(), arg_data));
|
||||
}
|
||||
return args;
|
||||
}
|
||||
@@ -58,9 +58,9 @@ std::string SecondOrderConeConstraint::ToString() const {
|
||||
str << ", ";
|
||||
}
|
||||
leading_comma = true;
|
||||
str << ToLinearExpression(*storage_, arg_data);
|
||||
str << ToLinearExpression(*storage(), arg_data);
|
||||
}
|
||||
str << "}||₂ ≤ " << ToLinearExpression(*storage_, data.upper_bound);
|
||||
str << "}||₂ ≤ " << ToLinearExpression(*storage(), data.upper_bound);
|
||||
return str.str();
|
||||
}
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
#define OR_TOOLS_MATH_OPT_CONSTRAINTS_SECOND_ORDER_CONE_SECOND_ORDER_CONE_CONSTRAINT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -28,6 +27,7 @@
|
||||
#include "ortools/math_opt/constraints/util/model_util.h"
|
||||
#include "ortools/math_opt/cpp/variable_and_expressions.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
#include "ortools/math_opt/storage/model_storage_types.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
@@ -36,18 +36,17 @@ namespace operations_research::math_opt {
|
||||
// ModelStorage. Usually this type is passed by copy.
|
||||
//
|
||||
// This type implements https://abseil.io/docs/cpp/guides/hash.
|
||||
class SecondOrderConeConstraint {
|
||||
class SecondOrderConeConstraint final : public ModelStorageItem {
|
||||
public:
|
||||
// The typed integer used for ids.
|
||||
using IdType = SecondOrderConeConstraintId;
|
||||
|
||||
inline SecondOrderConeConstraint(const ModelStorage* storage,
|
||||
inline SecondOrderConeConstraint(ModelStorageCPtr storage,
|
||||
SecondOrderConeConstraintId id);
|
||||
|
||||
inline int64_t id() const;
|
||||
|
||||
inline SecondOrderConeConstraintId typed_id() const;
|
||||
inline const ModelStorage* storage() const;
|
||||
|
||||
inline absl::string_view name() const;
|
||||
|
||||
@@ -77,7 +76,6 @@ class SecondOrderConeConstraint {
|
||||
const SecondOrderConeConstraint& constraint);
|
||||
|
||||
private:
|
||||
const ModelStorage* storage_;
|
||||
SecondOrderConeConstraintId id_;
|
||||
};
|
||||
|
||||
@@ -96,24 +94,20 @@ SecondOrderConeConstraintId SecondOrderConeConstraint::typed_id() const {
|
||||
return id_;
|
||||
}
|
||||
|
||||
const ModelStorage* SecondOrderConeConstraint::storage() const {
|
||||
return storage_;
|
||||
}
|
||||
|
||||
absl::string_view SecondOrderConeConstraint::name() const {
|
||||
if (storage_->has_constraint(id_)) {
|
||||
return storage_->constraint_data(id_).name;
|
||||
if (storage()->has_constraint(id_)) {
|
||||
return storage()->constraint_data(id_).name;
|
||||
}
|
||||
return kDeletedConstraintDefaultDescription;
|
||||
}
|
||||
|
||||
std::vector<Variable> SecondOrderConeConstraint::NonzeroVariables() const {
|
||||
return AtomicConstraintNonzeroVariables(*storage_, id_);
|
||||
return AtomicConstraintNonzeroVariables(*storage(), id_);
|
||||
}
|
||||
|
||||
bool operator==(const SecondOrderConeConstraint& lhs,
|
||||
const SecondOrderConeConstraint& rhs) {
|
||||
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
|
||||
return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage();
|
||||
}
|
||||
|
||||
bool operator!=(const SecondOrderConeConstraint& lhs,
|
||||
@@ -123,7 +117,7 @@ bool operator!=(const SecondOrderConeConstraint& lhs,
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const SecondOrderConeConstraint& constraint) {
|
||||
return H::combine(std::move(h), constraint.id_.value(), constraint.storage_);
|
||||
return H::combine(std::move(h), constraint.id_.value(), constraint.storage());
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr,
|
||||
@@ -139,8 +133,8 @@ std::ostream& operator<<(std::ostream& ostr,
|
||||
}
|
||||
|
||||
SecondOrderConeConstraint::SecondOrderConeConstraint(
|
||||
const ModelStorage* const storage, const SecondOrderConeConstraintId id)
|
||||
: storage_(storage), id_(id) {}
|
||||
const ModelStorageCPtr storage, const SecondOrderConeConstraintId id)
|
||||
: ModelStorageItem(storage), id_(id) {}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ struct SecondOrderConeConstraintData {
|
||||
using IdType = SecondOrderConeConstraintId;
|
||||
using ProtoType = SecondOrderConeConstraintProto;
|
||||
using UpdatesProtoType = SecondOrderConeConstraintUpdatesProto;
|
||||
static constexpr bool kSupportsElemental = false;
|
||||
|
||||
// The `in_proto` must be in a valid state; see the inline comments on
|
||||
// `SecondOrderConeConstraintProto` for details.
|
||||
@@ -50,6 +51,7 @@ struct SecondOrderConeConstraintData {
|
||||
template <>
|
||||
struct AtomicConstraintTraits<SecondOrderConeConstraintId> {
|
||||
using ConstraintData = SecondOrderConeConstraintData;
|
||||
static constexpr bool kSupportsElemental = false;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
@@ -24,7 +24,9 @@ cc_library(
|
||||
"//ortools/math_opt/cpp:variable_and_expressions",
|
||||
"//ortools/math_opt/storage:linear_expression_data",
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"//ortools/math_opt/storage:sparse_coefficient_map",
|
||||
"@abseil-cpp//absl/base:nullability",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
)
|
||||
@@ -57,6 +59,7 @@ cc_library(
|
||||
"//ortools/math_opt/cpp:variable_and_expressions",
|
||||
"//ortools/math_opt/storage:linear_expression_data",
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"//ortools/math_opt/storage:sparse_coefficient_map",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
|
||||
@@ -23,10 +23,10 @@ namespace operations_research::math_opt {
|
||||
|
||||
LinearExpression Sos1Constraint::Expression(int index) const {
|
||||
const LinearExpressionData& storage_expr =
|
||||
storage_->constraint_data(id_).expression(index);
|
||||
storage()->constraint_data(id_).expression(index);
|
||||
LinearExpression out_expr = storage_expr.offset;
|
||||
for (const auto [var_id, coeff] : storage_expr.coeffs.terms()) {
|
||||
out_expr += coeff * Variable(storage_, var_id);
|
||||
out_expr += coeff * Variable(storage(), var_id);
|
||||
}
|
||||
return out_expr;
|
||||
}
|
||||
|
||||
@@ -21,12 +21,14 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/math_opt/constraints/sos/util.h"
|
||||
#include "ortools/math_opt/constraints/util/model_util.h"
|
||||
#include "ortools/math_opt/cpp/variable_and_expressions.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
@@ -34,17 +36,16 @@ namespace operations_research::math_opt {
|
||||
// Usually this type is passed by copy.
|
||||
//
|
||||
// This type implements https://abseil.io/docs/cpp/guides/hash.
|
||||
class Sos1Constraint {
|
||||
class Sos1Constraint final : public ModelStorageItem {
|
||||
public:
|
||||
// The typed integer used for ids.
|
||||
using IdType = Sos1ConstraintId;
|
||||
|
||||
inline Sos1Constraint(const ModelStorage* storage, Sos1ConstraintId id);
|
||||
inline Sos1Constraint(ModelStorageCPtr storage, Sos1ConstraintId id);
|
||||
|
||||
inline int64_t id() const;
|
||||
|
||||
inline Sos1ConstraintId typed_id() const;
|
||||
inline const ModelStorage* storage() const;
|
||||
|
||||
inline int64_t num_expressions() const;
|
||||
LinearExpression Expression(int index) const;
|
||||
@@ -69,7 +70,6 @@ class Sos1Constraint {
|
||||
const Sos1Constraint& constraint);
|
||||
|
||||
private:
|
||||
const ModelStorage* storage_;
|
||||
Sos1ConstraintId id_;
|
||||
};
|
||||
|
||||
@@ -86,33 +86,31 @@ int64_t Sos1Constraint::id() const { return id_.value(); }
|
||||
|
||||
Sos1ConstraintId Sos1Constraint::typed_id() const { return id_; }
|
||||
|
||||
const ModelStorage* Sos1Constraint::storage() const { return storage_; }
|
||||
|
||||
int64_t Sos1Constraint::num_expressions() const {
|
||||
return storage_->constraint_data(id_).num_expressions();
|
||||
return storage()->constraint_data(id_).num_expressions();
|
||||
}
|
||||
|
||||
bool Sos1Constraint::has_weights() const {
|
||||
return storage_->constraint_data(id_).has_weights();
|
||||
return storage()->constraint_data(id_).has_weights();
|
||||
}
|
||||
|
||||
double Sos1Constraint::weight(int index) const {
|
||||
return storage_->constraint_data(id_).weight(index);
|
||||
return storage()->constraint_data(id_).weight(index);
|
||||
}
|
||||
|
||||
absl::string_view Sos1Constraint::name() const {
|
||||
if (storage_->has_constraint(id_)) {
|
||||
return storage_->constraint_data(id_).name();
|
||||
if (storage()->has_constraint(id_)) {
|
||||
return storage()->constraint_data(id_).name();
|
||||
}
|
||||
return kDeletedConstraintDefaultDescription;
|
||||
}
|
||||
|
||||
std::vector<Variable> Sos1Constraint::NonzeroVariables() const {
|
||||
return AtomicConstraintNonzeroVariables(*storage_, id_);
|
||||
return AtomicConstraintNonzeroVariables(*storage(), id_);
|
||||
}
|
||||
|
||||
bool operator==(const Sos1Constraint& lhs, const Sos1Constraint& rhs) {
|
||||
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
|
||||
return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage();
|
||||
}
|
||||
|
||||
bool operator!=(const Sos1Constraint& lhs, const Sos1Constraint& rhs) {
|
||||
@@ -121,7 +119,7 @@ bool operator!=(const Sos1Constraint& lhs, const Sos1Constraint& rhs) {
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const Sos1Constraint& constraint) {
|
||||
return H::combine(std::move(h), constraint.id_.value(), constraint.storage_);
|
||||
return H::combine(std::move(h), constraint.id_.value(), constraint.storage());
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr, const Sos1Constraint& constraint) {
|
||||
@@ -136,15 +134,15 @@ std::ostream& operator<<(std::ostream& ostr, const Sos1Constraint& constraint) {
|
||||
}
|
||||
|
||||
std::string Sos1Constraint::ToString() const {
|
||||
if (storage_->has_constraint(id_)) {
|
||||
if (storage()->has_constraint(id_)) {
|
||||
return internal::SosConstraintToString(*this, "SOS1");
|
||||
}
|
||||
return std::string(kDeletedConstraintDefaultDescription);
|
||||
}
|
||||
|
||||
Sos1Constraint::Sos1Constraint(const ModelStorage* const storage,
|
||||
Sos1Constraint::Sos1Constraint(const ModelStorageCPtr storage,
|
||||
const Sos1ConstraintId id)
|
||||
: storage_(storage), id_(id) {}
|
||||
: ModelStorageItem(storage), id_(id) {}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
|
||||
@@ -23,10 +23,10 @@ namespace operations_research::math_opt {
|
||||
|
||||
LinearExpression Sos2Constraint::Expression(int index) const {
|
||||
const LinearExpressionData& storage_expr =
|
||||
storage_->constraint_data(id_).expression(index);
|
||||
storage()->constraint_data(id_).expression(index);
|
||||
LinearExpression out_expr = storage_expr.offset;
|
||||
for (const auto [var_id, coeff] : storage_expr.coeffs.terms()) {
|
||||
out_expr += coeff * Variable(storage_, var_id);
|
||||
out_expr += coeff * Variable(storage(), var_id);
|
||||
}
|
||||
return out_expr;
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
#include "ortools/math_opt/constraints/util/model_util.h"
|
||||
#include "ortools/math_opt/cpp/variable_and_expressions.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
@@ -34,17 +35,16 @@ namespace operations_research::math_opt {
|
||||
// Usually this type is passed by copy.
|
||||
//
|
||||
// This type implements https://abseil.io/docs/cpp/guides/hash.
|
||||
class Sos2Constraint {
|
||||
class Sos2Constraint final : public ModelStorageItem {
|
||||
public:
|
||||
// The typed integer used for ids.
|
||||
using IdType = Sos2ConstraintId;
|
||||
|
||||
inline Sos2Constraint(const ModelStorage* storage, Sos2ConstraintId id);
|
||||
inline Sos2Constraint(ModelStorageCPtr storage, Sos2ConstraintId id);
|
||||
|
||||
inline int64_t id() const;
|
||||
|
||||
inline Sos2ConstraintId typed_id() const;
|
||||
inline const ModelStorage* storage() const;
|
||||
|
||||
inline int64_t num_expressions() const;
|
||||
LinearExpression Expression(int index) const;
|
||||
@@ -70,7 +70,6 @@ class Sos2Constraint {
|
||||
const Sos2Constraint& constraint);
|
||||
|
||||
private:
|
||||
const ModelStorage* storage_;
|
||||
Sos2ConstraintId id_;
|
||||
};
|
||||
|
||||
@@ -87,33 +86,31 @@ int64_t Sos2Constraint::id() const { return id_.value(); }
|
||||
|
||||
Sos2ConstraintId Sos2Constraint::typed_id() const { return id_; }
|
||||
|
||||
const ModelStorage* Sos2Constraint::storage() const { return storage_; }
|
||||
|
||||
int64_t Sos2Constraint::num_expressions() const {
|
||||
return storage_->constraint_data(id_).num_expressions();
|
||||
return storage()->constraint_data(id_).num_expressions();
|
||||
}
|
||||
|
||||
bool Sos2Constraint::has_weights() const {
|
||||
return storage_->constraint_data(id_).has_weights();
|
||||
return storage()->constraint_data(id_).has_weights();
|
||||
}
|
||||
|
||||
double Sos2Constraint::weight(int index) const {
|
||||
return storage_->constraint_data(id_).weight(index);
|
||||
return storage()->constraint_data(id_).weight(index);
|
||||
}
|
||||
|
||||
absl::string_view Sos2Constraint::name() const {
|
||||
if (storage_->has_constraint(id_)) {
|
||||
return storage_->constraint_data(id_).name();
|
||||
if (storage()->has_constraint(id_)) {
|
||||
return storage()->constraint_data(id_).name();
|
||||
}
|
||||
return kDeletedConstraintDefaultDescription;
|
||||
}
|
||||
|
||||
std::vector<Variable> Sos2Constraint::NonzeroVariables() const {
|
||||
return AtomicConstraintNonzeroVariables(*storage_, id_);
|
||||
return AtomicConstraintNonzeroVariables(*storage(), id_);
|
||||
}
|
||||
|
||||
bool operator==(const Sos2Constraint& lhs, const Sos2Constraint& rhs) {
|
||||
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
|
||||
return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage();
|
||||
}
|
||||
|
||||
bool operator!=(const Sos2Constraint& lhs, const Sos2Constraint& rhs) {
|
||||
@@ -122,7 +119,7 @@ bool operator!=(const Sos2Constraint& lhs, const Sos2Constraint& rhs) {
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const Sos2Constraint& constraint) {
|
||||
return H::combine(std::move(h), constraint.id_.value(), constraint.storage_);
|
||||
return H::combine(std::move(h), constraint.id_.value(), constraint.storage());
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr, const Sos2Constraint& constraint) {
|
||||
@@ -137,15 +134,15 @@ std::ostream& operator<<(std::ostream& ostr, const Sos2Constraint& constraint) {
|
||||
}
|
||||
|
||||
std::string Sos2Constraint::ToString() const {
|
||||
if (storage_->has_constraint(id_)) {
|
||||
if (storage()->has_constraint(id_)) {
|
||||
return internal::SosConstraintToString(*this, "SOS2");
|
||||
}
|
||||
return std::string(kDeletedConstraintDefaultDescription);
|
||||
}
|
||||
|
||||
Sos2Constraint::Sos2Constraint(const ModelStorage* const storage,
|
||||
Sos2Constraint::Sos2Constraint(const ModelStorageCPtr storage,
|
||||
const Sos2ConstraintId id)
|
||||
: storage_(storage), id_(id) {}
|
||||
: ModelStorageItem(storage), id_(id) {}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ class SosConstraintData {
|
||||
using IdType = ConstraintId;
|
||||
using ProtoType = SosConstraintProto;
|
||||
using UpdatesProtoType = SosConstraintUpdatesProto;
|
||||
static constexpr bool kSupportsElemental = false;
|
||||
|
||||
static_assert(
|
||||
std::disjunction_v<std::is_same<ConstraintId, Sos1ConstraintId>,
|
||||
@@ -101,11 +102,13 @@ using Sos2ConstraintData = internal::SosConstraintData<Sos2ConstraintId>;
|
||||
template <>
|
||||
struct AtomicConstraintTraits<Sos1ConstraintId> {
|
||||
using ConstraintData = Sos1ConstraintData;
|
||||
static constexpr bool kSupportsElemental = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AtomicConstraintTraits<Sos2ConstraintId> {
|
||||
using ConstraintData = Sos2ConstraintData;
|
||||
static constexpr bool kSupportsElemental = false;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@@ -52,7 +52,7 @@ std::vector<Variable> AtomicConstraintNonzeroVariables(
|
||||
}
|
||||
|
||||
// Duck-types on `ConstraintType` having a typedef for the associated `IdType`,
|
||||
// and having a `(const ModelStorage*, IdType)` constructor.
|
||||
// and having a `(ModelStorageCPtr, IdType)` constructor.
|
||||
template <typename ConstraintType>
|
||||
std::vector<ConstraintType> AtomicConstraints(const ModelStorage& storage) {
|
||||
using IdType = typename ConstraintType::IdType;
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#ifndef OR_TOOLS_MATH_OPT_CORE_INVALID_INDICATORS_H_
|
||||
#define OR_TOOLS_MATH_OPT_CORE_INVALID_INDICATORS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#ifndef OR_TOOLS_MATH_OPT_CORE_INVERTED_BOUNDS_H_
|
||||
#define OR_TOOLS_MATH_OPT_CORE_INVERTED_BOUNDS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/logging.h"
|
||||
|
||||
@@ -11,29 +11,29 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
pybind11_add_module(math_opt_pybind11 MODULE solver.cc)
|
||||
set_target_properties(math_opt_pybind11 PROPERTIES
|
||||
pybind11_add_module(math_opt_core_pybind11 MODULE solver.cc)
|
||||
set_target_properties(math_opt_core_pybind11 PROPERTIES
|
||||
LIBRARY_OUTPUT_NAME "solver")
|
||||
|
||||
# note: macOS is APPLE and also UNIX !
|
||||
if(APPLE)
|
||||
set_target_properties(math_opt_pybind11 PROPERTIES
|
||||
set_target_properties(math_opt_core_pybind11 PROPERTIES
|
||||
SUFFIX ".so"
|
||||
INSTALL_RPATH
|
||||
"@loader_path;@loader_path/../../../../${PYTHON_PROJECT}/.libs;@loader_path/../../../../pybind11_abseil")
|
||||
elseif(UNIX)
|
||||
set_target_properties(math_opt_pybind11 PROPERTIES
|
||||
set_target_properties(math_opt_core_pybind11 PROPERTIES
|
||||
INSTALL_RPATH
|
||||
"$ORIGIN:$ORIGIN/../../../../${PYTHON_PROJECT}/.libs:$ORIGIN/../../../../pybind11_abseil")
|
||||
endif()
|
||||
|
||||
target_link_libraries(math_opt_pybind11 PRIVATE
|
||||
target_link_libraries(math_opt_core_pybind11 PRIVATE
|
||||
${PROJECT_NAMESPACE}::ortools
|
||||
pybind11_abseil::absl_casters
|
||||
pybind11_abseil::status_casters
|
||||
pybind11_native_proto_caster
|
||||
protobuf::libprotobuf)
|
||||
add_library(${PROJECT_NAMESPACE}::math_opt_pybind11 ALIAS math_opt_pybind11)
|
||||
add_library(${PROJECT_NAMESPACE}::math_opt_core_pybind11 ALIAS math_opt_core_pybind11)
|
||||
|
||||
if(BUILD_TESTING)
|
||||
file(GLOB PYTHON_SRCS "*_test.py")
|
||||
|
||||
@@ -17,7 +17,7 @@ import threading
|
||||
from typing import Callable, Optional, Sequence
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
from pybind11_abseil.status import StatusNotOk
|
||||
from pybind11_abseil import status
|
||||
from ortools.math_opt import callback_pb2
|
||||
from ortools.math_opt import model_parameters_pb2
|
||||
from ortools.math_opt import model_pb2
|
||||
@@ -116,7 +116,7 @@ class PybindSolverTest(parameterized.TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "id 7 not found"):
|
||||
_solve_model(model, use_solver_class=use_solver_class)
|
||||
else:
|
||||
with self.assertRaisesRegex(StatusNotOk, "id 7 not found"):
|
||||
with self.assertRaisesRegex(status.StatusNotOk, "id 7 not found"):
|
||||
_solve_model(model, use_solver_class=use_solver_class)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
|
||||
@@ -49,7 +49,7 @@ namespace {
|
||||
|
||||
// Returns an InternalError with the input status message if the input status is
|
||||
// not OK.
|
||||
absl::Status ToInternalError(const absl::Status original) {
|
||||
absl::Status ToInternalError(absl::Status original) {
|
||||
if (original.ok()) {
|
||||
return original;
|
||||
}
|
||||
@@ -201,7 +201,7 @@ Solver::ComputeInfeasibleSubsystem(
|
||||
RETURN_IF_ERROR(ValidateSolveParameters(arguments.parameters))
|
||||
<< "invalid parameters";
|
||||
|
||||
ASSIGN_OR_RETURN(const ComputeInfeasibleSubsystemResultProto result,
|
||||
ASSIGN_OR_RETURN(ComputeInfeasibleSubsystemResultProto result,
|
||||
underlying_solver_->ComputeInfeasibleSubsystem(
|
||||
arguments.parameters, arguments.message_callback,
|
||||
arguments.interrupter));
|
||||
|
||||
@@ -28,7 +28,8 @@ namespace internal {
|
||||
// This variable is intended to be used by MathOpt unit tests in other languages
|
||||
// to test the proper garbage collection. It should never be used in any other
|
||||
// context.
|
||||
OR_DLL extern std::atomic<int64_t> debug_num_solver;
|
||||
OR_DLL
|
||||
extern std::atomic<int64_t> debug_num_solver;
|
||||
|
||||
} // namespace internal
|
||||
} // namespace math_opt
|
||||
|
||||
@@ -94,7 +94,9 @@ cc_library(
|
||||
"//ortools/math_opt/storage:sparse_coefficient_map",
|
||||
"//ortools/math_opt/storage:sparse_matrix",
|
||||
"//ortools/util:fp_roundtrip_conv",
|
||||
"@abseil-cpp//absl/base:nullability",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/log:die_if_null",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
"@abseil-cpp//absl/strings",
|
||||
@@ -113,6 +115,7 @@ cc_library(
|
||||
"//ortools/base:intops",
|
||||
"//ortools/base:map_util",
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"//ortools/math_opt/storage:model_storage_types",
|
||||
"//ortools/util:fp_roundtrip_conv",
|
||||
"@abseil-cpp//absl/base:core_headers",
|
||||
@@ -130,8 +133,8 @@ cc_library(
|
||||
deps = [
|
||||
":key_types",
|
||||
":variable_and_expressions",
|
||||
"//ortools/base:intops",
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"//ortools/math_opt/storage:model_storage_types",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/strings",
|
||||
@@ -144,10 +147,11 @@ cc_library(
|
||||
deps = [
|
||||
":key_types",
|
||||
":variable_and_expressions",
|
||||
"//ortools/base:intops",
|
||||
"//ortools/math_opt/constraints/util:model_util",
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"//ortools/math_opt/storage:model_storage_types",
|
||||
"@abseil-cpp//absl/container:flat_hash_map",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
@@ -259,6 +263,7 @@ cc_library(
|
||||
hdrs = ["key_types.h"],
|
||||
deps = [
|
||||
"//ortools/math_opt/storage:model_storage",
|
||||
"//ortools/math_opt/storage:model_storage_item",
|
||||
"@abseil-cpp//absl/algorithm:container",
|
||||
"@abseil-cpp//absl/container:flat_hash_map",
|
||||
"@abseil-cpp//absl/status",
|
||||
|
||||
@@ -18,9 +18,7 @@
|
||||
#define OR_TOOLS_MATH_OPT_CPP_BASIS_STATUS_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/math_opt/cpp/enums.h" // IWYU pragma: export
|
||||
#include "ortools/math_opt/solution.pb.h"
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ CallbackData::CallbackData(const CallbackEvent event,
|
||||
const absl::Duration runtime)
|
||||
: event(event), runtime(runtime) {}
|
||||
|
||||
CallbackData::CallbackData(const ModelStorage* storage,
|
||||
CallbackData::CallbackData(const ModelStorageCPtr storage,
|
||||
const CallbackDataProto& proto)
|
||||
// iOS 11 does not support .value() hence we use operator* here and CHECK
|
||||
// below that we have a value.
|
||||
@@ -90,7 +90,7 @@ CallbackData::CallbackData(const ModelStorage* storage,
|
||||
}
|
||||
|
||||
absl::Status CallbackRegistration::CheckModelStorage(
|
||||
const ModelStorage* const expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
RETURN_IF_ERROR(mip_node_filter.CheckModelStorage(expected_storage))
|
||||
<< "invalid mip_node_filter";
|
||||
RETURN_IF_ERROR(mip_solution_filter.CheckModelStorage(expected_storage))
|
||||
@@ -113,7 +113,7 @@ CallbackRegistrationProto CallbackRegistration::Proto() const {
|
||||
}
|
||||
|
||||
absl::Status CallbackResult::CheckModelStorage(
|
||||
const ModelStorage* const expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
for (const GeneratedLinearConstraint& constraint : new_constraints) {
|
||||
RETURN_IF_ERROR(
|
||||
internal::CheckModelStorage(/*storage=*/constraint.storage(),
|
||||
|
||||
@@ -144,7 +144,7 @@ MATH_OPT_DEFINE_ENUM(CallbackEvent, CALLBACK_EVENT_UNSPECIFIED);
|
||||
struct CallbackRegistration {
|
||||
// Returns a failure if the referenced variables don't belong to the input
|
||||
// expected_storage (which must not be nullptr).
|
||||
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
|
||||
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
|
||||
|
||||
// Returns the proto equivalent of this object.
|
||||
//
|
||||
@@ -189,7 +189,7 @@ struct CallbackData {
|
||||
|
||||
// Users will typically not need this function.
|
||||
// Will CHECK fail if proto is not valid.
|
||||
CallbackData(const ModelStorage* storage, const CallbackDataProto& proto);
|
||||
CallbackData(ModelStorageCPtr storage, const CallbackDataProto& proto);
|
||||
|
||||
// The current state of the underlying solver.
|
||||
CallbackEvent event;
|
||||
@@ -229,7 +229,7 @@ struct CallbackResult {
|
||||
BoundedLinearExpression linear_constraint;
|
||||
bool is_lazy = false;
|
||||
|
||||
const ModelStorage* storage() const {
|
||||
NullableModelStorageCPtr storage() const {
|
||||
return linear_constraint.expression.storage();
|
||||
}
|
||||
};
|
||||
@@ -249,7 +249,7 @@ struct CallbackResult {
|
||||
|
||||
// Returns a failure if the referenced variables don't belong to the input
|
||||
// expected_storage (which must not be nullptr).
|
||||
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
|
||||
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
|
||||
|
||||
// Returns the proto equivalent of this object.
|
||||
//
|
||||
@@ -257,7 +257,17 @@ struct CallbackResult {
|
||||
// internal consistency of the referenced variables.
|
||||
CallbackResultProto Proto() const;
|
||||
|
||||
// Stop the solve process and return early. Can be called from any event.
|
||||
// When true it tells the solver to interrupt the solve as soon as possible.
|
||||
//
|
||||
// It can be set from any event. This is equivalent to using a
|
||||
// SolveInterrupter and triggering it from the callback.
|
||||
//
|
||||
// Some solvers don't support interruption, in that case this is simply
|
||||
// ignored and the solve terminates as usual. On top of that solvers may not
|
||||
// immediately stop the solve. Thus the user should expect the callback to
|
||||
// still be called after they set `terminate` to true in a previous
|
||||
// call. Returning with `terminate` false after having previously returned
|
||||
// true won't cancel the interruption.
|
||||
bool terminate = false;
|
||||
|
||||
// The user cuts and lazy constraints added. Prefer AddUserCut() and
|
||||
|
||||
@@ -76,7 +76,7 @@ template <typename K>
|
||||
absl::Status BoundsMapProtoToCpp(
|
||||
const google::protobuf::Map<int64_t, ModelSubsetProto::Bounds>& source,
|
||||
absl::flat_hash_map<K, ModelSubset::Bounds>& target,
|
||||
const ModelStorage* const model,
|
||||
const ModelStorageCPtr model,
|
||||
bool (ModelStorage::* const contains_strong_id)(typename K::IdType id)
|
||||
const,
|
||||
const absl::string_view object_name) {
|
||||
@@ -95,7 +95,7 @@ absl::Status BoundsMapProtoToCpp(
|
||||
template <typename K>
|
||||
absl::Status RepeatedIdsProtoToCpp(
|
||||
const google::protobuf::RepeatedField<int64_t>& source,
|
||||
absl::flat_hash_set<K>& target, const ModelStorage* const model,
|
||||
absl::flat_hash_set<K>& target, const ModelStorageCPtr model,
|
||||
bool (ModelStorage::* const contains_strong_id)(typename K::IdType id)
|
||||
const,
|
||||
const absl::string_view object_name) {
|
||||
@@ -134,7 +134,7 @@ google::protobuf::RepeatedField<int64_t> RepeatedIdsCppToProto(
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<ModelSubset> ModelSubset::FromProto(
|
||||
const ModelStorage* const model, const ModelSubsetProto& proto) {
|
||||
const ModelStorageCPtr model, const ModelSubsetProto& proto) {
|
||||
ModelSubset model_subset;
|
||||
RETURN_IF_ERROR(BoundsMapProtoToCpp(proto.variable_bounds(),
|
||||
model_subset.variable_bounds, model,
|
||||
@@ -184,7 +184,7 @@ ModelSubsetProto ModelSubset::Proto() const {
|
||||
}
|
||||
|
||||
absl::Status ModelSubset::CheckModelStorage(
|
||||
const ModelStorage* const expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
const auto validate_map_keys =
|
||||
[expected_storage](const auto& map,
|
||||
const absl::string_view name) -> absl::Status {
|
||||
@@ -348,7 +348,7 @@ std::ostream& operator<<(std::ostream& out, const ModelSubset& model_subset) {
|
||||
|
||||
absl::StatusOr<ComputeInfeasibleSubsystemResult>
|
||||
ComputeInfeasibleSubsystemResult::FromProto(
|
||||
const ModelStorage* const model,
|
||||
const ModelStorageCPtr model,
|
||||
const ComputeInfeasibleSubsystemResultProto& result_proto) {
|
||||
ComputeInfeasibleSubsystemResult result;
|
||||
const std::optional<FeasibilityStatus> feasibility =
|
||||
@@ -383,7 +383,7 @@ ComputeInfeasibleSubsystemResultProto ComputeInfeasibleSubsystemResult::Proto()
|
||||
}
|
||||
|
||||
absl::Status ComputeInfeasibleSubsystemResult::CheckModelStorage(
|
||||
const ModelStorage* const expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
return infeasible_subsystem.CheckModelStorage(expected_storage);
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ struct ModelSubset {
|
||||
//
|
||||
// Returns an error when `model` does not contain a variable or constraint
|
||||
// associated with an index present in `proto`.
|
||||
static absl::StatusOr<ModelSubset> FromProto(const ModelStorage* model,
|
||||
static absl::StatusOr<ModelSubset> FromProto(ModelStorageCPtr model,
|
||||
const ModelSubsetProto& proto);
|
||||
|
||||
// Returns the proto equivalent of this object.
|
||||
@@ -74,8 +74,8 @@ struct ModelSubset {
|
||||
ModelSubsetProto Proto() const;
|
||||
|
||||
// Returns a failure if the `Variable` and Constraints contained in the fields
|
||||
// do not belong to the input expected_storage (which must not be nullptr).
|
||||
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
|
||||
// do not belong to the input expected_storage.
|
||||
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
|
||||
|
||||
// True if this object corresponds to the empty subset.
|
||||
bool empty() const;
|
||||
@@ -105,7 +105,7 @@ struct ComputeInfeasibleSubsystemResult {
|
||||
// index present in `proto.infeasible_subsystem`.
|
||||
// * ValidateComputeInfeasibleSubsystemResultNoModel(result_proto) fails.
|
||||
static absl::StatusOr<ComputeInfeasibleSubsystemResult> FromProto(
|
||||
const ModelStorage* model,
|
||||
ModelStorageCPtr model,
|
||||
const ComputeInfeasibleSubsystemResultProto& result_proto);
|
||||
|
||||
// Returns the proto equivalent of this object.
|
||||
@@ -116,8 +116,8 @@ struct ComputeInfeasibleSubsystemResult {
|
||||
ComputeInfeasibleSubsystemResultProto Proto() const;
|
||||
|
||||
// Returns a failure if this object contains references to a model other than
|
||||
// `expected_storage` (which must not be nullptr).
|
||||
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
|
||||
// `expected_storage`.
|
||||
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
|
||||
|
||||
// The primal feasibility status of the model, as determined by the solver.
|
||||
FeasibilityStatus feasibility = FeasibilityStatus::kUndetermined;
|
||||
|
||||
@@ -25,14 +25,16 @@
|
||||
//
|
||||
// A key type K must match the following requirements:
|
||||
// - K::IdType is a value type used for indices.
|
||||
// - K has a constructor K(const ModelStorage*, K::IdType).
|
||||
// - K has a constructor K(ModelStorageCPtr, K::IdType).
|
||||
// - K is a value-semantic type.
|
||||
// - K has a function with signature `K::IdType K::typed_id() const`.
|
||||
// - K has a function with signature `const ModelStorage* K::storage() const`.
|
||||
// It must return a non-null pointer.
|
||||
// - K has a function with signature `ModelStorageCPtr K::storage() const`.
|
||||
// - K::IdType is a valid key for absl::flat_hash_map or absl::flat_hash_set
|
||||
// (supports hash and ==).
|
||||
// - the is_key_type_v<> below should include them.
|
||||
// TODO(b/396580721): Those requirements are those of `ModelStorageElement`.
|
||||
// Once we've migrated most key types to `ModelStorageElement`, we should be
|
||||
// able to simplify this code.
|
||||
#ifndef OR_TOOLS_MATH_OPT_CPP_KEY_TYPES_H_
|
||||
#define OR_TOOLS_MATH_OPT_CPP_KEY_TYPES_H_
|
||||
|
||||
@@ -45,6 +47,7 @@
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
@@ -70,11 +73,9 @@ class Objective;
|
||||
// the values in the hash map are in the math_opt namespace.
|
||||
template <typename T>
|
||||
constexpr inline bool is_key_type_v =
|
||||
(std::is_same_v<T, Variable> || std::is_same_v<T, LinearConstraint> ||
|
||||
std::is_same_v<T, QuadraticConstraint> ||
|
||||
(is_model_storage_element<T>::value ||
|
||||
std::is_same_v<T, SecondOrderConeConstraint> ||
|
||||
std::is_same_v<T, Sos1Constraint> || std::is_same_v<T, Sos2Constraint> ||
|
||||
std::is_same_v<T, IndicatorConstraint> ||
|
||||
std::is_same_v<T, QuadraticTermKey> || std::is_same_v<T, Objective>);
|
||||
|
||||
// Returns the keys of the map sorted by their (storage(), type_id()).
|
||||
@@ -162,12 +163,12 @@ inline constexpr absl::string_view kInputFromInvalidModelStorage =
|
||||
"the input does not belong to the same model";
|
||||
|
||||
// Returns a failure when the input pointer is not nullptr and points to a
|
||||
// different model storage than expected_storage (which must not be nullptr).
|
||||
// different model storage than expected_storage.
|
||||
//
|
||||
// Failure message is kInputFromInvalidModelStorage.
|
||||
inline absl::Status CheckModelStorage(
|
||||
const ModelStorage* const storage,
|
||||
const ModelStorage* const expected_storage) {
|
||||
inline absl::Status CheckModelStorage(const NullableModelStorageCPtr storage,
|
||||
const ModelStorageCPtr expected_storage) {
|
||||
// This is not allowed by the contract, but let's be safe.
|
||||
if (expected_storage == nullptr) {
|
||||
return absl::InternalError("expected_storage is nullptr");
|
||||
}
|
||||
|
||||
@@ -18,19 +18,18 @@
|
||||
#ifndef OR_TOOLS_MATH_OPT_CPP_LINEAR_CONSTRAINT_H_
|
||||
#define OR_TOOLS_MATH_OPT_CPP_LINEAR_CONSTRAINT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/math_opt/constraints/util/model_util.h"
|
||||
#include "ortools/math_opt/cpp/key_types.h"
|
||||
#include "ortools/math_opt/cpp/variable_and_expressions.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
#include "ortools/math_opt/storage/model_storage_types.h"
|
||||
|
||||
namespace operations_research {
|
||||
@@ -38,19 +37,11 @@ namespace math_opt {
|
||||
|
||||
// A value type that references a linear constraint from ModelStorage. Usually
|
||||
// this type is passed by copy.
|
||||
//
|
||||
// This type implements https://abseil.io/docs/cpp/guides/hash.
|
||||
class LinearConstraint {
|
||||
class LinearConstraint final
|
||||
: public ModelStorageElement<ElementType::kLinearConstraint,
|
||||
LinearConstraint> {
|
||||
public:
|
||||
// The typed integer used for ids.
|
||||
using IdType = LinearConstraintId;
|
||||
|
||||
inline LinearConstraint(const ModelStorage* storage, LinearConstraintId id);
|
||||
|
||||
inline int64_t id() const;
|
||||
|
||||
inline LinearConstraintId typed_id() const;
|
||||
inline const ModelStorage* storage() const;
|
||||
using ModelStorageElement::ModelStorageElement;
|
||||
|
||||
inline double lower_bound() const;
|
||||
inline double upper_bound() const;
|
||||
@@ -76,65 +67,42 @@ class LinearConstraint {
|
||||
// Returns a detailed string description of the contents of the constraint
|
||||
// (not its name, use `<<` for that instead).
|
||||
inline std::string ToString() const;
|
||||
|
||||
friend inline bool operator==(const LinearConstraint& lhs,
|
||||
const LinearConstraint& rhs);
|
||||
friend inline bool operator!=(const LinearConstraint& lhs,
|
||||
const LinearConstraint& rhs);
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const LinearConstraint& linear_constraint);
|
||||
friend std::ostream& operator<<(std::ostream& ostr,
|
||||
const LinearConstraint& linear_constraint);
|
||||
|
||||
private:
|
||||
const ModelStorage* storage_;
|
||||
LinearConstraintId id_;
|
||||
};
|
||||
|
||||
template <typename V>
|
||||
using LinearConstraintMap = absl::flat_hash_map<LinearConstraint, V>;
|
||||
|
||||
// Streams the name of the constraint, as registered upon constraint creation,
|
||||
// or a short default if none was provided.
|
||||
inline std::ostream& operator<<(std::ostream& ostr,
|
||||
const LinearConstraint& linear_constraint);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Inline function implementations
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
int64_t LinearConstraint::id() const { return id_.value(); }
|
||||
|
||||
LinearConstraintId LinearConstraint::typed_id() const { return id_; }
|
||||
|
||||
const ModelStorage* LinearConstraint::storage() const { return storage_; }
|
||||
|
||||
double LinearConstraint::lower_bound() const {
|
||||
return storage_->linear_constraint_lower_bound(id_);
|
||||
return storage()->linear_constraint_lower_bound(typed_id());
|
||||
}
|
||||
|
||||
double LinearConstraint::upper_bound() const {
|
||||
return storage_->linear_constraint_upper_bound(id_);
|
||||
return storage()->linear_constraint_upper_bound(typed_id());
|
||||
}
|
||||
|
||||
absl::string_view LinearConstraint::name() const {
|
||||
if (storage()->has_linear_constraint(id_)) {
|
||||
return storage_->linear_constraint_name(id_);
|
||||
if (storage()->has_linear_constraint(typed_id())) {
|
||||
return storage()->linear_constraint_name(typed_id());
|
||||
}
|
||||
return kDeletedConstraintDefaultDescription;
|
||||
}
|
||||
|
||||
bool LinearConstraint::is_coefficient_nonzero(const Variable variable) const {
|
||||
CHECK_EQ(variable.storage(), storage_)
|
||||
CHECK_EQ(variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
return storage_->is_linear_constraint_coefficient_nonzero(
|
||||
id_, variable.typed_id());
|
||||
return storage()->is_linear_constraint_coefficient_nonzero(
|
||||
typed_id(), variable.typed_id());
|
||||
}
|
||||
|
||||
double LinearConstraint::coefficient(const Variable variable) const {
|
||||
CHECK_EQ(variable.storage(), storage_)
|
||||
CHECK_EQ(variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
return storage_->linear_constraint_coefficient(id_, variable.typed_id());
|
||||
return storage()->linear_constraint_coefficient(typed_id(),
|
||||
variable.typed_id());
|
||||
}
|
||||
|
||||
BoundedLinearExpression LinearConstraint::AsBoundedLinearExpression() const {
|
||||
@@ -150,7 +118,7 @@ BoundedLinearExpression LinearConstraint::AsBoundedLinearExpression() const {
|
||||
}
|
||||
|
||||
std::string LinearConstraint::ToString() const {
|
||||
if (!storage()->has_linear_constraint(id_)) {
|
||||
if (!storage()->has_linear_constraint(typed_id())) {
|
||||
return std::string(kDeletedConstraintDefaultDescription);
|
||||
}
|
||||
std::stringstream str;
|
||||
@@ -158,36 +126,6 @@ std::string LinearConstraint::ToString() const {
|
||||
return str.str();
|
||||
}
|
||||
|
||||
bool operator==(const LinearConstraint& lhs, const LinearConstraint& rhs) {
|
||||
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
|
||||
}
|
||||
|
||||
bool operator!=(const LinearConstraint& lhs, const LinearConstraint& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const LinearConstraint& linear_constraint) {
|
||||
return H::combine(std::move(h), linear_constraint.id_.value(),
|
||||
linear_constraint.storage_);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr,
|
||||
const LinearConstraint& linear_constraint) {
|
||||
// TODO(b/170992529): handle quoting of invalid characters in the name.
|
||||
const absl::string_view name = linear_constraint.name();
|
||||
if (name.empty()) {
|
||||
ostr << "__lin_con#" << linear_constraint.id() << "__";
|
||||
} else {
|
||||
ostr << name;
|
||||
}
|
||||
return ostr;
|
||||
}
|
||||
|
||||
LinearConstraint::LinearConstraint(const ModelStorage* const storage,
|
||||
const LinearConstraintId id)
|
||||
: storage_(storage), id_(id) {}
|
||||
|
||||
} // namespace math_opt
|
||||
} // namespace operations_research
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ struct MapFilter {
|
||||
// Returns a failure if the keys don't belong to the input expected_storage
|
||||
// (which must not be nullptr).
|
||||
inline absl::Status CheckModelStorage(
|
||||
const ModelStorage* expected_storage) const;
|
||||
ModelStorageCPtr expected_storage) const;
|
||||
|
||||
// Returns the proto corresponding to this filter.
|
||||
//
|
||||
@@ -192,7 +192,7 @@ MapFilter<KeyType> MakeKeepKeysFilter(std::initializer_list<KeyType> keys) {
|
||||
|
||||
template <typename KeyType>
|
||||
absl::Status MapFilter<KeyType>::CheckModelStorage(
|
||||
const ModelStorage* expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
if (!filtered_keys.has_value()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "gtest/gtest.h"
|
||||
@@ -871,7 +872,7 @@ std::vector<TerminationReason> CompatibleReasons(
|
||||
}
|
||||
|
||||
Matcher<std::vector<Solution>> CheckSolutions(
|
||||
const std::vector<Solution>& expected_solutions,
|
||||
absl::Span<const Solution> expected_solutions,
|
||||
const SolveResultMatcherOptions& options) {
|
||||
if (options.first_solution_only && !expected_solutions.empty()) {
|
||||
return FirstElementIs(
|
||||
|
||||
@@ -37,7 +37,7 @@ class PrinterMessageCallbackImpl {
|
||||
const absl::string_view prefix)
|
||||
: output_stream_(output_stream), prefix_(prefix) {}
|
||||
|
||||
void Call(const std::vector<std::string>& messages) {
|
||||
void Call(absl::Span<const std::string> messages) {
|
||||
const absl::MutexLock lock(&mutex_);
|
||||
for (const std::string& message : messages) {
|
||||
output_stream_ << prefix_ << message << '\n';
|
||||
@@ -56,7 +56,7 @@ void PushBack(absl::Span<const std::string> messages,
|
||||
sink->insert(sink->end(), messages.begin(), messages.end());
|
||||
}
|
||||
|
||||
void PushBack(const std::vector<std::string>& messages,
|
||||
void PushBack(absl::Span<const std::string> messages,
|
||||
google::protobuf::RepeatedPtrField<std::string>* const sink) {
|
||||
std::copy(messages.begin(), messages.end(),
|
||||
google::protobuf::RepeatedFieldBackInserter(sink));
|
||||
@@ -68,7 +68,7 @@ class VectorLikeMessageCallbackImpl {
|
||||
explicit VectorLikeMessageCallbackImpl(Sink* const sink)
|
||||
: sink_(ABSL_DIE_IF_NULL(sink)) {}
|
||||
|
||||
void Call(const std::vector<std::string>& messages) {
|
||||
void Call(absl::Span<const std::string> messages) {
|
||||
const absl::MutexLock lock(&mutex_);
|
||||
PushBack(messages, sink_);
|
||||
}
|
||||
@@ -102,7 +102,7 @@ MessageCallback InfoLoggerMessageCallback(const absl::string_view prefix,
|
||||
|
||||
MessageCallback VLoggerMessageCallback(int level, absl::string_view prefix,
|
||||
absl::SourceLocation loc) {
|
||||
return [=](const std::vector<std::string>& messages) {
|
||||
return [=](absl::Span<const std::string> messages) {
|
||||
for (const std::string& message : messages) {
|
||||
VLOG(level).AtLocation(loc.file_name(), loc.line()) << prefix << message;
|
||||
}
|
||||
@@ -117,8 +117,7 @@ MessageCallback VectorMessageCallback(std::vector<std::string>* sink) {
|
||||
const auto impl =
|
||||
std::make_shared<VectorLikeMessageCallbackImpl<std::vector<std::string>>>(
|
||||
sink);
|
||||
return
|
||||
[=](const std::vector<std::string>& messages) { impl->Call(messages); };
|
||||
return [=](absl::Span<const std::string> messages) { impl->Call(messages); };
|
||||
}
|
||||
|
||||
MessageCallback RepeatedPtrFieldMessageCallback(
|
||||
|
||||
@@ -22,7 +22,9 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/die_if_null.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
@@ -53,7 +55,7 @@ constexpr double kInf = std::numeric_limits<double>::infinity();
|
||||
|
||||
absl::StatusOr<std::unique_ptr<Model>> Model::FromModelProto(
|
||||
const ModelProto& model_proto) {
|
||||
ASSIGN_OR_RETURN(std::unique_ptr<ModelStorage> storage,
|
||||
ASSIGN_OR_RETURN(absl::Nonnull<std::unique_ptr<ModelStorage>> storage,
|
||||
ModelStorage::FromModelProto(model_proto));
|
||||
return std::make_unique<Model>(std::move(storage));
|
||||
}
|
||||
@@ -61,10 +63,10 @@ absl::StatusOr<std::unique_ptr<Model>> Model::FromModelProto(
|
||||
Model::Model(const absl::string_view name)
|
||||
: storage_(std::make_shared<ModelStorage>(name)) {}
|
||||
|
||||
Model::Model(std::unique_ptr<ModelStorage> storage)
|
||||
: storage_(std::move(storage)) {}
|
||||
Model::Model(absl::Nonnull<std::unique_ptr<ModelStorage>> storage)
|
||||
: storage_(ABSL_DIE_IF_NULL(std::move(storage))) {}
|
||||
|
||||
std::unique_ptr<Model> Model::Clone(
|
||||
absl::Nonnull<std::unique_ptr<Model>> Model::Clone(
|
||||
const std::optional<absl::string_view> new_name) const {
|
||||
return std::make_unique<Model>(storage_->Clone(new_name));
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/nullability.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
@@ -136,7 +137,7 @@ class Model {
|
||||
// This constructor is used when loading a model, for example from a
|
||||
// ModelProto or an MPS file. Note that in those cases the FromModelProto()
|
||||
// should be used.
|
||||
explicit Model(std::unique_ptr<ModelStorage> storage);
|
||||
explicit Model(absl::Nonnull<std::unique_ptr<ModelStorage>> storage);
|
||||
|
||||
Model(const Model&) = delete;
|
||||
Model& operator=(const Model&) = delete;
|
||||
@@ -158,7 +159,7 @@ class Model {
|
||||
// * in an arbitrary order using Variables() and LinearConstraints().
|
||||
//
|
||||
// Note that the returned model does not have any update tracker.
|
||||
std::unique_ptr<Model> Clone(
|
||||
absl::Nonnull<std::unique_ptr<Model>> Clone(
|
||||
std::optional<absl::string_view> new_name = std::nullopt) const;
|
||||
|
||||
inline absl::string_view name() const;
|
||||
@@ -893,13 +894,13 @@ class Model {
|
||||
//
|
||||
// This API is for internal use only and regular users should have no need for
|
||||
// it.
|
||||
const ModelStorage* storage() const { return storage_.get(); }
|
||||
ModelStorageCPtr storage() const { return storage_.get(); }
|
||||
|
||||
// Returns a pointer to the underlying model storage.
|
||||
//
|
||||
// This API is for internal use only and regular users should have no need for
|
||||
// it.
|
||||
ModelStorage* storage() { return storage_.get(); }
|
||||
ModelStoragePtr storage() { return storage_.get(); }
|
||||
|
||||
// Prints the objective, the constraints and the variables of the model over
|
||||
// several lines in a human-readable way. Includes a new line at the end of
|
||||
@@ -911,12 +912,12 @@ class Model {
|
||||
// points to the same model as storage_.
|
||||
//
|
||||
// Use CheckModel() when nullptr is not a valid value.
|
||||
inline void CheckOptionalModel(const ModelStorage* other_storage) const;
|
||||
inline void CheckOptionalModel(NullableModelStorageCPtr other_storage) const;
|
||||
|
||||
// Asserts (with CHECK) that the input pointer is the same as storage_.
|
||||
//
|
||||
// Use CheckOptionalModel() if nullptr is a valid value too.
|
||||
inline void CheckModel(const ModelStorage* other_storage) const;
|
||||
inline void CheckModel(ModelStorageCPtr other_storage) const;
|
||||
|
||||
// Don't use storage_ directly; prefer to use storage() so that const member
|
||||
// functions don't have modifying access to the underlying storage.
|
||||
@@ -924,7 +925,7 @@ class Model {
|
||||
// We use a shared_ptr here so that the UpdateTracker class can have a
|
||||
// weak_ptr on the ModelStorage. This let it have a destructor that don't
|
||||
// crash when called after the destruction of the associated Model.
|
||||
const std::shared_ptr<ModelStorage> storage_;
|
||||
const absl::Nonnull<std::shared_ptr<ModelStorage>> storage_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -973,7 +974,7 @@ int64_t Model::next_variable_id() const {
|
||||
}
|
||||
|
||||
bool Model::has_variable(const int64_t id) const {
|
||||
return has_variable(VariableId(id));
|
||||
return id < 0 ? false : has_variable(VariableId(id));
|
||||
}
|
||||
|
||||
bool Model::has_variable(const VariableId id) const {
|
||||
@@ -1074,7 +1075,7 @@ int64_t Model::next_linear_constraint_id() const {
|
||||
}
|
||||
|
||||
bool Model::has_linear_constraint(const int64_t id) const {
|
||||
return has_linear_constraint(LinearConstraintId(id));
|
||||
return id < 0 ? false : has_linear_constraint(LinearConstraintId(id));
|
||||
}
|
||||
|
||||
bool Model::has_linear_constraint(const LinearConstraintId id) const {
|
||||
@@ -1082,6 +1083,7 @@ bool Model::has_linear_constraint(const LinearConstraintId id) const {
|
||||
}
|
||||
|
||||
LinearConstraint Model::linear_constraint(const int64_t id) const {
|
||||
CHECK_GE(id, 0) << "negative linear constraint id: " << id;
|
||||
return linear_constraint(LinearConstraintId(id));
|
||||
}
|
||||
|
||||
@@ -1176,7 +1178,7 @@ int64_t Model::next_quadratic_constraint_id() const {
|
||||
}
|
||||
|
||||
bool Model::has_quadratic_constraint(const int64_t id) const {
|
||||
return has_quadratic_constraint(QuadraticConstraintId(id));
|
||||
return id < 0 ? false : has_quadratic_constraint(QuadraticConstraintId(id));
|
||||
}
|
||||
|
||||
bool Model::has_quadratic_constraint(const QuadraticConstraintId id) const {
|
||||
@@ -1184,6 +1186,7 @@ bool Model::has_quadratic_constraint(const QuadraticConstraintId id) const {
|
||||
}
|
||||
|
||||
QuadraticConstraint Model::quadratic_constraint(const int64_t id) const {
|
||||
CHECK_GE(id, 0) << "negative quadratic constraint id: " << id;
|
||||
return quadratic_constraint(QuadraticConstraintId(id));
|
||||
}
|
||||
|
||||
@@ -1219,7 +1222,9 @@ int64_t Model::next_second_order_cone_constraint_id() const {
|
||||
}
|
||||
|
||||
bool Model::has_second_order_cone_constraint(const int64_t id) const {
|
||||
return has_second_order_cone_constraint(SecondOrderConeConstraintId(id));
|
||||
return id < 0 ? false
|
||||
: has_second_order_cone_constraint(
|
||||
SecondOrderConeConstraintId(id));
|
||||
}
|
||||
|
||||
bool Model::has_second_order_cone_constraint(
|
||||
@@ -1265,7 +1270,7 @@ int64_t Model::next_sos1_constraint_id() const {
|
||||
}
|
||||
|
||||
bool Model::has_sos1_constraint(const int64_t id) const {
|
||||
return has_sos1_constraint(Sos1ConstraintId(id));
|
||||
return id < 0 ? false : has_sos1_constraint(Sos1ConstraintId(id));
|
||||
}
|
||||
|
||||
bool Model::has_sos1_constraint(const Sos1ConstraintId id) const {
|
||||
@@ -1306,7 +1311,7 @@ int64_t Model::next_sos2_constraint_id() const {
|
||||
}
|
||||
|
||||
bool Model::has_sos2_constraint(const int64_t id) const {
|
||||
return has_sos2_constraint(Sos2ConstraintId(id));
|
||||
return id < 0 ? false : has_sos2_constraint(Sos2ConstraintId(id));
|
||||
}
|
||||
|
||||
bool Model::has_sos2_constraint(const Sos2ConstraintId id) const {
|
||||
@@ -1347,7 +1352,7 @@ int64_t Model::next_indicator_constraint_id() const {
|
||||
}
|
||||
|
||||
bool Model::has_indicator_constraint(const int64_t id) const {
|
||||
return has_indicator_constraint(IndicatorConstraintId(id));
|
||||
return id < 0 ? false : has_indicator_constraint(IndicatorConstraintId(id));
|
||||
}
|
||||
|
||||
bool Model::has_indicator_constraint(const IndicatorConstraintId id) const {
|
||||
@@ -1555,7 +1560,7 @@ int64_t Model::next_auxiliary_objective_id() const {
|
||||
}
|
||||
|
||||
bool Model::has_auxiliary_objective(const int64_t id) const {
|
||||
return has_auxiliary_objective(AuxiliaryObjectiveId(id));
|
||||
return id < 0 ? false : has_auxiliary_objective(AuxiliaryObjectiveId(id));
|
||||
}
|
||||
|
||||
bool Model::has_auxiliary_objective(const AuxiliaryObjectiveId id) const {
|
||||
@@ -1563,6 +1568,7 @@ bool Model::has_auxiliary_objective(const AuxiliaryObjectiveId id) const {
|
||||
}
|
||||
|
||||
Objective Model::auxiliary_objective(const int64_t id) const {
|
||||
CHECK_GE(id, 0) << "negative auxiliary objective id: " << id;
|
||||
return auxiliary_objective(AuxiliaryObjectiveId(id));
|
||||
}
|
||||
|
||||
@@ -1618,14 +1624,15 @@ void Model::set_is_maximize(const Objective objective, const bool is_maximize) {
|
||||
storage()->set_is_maximize(objective.typed_id(), is_maximize);
|
||||
}
|
||||
|
||||
void Model::CheckOptionalModel(const ModelStorage* const other_storage) const {
|
||||
void Model::CheckOptionalModel(
|
||||
const NullableModelStorageCPtr other_storage) const {
|
||||
if (other_storage != nullptr) {
|
||||
CHECK_EQ(other_storage, storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
}
|
||||
}
|
||||
|
||||
void Model::CheckModel(const ModelStorage* const other_storage) const {
|
||||
void Model::CheckModel(const ModelStorageCPtr other_storage) const {
|
||||
CHECK_EQ(other_storage, storage()) << internal::kObjectsFromOtherModelStorage;
|
||||
}
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ ModelSolveParameters ModelSolveParameters::OnlySomePrimalVariables(
|
||||
}
|
||||
|
||||
absl::Status ModelSolveParameters::CheckModelStorage(
|
||||
const ModelStorage* const expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
for (const SolutionHint& hint : solution_hints) {
|
||||
RETURN_IF_ERROR(hint.CheckModelStorage(expected_storage))
|
||||
<< "invalid hint in solution_hints";
|
||||
@@ -100,7 +100,7 @@ absl::Status ModelSolveParameters::CheckModelStorage(
|
||||
}
|
||||
|
||||
absl::Status ModelSolveParameters::SolutionHint::CheckModelStorage(
|
||||
const ModelStorage* expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
for (const auto& [v, _] : variable_values) {
|
||||
RETURN_IF_ERROR(internal::CheckModelStorage(
|
||||
/*storage=*/v.storage(),
|
||||
|
||||
@@ -128,7 +128,7 @@ struct ModelSolveParameters {
|
||||
|
||||
// Returns a failure if the referenced variables and constraints don't
|
||||
// belong to the input expected_storage (which must not be nullptr).
|
||||
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
|
||||
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
|
||||
|
||||
// Returns the proto equivalent of this object.
|
||||
//
|
||||
@@ -215,7 +215,7 @@ struct ModelSolveParameters {
|
||||
|
||||
// Returns a failure if the referenced variables and constraints do not belong
|
||||
// to the input expected_storage (which must not be nullptr).
|
||||
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
|
||||
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
|
||||
|
||||
// Returns the proto equivalent of this object.
|
||||
//
|
||||
|
||||
@@ -31,18 +31,18 @@ LinearExpression Objective::AsLinearExpression() const {
|
||||
<< "The objective function contains quadratic terms and cannot be "
|
||||
"represented as a LinearExpression";
|
||||
LinearExpression objective = offset();
|
||||
for (const auto [raw_var_id, coeff] : storage_->linear_objective(id_)) {
|
||||
objective += coeff * Variable(storage_, raw_var_id);
|
||||
for (const auto [raw_var_id, coeff] : storage()->linear_objective(id_)) {
|
||||
objective += coeff * Variable(storage(), raw_var_id);
|
||||
}
|
||||
return objective;
|
||||
}
|
||||
|
||||
QuadraticExpression Objective::AsQuadraticExpression() const {
|
||||
QuadraticExpression result = offset();
|
||||
for (const auto& [v, coef] : storage_->linear_objective(id_)) {
|
||||
for (const auto& [v, coef] : storage()->linear_objective(id_)) {
|
||||
result += coef * Variable(storage(), v);
|
||||
}
|
||||
for (const auto& [v1, v2, coef] : storage_->quadratic_objective_terms(id_)) {
|
||||
for (const auto& [v1, v2, coef] : storage()->quadratic_objective_terms(id_)) {
|
||||
result +=
|
||||
QuadraticTerm(Variable(storage(), v1), Variable(storage(), v2), coef);
|
||||
}
|
||||
|
||||
@@ -25,10 +25,10 @@
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/math_opt/cpp/key_types.h"
|
||||
#include "ortools/math_opt/cpp/variable_and_expressions.h"
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
#include "ortools/math_opt/storage/model_storage_types.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
@@ -40,15 +40,15 @@ constexpr absl::string_view kDeletedObjectiveDefaultDescription =
|
||||
// ModelStorage. Usually this type is passed by copy.
|
||||
//
|
||||
// This type implements https://abseil.io/docs/cpp/guides/hash.
|
||||
class Objective {
|
||||
class Objective final : public ModelStorageItem {
|
||||
public:
|
||||
// The type used for ids.
|
||||
using IdType = AuxiliaryObjectiveId;
|
||||
|
||||
// Returns an object that refers to the primary objective of the model.
|
||||
inline static Objective Primary(const ModelStorage* storage);
|
||||
inline static Objective Primary(ModelStorageCPtr storage);
|
||||
// Returns an object that refers to an auxiliary objective of the model.
|
||||
inline static Objective Auxiliary(const ModelStorage* storage,
|
||||
inline static Objective Auxiliary(ModelStorageCPtr storage,
|
||||
AuxiliaryObjectiveId id);
|
||||
|
||||
// Returns the raw integer ID associated with the objective: nullopt for the
|
||||
@@ -57,8 +57,6 @@ class Objective {
|
||||
// Returns the strong int ID associated with the objective: nullopt for the
|
||||
// primary objective, an AuxiliaryObjectiveId for an auxiliary objective.
|
||||
inline ObjectiveId typed_id() const;
|
||||
// Returns a const-pointer to the underlying storage object for the model.
|
||||
inline const ModelStorage* storage() const;
|
||||
|
||||
// Returns true if the ID corresponds to the primary objective, and false if
|
||||
// it is an auxiliary objective.
|
||||
@@ -113,9 +111,8 @@ class Objective {
|
||||
const Objective& objective);
|
||||
|
||||
private:
|
||||
inline Objective(const ModelStorage* storage, ObjectiveId id);
|
||||
inline Objective(ModelStorageCPtr storage, ObjectiveId id);
|
||||
|
||||
const ModelStorage* storage_;
|
||||
ObjectiveId id_;
|
||||
};
|
||||
|
||||
@@ -139,68 +136,66 @@ std::optional<int64_t> Objective::id() const {
|
||||
|
||||
ObjectiveId Objective::typed_id() const { return id_; }
|
||||
|
||||
const ModelStorage* Objective::storage() const { return storage_; }
|
||||
|
||||
bool Objective::is_primary() const { return id_ == kPrimaryObjectiveId; }
|
||||
|
||||
int64_t Objective::priority() const {
|
||||
return storage_->objective_priority(id_);
|
||||
return storage()->objective_priority(id_);
|
||||
}
|
||||
|
||||
bool Objective::maximize() const { return storage_->is_maximize(id_); }
|
||||
bool Objective::maximize() const { return storage()->is_maximize(id_); }
|
||||
|
||||
absl::string_view Objective::name() const {
|
||||
if (is_primary() || storage_->has_auxiliary_objective(*id_)) {
|
||||
return storage_->objective_name(id_);
|
||||
if (is_primary() || storage()->has_auxiliary_objective(*id_)) {
|
||||
return storage()->objective_name(id_);
|
||||
}
|
||||
return kDeletedObjectiveDefaultDescription;
|
||||
}
|
||||
|
||||
double Objective::offset() const { return storage_->objective_offset(id_); }
|
||||
double Objective::offset() const { return storage()->objective_offset(id_); }
|
||||
|
||||
int64_t Objective::num_quadratic_terms() const {
|
||||
return storage_->num_quadratic_objective_terms(id_);
|
||||
return storage()->num_quadratic_objective_terms(id_);
|
||||
}
|
||||
|
||||
int64_t Objective::num_linear_terms() const {
|
||||
return storage_->num_linear_objective_terms(id_);
|
||||
return storage()->num_linear_objective_terms(id_);
|
||||
}
|
||||
|
||||
double Objective::coefficient(const Variable variable) const {
|
||||
CHECK_EQ(variable.storage(), storage_)
|
||||
CHECK_EQ(variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
return storage_->linear_objective_coefficient(id_, variable.typed_id());
|
||||
return storage()->linear_objective_coefficient(id_, variable.typed_id());
|
||||
}
|
||||
|
||||
double Objective::coefficient(const Variable first_variable,
|
||||
const Variable second_variable) const {
|
||||
CHECK_EQ(first_variable.storage(), storage_)
|
||||
CHECK_EQ(first_variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
CHECK_EQ(second_variable.storage(), storage_)
|
||||
CHECK_EQ(second_variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
return storage_->quadratic_objective_coefficient(
|
||||
return storage()->quadratic_objective_coefficient(
|
||||
id_, first_variable.typed_id(), second_variable.typed_id());
|
||||
}
|
||||
|
||||
bool Objective::is_coefficient_nonzero(const Variable variable) const {
|
||||
CHECK_EQ(variable.storage(), storage_)
|
||||
CHECK_EQ(variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
return storage_->is_linear_objective_coefficient_nonzero(id_,
|
||||
variable.typed_id());
|
||||
return storage()->is_linear_objective_coefficient_nonzero(
|
||||
id_, variable.typed_id());
|
||||
}
|
||||
|
||||
bool Objective::is_coefficient_nonzero(const Variable first_variable,
|
||||
const Variable second_variable) const {
|
||||
CHECK_EQ(first_variable.storage(), storage_)
|
||||
CHECK_EQ(first_variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
CHECK_EQ(second_variable.storage(), storage_)
|
||||
CHECK_EQ(second_variable.storage(), storage())
|
||||
<< internal::kObjectsFromOtherModelStorage;
|
||||
return storage_->is_quadratic_objective_coefficient_nonzero(
|
||||
return storage()->is_quadratic_objective_coefficient_nonzero(
|
||||
id_, first_variable.typed_id(), second_variable.typed_id());
|
||||
}
|
||||
|
||||
bool operator==(const Objective& lhs, const Objective& rhs) {
|
||||
return lhs.id_ == rhs.id_ && lhs.storage_ == rhs.storage_;
|
||||
return lhs.id_ == rhs.id_ && lhs.storage() == rhs.storage();
|
||||
}
|
||||
|
||||
bool operator!=(const Objective& lhs, const Objective& rhs) {
|
||||
@@ -209,17 +204,17 @@ bool operator!=(const Objective& lhs, const Objective& rhs) {
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const Objective& objective) {
|
||||
return H::combine(std::move(h), objective.id_, objective.storage_);
|
||||
return H::combine(std::move(h), objective.id_, objective.storage());
|
||||
}
|
||||
|
||||
Objective::Objective(const ModelStorage* const storage, const ObjectiveId id)
|
||||
: storage_(storage), id_(id) {}
|
||||
Objective::Objective(const ModelStorageCPtr storage, const ObjectiveId id)
|
||||
: ModelStorageItem(storage), id_(id) {}
|
||||
|
||||
Objective Objective::Primary(const ModelStorage* const storage) {
|
||||
Objective Objective::Primary(const ModelStorageCPtr storage) {
|
||||
return Objective(storage, kPrimaryObjectiveId);
|
||||
}
|
||||
|
||||
Objective Objective::Auxiliary(const ModelStorage* const storage,
|
||||
Objective Objective::Auxiliary(const ModelStorageCPtr storage,
|
||||
const AuxiliaryObjectiveId id) {
|
||||
return Objective(storage, id);
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
|
||||
#include "ortools/math_opt/cpp/parameters.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -86,7 +85,7 @@ std::optional<absl::string_view> Enum<SolverType>::ToOptString(
|
||||
case SolverType::kSantorini:
|
||||
return "santorini";
|
||||
case SolverType::kXpress:
|
||||
return "xpress";
|
||||
return "xpress";
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
@@ -96,7 +95,7 @@ absl::Span<const SolverType> Enum<SolverType>::AllValues() {
|
||||
SolverType::kGscip, SolverType::kGurobi, SolverType::kGlop,
|
||||
SolverType::kCpSat, SolverType::kPdlp, SolverType::kGlpk,
|
||||
SolverType::kEcos, SolverType::kScs, SolverType::kHighs,
|
||||
SolverType::kSantorini,
|
||||
SolverType::kSantorini, SolverType::kXpress,
|
||||
};
|
||||
return absl::MakeConstSpan(kSolverTypeValues);
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/base/linked_hash_map.h"
|
||||
#include "ortools/glop/parameters.pb.h" // IWYU pragma: export
|
||||
#include "ortools/gscip/gscip.pb.h" // IWYU pragma: export
|
||||
@@ -114,7 +113,7 @@ enum class SolverType {
|
||||
//
|
||||
// Supports LP, MIP, and nonconvex integer quadratic problems.
|
||||
// A fast option, but has special licensing.
|
||||
kXpress = SOLVER_TYPE_XPRESS
|
||||
kXpress = SOLVER_TYPE_XPRESS,
|
||||
};
|
||||
|
||||
MATH_OPT_DEFINE_ENUM(SolverType, SOLVER_TYPE_UNSPECIFIED);
|
||||
|
||||
@@ -18,10 +18,10 @@
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/base/status_builder.h"
|
||||
#include "ortools/base/status_macros.h"
|
||||
#include "ortools/math_opt/cpp/sparse_containers.h"
|
||||
@@ -57,7 +57,7 @@ absl::Span<const SolutionStatus> Enum<SolutionStatus>::AllValues() {
|
||||
}
|
||||
|
||||
absl::StatusOr<PrimalSolution> PrimalSolution::FromProto(
|
||||
const ModelStorage* model,
|
||||
const ModelStorageCPtr model,
|
||||
const PrimalSolutionProto& primal_solution_proto) {
|
||||
PrimalSolution primal_solution;
|
||||
OR_ASSIGN_OR_RETURN3(
|
||||
@@ -104,7 +104,7 @@ double PrimalSolution::get_objective_value(const Objective objective) const {
|
||||
}
|
||||
|
||||
absl::StatusOr<PrimalRay> PrimalRay::FromProto(
|
||||
const ModelStorage* model, const PrimalRayProto& primal_ray_proto) {
|
||||
const ModelStorageCPtr model, const PrimalRayProto& primal_ray_proto) {
|
||||
PrimalRay result;
|
||||
OR_ASSIGN_OR_RETURN3(
|
||||
result.variable_values,
|
||||
@@ -120,7 +120,8 @@ PrimalRayProto PrimalRay::Proto() const {
|
||||
}
|
||||
|
||||
absl::StatusOr<DualSolution> DualSolution::FromProto(
|
||||
const ModelStorage* model, const DualSolutionProto& dual_solution_proto) {
|
||||
const ModelStorageCPtr model,
|
||||
const DualSolutionProto& dual_solution_proto) {
|
||||
DualSolution dual_solution;
|
||||
OR_ASSIGN_OR_RETURN3(
|
||||
dual_solution.dual_values,
|
||||
@@ -159,7 +160,7 @@ DualSolutionProto DualSolution::Proto() const {
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::StatusOr<DualRay> DualRay::FromProto(const ModelStorage* model,
|
||||
absl::StatusOr<DualRay> DualRay::FromProto(const ModelStorageCPtr model,
|
||||
const DualRayProto& dual_ray_proto) {
|
||||
DualRay result;
|
||||
OR_ASSIGN_OR_RETURN3(
|
||||
@@ -180,7 +181,7 @@ DualRayProto DualRay::Proto() const {
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::StatusOr<Basis> Basis::FromProto(const ModelStorage* model,
|
||||
absl::StatusOr<Basis> Basis::FromProto(const ModelStorageCPtr model,
|
||||
const BasisProto& basis_proto) {
|
||||
Basis basis;
|
||||
OR_ASSIGN_OR_RETURN3(
|
||||
@@ -197,7 +198,7 @@ absl::StatusOr<Basis> Basis::FromProto(const ModelStorage* model,
|
||||
}
|
||||
|
||||
absl::Status Basis::CheckModelStorage(
|
||||
const ModelStorage* const expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
for (const auto& [v, _] : variable_status) {
|
||||
RETURN_IF_ERROR(
|
||||
internal::CheckModelStorage(/*storage=*/v.storage(),
|
||||
@@ -223,7 +224,7 @@ BasisProto Basis::Proto() const {
|
||||
}
|
||||
|
||||
absl::StatusOr<Solution> Solution::FromProto(
|
||||
const ModelStorage* model, const SolutionProto& solution_proto) {
|
||||
const ModelStorageCPtr model, const SolutionProto& solution_proto) {
|
||||
Solution solution;
|
||||
if (solution_proto.has_primal_solution()) {
|
||||
OR_ASSIGN_OR_RETURN3(
|
||||
|
||||
@@ -67,8 +67,7 @@ struct PrimalSolution {
|
||||
// * VariableValuesFromProto(primal_solution_proto.variable_values) fails.
|
||||
// * the feasibility_status is not specified.
|
||||
static absl::StatusOr<PrimalSolution> FromProto(
|
||||
const ModelStorage* model,
|
||||
const PrimalSolutionProto& primal_solution_proto);
|
||||
ModelStorageCPtr model, const PrimalSolutionProto& primal_solution_proto);
|
||||
|
||||
// Returns the proto equivalent of this.
|
||||
PrimalSolutionProto Proto() const;
|
||||
@@ -112,7 +111,7 @@ struct PrimalRay {
|
||||
// Returns an error when
|
||||
// VariableValuesFromProto(primal_ray_proto.variable_values) fails.
|
||||
static absl::StatusOr<PrimalRay> FromProto(
|
||||
const ModelStorage* model, const PrimalRayProto& primal_ray_proto);
|
||||
ModelStorageCPtr model, const PrimalRayProto& primal_ray_proto);
|
||||
|
||||
// Returns the proto equivalent of this.
|
||||
PrimalRayProto Proto() const;
|
||||
@@ -139,7 +138,7 @@ struct DualSolution {
|
||||
// * LinearConstraintValuesFromProto(dual_solution_proto.dual_values) fails.
|
||||
// * dual_solution_proto.feasibility_status is not specified.
|
||||
static absl::StatusOr<DualSolution> FromProto(
|
||||
const ModelStorage* model, const DualSolutionProto& dual_solution_proto);
|
||||
ModelStorageCPtr model, const DualSolutionProto& dual_solution_proto);
|
||||
|
||||
// Returns the proto equivalent of this.
|
||||
DualSolutionProto Proto() const;
|
||||
@@ -179,7 +178,7 @@ struct DualRay {
|
||||
// Returns an error when either of:
|
||||
// * VariableValuesFromProto(dual_ray_proto.reduced_costs) fails.
|
||||
// * LinearConstraintValuesFromProto(dual_ray_proto.dual_values) fails.
|
||||
static absl::StatusOr<DualRay> FromProto(const ModelStorage* model,
|
||||
static absl::StatusOr<DualRay> FromProto(ModelStorageCPtr model,
|
||||
const DualRayProto& dual_ray_proto);
|
||||
|
||||
// Returns the proto equivalent of this.
|
||||
@@ -218,12 +217,12 @@ struct Basis {
|
||||
// Returns an error if:
|
||||
// * VariableBasisFromProto(basis_proto.variable_status) fails.
|
||||
// * LinearConstraintBasisFromProto(basis_proto.constraint_status) fails.
|
||||
static absl::StatusOr<Basis> FromProto(const ModelStorage* model,
|
||||
static absl::StatusOr<Basis> FromProto(ModelStorageCPtr model,
|
||||
const BasisProto& basis_proto);
|
||||
|
||||
// Returns a failure if the referenced variables don't belong to the input
|
||||
// expected_storage (which must not be nullptr).
|
||||
absl::Status CheckModelStorage(const ModelStorage* expected_storage) const;
|
||||
absl::Status CheckModelStorage(ModelStorageCPtr expected_storage) const;
|
||||
|
||||
// Returns the proto equivalent of this object.
|
||||
//
|
||||
@@ -262,7 +261,7 @@ struct Solution {
|
||||
// Returns an error if FromProto() fails on any field that is not std::nullopt
|
||||
// (see the static FromProto() functions for each field type for details).
|
||||
static absl::StatusOr<Solution> FromProto(
|
||||
const ModelStorage* model, const SolutionProto& solution_proto);
|
||||
ModelStorageCPtr model, const SolutionProto& solution_proto);
|
||||
|
||||
// Returns the proto equivalent of this.
|
||||
SolutionProto Proto() const;
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
absl::Status SolveArguments::CheckModelStorageAndCallback(
|
||||
const ModelStorage* const expected_storage) const {
|
||||
const ModelStorageCPtr expected_storage) const {
|
||||
RETURN_IF_ERROR(model_parameters.CheckModelStorage(expected_storage))
|
||||
<< "invalid model_parameters";
|
||||
RETURN_IF_ERROR(callback_registration.CheckModelStorage(expected_storage))
|
||||
|
||||
@@ -88,7 +88,7 @@ struct SolveArguments {
|
||||
// to the input expected_storage (which must not be nullptr). Also returns a
|
||||
// failure if callback events are registered but no callback is provided.
|
||||
absl::Status CheckModelStorageAndCallback(
|
||||
const ModelStorage* expected_storage) const;
|
||||
ModelStorageCPtr expected_storage) const;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
@@ -40,9 +40,10 @@
|
||||
namespace operations_research::math_opt::internal {
|
||||
namespace {
|
||||
|
||||
absl::StatusOr<SolveResult> CallSolve(
|
||||
BaseSolver& solver, const ModelStorage* const expected_storage,
|
||||
const SolveArguments& arguments, SolveInterrupter& local_canceller) {
|
||||
absl::StatusOr<SolveResult> CallSolve(BaseSolver& solver,
|
||||
const ModelStorageCPtr expected_storage,
|
||||
const SolveArguments& arguments,
|
||||
SolveInterrupter& local_canceller) {
|
||||
RETURN_IF_ERROR(arguments.CheckModelStorageAndCallback(expected_storage));
|
||||
|
||||
BaseSolver::Callback cb = nullptr;
|
||||
@@ -104,7 +105,7 @@ absl::StatusOr<SolveResult> CallSolve(
|
||||
}
|
||||
|
||||
absl::StatusOr<ComputeInfeasibleSubsystemResult> CallComputeInfeasibleSubsystem(
|
||||
BaseSolver& solver, const ModelStorage* const expected_storage,
|
||||
BaseSolver& solver, const ModelStorageCPtr expected_storage,
|
||||
const ComputeInfeasibleSubsystemArguments& arguments,
|
||||
SolveInterrupter& local_canceller) {
|
||||
ASSIGN_OR_RETURN(
|
||||
@@ -180,7 +181,7 @@ IncrementalSolverImpl::IncrementalSolverImpl(
|
||||
BaseSolverFactory solver_factory, SolverType solver_type,
|
||||
const bool remove_names, std::shared_ptr<SolveInterrupter> local_canceller,
|
||||
std::unique_ptr<const ScopedSolveInterrupterCallback> user_canceller_cb,
|
||||
const ModelStorage* const expected_storage,
|
||||
const ModelStorageCPtr expected_storage,
|
||||
std::unique_ptr<UpdateTracker> update_tracker,
|
||||
std::unique_ptr<BaseSolver> solver)
|
||||
: solver_factory_(std::move(solver_factory)),
|
||||
|
||||
@@ -102,7 +102,7 @@ class IncrementalSolverImpl : public IncrementalSolver {
|
||||
BaseSolverFactory solver_factory, SolverType solver_type,
|
||||
bool remove_names, std::shared_ptr<SolveInterrupter> local_canceller,
|
||||
std::unique_ptr<const ScopedSolveInterrupterCallback> user_canceller_cb,
|
||||
const ModelStorage* expected_storage,
|
||||
ModelStorageCPtr expected_storage,
|
||||
std::unique_ptr<UpdateTracker> update_tracker,
|
||||
std::unique_ptr<BaseSolver> solver);
|
||||
|
||||
@@ -114,7 +114,7 @@ class IncrementalSolverImpl : public IncrementalSolver {
|
||||
// can be destroyed after local_canceller_ without risk.
|
||||
std::shared_ptr<SolveInterrupter> local_canceller_;
|
||||
std::unique_ptr<const ScopedSolveInterrupterCallback> user_canceller_cb_;
|
||||
const ModelStorage* const expected_storage_;
|
||||
const ModelStorageCPtr expected_storage_;
|
||||
const std::unique_ptr<UpdateTracker> update_tracker_;
|
||||
std::unique_ptr<BaseSolver> solver_;
|
||||
};
|
||||
|
||||
@@ -536,7 +536,7 @@ TerminationProto UpgradedTerminationProtoForStatsMigration(
|
||||
|
||||
} // namespace
|
||||
absl::StatusOr<SolveResult> SolveResult::FromProto(
|
||||
const ModelStorage* model, const SolveResultProto& solve_result_proto) {
|
||||
const ModelStorageCPtr model, const SolveResultProto& solve_result_proto) {
|
||||
OR_ASSIGN_OR_RETURN3(
|
||||
auto termination,
|
||||
Termination::FromProto(
|
||||
|
||||
@@ -518,7 +518,7 @@ struct SolveResult {
|
||||
// validation, or not rely on the strong guarantees of ValidateResult()
|
||||
// and just treat SolveResult as a simple struct.
|
||||
static absl::StatusOr<SolveResult> FromProto(
|
||||
const ModelStorage* model, const SolveResultProto& solve_result_proto);
|
||||
ModelStorageCPtr model, const SolveResultProto& solve_result_proto);
|
||||
|
||||
// Returns the proto equivalent of this.
|
||||
//
|
||||
|
||||
@@ -49,8 +49,7 @@ absl::Status CheckSparseVectorProto(const SparseVectorProtoType& vec) {
|
||||
|
||||
template <typename Key>
|
||||
absl::StatusOr<absl::flat_hash_map<Key, BasisStatus>> BasisVectorFromProto(
|
||||
const ModelStorage* const model,
|
||||
const SparseBasisStatusVector& basis_proto) {
|
||||
const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) {
|
||||
using IdType = typename Key::IdType;
|
||||
absl::flat_hash_map<Key, BasisStatus> map;
|
||||
map.reserve(basis_proto.ids_size());
|
||||
@@ -104,7 +103,7 @@ SparseBasisStatusVector BasisMapToProto(
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status VariableIdsExist(const ModelStorage* const model,
|
||||
absl::Status VariableIdsExist(const ModelStorageCPtr model,
|
||||
const absl::Span<const int64_t> ids) {
|
||||
for (const int64_t id : ids) {
|
||||
if (!model->has_variable(VariableId(id))) {
|
||||
@@ -115,7 +114,7 @@ absl::Status VariableIdsExist(const ModelStorage* const model,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status LinearConstraintIdsExist(const ModelStorage* const model,
|
||||
absl::Status LinearConstraintIdsExist(const ModelStorageCPtr model,
|
||||
const absl::Span<const int64_t> ids) {
|
||||
for (const int64_t id : ids) {
|
||||
if (!model->has_linear_constraint(LinearConstraintId(id))) {
|
||||
@@ -126,7 +125,7 @@ absl::Status LinearConstraintIdsExist(const ModelStorage* const model,
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status QuadraticConstraintIdsExist(const ModelStorage* const model,
|
||||
absl::Status QuadraticConstraintIdsExist(const ModelStorageCPtr model,
|
||||
const absl::Span<const int64_t> ids) {
|
||||
for (const int64_t id : ids) {
|
||||
if (!model->has_constraint(QuadraticConstraintId(id))) {
|
||||
@@ -140,15 +139,14 @@ absl::Status QuadraticConstraintIdsExist(const ModelStorage* const model,
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<VariableMap<double>> VariableValuesFromProto(
|
||||
const ModelStorage* const model,
|
||||
const SparseDoubleVectorProto& vars_proto) {
|
||||
const ModelStorageCPtr model, const SparseDoubleVectorProto& vars_proto) {
|
||||
RETURN_IF_ERROR(CheckSparseVectorProto(vars_proto));
|
||||
RETURN_IF_ERROR(VariableIdsExist(model, vars_proto.ids()));
|
||||
return MakeView(vars_proto).as_map<Variable>(model);
|
||||
}
|
||||
|
||||
absl::StatusOr<VariableMap<int32_t>> VariableValuesFromProto(
|
||||
const ModelStorage* model, const SparseInt32VectorProto& vars_proto) {
|
||||
const ModelStorageCPtr model, const SparseInt32VectorProto& vars_proto) {
|
||||
RETURN_IF_ERROR(CheckSparseVectorProto(vars_proto));
|
||||
RETURN_IF_ERROR(VariableIdsExist(model, vars_proto.ids()));
|
||||
return MakeView(vars_proto).as_map<Variable>(model);
|
||||
@@ -161,7 +159,7 @@ SparseDoubleVectorProto VariableValuesToProto(
|
||||
|
||||
absl::StatusOr<absl::flat_hash_map<Objective, double>>
|
||||
AuxiliaryObjectiveValuesFromProto(
|
||||
const ModelStorage* const model,
|
||||
const ModelStorageCPtr model,
|
||||
const google::protobuf::Map<int64_t, double>& aux_obj_proto) {
|
||||
absl::flat_hash_map<Objective, double> result;
|
||||
for (const auto [raw_id, value] : aux_obj_proto) {
|
||||
@@ -187,7 +185,7 @@ google::protobuf::Map<int64_t, double> AuxiliaryObjectiveValuesToProto(
|
||||
}
|
||||
|
||||
absl::StatusOr<LinearConstraintMap<double>> LinearConstraintValuesFromProto(
|
||||
const ModelStorage* const model,
|
||||
const ModelStorageCPtr model,
|
||||
const SparseDoubleVectorProto& lin_cons_proto) {
|
||||
RETURN_IF_ERROR(CheckSparseVectorProto(lin_cons_proto));
|
||||
RETURN_IF_ERROR(LinearConstraintIdsExist(model, lin_cons_proto.ids()));
|
||||
@@ -201,7 +199,7 @@ SparseDoubleVectorProto LinearConstraintValuesToProto(
|
||||
|
||||
absl::StatusOr<absl::flat_hash_map<QuadraticConstraint, double>>
|
||||
QuadraticConstraintValuesFromProto(
|
||||
const ModelStorage* const model,
|
||||
const ModelStorageCPtr model,
|
||||
const SparseDoubleVectorProto& quad_cons_proto) {
|
||||
RETURN_IF_ERROR(CheckSparseVectorProto(quad_cons_proto));
|
||||
RETURN_IF_ERROR(QuadraticConstraintIdsExist(model, quad_cons_proto.ids()));
|
||||
@@ -215,8 +213,7 @@ SparseDoubleVectorProto QuadraticConstraintValuesToProto(
|
||||
}
|
||||
|
||||
absl::StatusOr<VariableMap<BasisStatus>> VariableBasisFromProto(
|
||||
const ModelStorage* const model,
|
||||
const SparseBasisStatusVector& basis_proto) {
|
||||
const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) {
|
||||
RETURN_IF_ERROR(CheckSparseVectorProto(basis_proto));
|
||||
RETURN_IF_ERROR(VariableIdsExist(model, basis_proto.ids()));
|
||||
return BasisVectorFromProto<Variable>(model, basis_proto);
|
||||
@@ -228,8 +225,7 @@ SparseBasisStatusVector VariableBasisToProto(
|
||||
}
|
||||
|
||||
absl::StatusOr<LinearConstraintMap<BasisStatus>> LinearConstraintBasisFromProto(
|
||||
const ModelStorage* const model,
|
||||
const SparseBasisStatusVector& basis_proto) {
|
||||
const ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto) {
|
||||
RETURN_IF_ERROR(CheckSparseVectorProto(basis_proto));
|
||||
RETURN_IF_ERROR(LinearConstraintIdsExist(model, basis_proto.ids()));
|
||||
return BasisVectorFromProto<LinearConstraint>(model, basis_proto);
|
||||
|
||||
@@ -42,7 +42,7 @@ namespace operations_research::math_opt {
|
||||
//
|
||||
// Note that the values of vars_proto.values are not checked (it may have NaNs).
|
||||
absl::StatusOr<VariableMap<double>> VariableValuesFromProto(
|
||||
const ModelStorage* model, const SparseDoubleVectorProto& vars_proto);
|
||||
ModelStorageCPtr model, const SparseDoubleVectorProto& vars_proto);
|
||||
|
||||
// Returns the VariableMap<int32_t> equivalent to `vars_proto`.
|
||||
//
|
||||
@@ -52,7 +52,7 @@ absl::StatusOr<VariableMap<double>> VariableValuesFromProto(
|
||||
// * vars_proto.ids has elements that are variables in `model` (this implies
|
||||
// that each id is in [0, max(int64_t))).
|
||||
absl::StatusOr<VariableMap<int32_t>> VariableValuesFromProto(
|
||||
const ModelStorage* model, const SparseInt32VectorProto& vars_proto);
|
||||
ModelStorageCPtr model, const SparseInt32VectorProto& vars_proto);
|
||||
|
||||
// Returns the proto equivalent of variable_values.
|
||||
SparseDoubleVectorProto VariableValuesToProto(
|
||||
@@ -67,7 +67,7 @@ SparseDoubleVectorProto VariableValuesToProto(
|
||||
// Note that the values of `aux_obj_proto` are not checked (it may have NaNs).
|
||||
absl::StatusOr<absl::flat_hash_map<Objective, double>>
|
||||
AuxiliaryObjectiveValuesFromProto(
|
||||
const ModelStorage* model,
|
||||
ModelStorageCPtr model,
|
||||
const google::protobuf::Map<int64_t, double>& aux_obj_proto);
|
||||
|
||||
// Returns the proto equivalent of auxiliary_obj_values.
|
||||
@@ -88,7 +88,7 @@ google::protobuf::Map<int64_t, double> AuxiliaryObjectiveValuesToProto(
|
||||
// Note that the values of lin_cons_proto.values are not checked (it may have
|
||||
// NaNs).
|
||||
absl::StatusOr<LinearConstraintMap<double>> LinearConstraintValuesFromProto(
|
||||
const ModelStorage* model, const SparseDoubleVectorProto& lin_cons_proto);
|
||||
ModelStorageCPtr model, const SparseDoubleVectorProto& lin_cons_proto);
|
||||
|
||||
// Returns the proto equivalent of linear_constraint_values.
|
||||
SparseDoubleVectorProto LinearConstraintValuesToProto(
|
||||
@@ -107,7 +107,7 @@ SparseDoubleVectorProto LinearConstraintValuesToProto(
|
||||
// NaNs).
|
||||
absl::StatusOr<absl::flat_hash_map<QuadraticConstraint, double>>
|
||||
QuadraticConstraintValuesFromProto(
|
||||
const ModelStorage* model, const SparseDoubleVectorProto& quad_cons_proto);
|
||||
ModelStorageCPtr model, const SparseDoubleVectorProto& quad_cons_proto);
|
||||
|
||||
// Returns the proto equivalent of quadratic_constraint_values.
|
||||
SparseDoubleVectorProto QuadraticConstraintValuesToProto(
|
||||
@@ -123,7 +123,7 @@ SparseDoubleVectorProto QuadraticConstraintValuesToProto(
|
||||
// that each id is in [0, max(int64_t))).
|
||||
// * basis_proto.values does not contain UNSPECIFIED and has valid enum values.
|
||||
absl::StatusOr<VariableMap<BasisStatus>> VariableBasisFromProto(
|
||||
const ModelStorage* model, const SparseBasisStatusVector& basis_proto);
|
||||
ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto);
|
||||
|
||||
// Returns the proto equivalent of basis_values.
|
||||
SparseBasisStatusVector VariableBasisToProto(
|
||||
@@ -138,7 +138,7 @@ SparseBasisStatusVector VariableBasisToProto(
|
||||
// implies that each id is in [0, max(int64_t))).
|
||||
// * basis_proto.values does not contain UNSPECIFIED and has valid enum values.
|
||||
absl::StatusOr<LinearConstraintMap<BasisStatus>> LinearConstraintBasisFromProto(
|
||||
const ModelStorage* model, const SparseBasisStatusVector& basis_proto);
|
||||
ModelStorageCPtr model, const SparseBasisStatusVector& basis_proto);
|
||||
|
||||
// Returns the proto equivalent of basis_values.
|
||||
SparseBasisStatusVector LinearConstraintBasisToProto(
|
||||
|
||||
@@ -24,6 +24,9 @@
|
||||
#include "ortools/base/map_util.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/math_opt/cpp/formatters.h"
|
||||
#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
#include "ortools/util/fp_roundtrip_conv.h"
|
||||
|
||||
namespace operations_research {
|
||||
@@ -35,7 +38,9 @@ constexpr double kInf = std::numeric_limits<double>::infinity();
|
||||
LinearExpression::LinearExpression() { ++num_calls_default_constructor_; }
|
||||
|
||||
LinearExpression::LinearExpression(const LinearExpression& other)
|
||||
: storage_(other.storage_), terms_(other.terms_), offset_(other.offset_) {
|
||||
: ModelStorageItemContainer(other.storage()),
|
||||
terms_(other.terms_),
|
||||
offset_(other.offset_) {
|
||||
++num_calls_copy_constructor_;
|
||||
}
|
||||
|
||||
@@ -203,7 +208,7 @@ std::ostream& operator<<(std::ostream& ostr,
|
||||
QuadraticExpression::QuadraticExpression() { ++num_calls_default_constructor_; }
|
||||
|
||||
QuadraticExpression::QuadraticExpression(const QuadraticExpression& other)
|
||||
: storage_(other.storage_),
|
||||
: ModelStorageItemContainer(other),
|
||||
quadratic_terms_(other.quadratic_terms_),
|
||||
linear_terms_(other.linear_terms_),
|
||||
offset_(other.offset_) {
|
||||
|
||||
@@ -103,11 +103,11 @@
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/logging.h"
|
||||
#include "ortools/base/strong_int.h"
|
||||
#include "ortools/math_opt/cpp/key_types.h" // IWYU pragma: export
|
||||
#include "ortools/math_opt/storage/model_storage.h"
|
||||
#include "ortools/math_opt/storage/model_storage_item.h"
|
||||
#include "ortools/math_opt/storage/model_storage_types.h"
|
||||
|
||||
namespace operations_research {
|
||||
@@ -118,40 +118,20 @@ class LinearExpression;
|
||||
|
||||
// A value type that references a variable from ModelStorage. Usually this type
|
||||
// is passed by copy.
|
||||
//
|
||||
// This type implements https://abseil.io/docs/cpp/guides/hash (see
|
||||
// VariablesEquality for details about how operator== works).
|
||||
class Variable {
|
||||
class Variable final : public ModelStorageElement<
|
||||
ElementType::kVariable, Variable,
|
||||
// This type has a special equality operator
|
||||
// (see `VariablesEquality` below).
|
||||
ModelStorageElementEquality::kWithoutEquality> {
|
||||
public:
|
||||
// The typed integer used for ids.
|
||||
using IdType = VariableId;
|
||||
|
||||
// Usually users will obtain variables using Model::AddVariable(). There
|
||||
// should be little for users to build this object from an ModelStorage.
|
||||
inline Variable(const ModelStorage* storage, VariableId id);
|
||||
|
||||
// Each call to AddVariable will produce Variables id() increasing by one,
|
||||
// starting at zero. Deleted ids are NOT reused. Thus, if no variables are
|
||||
// deleted, the ids in the model will be consecutive.
|
||||
inline int64_t id() const;
|
||||
|
||||
inline VariableId typed_id() const;
|
||||
inline const ModelStorage* storage() const;
|
||||
using ModelStorageElement::ModelStorageElement;
|
||||
|
||||
inline double lower_bound() const;
|
||||
inline double upper_bound() const;
|
||||
inline bool is_integer() const;
|
||||
inline absl::string_view name() const;
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Variable& variable);
|
||||
friend std::ostream& operator<<(std::ostream& ostr, const Variable& variable);
|
||||
|
||||
inline LinearExpression operator-() const;
|
||||
|
||||
private:
|
||||
const ModelStorage* storage_;
|
||||
VariableId id_;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
@@ -186,8 +166,6 @@ inline bool operator!=(const Variable& lhs, const Variable& rhs);
|
||||
template <typename V>
|
||||
using VariableMap = absl::flat_hash_map<Variable, V>;
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& ostr, const Variable& variable);
|
||||
|
||||
// A term in an sum of variables multiplied by coefficients.
|
||||
struct LinearTerm {
|
||||
// Usually this constructor is never called explicitly by users. Instead it
|
||||
@@ -226,7 +204,7 @@ class QuadraticExpression;
|
||||
// TODO(b/169415098): add a function to remove zero terms.
|
||||
// TODO(b/169415834): study if exact zeros should be automatically removed.
|
||||
// TODO(b/169415103): add tests that some expressions don't compile.
|
||||
class LinearExpression {
|
||||
class LinearExpression final : public ModelStorageItemContainer {
|
||||
public:
|
||||
// For unit testing purpose, we define optional counters. We have to
|
||||
// explicitly define the default constructor, copy constructor and assignment
|
||||
@@ -238,9 +216,6 @@ class LinearExpression {
|
||||
LinearExpression();
|
||||
LinearExpression(const LinearExpression& other);
|
||||
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
// We have to define a custom move constructor as we need to reset storage_ to
|
||||
// nullptr.
|
||||
inline LinearExpression(LinearExpression&& other) noexcept;
|
||||
// Usually users should use the overloads of operators to build linear
|
||||
// expressions. For example, assuming `x` and `y` are Variable, then `x + 2*y
|
||||
// + 5` will build a LinearExpression automatically.
|
||||
@@ -250,8 +225,9 @@ class LinearExpression {
|
||||
inline LinearExpression(Variable variable); // NOLINT
|
||||
inline LinearExpression(const LinearTerm& term); // NOLINT
|
||||
LinearExpression& operator=(const LinearExpression& other) = default;
|
||||
// We have to define a custom move assignment operator as we need to reset
|
||||
// storage_ to nullptr.
|
||||
// A moved-from `LinearExpression` is the zero expression: it's not associated
|
||||
// to a storage, has no terms and its offset is zero.
|
||||
inline LinearExpression(LinearExpression&& other) noexcept;
|
||||
inline LinearExpression& operator=(LinearExpression&& other) noexcept;
|
||||
|
||||
inline LinearExpression& operator+=(const LinearExpression& other);
|
||||
@@ -367,8 +343,6 @@ class LinearExpression {
|
||||
double EvaluateWithDefaultZero(
|
||||
const VariableMap<double>& variable_values) const;
|
||||
|
||||
inline const ModelStorage* storage() const;
|
||||
|
||||
#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
static thread_local int num_calls_default_constructor_;
|
||||
static thread_local int num_calls_copy_constructor_;
|
||||
@@ -384,14 +358,9 @@ class LinearExpression {
|
||||
const LinearExpression& expression);
|
||||
friend QuadraticExpression;
|
||||
|
||||
// Sets the storage_ to the input value if nullptr, else CHECKs that it is
|
||||
// equal. Also CHECKs that the input value is not nullptr.
|
||||
inline void SetOrCheckStorage(const ModelStorage* storage);
|
||||
|
||||
// Invariants:
|
||||
// * nullptr, if terms_ is empty
|
||||
// * equal to Variable::storage() of each key of terms_, else
|
||||
const ModelStorage* storage_ = nullptr;
|
||||
// * storage() == v.storage()for each v in terms_;
|
||||
// * storage() == nullptr, if terms_ is empty.
|
||||
VariableMap<double> terms_;
|
||||
double offset_ = 0.0;
|
||||
};
|
||||
@@ -647,14 +616,14 @@ using QuadraticProductId = std::pair<VariableId, VariableId>;
|
||||
// silently correct this if not satisfied by the inputs.
|
||||
//
|
||||
// This type can be used as a key in ABSL hash containers.
|
||||
class QuadraticTermKey {
|
||||
class QuadraticTermKey final : public ModelStorageItem {
|
||||
public:
|
||||
// NOTE: this definition is for use by IdMap; clients should not rely upon it.
|
||||
using IdType = QuadraticProductId;
|
||||
|
||||
// NOTE: This constructor will silently re-order the passed id so that, upon
|
||||
// exiting the constructor, variable_ids_.first <= variable_ids_.second.
|
||||
inline QuadraticTermKey(const ModelStorage* storage, QuadraticProductId id);
|
||||
inline QuadraticTermKey(ModelStorageCPtr storage, QuadraticProductId id);
|
||||
// NOTE: This constructor will CHECK fail if the variable models do not agree,
|
||||
// i.e. first_variable.storage() != second_variable.storage(). It will also
|
||||
// silently re-order the passed id so that, upon exiting the constructor,
|
||||
@@ -662,19 +631,17 @@ class QuadraticTermKey {
|
||||
inline QuadraticTermKey(Variable first_variable, Variable second_variable);
|
||||
|
||||
inline QuadraticProductId typed_id() const;
|
||||
inline const ModelStorage* storage() const;
|
||||
|
||||
// Returns the Variable with the smallest id.
|
||||
Variable first() const { return Variable(storage_, variable_ids_.first); }
|
||||
Variable first() const { return Variable(storage(), variable_ids_.first); }
|
||||
|
||||
// Returns the Variable the largest id.
|
||||
Variable second() const { return Variable(storage_, variable_ids_.second); }
|
||||
Variable second() const { return Variable(storage(), variable_ids_.second); }
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const QuadraticTermKey& key);
|
||||
|
||||
private:
|
||||
const ModelStorage* storage_;
|
||||
QuadraticProductId variable_ids_;
|
||||
};
|
||||
|
||||
@@ -744,7 +711,7 @@ using QuadraticTermMap = absl::flat_hash_map<QuadraticTermKey, V>;
|
||||
// is, it is forbidden that both are non-null and not equal. Use
|
||||
// CheckModelsAgree() and the initializer_list constructor to enforce this
|
||||
// invariant in any class or friend method.
|
||||
class QuadraticExpression {
|
||||
class QuadraticExpression final : public ModelStorageItemContainer {
|
||||
public:
|
||||
// For unit testing purpose, we define optional counters. We have to
|
||||
// explicitly define the default constructor, copy constructor and assignment
|
||||
@@ -756,9 +723,6 @@ class QuadraticExpression {
|
||||
QuadraticExpression();
|
||||
QuadraticExpression(const QuadraticExpression& other);
|
||||
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
// We have to define a custom move constructor as we need to reset storage_ to
|
||||
// nullptr.
|
||||
inline QuadraticExpression(QuadraticExpression&& other) noexcept;
|
||||
// Users should prefer the default constructor and operator overloads to build
|
||||
// expressions.
|
||||
inline QuadraticExpression(
|
||||
@@ -770,8 +734,9 @@ class QuadraticExpression {
|
||||
inline QuadraticExpression(LinearExpression expr); // NOLINT
|
||||
inline QuadraticExpression(const QuadraticTerm& term); // NOLINT
|
||||
QuadraticExpression& operator=(const QuadraticExpression& other) = default;
|
||||
// We have to define a custom move assignment operator as we need to reset
|
||||
// storage_ to nullptr.
|
||||
// A moved-from `LinearExpression` is the zero expression: it's not associated
|
||||
// to a storage, has no terms and its offset is zero.
|
||||
inline QuadraticExpression(QuadraticExpression&& other) noexcept;
|
||||
inline QuadraticExpression& operator=(QuadraticExpression&& other) noexcept;
|
||||
|
||||
inline double offset() const;
|
||||
@@ -933,8 +898,6 @@ class QuadraticExpression {
|
||||
double EvaluateWithDefaultZero(
|
||||
const VariableMap<double>& variable_values) const;
|
||||
|
||||
inline const ModelStorage* storage() const;
|
||||
|
||||
#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
static thread_local int num_calls_default_constructor_;
|
||||
static thread_local int num_calls_copy_constructor_;
|
||||
@@ -950,15 +913,10 @@ class QuadraticExpression {
|
||||
friend std::ostream& operator<<(std::ostream& ostr,
|
||||
const QuadraticExpression& expr);
|
||||
|
||||
// Sets the storage_ to the input value if nullptr, else CHECKs that it is
|
||||
// equal. Also CHECKs that the input value is not nullptr.
|
||||
inline void SetOrCheckStorage(const ModelStorage* storage);
|
||||
|
||||
// Invariants:
|
||||
// * nullptr, if both quadratic_terms_ and linear_terms_ are empty
|
||||
// * equal to Variable::storage() of each key of linear_terms_ and
|
||||
// QuadraticTermKey::storage() of each key of quadratic_terms_, else
|
||||
const ModelStorage* storage_ = nullptr;
|
||||
// * storage() == v.storage() for each v in linear_terms_;
|
||||
// * storage() == v.storage() for each v in quadratic_terms_;
|
||||
// * storage() == nullptr, if both terms_ and quadratic_terms_ are empty.
|
||||
QuadraticTermMap<double> quadratic_terms_;
|
||||
VariableMap<double> linear_terms_;
|
||||
double offset_ = 0.0;
|
||||
@@ -1268,50 +1226,25 @@ inline BoundedQuadraticExpression operator==(double lhs,
|
||||
// Variable
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Variable::Variable(const ModelStorage* const storage, const VariableId id)
|
||||
: storage_(storage), id_(id) {
|
||||
DCHECK(storage != nullptr);
|
||||
}
|
||||
|
||||
int64_t Variable::id() const { return id_.value(); }
|
||||
|
||||
VariableId Variable::typed_id() const { return id_; }
|
||||
|
||||
const ModelStorage* Variable::storage() const { return storage_; }
|
||||
|
||||
double Variable::lower_bound() const {
|
||||
return storage_->variable_lower_bound(id_);
|
||||
return storage()->variable_lower_bound(typed_id());
|
||||
}
|
||||
|
||||
double Variable::upper_bound() const {
|
||||
return storage_->variable_upper_bound(id_);
|
||||
return storage()->variable_upper_bound(typed_id());
|
||||
}
|
||||
|
||||
bool Variable::is_integer() const { return storage_->is_variable_integer(id_); }
|
||||
bool Variable::is_integer() const {
|
||||
return storage()->is_variable_integer(typed_id());
|
||||
}
|
||||
|
||||
absl::string_view Variable::name() const {
|
||||
if (storage()->has_variable(id_)) {
|
||||
return storage_->variable_name(id_);
|
||||
if (storage()->has_variable(typed_id())) {
|
||||
return storage()->variable_name(typed_id());
|
||||
}
|
||||
return "[variable deleted from model]";
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const Variable& variable) {
|
||||
return H::combine(std::move(h), variable.id_.value(), variable.storage_);
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr, const Variable& variable) {
|
||||
// TODO(b/170992529): handle quoting of invalid characters in the name.
|
||||
const absl::string_view name = variable.name();
|
||||
if (name.empty()) {
|
||||
ostr << "__var#" << variable.id() << "__";
|
||||
} else {
|
||||
ostr << name;
|
||||
}
|
||||
return ostr;
|
||||
}
|
||||
|
||||
LinearExpression Variable::operator-() const {
|
||||
return LinearExpression({LinearTerm(*this, -1.0)}, 0.0);
|
||||
}
|
||||
@@ -1368,17 +1301,9 @@ LinearTerm operator/(Variable variable, const double coefficient) {
|
||||
// LinearExpression
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void LinearExpression::SetOrCheckStorage(const ModelStorage* const storage) {
|
||||
CHECK(storage != nullptr) << internal::kKeyHasNullModelStorage;
|
||||
if (storage_ == nullptr) {
|
||||
storage_ = storage;
|
||||
return;
|
||||
}
|
||||
CHECK_EQ(storage, storage_) << internal::kObjectsFromOtherModelStorage;
|
||||
}
|
||||
|
||||
LinearExpression::LinearExpression(LinearExpression&& other) noexcept
|
||||
: storage_(std::exchange(other.storage_, nullptr)),
|
||||
: ModelStorageItemContainer(
|
||||
static_cast<ModelStorageItemContainer&&>(other)),
|
||||
terms_(std::move(other.terms_)),
|
||||
offset_(std::exchange(other.offset_, 0.0)) {
|
||||
other.terms_.clear();
|
||||
@@ -1389,7 +1314,8 @@ LinearExpression::LinearExpression(LinearExpression&& other) noexcept
|
||||
|
||||
LinearExpression& LinearExpression::operator=(
|
||||
LinearExpression&& other) noexcept {
|
||||
storage_ = std::exchange(other.storage_, nullptr);
|
||||
ModelStorageItemContainer::operator=(
|
||||
static_cast<ModelStorageItemContainer&&>(other));
|
||||
terms_ = std::move(other.terms_);
|
||||
other.terms_.clear();
|
||||
offset_ = std::exchange(other.offset_, 0.0);
|
||||
@@ -1403,7 +1329,7 @@ LinearExpression::LinearExpression(std::initializer_list<LinearTerm> terms,
|
||||
++num_calls_initializer_list_constructor_;
|
||||
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
for (const auto& term : terms) {
|
||||
SetOrCheckStorage(term.variable.storage());
|
||||
SetOrCheckStorage(term.variable);
|
||||
// The same variable may appear multiple times in the input list; we must
|
||||
// accumulate the coefficients.
|
||||
terms_[term.variable] += term.coefficient;
|
||||
@@ -1579,7 +1505,7 @@ LinearExpression& LinearExpression::operator+=(const LinearExpression& other) {
|
||||
// thus we don't need to compare in the loop. Of course this only applies if
|
||||
// the other has terms.
|
||||
if (!other.terms_.empty()) {
|
||||
SetOrCheckStorage(other.storage());
|
||||
SetOrCheckStorage(other);
|
||||
for (const auto& [v, coeff] : other.terms_) {
|
||||
terms_[v] += coeff;
|
||||
}
|
||||
@@ -1589,13 +1515,13 @@ LinearExpression& LinearExpression::operator+=(const LinearExpression& other) {
|
||||
}
|
||||
|
||||
LinearExpression& LinearExpression::operator+=(const LinearTerm& term) {
|
||||
SetOrCheckStorage(term.variable.storage());
|
||||
SetOrCheckStorage(term.variable);
|
||||
terms_[term.variable] += term.coefficient;
|
||||
return *this;
|
||||
}
|
||||
|
||||
LinearExpression& LinearExpression::operator+=(const Variable variable) {
|
||||
SetOrCheckStorage(variable.storage());
|
||||
SetOrCheckStorage(variable);
|
||||
return *this += LinearTerm(variable, 1.0);
|
||||
}
|
||||
|
||||
@@ -1607,7 +1533,7 @@ LinearExpression& LinearExpression::operator+=(const double value) {
|
||||
LinearExpression& LinearExpression::operator-=(const LinearExpression& other) {
|
||||
// See operator+=.
|
||||
if (!other.terms_.empty()) {
|
||||
SetOrCheckStorage(other.storage());
|
||||
SetOrCheckStorage(other);
|
||||
for (const auto& [v, coeff] : other.terms_) {
|
||||
terms_[v] -= coeff;
|
||||
}
|
||||
@@ -1617,13 +1543,13 @@ LinearExpression& LinearExpression::operator-=(const LinearExpression& other) {
|
||||
}
|
||||
|
||||
LinearExpression& LinearExpression::operator-=(const LinearTerm& term) {
|
||||
SetOrCheckStorage(term.variable.storage());
|
||||
SetOrCheckStorage(term.variable);
|
||||
terms_[term.variable] -= term.coefficient;
|
||||
return *this;
|
||||
}
|
||||
|
||||
LinearExpression& LinearExpression::operator-=(const Variable variable) {
|
||||
SetOrCheckStorage(variable.storage());
|
||||
SetOrCheckStorage(variable);
|
||||
return *this -= LinearTerm(variable, 1.0);
|
||||
}
|
||||
|
||||
@@ -1713,8 +1639,6 @@ const VariableMap<double>& LinearExpression::terms() const { return terms_; }
|
||||
|
||||
double LinearExpression::offset() const { return offset_; }
|
||||
|
||||
const ModelStorage* LinearExpression::storage() const { return storage_; }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// VariablesEquality
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@@ -2055,9 +1979,9 @@ BoundedLinearExpression operator==(const double lhs, const Variable rhs) {
|
||||
// QuadraticTermKey
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
QuadraticTermKey::QuadraticTermKey(const ModelStorage* storage,
|
||||
QuadraticTermKey::QuadraticTermKey(const ModelStorageCPtr storage,
|
||||
const QuadraticProductId id)
|
||||
: storage_(storage), variable_ids_(id) {
|
||||
: ModelStorageItem(storage), variable_ids_(id) {
|
||||
if (variable_ids_.first > variable_ids_.second) {
|
||||
// See https://en.cppreference.com/w/cpp/named_req/Swappable for details.
|
||||
using std::swap;
|
||||
@@ -2075,8 +1999,6 @@ QuadraticTermKey::QuadraticTermKey(const Variable first_variable,
|
||||
|
||||
QuadraticProductId QuadraticTermKey::typed_id() const { return variable_ids_; }
|
||||
|
||||
const ModelStorage* QuadraticTermKey::storage() const { return storage_; }
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const QuadraticTermKey& key) {
|
||||
return H::combine(std::move(h), key.typed_id().first.value(),
|
||||
@@ -2124,17 +2046,9 @@ QuadraticTermKey QuadraticTerm::GetKey() const {
|
||||
// QuadraticExpression (no arithmetic)
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void QuadraticExpression::SetOrCheckStorage(const ModelStorage* const storage) {
|
||||
CHECK(storage != nullptr) << internal::kKeyHasNullModelStorage;
|
||||
if (storage_ == nullptr) {
|
||||
storage_ = storage;
|
||||
return;
|
||||
}
|
||||
CHECK_EQ(storage, storage_) << internal::kObjectsFromOtherModelStorage;
|
||||
}
|
||||
|
||||
QuadraticExpression::QuadraticExpression(QuadraticExpression&& other) noexcept
|
||||
: storage_(std::exchange(other.storage_, nullptr)),
|
||||
: ModelStorageItemContainer(
|
||||
static_cast<ModelStorageItemContainer&&>(other)),
|
||||
quadratic_terms_(std::move(other.quadratic_terms_)),
|
||||
linear_terms_(std::move(other.linear_terms_)),
|
||||
offset_(std::exchange(other.offset_, 0.0)) {
|
||||
@@ -2147,7 +2061,8 @@ QuadraticExpression::QuadraticExpression(QuadraticExpression&& other) noexcept
|
||||
|
||||
QuadraticExpression& QuadraticExpression::operator=(
|
||||
QuadraticExpression&& other) noexcept {
|
||||
storage_ = std::exchange(other.storage_, nullptr);
|
||||
ModelStorageItemContainer::operator=(
|
||||
static_cast<ModelStorageItemContainer&&>(other));
|
||||
quadratic_terms_ = std::move(other.quadratic_terms_);
|
||||
other.quadratic_terms_.clear();
|
||||
linear_terms_ = std::move(other.linear_terms_);
|
||||
@@ -2164,12 +2079,12 @@ QuadraticExpression::QuadraticExpression(
|
||||
++num_calls_initializer_list_constructor_;
|
||||
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
for (const LinearTerm& term : linear_terms) {
|
||||
SetOrCheckStorage(term.variable.storage());
|
||||
SetOrCheckStorage(term.variable);
|
||||
linear_terms_[term.variable] += term.coefficient;
|
||||
}
|
||||
for (const QuadraticTerm& term : quadratic_terms) {
|
||||
const QuadraticTermKey key = term.GetKey();
|
||||
SetOrCheckStorage(key.storage());
|
||||
SetOrCheckStorage(key);
|
||||
quadratic_terms_[key] += term.coefficient();
|
||||
}
|
||||
}
|
||||
@@ -2184,9 +2099,9 @@ QuadraticExpression::QuadraticExpression(const LinearTerm& term)
|
||||
: QuadraticExpression({}, {term}, 0.0) {}
|
||||
|
||||
QuadraticExpression::QuadraticExpression(LinearExpression expr)
|
||||
: storage_(std::exchange(expr.storage_, nullptr)),
|
||||
: ModelStorageItemContainer(expr.storage()),
|
||||
linear_terms_(std::move(expr.terms_)),
|
||||
offset_(std::exchange(expr.offset_, 0.0)) {
|
||||
offset_(expr.offset_) {
|
||||
#ifdef MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
++num_calls_linear_expression_constructor_;
|
||||
#endif // MATH_OPT_USE_EXPRESSION_COUNTERS
|
||||
@@ -2195,8 +2110,6 @@ QuadraticExpression::QuadraticExpression(LinearExpression expr)
|
||||
QuadraticExpression::QuadraticExpression(const QuadraticTerm& term)
|
||||
: QuadraticExpression({term}, {}, 0.0) {}
|
||||
|
||||
const ModelStorage* QuadraticExpression::storage() const { return storage_; }
|
||||
|
||||
double QuadraticExpression::offset() const { return offset_; }
|
||||
|
||||
const VariableMap<double>& QuadraticExpression::linear_terms() const {
|
||||
@@ -2582,13 +2495,13 @@ QuadraticExpression& QuadraticExpression::operator+=(const double value) {
|
||||
}
|
||||
|
||||
QuadraticExpression& QuadraticExpression::operator+=(const Variable variable) {
|
||||
SetOrCheckStorage(variable.storage());
|
||||
SetOrCheckStorage(variable);
|
||||
linear_terms_[variable] += 1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
QuadraticExpression& QuadraticExpression::operator+=(const LinearTerm& term) {
|
||||
SetOrCheckStorage(term.variable.storage());
|
||||
SetOrCheckStorage(term.variable);
|
||||
linear_terms_[term.variable] += term.coefficient;
|
||||
return *this;
|
||||
}
|
||||
@@ -2598,7 +2511,7 @@ QuadraticExpression& QuadraticExpression::operator+=(
|
||||
offset_ += expr.offset();
|
||||
// See comment in LinearExpression::operator+=.
|
||||
if (!expr.terms().empty()) {
|
||||
SetOrCheckStorage(expr.storage());
|
||||
SetOrCheckStorage(expr);
|
||||
for (const auto& [v, coeff] : expr.terms()) {
|
||||
linear_terms_[v] += coeff;
|
||||
}
|
||||
@@ -2609,7 +2522,7 @@ QuadraticExpression& QuadraticExpression::operator+=(
|
||||
QuadraticExpression& QuadraticExpression::operator+=(
|
||||
const QuadraticTerm& term) {
|
||||
const QuadraticTermKey key = term.GetKey();
|
||||
SetOrCheckStorage(key.storage());
|
||||
SetOrCheckStorage(key);
|
||||
quadratic_terms_[key] += term.coefficient();
|
||||
return *this;
|
||||
}
|
||||
@@ -2619,7 +2532,7 @@ QuadraticExpression& QuadraticExpression::operator+=(
|
||||
offset_ += expr.offset();
|
||||
// See comment in LinearExpression::operator+=.
|
||||
if (!expr.linear_terms().empty() || !expr.quadratic_terms().empty()) {
|
||||
SetOrCheckStorage(expr.storage());
|
||||
SetOrCheckStorage(expr);
|
||||
for (const auto& [v, coeff] : expr.linear_terms()) {
|
||||
linear_terms_[v] += coeff;
|
||||
}
|
||||
@@ -2637,13 +2550,13 @@ QuadraticExpression& QuadraticExpression::operator-=(const double value) {
|
||||
}
|
||||
|
||||
QuadraticExpression& QuadraticExpression::operator-=(const Variable variable) {
|
||||
SetOrCheckStorage(variable.storage());
|
||||
SetOrCheckStorage(variable);
|
||||
linear_terms_[variable] -= 1;
|
||||
return *this;
|
||||
}
|
||||
|
||||
QuadraticExpression& QuadraticExpression::operator-=(const LinearTerm& term) {
|
||||
SetOrCheckStorage(term.variable.storage());
|
||||
SetOrCheckStorage(term.variable);
|
||||
linear_terms_[term.variable] -= term.coefficient;
|
||||
return *this;
|
||||
}
|
||||
@@ -2653,7 +2566,7 @@ QuadraticExpression& QuadraticExpression::operator-=(
|
||||
offset_ -= expr.offset();
|
||||
// See comment in LinearExpression::operator+=.
|
||||
if (!expr.terms().empty()) {
|
||||
SetOrCheckStorage(expr.storage());
|
||||
SetOrCheckStorage(expr);
|
||||
for (const auto& [v, coeff] : expr.terms()) {
|
||||
linear_terms_[v] -= coeff;
|
||||
}
|
||||
@@ -2664,7 +2577,7 @@ QuadraticExpression& QuadraticExpression::operator-=(
|
||||
QuadraticExpression& QuadraticExpression::operator-=(
|
||||
const QuadraticTerm& term) {
|
||||
const QuadraticTermKey key = term.GetKey();
|
||||
SetOrCheckStorage(key.storage());
|
||||
SetOrCheckStorage(key);
|
||||
quadratic_terms_[key] -= term.coefficient();
|
||||
return *this;
|
||||
}
|
||||
@@ -2674,7 +2587,7 @@ QuadraticExpression& QuadraticExpression::operator-=(
|
||||
offset_ -= expr.offset();
|
||||
// See comment in LinearExpression::operator+=.
|
||||
if (!expr.linear_terms().empty() || !expr.quadratic_terms().empty()) {
|
||||
SetOrCheckStorage(expr.storage());
|
||||
SetOrCheckStorage(expr);
|
||||
for (const auto& [v, coeff] : expr.linear_terms()) {
|
||||
linear_terms_[v] -= coeff;
|
||||
}
|
||||
|
||||
553
ortools/math_opt/elemental/BUILD.bazel
Normal file
553
ortools/math_opt/elemental/BUILD.bazel
Normal file
@@ -0,0 +1,553 @@
|
||||
# Copyright 2010-2025 Google LLC
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@rules_cc//cc:cc_library.bzl", "cc_library")
|
||||
load("@rules_cc//cc:cc_test.bzl", "cc_test")
|
||||
|
||||
cc_library(
|
||||
name = "attributes",
|
||||
hdrs = ["attributes.h"],
|
||||
visibility = ["//ortools/math_opt:__subpackages__"],
|
||||
deps = [
|
||||
":arrays",
|
||||
":elements",
|
||||
":symmetry",
|
||||
"//ortools/base:array",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "attributes_test",
|
||||
srcs = ["attributes_test.cc"],
|
||||
deps = [
|
||||
":arrays",
|
||||
":attributes",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt/testing:stream",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "elemental",
|
||||
srcs = [
|
||||
"elemental.cc",
|
||||
"elemental_export_model.cc",
|
||||
"elemental_from_proto.cc",
|
||||
"elemental_to_string.cc",
|
||||
],
|
||||
hdrs = ["elemental.h"],
|
||||
visibility = ["//ortools/math_opt:__subpackages__"],
|
||||
deps = [
|
||||
":arrays",
|
||||
":attr_key",
|
||||
":attr_storage",
|
||||
":attributes",
|
||||
":derived_data",
|
||||
":diff",
|
||||
":element_ref_tracker",
|
||||
":element_storage",
|
||||
":elements",
|
||||
":symmetry",
|
||||
":thread_safe_id_map",
|
||||
"//ortools/base:status_macros",
|
||||
"//ortools/math_opt:model_cc_proto",
|
||||
"//ortools/math_opt:model_update_cc_proto",
|
||||
"//ortools/math_opt:sparse_containers_cc_proto",
|
||||
"//ortools/math_opt/core:model_summary",
|
||||
"//ortools/math_opt/validators:model_validator",
|
||||
"@abseil-cpp//absl/algorithm:container",
|
||||
"@abseil-cpp//absl/container:flat_hash_map",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/log",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/log:die_if_null",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"@com_google_protobuf//:protobuf",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elemental_test",
|
||||
srcs = ["elemental_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":derived_data",
|
||||
":diff",
|
||||
":elemental",
|
||||
":elemental_matcher",
|
||||
":elements",
|
||||
":symmetry",
|
||||
":testing",
|
||||
"//ortools/base:gmock_main",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@com_google_benchmark//:benchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "derived_data",
|
||||
hdrs = ["derived_data.h"],
|
||||
visibility = ["//ortools/math_opt:__subpackages__"],
|
||||
deps = [
|
||||
":arrays",
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":elements",
|
||||
"//ortools/util:fp_roundtrip_conv",
|
||||
"@abseil-cpp//absl/log",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "derived_data_test",
|
||||
srcs = ["derived_data_test.cc"],
|
||||
deps = [
|
||||
":arrays",
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":derived_data",
|
||||
":elements",
|
||||
":symmetry",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt/testing:stream",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "element_storage",
|
||||
srcs = ["element_storage.cc"],
|
||||
hdrs = ["element_storage.h"],
|
||||
deps = [
|
||||
"//ortools/base:status_macros",
|
||||
"@abseil-cpp//absl/algorithm:container",
|
||||
"@abseil-cpp//absl/container:flat_hash_map",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "element_storage_test",
|
||||
srcs = ["element_storage_test.cc"],
|
||||
deps = [
|
||||
":element_storage",
|
||||
"//ortools/base:gmock_main",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@com_google_benchmark//:benchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "element_diff",
|
||||
hdrs = ["element_diff.h"],
|
||||
deps = ["@abseil-cpp//absl/container:flat_hash_set"],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "element_diff_test",
|
||||
srcs = ["element_diff_test.cc"],
|
||||
deps = [
|
||||
":element_diff",
|
||||
"//ortools/base:gmock_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "diff",
|
||||
srcs = ["diff.cc"],
|
||||
hdrs = ["diff.h"],
|
||||
deps = [
|
||||
"derived_data",
|
||||
":attr_diff",
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":element_diff",
|
||||
":elements",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "diff_test",
|
||||
srcs = ["diff_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":diff",
|
||||
":elements",
|
||||
"//ortools/base:gmock_main",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "attr_storage",
|
||||
hdrs = ["attr_storage.h"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":symmetry",
|
||||
"//ortools/base:map_util",
|
||||
"@abseil-cpp//absl/container:flat_hash_map",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/functional:function_ref",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "attr_storage_test",
|
||||
srcs = ["attr_storage_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attr_storage",
|
||||
":symmetry",
|
||||
"//ortools/base:gmock_main",
|
||||
"@com_google_benchmark//:benchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "attr_diff",
|
||||
hdrs = ["attr_diff.h"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "attr_diff_test",
|
||||
srcs = ["attr_diff_test.cc"],
|
||||
deps = [
|
||||
":attr_diff",
|
||||
":attr_key",
|
||||
":symmetry",
|
||||
"//ortools/base:gmock_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "attr_key",
|
||||
hdrs = ["attr_key.h"],
|
||||
visibility = ["//ortools/math_opt:__subpackages__"],
|
||||
deps = [
|
||||
":elements",
|
||||
":symmetry",
|
||||
"//ortools/base:status_macros",
|
||||
"@abseil-cpp//absl/container:flat_hash_map",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "attr_key_test",
|
||||
srcs = ["attr_key_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":elements",
|
||||
":symmetry",
|
||||
":testing",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt/testing:stream",
|
||||
"@abseil-cpp//absl/algorithm:container",
|
||||
"@abseil-cpp//absl/container:flat_hash_map",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/hash:hash_testing",
|
||||
"@abseil-cpp//absl/meta:type_traits",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@com_google_benchmark//:benchmark",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "arrays",
|
||||
hdrs = ["arrays.h"],
|
||||
visibility = ["//ortools/math_opt/elemental:__subpackages__"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "elemental_differencer",
|
||||
srcs = ["elemental_differencer.cc"],
|
||||
hdrs = ["elemental_differencer.h"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":derived_data",
|
||||
":elemental",
|
||||
":elements",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elemental_differencer_test",
|
||||
srcs = ["elemental_differencer_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":elemental",
|
||||
":elemental_differencer",
|
||||
":elements",
|
||||
"//ortools/base:gmock_main",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "arrays_test",
|
||||
srcs = ["arrays_test.cc"],
|
||||
deps = [
|
||||
":arrays",
|
||||
"//ortools/base:array",
|
||||
"//ortools/base:gmock_main",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elemental_export_model_test",
|
||||
srcs = ["elemental_export_model_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":derived_data",
|
||||
":elemental",
|
||||
":elements",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt:model_cc_proto",
|
||||
"//ortools/math_opt:sparse_containers_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elemental_to_string_test",
|
||||
srcs = ["elemental_to_string_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":elemental",
|
||||
":elements",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt/testing:stream",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "safe_attr_ops_test",
|
||||
srcs = ["safe_attr_ops_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":elemental",
|
||||
":elements",
|
||||
":safe_attr_ops",
|
||||
"//ortools/base:gmock_main",
|
||||
"@abseil-cpp//absl/status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "safe_attr_ops",
|
||||
hdrs = ["safe_attr_ops.h"],
|
||||
visibility = ["//ortools/math_opt/elemental/c_api:__subpackages__"],
|
||||
deps = [
|
||||
":derived_data",
|
||||
":elemental",
|
||||
"//ortools/base:status_macros",
|
||||
"@abseil-cpp//absl/status",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "testing",
|
||||
testonly = 1,
|
||||
hdrs = ["testing.h"],
|
||||
deps = [":attr_key"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "symmetry",
|
||||
hdrs = ["symmetry.h"],
|
||||
deps = [
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/strings:str_format",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "elemental_matcher",
|
||||
testonly = 1,
|
||||
srcs = ["elemental_matcher.cc"],
|
||||
hdrs = ["elemental_matcher.h"],
|
||||
deps = [
|
||||
":elemental",
|
||||
":elemental_differencer",
|
||||
"//ortools/base:gmock",
|
||||
"@abseil-cpp//absl/base:core_headers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "element_ref_tracker",
|
||||
hdrs = ["element_ref_tracker.h"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":elements",
|
||||
"//ortools/base:map_util",
|
||||
"@abseil-cpp//absl/container:flat_hash_map",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "elements",
|
||||
srcs = ["elements.cc"],
|
||||
hdrs = ["elements.h"],
|
||||
visibility = ["//ortools/math_opt:__subpackages__"],
|
||||
deps = [
|
||||
"//ortools/base:array",
|
||||
"@abseil-cpp//absl/base:core_headers",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/strings:str_format",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elements_test",
|
||||
srcs = ["elements_test.cc"],
|
||||
deps = [
|
||||
":elements",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt/testing:stream",
|
||||
"@abseil-cpp//absl/hash:hash_testing",
|
||||
"@abseil-cpp//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elemental_matcher_test",
|
||||
srcs = ["elemental_matcher_test.cc"],
|
||||
deps = [
|
||||
":elemental",
|
||||
":elemental_differencer",
|
||||
":elemental_matcher",
|
||||
":elements",
|
||||
"//ortools/base:gmock_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "element_ref_tracker_test",
|
||||
srcs = ["element_ref_tracker_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":element_ref_tracker",
|
||||
":elements",
|
||||
":symmetry",
|
||||
"//ortools/base:gmock_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "thread_safe_id_map",
|
||||
hdrs = ["thread_safe_id_map.h"],
|
||||
deps = [
|
||||
"//ortools/base:stl_util",
|
||||
"@abseil-cpp//absl/base:core_headers",
|
||||
"@abseil-cpp//absl/container:flat_hash_set",
|
||||
"@abseil-cpp//absl/synchronization",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "thread_safe_id_map_test",
|
||||
srcs = ["thread_safe_id_map_test.cc"],
|
||||
deps = [
|
||||
":thread_safe_id_map",
|
||||
"//ortools/base:gmock_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elemental_from_proto_test",
|
||||
srcs = ["elemental_from_proto_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":derived_data",
|
||||
":elemental",
|
||||
":elemental_matcher",
|
||||
":elements",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt:model_cc_proto",
|
||||
"//ortools/math_opt:sparse_containers_cc_proto",
|
||||
"@abseil-cpp//absl/status",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elemental_from_proto_fuzz_test",
|
||||
srcs = ["elemental_from_proto_fuzz_test.cc"],
|
||||
tags = ["componentid:1147829"],
|
||||
deps = [
|
||||
":elemental",
|
||||
":elemental_matcher",
|
||||
"//ortools/base:fuzztest",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt:model_update_cc_proto",
|
||||
"@abseil-cpp//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "elemental_update_from_proto_test",
|
||||
srcs = ["elemental_update_from_proto_test.cc"],
|
||||
deps = [
|
||||
":attr_key",
|
||||
":attributes",
|
||||
":derived_data",
|
||||
":elemental",
|
||||
":elemental_matcher",
|
||||
":elements",
|
||||
"//ortools/base:gmock_main",
|
||||
"//ortools/math_opt:model_cc_proto",
|
||||
"//ortools/math_opt:model_update_cc_proto",
|
||||
"//ortools/math_opt:sparse_containers_cc_proto",
|
||||
"@abseil-cpp//absl/status",
|
||||
],
|
||||
)
|
||||
30
ortools/math_opt/elemental/CMakeLists.txt
Normal file
30
ortools/math_opt/elemental/CMakeLists.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
# Copyright 2010-2025 Google LLC
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set(NAME ${PROJECT_NAME}_math_opt_elemental)
|
||||
add_library(${NAME} OBJECT)
|
||||
|
||||
file(GLOB_RECURSE _SRCS "*.h" "*.cc")
|
||||
list(FILTER _SRCS EXCLUDE REGEX ".*/.*_test.cc")
|
||||
list(FILTER _SRCS EXCLUDE REGEX "/elemental_matcher.*")
|
||||
list(FILTER _SRCS EXCLUDE REGEX "/python/.*")
|
||||
list(FILTER _SRCS EXCLUDE REGEX "/codegen/codegen.cc")
|
||||
target_sources(${NAME} PRIVATE ${_SRCS})
|
||||
set_target_properties(${NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(${NAME} PUBLIC
|
||||
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}>
|
||||
$<BUILD_INTERFACE:${PROJECT_BINARY_DIR}>)
|
||||
target_link_libraries(${NAME} PRIVATE
|
||||
${PROJECT_NAMESPACE}::math_opt_proto
|
||||
absl::strings
|
||||
)
|
||||
3
ortools/math_opt/elemental/README.md
Normal file
3
ortools/math_opt/elemental/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Elemental
|
||||
|
||||
See go/math-opt-elemental and g/math-opt-dev/c/0cgOO6qkoWM.
|
||||
74
ortools/math_opt/elemental/arrays.h
Normal file
74
ortools/math_opt/elemental/arrays.h
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Utilities to apply template functors on index ranges.
|
||||
// See tests for examples.
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_
|
||||
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// Calls `fn<0, ..., n-1>()`, and returns the result. Typically used for
|
||||
// simple reduce operations that can be expressed as a fold.
|
||||
//
|
||||
// Examples:
|
||||
// - Sum of elements from 0 to 5 (result is 15):
|
||||
// `ApplyOnIndexRange<6>([]<int... i>() { return (i + ... + 0); });`
|
||||
//
|
||||
// - Sum of elements of array `a`:
|
||||
// ```
|
||||
// ApplyOnIndexRange<a.size()>([&a]<int... i>() {
|
||||
// return (a[i] + ... + 0);
|
||||
// });
|
||||
// ```
|
||||
template <int n, typename Fn>
|
||||
constexpr decltype(auto) ApplyOnIndexRange(Fn&& fn) {
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
return [&fn]<int... is>(std::integer_sequence<int, is...>) mutable {
|
||||
return fn.template operator()<is...>();
|
||||
}(std::make_integer_sequence<int, n>());
|
||||
}
|
||||
|
||||
// Calls (fn<0>(), ..., fn<n-1>()).
|
||||
// Typically used for independent operations on elements, or more complex reduce
|
||||
// operations that cannot be expressed with a fold.
|
||||
//
|
||||
// Example (independent operations): Log each array element for some array `a`:
|
||||
// `ForEachIndex<a.size()>([&a]<int i>() { LOG(ERROR) << a[i]; });`
|
||||
//
|
||||
// NOTE: this returns the result of the last call, which allows returning some
|
||||
// internal state (and avoids capturing an external variable by reference) for
|
||||
// complex fold operations. See `CollectTest` for an example.
|
||||
template <int n, typename Fn>
|
||||
constexpr decltype(auto) ForEachIndex(Fn&& fn) {
|
||||
return ApplyOnIndexRange<n>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[&fn]<int... is>() { return (fn.template operator()<is>(), ...); });
|
||||
}
|
||||
|
||||
// Calls `fn` of each element of `tuple`, and returns the result of the
|
||||
// last invocation.
|
||||
template <typename Fn, typename Tuple>
|
||||
constexpr decltype(auto) ForEach(Fn&& fn, Tuple&& tuple) {
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
return std::apply([&fn]<typename... Ts>(
|
||||
Ts&&... ts) { return (fn(std::forward<Ts>(ts)), ...); },
|
||||
std::forward<Tuple>(tuple));
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ARRAYS_H_
|
||||
179
ortools/math_opt/elemental/arrays_test.cc
Normal file
179
ortools/math_opt/elemental/arrays_test.cc
Normal file
@@ -0,0 +1,179 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/arrays.h"
|
||||
|
||||
#include <array>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/array.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
|
||||
// Sums the elements of an array-like object `a`.
|
||||
template <auto a>
|
||||
constexpr int Sum() {
|
||||
return ApplyOnIndexRange<std::size(a)>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<int... i>() { return (a[i] + ... + 0); });
|
||||
}
|
||||
|
||||
// Same as `Sum`, but starts at 1.
|
||||
template <auto a>
|
||||
constexpr int SumPlusOne() {
|
||||
return ApplyOnIndexRange<std::size(a)>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<int... i>() { return (a[i] + ... + 1); });
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
|
||||
TEST(ApplyOnIndexRangeTest, Sum) {
|
||||
EXPECT_EQ(Sum<gtl::to_array({5, 3, 1})>(), 9);
|
||||
EXPECT_EQ(SumPlusOne<gtl::to_array({5, 3, 1})>(), 10);
|
||||
}
|
||||
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
|
||||
#endif
|
||||
|
||||
// Returns the weighted sum of the elements of an array-like object `a`, where
|
||||
// weights are indices.
|
||||
template <auto a>
|
||||
constexpr double ScaledSum() {
|
||||
return ApplyOnIndexRange<std::size(a)>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<int... i>() { return ((i * a[i]) + ... + 0.0); });
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
|
||||
TEST(ApplyOnIndexRangeTest, ScaledSum) {
|
||||
EXPECT_EQ(ScaledSum<gtl::to_array({5, 3, 1})>(), 5.0);
|
||||
}
|
||||
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
|
||||
#endif
|
||||
|
||||
// Returns the number of even elements in an array-like object `a`.
|
||||
template <auto a>
|
||||
constexpr int CountEven() {
|
||||
return ApplyOnIndexRange<std::size(a)>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<int... i>() { return ((a[i] % 2 == 0 ? 1 : 0) + ... + 0); });
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
|
||||
TEST(ApplyOnIndexRangeTest, CountEven) {
|
||||
EXPECT_EQ(CountEven<gtl::to_array({5, 4, 8, 1, 10})>(), 3);
|
||||
}
|
||||
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
|
||||
#endif
|
||||
|
||||
// Returns array of doubles of the same size as `a`, where each element has been
|
||||
// halved.
|
||||
template <auto a>
|
||||
constexpr std::array<double, std::size(a)> Half() {
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
return ApplyOnIndexRange<std::size(a)>([]<int... i>() {
|
||||
return std::array<double, std::size(a)>(
|
||||
{(static_cast<double>(a[i]) / 2.0)...});
|
||||
});
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
|
||||
TEST(ApplyOnIndexRangeTest, Half) {
|
||||
EXPECT_THAT(Half<gtl::to_array({5, 4, 8, 1, 10})>(),
|
||||
ElementsAre(2.5, 2.0, 4.0, 0.5, 5.0));
|
||||
}
|
||||
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
|
||||
#endif
|
||||
|
||||
// Returns true of all elements of `a` are even.
|
||||
template <auto a>
|
||||
constexpr int AllEven() {
|
||||
return ApplyOnIndexRange<std::size(a)>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<int... i>() { return (((a[i] % 2) == 0) && ...); });
|
||||
}
|
||||
|
||||
// Returns true of any element of `a` is even.
|
||||
template <auto a>
|
||||
constexpr int AnyEven() {
|
||||
return ApplyOnIndexRange<std::size(a)>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<int... i>() { return (((a[i] % 2) == 0) || ...); });
|
||||
}
|
||||
|
||||
#if __cplusplus >= 202002L
|
||||
// NOLINTBEGIN(clang-diagnostic-pre-c++20-compat)
|
||||
TEST(ApplyOnIndexRangeTest, Even) {
|
||||
EXPECT_FALSE(AllEven<gtl::to_array({5, 4, 8, 1, 10})>());
|
||||
EXPECT_TRUE(AnyEven<gtl::to_array({5, 4, 8, 1, 10})>());
|
||||
|
||||
EXPECT_TRUE(AllEven<gtl::to_array({8, 2, 6})>());
|
||||
EXPECT_TRUE(AnyEven<gtl::to_array({8, 2, 6})>());
|
||||
|
||||
EXPECT_FALSE(AllEven<gtl::to_array({3, 7, 1})>());
|
||||
EXPECT_FALSE(AnyEven<gtl::to_array({3, 7, 1})>());
|
||||
}
|
||||
// NOLINTEND(clang-diagnostic-pre-c++20-compat)
|
||||
#endif
|
||||
|
||||
// A example of a more complex reduce operation using `ForEachIndex`. Here, we
|
||||
// want to collect a list of integers for which an operation (`may_fail`)
|
||||
// failed.
|
||||
TEST(ForEachIndexTest, CollectTest) {
|
||||
constexpr auto may_fail = [](int i) {
|
||||
if (i == 3 || i == 7 || i == 42) {
|
||||
return absl::InvalidArgumentError("bad number");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
};
|
||||
|
||||
EXPECT_THAT(
|
||||
ForEachIndex<21>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[&may_fail, failed_indices = std::vector<int>()]<int i>() mutable
|
||||
-> const std::vector<int>& {
|
||||
if (!may_fail(i).ok()) {
|
||||
failed_indices.push_back(i);
|
||||
}
|
||||
return failed_indices;
|
||||
}),
|
||||
ElementsAre(3, 7));
|
||||
}
|
||||
|
||||
TEST(ForEachTest, StrCatHeterogeneousTypes) {
|
||||
EXPECT_EQ(
|
||||
ForEach(
|
||||
[r = std::string()](const auto& v) mutable -> absl::string_view {
|
||||
absl::StrAppend(&r, " ", v);
|
||||
return r;
|
||||
},
|
||||
std::make_tuple("a", 1, 0.5)),
|
||||
" a 1 0.5");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
58
ortools/math_opt/elemental/attr_diff.h
Normal file
58
ortools/math_opt/elemental/attr_diff.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// Tracks modifications to an Attribute with a key size of n (e.g., variable
|
||||
// lower bound has a key size of 1).
|
||||
template <int n, typename Symmetry>
|
||||
class AttrDiff {
|
||||
public:
|
||||
using Key = AttrKey<n, Symmetry>;
|
||||
|
||||
// On creation, the attribute is not modified for any key.
|
||||
AttrDiff() = default;
|
||||
|
||||
// Clear all tracked modifications.
|
||||
void Advance() { modified_keys_.clear(); }
|
||||
|
||||
// Mark the attribute as modified for `key`.
|
||||
void SetModified(const Key key) { modified_keys_.insert(key); }
|
||||
|
||||
// Returns the attribute keys that have been modified for this attribute (the
|
||||
// elements where set_modified() was called without a subsequent call to
|
||||
// Advance()).
|
||||
const AttrKeyHashSet<Key>& modified_keys() const { return modified_keys_; }
|
||||
|
||||
bool has_modified_keys() const { return !modified_keys_.empty(); }
|
||||
|
||||
// Stop tracking modifications for this attribute key. (Typically invoked when
|
||||
// an element in the key was deleted from the model.)
|
||||
void Erase(const Key key) { modified_keys_.erase(key); }
|
||||
|
||||
private:
|
||||
AttrKeyHashSet<Key> modified_keys_;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_DIFF_H_
|
||||
168
ortools/math_opt/elemental/attr_diff_test.cc
Normal file
168
ortools/math_opt/elemental/attr_diff_test.cc
Normal file
@@ -0,0 +1,168 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/attr_diff.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/symmetry.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// AttrDiff<0>
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(AttrDiff0Test, InitNotModified) {
|
||||
AttrDiff<0, NoSymmetry> diff;
|
||||
EXPECT_THAT(diff.modified_keys(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(AttrDiff0Test, SetModified) {
|
||||
AttrDiff<0, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey());
|
||||
EXPECT_THAT(diff.modified_keys(), UnorderedElementsAre(AttrKey()));
|
||||
}
|
||||
|
||||
TEST(AttrDiff0Test, Advance) {
|
||||
AttrDiff<0, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey());
|
||||
diff.Advance();
|
||||
EXPECT_THAT(diff.modified_keys(), IsEmpty());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Attr1Diff
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(AttrDiff1Test, InitNotModified) {
|
||||
AttrDiff<1, NoSymmetry> diff;
|
||||
EXPECT_THAT(diff.modified_keys(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(AttrDiff1Test, SetModified) {
|
||||
AttrDiff<1, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey(2));
|
||||
diff.SetModified(AttrKey(5));
|
||||
diff.SetModified(AttrKey(6));
|
||||
EXPECT_THAT(diff.modified_keys(),
|
||||
UnorderedElementsAre(AttrKey(2), AttrKey(5), AttrKey(6)));
|
||||
}
|
||||
|
||||
TEST(AttrDiff1Test, Advance) {
|
||||
AttrDiff<1, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey(2));
|
||||
diff.SetModified(AttrKey(5));
|
||||
|
||||
diff.Advance();
|
||||
EXPECT_THAT(diff.modified_keys(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(AttrDiff1Test, EraseIsModifiedGetsRemoved) {
|
||||
AttrDiff<1, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey(2));
|
||||
diff.SetModified(AttrKey(5));
|
||||
diff.SetModified(AttrKey(6));
|
||||
|
||||
diff.Erase(AttrKey(5));
|
||||
EXPECT_THAT(diff.modified_keys(),
|
||||
UnorderedElementsAre(AttrKey(2), AttrKey(6)));
|
||||
}
|
||||
|
||||
TEST(AttrDiff1Test, EraseNotModifiedNoEffect) {
|
||||
AttrDiff<1, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey(2));
|
||||
diff.SetModified(AttrKey(5));
|
||||
|
||||
diff.Erase(AttrKey(1));
|
||||
EXPECT_THAT(diff.modified_keys(),
|
||||
UnorderedElementsAre(AttrKey(2), AttrKey(5)));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Attr2Diff
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(AttrDiffTest2, InitNotModified) {
|
||||
AttrDiff<2, NoSymmetry> diff;
|
||||
EXPECT_THAT(diff.modified_keys(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(AttrDiffTest2, SetModified) {
|
||||
AttrDiff<2, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey(2, 4));
|
||||
diff.SetModified(AttrKey(5, 2));
|
||||
diff.SetModified(AttrKey(2, 5));
|
||||
diff.SetModified(AttrKey(6, 6));
|
||||
EXPECT_THAT(diff.modified_keys(),
|
||||
UnorderedElementsAre(AttrKey(2, 4), AttrKey(5, 2), AttrKey(2, 5),
|
||||
AttrKey(6, 6)));
|
||||
}
|
||||
|
||||
TEST(AttrDiffTest2, Advance) {
|
||||
AttrDiff<2, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey(2, 3));
|
||||
diff.SetModified(AttrKey(2, 8));
|
||||
|
||||
diff.Advance();
|
||||
EXPECT_THAT(diff.modified_keys(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(AttrDiffTest2, EraseIsModifiedGetsRemoved) {
|
||||
AttrDiff<2, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey(2, 5));
|
||||
diff.SetModified(AttrKey(4, 3));
|
||||
diff.SetModified(AttrKey(3, 4));
|
||||
diff.SetModified(AttrKey(6, 6));
|
||||
|
||||
EXPECT_THAT(diff.modified_keys(),
|
||||
UnorderedElementsAre(AttrKey(2, 5), AttrKey(3, 4), AttrKey(4, 3),
|
||||
AttrKey(6, 6)));
|
||||
|
||||
diff.Erase(AttrKey(4, 3));
|
||||
EXPECT_THAT(
|
||||
diff.modified_keys(),
|
||||
UnorderedElementsAre(AttrKey(2, 5), AttrKey(3, 4), AttrKey(6, 6)));
|
||||
}
|
||||
|
||||
TEST(AttrDiffTest2, EraseIsModifiedGetsRemovedSymmetric) {
|
||||
using Diff = AttrDiff<2, ElementSymmetry<0, 1>>;
|
||||
using Key = Diff::Key;
|
||||
Diff diff;
|
||||
diff.SetModified(Key(2, 5));
|
||||
diff.SetModified(Key(4, 3));
|
||||
diff.SetModified(Key(3, 4)); // Noop, same as (4,3).
|
||||
diff.SetModified(Key(6, 6));
|
||||
|
||||
EXPECT_THAT(diff.modified_keys(),
|
||||
UnorderedElementsAre(Key(2, 5), Key(3, 4), Key(6, 6)));
|
||||
|
||||
diff.Erase(Key(4, 3));
|
||||
EXPECT_THAT(diff.modified_keys(), UnorderedElementsAre(Key(2, 5), Key(6, 6)));
|
||||
}
|
||||
|
||||
TEST(AttrDiffTest2, EraseNotModifiedNoEffect) {
|
||||
AttrDiff<2, NoSymmetry> diff;
|
||||
diff.SetModified(AttrKey(2, 5));
|
||||
diff.SetModified(AttrKey(6, 6));
|
||||
|
||||
diff.Erase(AttrKey(1, 3));
|
||||
EXPECT_THAT(diff.modified_keys(),
|
||||
UnorderedElementsAre(AttrKey(2, 5), AttrKey(6, 6)));
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
360
ortools/math_opt/elemental/attr_key.h
Normal file
360
ortools/math_opt/elemental/attr_key.h
Normal file
@@ -0,0 +1,360 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <ostream>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/base/status_builder.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/elemental/symmetry.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// An attribute key for an attribute keyed on `n` elements.
|
||||
// `AttrKey` is a value type.
|
||||
template <int n, typename Symmetry = NoSymmetry>
|
||||
class AttrKey {
|
||||
public:
|
||||
using value_type = int64_t;
|
||||
using SymmetryT = Symmetry;
|
||||
|
||||
// Default constructor: values are uninitialized.
|
||||
constexpr AttrKey() {} // NOLINT: uninitialized on purpose.
|
||||
|
||||
template <typename... Ints,
|
||||
typename = std::enable_if_t<(sizeof...(Ints) == n &&
|
||||
(std::is_integral_v<Ints> && ...))>>
|
||||
explicit constexpr AttrKey(const Ints... ids) {
|
||||
auto push_back = [this, i = 0](auto e) mutable { element_ids_[i++] = e; };
|
||||
(push_back(ids), ...);
|
||||
Symmetry::Enforce(element_ids_);
|
||||
}
|
||||
|
||||
template <ElementType... element_types,
|
||||
typename = std::enable_if_t<(sizeof...(element_types) == n)>>
|
||||
explicit constexpr AttrKey(const ElementId<element_types>... ids)
|
||||
: AttrKey(ids.value()...) {}
|
||||
|
||||
constexpr AttrKey(std::array<value_type, n> ids) // NOLINT: pybind11.
|
||||
: element_ids_(ids) {
|
||||
Symmetry::Enforce(element_ids_);
|
||||
}
|
||||
|
||||
// Canonicalizes a non-canonical key, i.e., enforces the symmetries
|
||||
static constexpr AttrKey Canonicalize(AttrKey<n, NoSymmetry> key) {
|
||||
return AttrKey(key.element_ids_);
|
||||
}
|
||||
|
||||
// Creates a key from a range of `n` elements.
|
||||
static absl::StatusOr<AttrKey> FromRange(absl::Span<const int64_t> range) {
|
||||
if (range.size() != n) {
|
||||
return ::util::InvalidArgumentErrorBuilder()
|
||||
<< "cannot build AttrKey<" << n << "> from a range of size "
|
||||
<< range.size();
|
||||
}
|
||||
AttrKey result;
|
||||
std::copy(range.begin(), range.end(), result.element_ids_.begin());
|
||||
Symmetry::Enforce(result.element_ids_);
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr AttrKey(const AttrKey&) = default;
|
||||
constexpr AttrKey& operator=(const AttrKey&) = default;
|
||||
|
||||
static constexpr int size() { return n; }
|
||||
|
||||
// Element access.
|
||||
constexpr value_type operator[](const int dim) const {
|
||||
DCHECK_LT(dim, n);
|
||||
DCHECK_GE(dim, 0);
|
||||
return element_ids_[dim];
|
||||
}
|
||||
constexpr value_type& operator[](const int dim) {
|
||||
DCHECK_LT(dim, n);
|
||||
DCHECK_GE(dim, 0);
|
||||
return element_ids_[dim];
|
||||
}
|
||||
|
||||
// Element iteration.
|
||||
constexpr const value_type* begin() const { return element_ids_.begin(); }
|
||||
constexpr const value_type* end() const { return element_ids_.end(); }
|
||||
|
||||
// `AttrKey` is comparable (ordering is lexicographic) and hashable.
|
||||
//
|
||||
// TODO(b/365998156): post C++ 20, replace by spaceship operator (with all
|
||||
// comparison operators below). Do NOT use the default generated operator (see
|
||||
// below).
|
||||
friend constexpr bool operator==(const AttrKey& l, const AttrKey& r) {
|
||||
// This is much faster than using the default generated `operator==`.
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (l.element_ids_[i] != r.element_ids_[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
friend constexpr bool operator<(const AttrKey& l, const AttrKey& r) {
|
||||
// This is much faster than using the default generated `operator<`.
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (l.element_ids_[i] < r.element_ids_[i]) {
|
||||
return true;
|
||||
}
|
||||
if (l.element_ids_[i] > r.element_ids_[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
friend constexpr bool operator<=(const AttrKey& l, const AttrKey& r) {
|
||||
// This is much faster than using the default generated `operator<`.
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (l.element_ids_[i] < r.element_ids_[i]) {
|
||||
return true;
|
||||
}
|
||||
if (l.element_ids_[i] > r.element_ids_[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
friend constexpr bool operator>(const AttrKey& l, const AttrKey& r) {
|
||||
return r < l;
|
||||
}
|
||||
|
||||
friend constexpr bool operator>=(const AttrKey& l, const AttrKey& r) {
|
||||
return r <= l;
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const AttrKey& a) {
|
||||
return H::combine_contiguous(std::move(h), a.element_ids_.data(), n);
|
||||
}
|
||||
|
||||
// `AttrKey` is printable for logging and tests.
|
||||
template <typename Sink>
|
||||
friend void AbslStringify(Sink& sink, const AttrKey& key) {
|
||||
sink.Append(absl::StrCat(
|
||||
"AttrKey(", absl::StrJoin(absl::MakeSpan(key.element_ids_), ", "),
|
||||
")"));
|
||||
}
|
||||
|
||||
// Removes the element at dimension `dim` from the key and returns a key with
|
||||
// only remaining dimensions.
|
||||
template <int dim>
|
||||
AttrKey<n - 1, NoSymmetry> RemoveElement() const {
|
||||
static_assert(dim >= 0);
|
||||
static_assert(dim < n);
|
||||
AttrKey<n - 1, NoSymmetry> result;
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
result.element_ids_[i] = element_ids_[i];
|
||||
}
|
||||
for (int i = dim + 1; i < n; ++i) {
|
||||
result.element_ids_[i - 1] = element_ids_[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Adds element `elem` at dimension `dim` and returns the result.
|
||||
// The result must respect `NewSymmetry` (we `DCHECK` this).
|
||||
template <int dim, typename NewSymmetry>
|
||||
AttrKey<n + 1, NewSymmetry> AddElement(const value_type elem) const {
|
||||
static_assert(dim >= 0);
|
||||
static_assert(dim < n + 1);
|
||||
AttrKey<n + 1, NewSymmetry> result;
|
||||
for (int i = 0; i < dim; ++i) {
|
||||
result.element_ids_[i] = element_ids_[i];
|
||||
}
|
||||
result.element_ids_[dim] = elem;
|
||||
for (int i = dim + 1; i < n + 1; ++i) {
|
||||
result.element_ids_[i] = element_ids_[i - 1];
|
||||
}
|
||||
DCHECK(NewSymmetry::Validate(result.element_ids_))
|
||||
<< result << " does not have `" << NewSymmetry::GetName()
|
||||
<< "` symmetry";
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
template <int other_n, typename OtherSymmetry>
|
||||
friend class AttrKey;
|
||||
std::array<value_type, n> element_ids_;
|
||||
};
|
||||
|
||||
// CTAD for `AttrKey(1,2)`.
|
||||
template <typename... Ints>
|
||||
AttrKey(Ints... dims) -> AttrKey<sizeof...(Ints), NoSymmetry>;
|
||||
|
||||
// Traits to detect whether `T` is an `AttrKey`.
|
||||
template <typename T>
|
||||
struct is_attr_key : public std::false_type {};
|
||||
|
||||
template <int n, typename Symmetry>
|
||||
struct is_attr_key<AttrKey<n, Symmetry>> : public std::true_type {};
|
||||
|
||||
template <typename T>
|
||||
static constexpr inline bool is_attr_key_v = is_attr_key<T>::value;
|
||||
|
||||
// Required for open-source `StatusBuilder` support.
|
||||
template <int n, typename Symmetry>
|
||||
std::ostream& operator<<(std::ostream& ostr, const AttrKey<n, Symmetry>& key) {
|
||||
ostr << absl::StrCat(key);
|
||||
return ostr;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
// A set of zero or one `AttrKey<0, Symmetry>, V`. This is used to make
|
||||
// implementations of `AttrDiff` and `AttrStorage` uniform.
|
||||
// `V` must by default constructible, trivially destructible and copyable
|
||||
// (we'll fail to compile otherwise).
|
||||
// After c++26, optional is a sequence container, so this can pretty much become
|
||||
// `std::optional<AttrKey<0, Symmetry>>` + `find()`.
|
||||
template <typename Symmetry, typename V>
|
||||
class AttrKey0RawSet {
|
||||
public:
|
||||
using value_type = V;
|
||||
using Key = AttrKey<0, Symmetry>;
|
||||
|
||||
template <typename ValueT>
|
||||
class IteratorImpl {
|
||||
public:
|
||||
IteratorImpl() = default;
|
||||
// `iterator` converts to `const_iterator`.
|
||||
IteratorImpl(const IteratorImpl<std::remove_cv_t<ValueT>>& other) // NOLINT
|
||||
: value_(other.value_) {}
|
||||
|
||||
// Dereference.
|
||||
ValueT& operator*() const {
|
||||
DCHECK_NE(value_, nullptr);
|
||||
return *value_;
|
||||
}
|
||||
ValueT* operator->() const {
|
||||
DCHECK_NE(value_, nullptr);
|
||||
return value_;
|
||||
}
|
||||
|
||||
// Increment.
|
||||
IteratorImpl& operator++() {
|
||||
DCHECK_NE(value_, nullptr);
|
||||
value_ = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Equality.
|
||||
friend bool operator==(const IteratorImpl& l, const IteratorImpl& r) {
|
||||
return l.value_ == r.value_;
|
||||
}
|
||||
friend bool operator!=(const IteratorImpl& l, const IteratorImpl& r) {
|
||||
return !(l == r);
|
||||
}
|
||||
|
||||
private:
|
||||
friend class AttrKey0RawSet;
|
||||
explicit IteratorImpl(ValueT& value) : value_(&value) {}
|
||||
|
||||
ValueT* value_ = nullptr;
|
||||
};
|
||||
|
||||
using iterator = IteratorImpl<value_type>;
|
||||
using const_iterator = IteratorImpl<const value_type>;
|
||||
|
||||
AttrKey0RawSet() = default;
|
||||
|
||||
bool empty() const { return !engaged_; }
|
||||
size_t size() const { return engaged_ ? 1 : 0; }
|
||||
|
||||
const_iterator begin() const {
|
||||
return engaged_ ? const_iterator(value_) : const_iterator();
|
||||
}
|
||||
const_iterator end() const { return const_iterator(); }
|
||||
iterator begin() { return engaged_ ? iterator(value_) : iterator(); }
|
||||
iterator end() { return iterator(); }
|
||||
|
||||
bool contains(Key) const { return engaged_; }
|
||||
const_iterator find(Key) const { return begin(); }
|
||||
iterator find(Key) { return begin(); }
|
||||
|
||||
void clear() { engaged_ = false; }
|
||||
size_t erase(Key) {
|
||||
if (engaged_) {
|
||||
engaged_ = false;
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
size_t erase(const_iterator) { return erase(Key()); }
|
||||
|
||||
template <typename... Args>
|
||||
std::pair<iterator, bool> try_emplace(Key, Args&&... args) {
|
||||
if (engaged_) {
|
||||
return std::make_pair(iterator(value_), false);
|
||||
}
|
||||
value_ = value_type(Key(), std::forward<Args>(args)...);
|
||||
engaged_ = true;
|
||||
return std::make_pair(iterator(value_), true);
|
||||
}
|
||||
|
||||
std::pair<iterator, bool> insert(const value_type& v) {
|
||||
if (engaged_) {
|
||||
return std::make_pair(iterator(value_), false);
|
||||
}
|
||||
value_ = v;
|
||||
engaged_ = true;
|
||||
return std::make_pair(iterator(value_), true);
|
||||
}
|
||||
|
||||
private:
|
||||
// The following greatly simplifies the implementation because we don't have
|
||||
// to worry about side effects of the dtor (see e.g. `clear()`).
|
||||
static_assert(std::is_trivially_destructible_v<value_type>);
|
||||
|
||||
bool engaged_ = false;
|
||||
value_type value_;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// A hash set of `AttrKeyT`, where `AttrKeyT` is an `AttrKey<n, Symmetry>`.
|
||||
template <typename AttrKeyT,
|
||||
typename = std::enable_if_t<is_attr_key_v<AttrKeyT>>>
|
||||
using AttrKeyHashSet = std::conditional_t<
|
||||
(AttrKeyT::size() > 0), absl::flat_hash_set<AttrKeyT>,
|
||||
detail::AttrKey0RawSet<typename AttrKeyT::SymmetryT, AttrKeyT>>;
|
||||
|
||||
// A hash map of `AttrKeyT` to `V`, where `AttrKeyT` is an
|
||||
// `AttrKey<n, Symmetry>`.
|
||||
template <typename AttrKeyT, typename V,
|
||||
typename = std::enable_if_t<is_attr_key_v<AttrKeyT>>>
|
||||
using AttrKeyHashMap =
|
||||
std::conditional_t<(AttrKeyT::size() > 0), absl::flat_hash_map<AttrKeyT, V>,
|
||||
detail::AttrKey0RawSet<typename AttrKeyT::SymmetryT,
|
||||
std::pair<AttrKeyT, V>>>;
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_KEY_H_
|
||||
335
ortools/math_opt/elemental/attr_key_test.cc
Normal file
335
ortools/math_opt/elemental/attr_key_test.cc
Normal file
@@ -0,0 +1,335 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/hash/hash_testing.h"
|
||||
#include "absl/meta/type_traits.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "benchmark/benchmark.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/elemental/symmetry.h"
|
||||
#include "ortools/math_opt/elemental/testing.h"
|
||||
#include "ortools/math_opt/testing/stream.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
using testing::ElementsAre;
|
||||
using testing::HasSubstr;
|
||||
using testing::IsEmpty;
|
||||
using testing::Pair;
|
||||
using testing::SizeIs;
|
||||
using testing::UnorderedElementsAre;
|
||||
using testing::status::IsOkAndHolds;
|
||||
using testing::status::StatusIs;
|
||||
|
||||
static_assert(sizeof(AttrKey<0>) <= sizeof(uint64_t));
|
||||
static_assert(sizeof(AttrKey<1>) == sizeof(uint64_t));
|
||||
static_assert(sizeof(AttrKey<2>) == 2 * sizeof(uint64_t));
|
||||
static_assert(sizeof(AttrKey<2, ElementSymmetry<0, 1>>) ==
|
||||
2 * sizeof(uint64_t));
|
||||
|
||||
// Make sure that passing AttrKey by value really puts it in registers rather
|
||||
// than leaving it in the caller's frame (see
|
||||
// https://itanium-cxx-abi.github.io/cxx-abi/abi.html#non-trivial).
|
||||
static_assert(absl::is_trivially_relocatable<AttrKey<0>>());
|
||||
static_assert(absl::is_trivially_relocatable<AttrKey<1>>());
|
||||
static_assert(absl::is_trivially_relocatable<AttrKey<2>>());
|
||||
static_assert(
|
||||
absl::is_trivially_relocatable<AttrKey<2, ElementSymmetry<0, 1>>>());
|
||||
|
||||
TEST(AttrKeyTest, CtorAndIteration) {
|
||||
EXPECT_THAT(AttrKey(), ElementsAre());
|
||||
EXPECT_THAT(AttrKey(1), ElementsAre(1));
|
||||
EXPECT_THAT(AttrKey(1, 2), ElementsAre(1, 2));
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, ElementIdCtor) {
|
||||
EXPECT_THAT(AttrKey(ElementId<ElementType::kVariable>(1)), ElementsAre(1));
|
||||
EXPECT_THAT(AttrKey(ElementId<ElementType::kVariable>(1),
|
||||
ElementId<ElementType::kLinearConstraint>(2)),
|
||||
ElementsAre(1, 2));
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, ElementAccess) {
|
||||
const AttrKey key(1, 2);
|
||||
EXPECT_EQ(key[0], 1);
|
||||
EXPECT_EQ(key[1], 2);
|
||||
|
||||
AttrKey mutable_key(1, 2);
|
||||
EXPECT_EQ(mutable_key[0], 1);
|
||||
EXPECT_EQ(mutable_key[1], 2);
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, SupportsAbslHash1) {
|
||||
EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({
|
||||
AttrKey(1),
|
||||
AttrKey(2),
|
||||
AttrKey(0),
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, SupportsAbslHash2) {
|
||||
EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({
|
||||
AttrKey(1, 2),
|
||||
AttrKey(2, 3),
|
||||
AttrKey(0, 0),
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, Stringify) {
|
||||
EXPECT_EQ(absl::StrCat(AttrKey(1, 2, 3)), "AttrKey(1, 2, 3)");
|
||||
EXPECT_EQ(StreamToString(AttrKey(1, 2, 3)), "AttrKey(1, 2, 3)");
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, AddRemove) {
|
||||
const AttrKey key0;
|
||||
EXPECT_THAT(key0, ElementsAre());
|
||||
const AttrKey key1 = key0.AddElement<0, NoSymmetry>(3);
|
||||
EXPECT_THAT(key1, ElementsAre(3));
|
||||
const AttrKey key2 = key1.AddElement<0, NoSymmetry>(1);
|
||||
EXPECT_THAT(key2, ElementsAre(1, 3));
|
||||
const AttrKey key3 = key2.AddElement<1, NoSymmetry>(2);
|
||||
EXPECT_THAT(key3, ElementsAre(1, 2, 3));
|
||||
const AttrKey key4 = key3.AddElement<3, NoSymmetry>(4);
|
||||
EXPECT_THAT(key4, ElementsAre(1, 2, 3, 4));
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, AddRemoveNotSymmetric) {
|
||||
using NoSym = NoSymmetry;
|
||||
EXPECT_THAT((AttrKey(0, 2).AddElement<1, NoSym>(1)), ElementsAre(0, 1, 2));
|
||||
EXPECT_THAT((AttrKey(0, 1).AddElement<2, NoSym>(2)), ElementsAre(0, 1, 2));
|
||||
EXPECT_THAT((AttrKey(0, 1).AddElement<1, NoSym>(2)), ElementsAre(0, 2, 1));
|
||||
EXPECT_THAT((AttrKey(0, 2).AddElement<2, NoSym>(1)), ElementsAre(0, 2, 1));
|
||||
}
|
||||
|
||||
TEST(AttrKeyDeathTest, AddRemoveSymmetric) {
|
||||
using Sym01 = ElementSymmetry<1, 2>;
|
||||
EXPECT_THAT((AttrKey(0, 2).AddElement<1, Sym01>(1)), ElementsAre(0, 1, 2));
|
||||
EXPECT_THAT((AttrKey(0, 1).AddElement<2, Sym01>(2)), ElementsAre(0, 1, 2));
|
||||
#ifndef NDEBUG
|
||||
EXPECT_DEATH(
|
||||
(AttrKey(0, 1).AddElement<1, Sym01>(2)),
|
||||
HasSubstr(
|
||||
"AttrKey(0, 2, 1) does not have `ElementSymmetry<1, 2>` symmetry"));
|
||||
EXPECT_DEATH(
|
||||
(AttrKey(0, 2).AddElement<2, Sym01>(1)),
|
||||
HasSubstr(
|
||||
"AttrKey(0, 2, 1) does not have `ElementSymmetry<1, 2>` symmetry"));
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, ComparisonOperators) {
|
||||
// a[0] < a[1] < a[2] < a[3] < a[4]
|
||||
const std::vector<AttrKey<4>> a = {AttrKey(1, 0, 0, 0), AttrKey(2, 5, 1, 12),
|
||||
AttrKey(2, 5, 3, 10), AttrKey(2, 5, 3, 11),
|
||||
AttrKey(3, 0, 0, 0)};
|
||||
|
||||
// Now test each of the operators
|
||||
for (int i = 0; i < a.size(); ++i) {
|
||||
SCOPED_TRACE(absl::StrCat(i));
|
||||
for (int j = 0; j < a.size(); ++j) {
|
||||
SCOPED_TRACE(absl::StrCat(j));
|
||||
if (i == j) {
|
||||
EXPECT_FALSE(a[i] < a[j]);
|
||||
EXPECT_TRUE(a[i] <= a[j]);
|
||||
EXPECT_TRUE(a[i] == a[j]);
|
||||
EXPECT_TRUE(a[i] >= a[j]);
|
||||
EXPECT_FALSE(a[i] > a[j]);
|
||||
} else if (i < j) {
|
||||
EXPECT_TRUE(a[i] < a[j]);
|
||||
EXPECT_TRUE(a[i] <= a[j]);
|
||||
EXPECT_FALSE(a[i] == a[j]);
|
||||
EXPECT_FALSE(a[i] >= a[j]);
|
||||
EXPECT_FALSE(a[i] > a[j]);
|
||||
} else {
|
||||
EXPECT_FALSE(a[i] < a[j]);
|
||||
EXPECT_FALSE(a[i] <= a[j]);
|
||||
EXPECT_FALSE(a[i] == a[j]);
|
||||
EXPECT_TRUE(a[i] >= a[j]);
|
||||
EXPECT_TRUE(a[i] > a[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(AttrKey0SetTest, Works) {
|
||||
AttrKeyHashSet<AttrKey<0>> set;
|
||||
|
||||
EXPECT_THAT(set, IsEmpty());
|
||||
EXPECT_THAT(set, SizeIs(0));
|
||||
EXPECT_THAT(set, UnorderedElementsAre());
|
||||
EXPECT_FALSE(set.contains(AttrKey()));
|
||||
EXPECT_TRUE(set.find(AttrKey()) == set.end());
|
||||
EXPECT_EQ(set.erase(AttrKey()), 0);
|
||||
|
||||
set.insert(AttrKey());
|
||||
|
||||
EXPECT_THAT(set, Not(IsEmpty()));
|
||||
EXPECT_THAT(set, SizeIs(1));
|
||||
EXPECT_THAT(set, UnorderedElementsAre(AttrKey()));
|
||||
EXPECT_TRUE(set.contains(AttrKey()));
|
||||
EXPECT_TRUE(set.find(AttrKey()) == set.begin());
|
||||
EXPECT_EQ(set.erase(AttrKey()), 1);
|
||||
EXPECT_THAT(set, IsEmpty());
|
||||
|
||||
set.insert(AttrKey());
|
||||
set.clear();
|
||||
EXPECT_THAT(set, IsEmpty());
|
||||
|
||||
set.insert(AttrKey());
|
||||
set.erase(AttrKey());
|
||||
EXPECT_THAT(set, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(AttrKey0MapTest, Works) {
|
||||
AttrKeyHashMap<AttrKey<0>, int> map;
|
||||
|
||||
EXPECT_THAT(map, IsEmpty());
|
||||
EXPECT_THAT(map, SizeIs(0));
|
||||
EXPECT_THAT(map, UnorderedElementsAre());
|
||||
EXPECT_FALSE(map.contains(AttrKey()));
|
||||
EXPECT_TRUE(map.find(AttrKey()) == map.end());
|
||||
EXPECT_EQ(map.erase(AttrKey()), 0);
|
||||
|
||||
map.try_emplace(AttrKey(), 42);
|
||||
|
||||
EXPECT_THAT(map, Not(IsEmpty()));
|
||||
EXPECT_THAT(map, SizeIs(1));
|
||||
EXPECT_THAT(map, UnorderedElementsAre(Pair(AttrKey(), 42)));
|
||||
EXPECT_EQ(map.begin()->first, AttrKey());
|
||||
EXPECT_EQ(map.begin()->second, 42);
|
||||
EXPECT_TRUE(map.contains(AttrKey()));
|
||||
EXPECT_TRUE(map.find(AttrKey()) == map.begin());
|
||||
EXPECT_EQ(map.erase(AttrKey()), 1);
|
||||
EXPECT_THAT(map, IsEmpty());
|
||||
|
||||
map.insert({AttrKey(), 43});
|
||||
map.clear();
|
||||
EXPECT_THAT(map, IsEmpty());
|
||||
|
||||
map.try_emplace(AttrKey(), 43);
|
||||
map.erase(AttrKey());
|
||||
EXPECT_THAT(map, IsEmpty());
|
||||
|
||||
map.try_emplace(AttrKey(), 43);
|
||||
map.erase(map.begin());
|
||||
EXPECT_THAT(map, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, FromRange) {
|
||||
EXPECT_THAT((AttrKey<0>::FromRange({})), IsOkAndHolds(AttrKey()));
|
||||
EXPECT_THAT((AttrKey<1>::FromRange({1})), IsOkAndHolds(AttrKey(1)));
|
||||
EXPECT_THAT((AttrKey<2>::FromRange({1, 2})), IsOkAndHolds(AttrKey(1, 2)));
|
||||
|
||||
EXPECT_THAT((AttrKey<0>::FromRange({1})),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument));
|
||||
EXPECT_THAT((AttrKey<1>::FromRange({})),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument));
|
||||
EXPECT_THAT((AttrKey<2>::FromRange({1})),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument));
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, FromRangeSymmetric) {
|
||||
using Key = AttrKey<3, ElementSymmetry<1, 2>>;
|
||||
EXPECT_THAT((Key::FromRange({0, 1, 2})), IsOkAndHolds(Key(0, 1, 2)));
|
||||
EXPECT_THAT((Key::FromRange({0, 2, 1})), IsOkAndHolds(Key(0, 1, 2)));
|
||||
EXPECT_THAT((Key::FromRange({3, 1, 2})), IsOkAndHolds(Key(3, 1, 2)));
|
||||
EXPECT_THAT((Key::FromRange({3, 2, 1})), IsOkAndHolds(Key(3, 1, 2)));
|
||||
}
|
||||
|
||||
TEST(AttrKeyTest, IsAttrKey) {
|
||||
EXPECT_TRUE(is_attr_key_v<AttrKey<0>>);
|
||||
EXPECT_TRUE(is_attr_key_v<AttrKey<1>>);
|
||||
EXPECT_FALSE(is_attr_key_v<int>);
|
||||
}
|
||||
|
||||
constexpr int kBenchmarkSize = 30;
|
||||
|
||||
template <typename SetT>
|
||||
void BM_HashSet0(benchmark::State& state) {
|
||||
SetT set;
|
||||
for (const auto s : state) {
|
||||
auto it = set.find(AttrKey());
|
||||
benchmark::DoNotOptimize(it);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_HashSet0<AttrKeyHashSet<AttrKey<0>>>);
|
||||
BENCHMARK(BM_HashSet0<absl::flat_hash_set<AttrKey<0>>>);
|
||||
|
||||
template <typename T>
|
||||
void BM_HashMap1(benchmark::State& state) {
|
||||
absl::flat_hash_map<T, int> map;
|
||||
for (int i = 0; i < kBenchmarkSize * kBenchmarkSize; ++i) {
|
||||
if (i % 2 > 0) { // Half of the lookups are hits.
|
||||
map[T(i)] = i;
|
||||
}
|
||||
}
|
||||
for (const auto s : state) {
|
||||
for (int i = 0; i < kBenchmarkSize * kBenchmarkSize; ++i) {
|
||||
auto it = map.find(T(i));
|
||||
benchmark::DoNotOptimize(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_HashMap1<AttrKey<1>>);
|
||||
BENCHMARK(BM_HashMap1<int64_t>);
|
||||
|
||||
template <typename T>
|
||||
void BM_HashMap2(benchmark::State& state) {
|
||||
absl::flat_hash_map<T, int> map;
|
||||
for (int i = 0; i < kBenchmarkSize; ++i) {
|
||||
for (int j = 0; j < kBenchmarkSize; ++j) {
|
||||
if ((i * kBenchmarkSize + j) % 2 > 0) { // Half of the lookups are hits.
|
||||
map[T(i, j)] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto s : state) {
|
||||
for (int i = 0; i < kBenchmarkSize; ++i) {
|
||||
for (int j = 0; j < kBenchmarkSize; ++j) {
|
||||
auto it = map.find(T(i, j));
|
||||
benchmark::DoNotOptimize(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_HashMap2<AttrKey<2>>);
|
||||
BENCHMARK(BM_HashMap2<std::pair<int64_t, int64_t>>);
|
||||
|
||||
template <int n>
|
||||
void BM_SortAttrKeys(benchmark::State& state) {
|
||||
const std::vector<AttrKey<n>> keys =
|
||||
MakeRandomAttrKeys<n, NoSymmetry>(state.range(0), state.range(0));
|
||||
|
||||
for (const auto s : state) {
|
||||
auto copy = keys;
|
||||
absl::c_sort(copy);
|
||||
benchmark::DoNotOptimize(copy);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_SortAttrKeys<1>)->Arg(100)->Arg(10000);
|
||||
BENCHMARK(BM_SortAttrKeys<2>)->Arg(100)->Arg(10000);
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
429
ortools/math_opt/elemental/attr_storage.h
Normal file
429
ortools/math_opt/elemental/attr_storage.h
Normal file
@@ -0,0 +1,429 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "ortools/base/map_util.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/symmetry.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// A non-default key set based on a vector. This is very efficient for
|
||||
// insertions, reads, and slicing, but does not support deletions.
|
||||
template <int n>
|
||||
class DenseKeySet {
|
||||
public:
|
||||
// {Dense,Sparse}KeySet stores symmetric keys, symmetry is handled by
|
||||
// `SlicingStorage`.
|
||||
using Key = AttrKey<n, NoSymmetry>;
|
||||
|
||||
DenseKeySet() = default;
|
||||
|
||||
size_t size() const { return key_set_.size(); }
|
||||
|
||||
template <typename F>
|
||||
// requires std::invocable<F, const Key&>
|
||||
void ForEach(F f) const {
|
||||
for (const Key& key : key_set_) {
|
||||
f(key);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: this does not check for duplicates. This is fine because inserting
|
||||
// into this map is gated on inserting into the AttrStorage, which does check
|
||||
// for duplicates.
|
||||
void Insert(const Key& key) { key_set_.push_back(key); }
|
||||
|
||||
auto begin() const { return key_set_.begin(); }
|
||||
auto end() const { return key_set_.end(); }
|
||||
|
||||
private:
|
||||
std::vector<Key> key_set_;
|
||||
};
|
||||
|
||||
// A non-default key set based on a hash set. Simple, but requires a hash lookup
|
||||
// for each insertion and deletion.
|
||||
template <int n>
|
||||
class SparseKeySet {
|
||||
public:
|
||||
// {Dense,Sparse}KeySet stores symmetric keys, symmetry is handled by
|
||||
// `SlicingStorage`.
|
||||
using Key = AttrKey<n, NoSymmetry>;
|
||||
|
||||
explicit SparseKeySet(const DenseKeySet<n>& dense_set)
|
||||
: key_set_(dense_set.begin(), dense_set.end()) {}
|
||||
|
||||
size_t size() const { return key_set_.size(); }
|
||||
|
||||
template <typename F>
|
||||
// requires std::invocable<F, const Key&>
|
||||
void ForEach(F f) const {
|
||||
for (const Key& key : key_set_) {
|
||||
f(key);
|
||||
}
|
||||
}
|
||||
|
||||
void Erase(const Key& key) { key_set_.erase(key); }
|
||||
void Insert(const Key& key) { key_set_.insert(key); }
|
||||
|
||||
private:
|
||||
absl::flat_hash_set<Key> key_set_;
|
||||
};
|
||||
|
||||
// A non-default key set that switches between implementations
|
||||
// opportunistically: It starts dense, and switches to sparse if there are
|
||||
// deletions.
|
||||
template <int n>
|
||||
class KeySet {
|
||||
public:
|
||||
using Key = AttrKey<n, NoSymmetry>;
|
||||
|
||||
size_t size() const {
|
||||
return std::visit([](const auto& impl) { return impl.size(); }, impl_);
|
||||
}
|
||||
|
||||
// We can't do begin/end because the iterator types are not the same.
|
||||
template <typename F>
|
||||
// requires std::invocable<F, const Key&>
|
||||
void ForEach(F f) const {
|
||||
return std::visit(
|
||||
[f = std::move(f)](const auto& impl) {
|
||||
return impl.ForEach(std::move(f));
|
||||
},
|
||||
impl_);
|
||||
}
|
||||
|
||||
auto Erase(const Key& key) { return AsSparse().Erase(key); }
|
||||
|
||||
void Insert(const Key& key) {
|
||||
std::visit([&](auto& impl) { impl.Insert(key); }, impl_);
|
||||
}
|
||||
|
||||
private:
|
||||
SparseKeySet<n>& AsSparse() {
|
||||
if (auto* sparse = std::get_if<SparseKeySet<n>>(&impl_)) {
|
||||
return *sparse;
|
||||
}
|
||||
// Switch to a sparse representation.
|
||||
impl_ = SparseKeySet<n>(std::get<DenseKeySet<n>>(impl_));
|
||||
return std::get<SparseKeySet<n>>(impl_);
|
||||
}
|
||||
|
||||
std::variant<DenseKeySet<n>, SparseKeySet<n>> impl_;
|
||||
};
|
||||
|
||||
// When we have two or more dimensions, we need to store the nondefaults for
|
||||
// each dimension to support slicing.
|
||||
template <int n, typename Symmetry, typename = void>
|
||||
class SlicingSupport {
|
||||
public:
|
||||
using Key = AttrKey<n, Symmetry>;
|
||||
|
||||
void AddRowsAndColumns(const Key key) {
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
ForEachDimension([this, key]<int i>() {
|
||||
if (MustInsertNondefault<i>(key, Symmetry{})) {
|
||||
key_nondefaults_[i][key[i]].Insert(key.template RemoveElement<i>());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Requires key is currently stored with a non-default value.
|
||||
void ClearRowsAndColumns(Key key) {
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
ForEachDimension([this, key]<int i>() {
|
||||
const auto& key_elem = key[i];
|
||||
auto& nondefaults = key_nondefaults_[i];
|
||||
if (nondefaults[key_elem].size() == 1) {
|
||||
nondefaults.erase(key_elem);
|
||||
} else {
|
||||
nondefaults[key_elem].Erase(key.template RemoveElement<i>());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
for (auto& key_nondefaults : key_nondefaults_) {
|
||||
key_nondefaults.clear();
|
||||
}
|
||||
}
|
||||
|
||||
template <int i>
|
||||
std::vector<Key> Slice(const int64_t key_elem) const {
|
||||
return SliceImpl<i>(
|
||||
key_elem, Symmetry{},
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[key_elem]<int... is>(KeySetExpansion<is>... expansions) {
|
||||
std::vector<Key> slice((expansions.key_set.size() + ...));
|
||||
Key* out = slice.data();
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
const auto append = [key_elem, &out]<int j>(
|
||||
const KeySetExpansion<j>& expansion) {
|
||||
expansion.key_set.ForEach(
|
||||
[key_elem, &out](const AttrKey<n - 1> other_elems) {
|
||||
*out = other_elems.template AddElement<j, Symmetry>(key_elem);
|
||||
++out;
|
||||
});
|
||||
};
|
||||
(append(expansions), ...);
|
||||
return slice;
|
||||
});
|
||||
}
|
||||
|
||||
template <int i>
|
||||
int64_t GetSliceSize(const int64_t key_elem) const {
|
||||
return SliceImpl<i>(key_elem, Symmetry{}, [](const auto... expansions) {
|
||||
return (expansions.key_set.size() + ...);
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
// We store the nondefaults for a given id along a given dimension as a set of
|
||||
// `AttrKey<n-1, NoSymmetry>` (the current dimension is not stored).
|
||||
using NonDefaultKeySet = KeySet<n - 1>;
|
||||
// For each dimension, we store the nondefaults for each id.
|
||||
using NonDefaultKeySetById = absl::flat_hash_map<int64_t, NonDefaultKeySet>;
|
||||
// We need one such set per dimension.
|
||||
using NonDefaultsPerDimension = std::array<NonDefaultKeySetById, n>;
|
||||
|
||||
// Represents a NonDefaultKeySet to be expanded by inserting an element on
|
||||
// dimension `i`.
|
||||
template <int i>
|
||||
struct KeySetExpansion {
|
||||
KeySetExpansion(const NonDefaultsPerDimension& key_nondefaults,
|
||||
int64_t key_elem)
|
||||
: key_set(gtl::FindWithDefault(key_nondefaults[i], key_elem)) {}
|
||||
const NonDefaultKeySet& key_set;
|
||||
};
|
||||
|
||||
// A helper function that calls `F` `i` times with template arguments `n-1` to
|
||||
// `0`.
|
||||
template <typename F, int i = n - 1>
|
||||
static void ForEachDimension(const F& f) {
|
||||
f.template operator()<i>();
|
||||
if constexpr (i > 0) {
|
||||
ForEachDimension<F, i - 1>(f);
|
||||
}
|
||||
}
|
||||
|
||||
template <int i>
|
||||
static bool MustInsertNondefault(const Key&, NoSymmetry) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <int i, int k, int l>
|
||||
static bool MustInsertNondefault(const Key& key, ElementSymmetry<k, l>) {
|
||||
// For attributes that are symmetric on `k` and `l`, elements on the
|
||||
// diagonal need to be in only one of the nondefaults for `k` or `l`
|
||||
// (otherwise they would be counted twice in `Slice()`). We arbitrarily pick
|
||||
// `k`.
|
||||
if constexpr (i == l) {
|
||||
const bool is_diagonal = key[k] == key[l];
|
||||
return !is_diagonal;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// `Fn` should be a functor that takes any number of `KeySetExpansion`
|
||||
// arguments.
|
||||
template <int i, typename Fn>
|
||||
auto SliceImpl(const int64_t key_elem, NoSymmetry, const Fn& fn) const {
|
||||
static_assert(n > 1);
|
||||
return fn(KeySetExpansion<i>(key_nondefaults_, key_elem));
|
||||
}
|
||||
|
||||
template <int i, int k, int l, typename Fn>
|
||||
auto SliceImpl(const int64_t key_elem, ElementSymmetry<k, l>,
|
||||
const Fn& fn) const {
|
||||
static_assert(n > 1);
|
||||
if constexpr (i != k && i != l) {
|
||||
// This is a normal dimension, not a symmetric one.
|
||||
return SliceImpl<i>(key_elem, NoSymmetry(), fn);
|
||||
} else {
|
||||
// For symmetric dimensions, we need to look up the keys on both
|
||||
// dimensions `l` and `k`.
|
||||
return fn(KeySetExpansion<k>(key_nondefaults_, key_elem),
|
||||
KeySetExpansion<l>(key_nondefaults_, key_elem));
|
||||
}
|
||||
}
|
||||
|
||||
NonDefaultsPerDimension key_nondefaults_;
|
||||
};
|
||||
|
||||
// Without slicing we don't need to track anything.
|
||||
template <int n, typename Symmetry>
|
||||
struct SlicingSupport<n, Symmetry, std::enable_if_t<(n < 2), std::void_t<>>> {
|
||||
using Key = AttrKey<n, Symmetry>;
|
||||
|
||||
void AddRowsAndColumns(Key) {}
|
||||
void ClearRowsAndColumns(Key) {}
|
||||
void Clear() {}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Stores the value of an attribute keyed on n elements (e.g.
|
||||
// linear_constraint_coefficient is a double valued attribute keyed first on
|
||||
// LinearConstraint and then on Variable).
|
||||
//
|
||||
// Memory usage:
|
||||
// Storing `k` elements with non-default values in a `AttrStorage<V, n>` uses
|
||||
// `sizeof(V) * (n^2 + 1) * k / load_factor` (where load_factor is the absl
|
||||
// hash map load factor, typically 0.8), plus a small allocation overhead of
|
||||
// `O(k)`.
|
||||
template <typename V, int n, typename Symmetry>
|
||||
class AttrStorage {
|
||||
public:
|
||||
using Key = AttrKey<n, Symmetry>;
|
||||
// If this no longer holds, we should sprinkle the code with `move`s and
|
||||
// return `V`s by ref.
|
||||
static_assert(std::is_trivially_copyable_v<V>);
|
||||
|
||||
// Generally avoid, provided to make working with std::array easier.
|
||||
explicit AttrStorage() : AttrStorage({}) {}
|
||||
|
||||
// The default value of the attribute is its value when the model is created
|
||||
// (e.g. for linear_constraint_coefficient, 0.0).
|
||||
explicit AttrStorage(const V default_value) : default_value_(default_value) {}
|
||||
|
||||
AttrStorage(const AttrStorage&) = default;
|
||||
AttrStorage& operator=(const AttrStorage&) = default;
|
||||
AttrStorage(AttrStorage&&) = default;
|
||||
AttrStorage& operator=(AttrStorage&&) = default;
|
||||
|
||||
// Returns true if the attribute for `key` has a value different from its
|
||||
// default.
|
||||
bool IsNonDefault(const Key key) const {
|
||||
return non_default_values_.contains(key);
|
||||
}
|
||||
|
||||
// Returns the previous value if value has changed, otherwise returns
|
||||
// `std::nullopt`.
|
||||
std::optional<V> Set(const Key key, const V value) {
|
||||
bool is_default = value == default_value_;
|
||||
if (is_default) {
|
||||
const auto it = non_default_values_.find(key);
|
||||
if (it == non_default_values_.end()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
const V prev_value = it->second;
|
||||
non_default_values_.erase(it);
|
||||
slicing_support_.ClearRowsAndColumns(key);
|
||||
return prev_value;
|
||||
}
|
||||
const auto [it, inserted] = non_default_values_.try_emplace(key, value);
|
||||
if (inserted) {
|
||||
slicing_support_.AddRowsAndColumns(key);
|
||||
return default_value_;
|
||||
}
|
||||
// !is_default and !inserted
|
||||
if (value == it->second) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return std::exchange(it->second, value);
|
||||
}
|
||||
|
||||
// Returns the value of the attribute for `key` (return the default value if
|
||||
// the attribute value for `key` is unset).
|
||||
V Get(const Key key) const {
|
||||
return GetIfNonDefault(key).value_or(default_value_);
|
||||
}
|
||||
|
||||
// Returns the value of the attribute for `key`, or nullopt.
|
||||
std::optional<V> GetIfNonDefault(const Key key) const {
|
||||
auto it = non_default_values_.find(key);
|
||||
if (it == non_default_values_.end()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// Sets the value of the attribute for `key` to the default value.
|
||||
void Erase(const Key key) {
|
||||
if (non_default_values_.erase(key)) {
|
||||
slicing_support_.ClearRowsAndColumns(key);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the keys (ids pairs) the of the elements with a non-default value
|
||||
// for this attribute.
|
||||
std::vector<Key> NonDefaults() const {
|
||||
std::vector<Key> result;
|
||||
result.reserve(non_default_values_.size());
|
||||
for (const auto& [key, unused] : non_default_values_) {
|
||||
result.push_back(key);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the set of all keys `K` such that:
|
||||
// - There exists `k_{0}..k_{n-1}` such that
|
||||
// `K == AttrKey(k_{0}, ..., k_{i-1}, key_elem, k_{i+1}, ..., k_{n-1})`, and
|
||||
// - `K` has a non-default value for this attribute.
|
||||
template <int i>
|
||||
std::vector<Key> Slice(const int64_t key_elem) const {
|
||||
static_assert(n >= 1);
|
||||
if constexpr (n == 1) {
|
||||
return non_default_values_.contains(Key(key_elem))
|
||||
? std::vector<Key>({Key(key_elem)})
|
||||
: std::vector<Key>();
|
||||
} else {
|
||||
return slicing_support_.template Slice<i>(key_elem);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the size of the given slice: This is equivalent to
|
||||
// `Slice(key_elem).size()`, but `O(1)`.
|
||||
template <int i>
|
||||
int64_t GetSliceSize(const int64_t key_elem) const {
|
||||
static_assert(n >= 1);
|
||||
if constexpr (n == 1) {
|
||||
return non_default_values_.count(Key(key_elem));
|
||||
} else {
|
||||
return slicing_support_.template GetSliceSize<i>(key_elem);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the number of keys (element pairs) with non-default values for this
|
||||
// attribute.
|
||||
int64_t num_non_defaults() const { return non_default_values_.size(); }
|
||||
|
||||
// Restore all elements to their default value for this attribute.
|
||||
void Clear() {
|
||||
non_default_values_.clear();
|
||||
slicing_support_.Clear();
|
||||
}
|
||||
|
||||
private:
|
||||
V default_value_;
|
||||
AttrKeyHashMap<Key, V> non_default_values_;
|
||||
detail::SlicingSupport<n, Symmetry> slicing_support_;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTR_STORAGE_H_
|
||||
574
ortools/math_opt/elemental/attr_storage_test.cc
Normal file
574
ortools/math_opt/elemental/attr_storage_test.cc
Normal file
@@ -0,0 +1,574 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/attr_storage.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
#include "benchmark/benchmark.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/symmetry.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::Optional;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
TEST(Attr0StorageTest, EmptyGetters) {
|
||||
const AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0);
|
||||
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey()));
|
||||
}
|
||||
|
||||
TEST(Attr0StorageTest, SetDefaultToDefault) {
|
||||
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_FALSE(attr_storage.Set(AttrKey(), 1.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0);
|
||||
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey()));
|
||||
}
|
||||
|
||||
TEST(Attr0StorageTest, SetDefaultToNonDefault) {
|
||||
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_TRUE(attr_storage.Set(AttrKey(), 10.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 10.0);
|
||||
EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey()));
|
||||
}
|
||||
|
||||
TEST(Attr0StorageTest, SetNonDefaultToDefault) {
|
||||
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
|
||||
EXPECT_TRUE(attr_storage.Set(AttrKey(), 10.0));
|
||||
|
||||
EXPECT_TRUE(attr_storage.Set(AttrKey(), 1.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 1.0);
|
||||
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey()));
|
||||
}
|
||||
|
||||
TEST(Attr0StorageTest, SetNonDefaultToNonDefaultDifferent) {
|
||||
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(), 10.0);
|
||||
|
||||
EXPECT_TRUE(attr_storage.Set(AttrKey(), 20.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 20.0);
|
||||
EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey()));
|
||||
}
|
||||
|
||||
TEST(Attr0StorageTest, SetNonDefaultToNonDefaultSame) {
|
||||
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(), 10.0);
|
||||
|
||||
EXPECT_FALSE(attr_storage.Set(AttrKey(), 10.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey()), 10.0);
|
||||
EXPECT_TRUE(attr_storage.IsNonDefault(AttrKey()));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Attr1Storage
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Attr1StorageTest, EmptyGetters) {
|
||||
const AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(0)), 1.0);
|
||||
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey(0)));
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 0);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(0), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(Attr1StorageTest, GettersNonEmpty) {
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2), 10.0);
|
||||
attr_storage.Set(AttrKey(3), 11.0);
|
||||
attr_storage.Set(AttrKey(5), 12.0);
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 11.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(4)), 1.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5)), 12.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(6)), 1.0);
|
||||
|
||||
EXPECT_THAT(attr_storage.NonDefaults(),
|
||||
UnorderedElementsAre(AttrKey(2), AttrKey(3), AttrKey(5)));
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 3);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(AttrKey(3)));
|
||||
}
|
||||
|
||||
TEST(Attr1StorageTest, SetDefaultToDefault) {
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_FALSE(attr_storage.Set(AttrKey(2), 1.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(Attr1StorageTest, SetDefaultToNonDefault) {
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_TRUE(attr_storage.Set(AttrKey(2), 10.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2)));
|
||||
}
|
||||
|
||||
TEST(Attr1StorageTest, SetNonDefaultToDefault) {
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2), 10.0);
|
||||
|
||||
EXPECT_TRUE(attr_storage.Set(AttrKey(2), 1.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(Attr1StorageTest, SetNonDefaultToNonDefaultDifferent) {
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2), 5.0);
|
||||
|
||||
EXPECT_TRUE(attr_storage.Set(AttrKey(2), 10.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2)));
|
||||
}
|
||||
|
||||
TEST(Attr1StorageTest, SetNonDefaultToNonDefaultSame) {
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2), 10.0);
|
||||
|
||||
EXPECT_FALSE(attr_storage.Set(AttrKey(2), 10.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 10.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2)));
|
||||
}
|
||||
|
||||
TEST(Attr1StorageTest, Clear) {
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2), 10.0);
|
||||
attr_storage.Set(AttrKey(3), 11.0);
|
||||
|
||||
EXPECT_THAT(attr_storage.NonDefaults(),
|
||||
UnorderedElementsAre(AttrKey(2), AttrKey(3)));
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 2);
|
||||
|
||||
attr_storage.Clear();
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 1.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 0);
|
||||
}
|
||||
|
||||
TEST(Attr1StorageTest, Erase) {
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2), 10.0);
|
||||
attr_storage.Set(AttrKey(3), 11.0);
|
||||
|
||||
attr_storage.Erase(AttrKey(2));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2)), 1.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3)), 11.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3)));
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 1);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Attr2Storage
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(Attr2StorageTest, EmptyGetters) {
|
||||
const AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(0, 0)), 1.0);
|
||||
EXPECT_FALSE(attr_storage.IsNonDefault(AttrKey(0, 0)));
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 0);
|
||||
EXPECT_THAT(attr_storage.Slice<1>(0), IsEmpty());
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<1>(0), 0);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(0), IsEmpty());
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<0>(0), 0);
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, GettersNonEmpty) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2, 3), 10.0);
|
||||
attr_storage.Set(AttrKey(2, 5), 11.0);
|
||||
attr_storage.Set(AttrKey(5, 5), 12.0);
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 5)), 11.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5, 5)), 12.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(5, 2)), 1.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 2)), 1.0);
|
||||
|
||||
EXPECT_THAT(
|
||||
attr_storage.NonDefaults(),
|
||||
UnorderedElementsAre(AttrKey(2, 3), AttrKey(2, 5), AttrKey(5, 5)));
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 3);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(2),
|
||||
UnorderedElementsAre(AttrKey(2, 3), AttrKey(2, 5)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<0>(2), 2);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(3), IsEmpty());
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<0>(3), 0);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 5)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<0>(5), 1);
|
||||
|
||||
EXPECT_THAT(attr_storage.Slice<1>(2), IsEmpty());
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<1>(2), 0);
|
||||
EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(AttrKey(2, 3)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<1>(3), 1);
|
||||
EXPECT_THAT(attr_storage.Slice<1>(5),
|
||||
UnorderedElementsAre(AttrKey(2, 5), AttrKey(5, 5)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<1>(5), 2);
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, GettersNonEmptySymmetric) {
|
||||
// Dim 0
|
||||
// | 0 1 2 3 4 5
|
||||
// --+------------------------
|
||||
// 0 | 0
|
||||
// D 1 | 0 0
|
||||
// i 2 | 0 0 0
|
||||
// m 3 | 0 10 0 0
|
||||
// 1 4 | 0 0 0 0 0
|
||||
// 5 | 0 11 0 0 0 12
|
||||
//
|
||||
using Storage = AttrStorage<double, 2, ElementSymmetry<0, 1>>;
|
||||
using Key = Storage::Key;
|
||||
Storage attr_storage(1.0);
|
||||
attr_storage.Set(Key(2, 3), 10.0);
|
||||
attr_storage.Set(Key(2, 5), 123.0);
|
||||
attr_storage.Set(Key(5, 2), 11.0); // Overwrites 123.0.
|
||||
attr_storage.Set(Key(5, 5), 12.0);
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 3)), 10.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 5)), 11.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(5, 5)), 12.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(3, 2)), 10.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(5, 2)), 11.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(Key(2, 2)), 1.0);
|
||||
|
||||
EXPECT_THAT(attr_storage.NonDefaults(),
|
||||
UnorderedElementsAre(Key(2, 3), Key(2, 5), Key(5, 5)));
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 3);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(2),
|
||||
UnorderedElementsAre(Key(2, 3), Key(2, 5)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<0>(2), 2);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(Key(2, 3)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<0>(3), 1);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(4), IsEmpty());
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<0>(4), 0);
|
||||
EXPECT_THAT(attr_storage.Slice<0>(5),
|
||||
UnorderedElementsAre(Key(2, 5), Key(5, 5)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<0>(5), 2);
|
||||
|
||||
EXPECT_THAT(attr_storage.Slice<1>(2),
|
||||
UnorderedElementsAre(Key(2, 3), Key(2, 5)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<1>(2), 2);
|
||||
EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(Key(2, 3)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<1>(3), 1);
|
||||
EXPECT_THAT(attr_storage.Slice<1>(4), IsEmpty());
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<1>(4), 0);
|
||||
EXPECT_THAT(attr_storage.Slice<1>(5),
|
||||
UnorderedElementsAre(Key(2, 5), Key(5, 5)));
|
||||
EXPECT_THAT(attr_storage.GetSliceSize<1>(5), 2);
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, SetDefaultToDefault) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_FALSE(attr_storage.Set(AttrKey(2, 3), 1.0).has_value());
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, SetDefaultToNonDefault) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 10.0), Optional(1.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3)));
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, SetNonDefaultToDefault) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2, 3), 10.0);
|
||||
|
||||
EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 1.0), Optional(10.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, SetNonDefaultToNonDefaultDifferent) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2, 3), 5.0);
|
||||
|
||||
EXPECT_THAT(attr_storage.Set(AttrKey(2, 3), 10.0), Optional(5.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3)));
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, SetNonDefaultToNonDefaultSame) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2, 3), 10.0);
|
||||
|
||||
EXPECT_FALSE(attr_storage.Set(AttrKey(2, 3), 10.0));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 10.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(2, 3)));
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, Clear) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2, 3), 10.0);
|
||||
attr_storage.Set(AttrKey(3, 4), 11.0);
|
||||
|
||||
EXPECT_THAT(attr_storage.NonDefaults(),
|
||||
UnorderedElementsAre(AttrKey(2, 3), AttrKey(3, 4)));
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 2);
|
||||
|
||||
attr_storage.Clear();
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3, 4)), 1.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), IsEmpty());
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 0);
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, Erase) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2, 3), 10.0);
|
||||
attr_storage.Set(AttrKey(3, 4), 11.0);
|
||||
|
||||
attr_storage.Erase(AttrKey(2, 3));
|
||||
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(2, 3)), 1.0);
|
||||
EXPECT_DOUBLE_EQ(attr_storage.Get(AttrKey(3, 4)), 11.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3, 4)));
|
||||
EXPECT_EQ(attr_storage.num_non_defaults(), 1);
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, EraseColumnLives) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(2, 3), 10.0);
|
||||
attr_storage.Set(AttrKey(5, 3), 11.0);
|
||||
|
||||
attr_storage.Erase(AttrKey(2, 3));
|
||||
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(5, 3)));
|
||||
EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 3)));
|
||||
EXPECT_THAT(attr_storage.Slice<1>(3), UnorderedElementsAre(AttrKey(5, 3)));
|
||||
|
||||
// Insert again.
|
||||
attr_storage.Set(AttrKey(2, 3), 12.0);
|
||||
EXPECT_THAT(attr_storage.NonDefaults(),
|
||||
UnorderedElementsAre(AttrKey(2, 3), AttrKey(5, 3)));
|
||||
EXPECT_THAT(attr_storage.Slice<0>(5), UnorderedElementsAre(AttrKey(5, 3)));
|
||||
EXPECT_THAT(attr_storage.Slice<1>(3),
|
||||
UnorderedElementsAre(AttrKey(2, 3), AttrKey(5, 3)));
|
||||
}
|
||||
|
||||
TEST(Attr2StorageTest, EraseRowLives) {
|
||||
AttrStorage<double, 2, NoSymmetry> attr_storage(1.0);
|
||||
attr_storage.Set(AttrKey(3, 2), 10.0);
|
||||
attr_storage.Set(AttrKey(3, 5), 11.0);
|
||||
|
||||
attr_storage.Erase(AttrKey(3, 2));
|
||||
|
||||
EXPECT_THAT(attr_storage.NonDefaults(), UnorderedElementsAre(AttrKey(3, 5)));
|
||||
EXPECT_THAT(attr_storage.Slice<0>(3), UnorderedElementsAre(AttrKey(3, 5)));
|
||||
EXPECT_THAT(attr_storage.Slice<1>(5), UnorderedElementsAre(AttrKey(3, 5)));
|
||||
}
|
||||
|
||||
// Makes a set of `n` 1-dimensional keys.
|
||||
std::vector<AttrKey<1>> Make1DKeys(int n) {
|
||||
std::vector<AttrKey<1>> keys;
|
||||
for (int64_t i = 0; i < n; ++i) {
|
||||
keys.emplace_back(i);
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
// Makes a set of `n^2` 2-dimensional keys.
|
||||
// NOTE: depending in `Symmetry` this might create duplicate keys. This is
|
||||
// intentional, as we want to have the same number of keys to be able to compare
|
||||
// the performance of different symmetries.
|
||||
template <typename Symmetry>
|
||||
std::vector<AttrKey<2, Symmetry>> Make2DKeys(int n) {
|
||||
std::vector<AttrKey<2, Symmetry>> keys;
|
||||
for (int64_t i = 0; i < n; ++i) {
|
||||
for (int64_t j = 0; j < n; ++j) {
|
||||
keys.emplace_back(i, j);
|
||||
}
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
// A functor that returns true every N calls, false otherwise.
|
||||
template <int N>
|
||||
struct TrueEvery {
|
||||
int n = 0;
|
||||
bool operator()() {
|
||||
if (n == N) {
|
||||
n = 0;
|
||||
return true;
|
||||
}
|
||||
++n;
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
void BM_Attr0StorageSet(benchmark::State& state) {
|
||||
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
for (const auto s : state) {
|
||||
attr_storage.Set(AttrKey(), 10.0);
|
||||
benchmark::DoNotOptimize(attr_storage);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_Attr0StorageSet);
|
||||
|
||||
void BM_Attr1StorageSet(benchmark::State& state) {
|
||||
const int n = state.range(0);
|
||||
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
const auto keys = Make1DKeys(n);
|
||||
|
||||
for (const auto s : state) {
|
||||
for (const auto& key : keys) {
|
||||
attr_storage.Set(key, 10.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_Attr1StorageSet)->Arg(900);
|
||||
|
||||
template <typename Symmetry>
|
||||
void BM_Attr2StorageSet(benchmark::State& state) {
|
||||
const int n = state.range(0);
|
||||
|
||||
const auto keys = Make2DKeys<Symmetry>(n);
|
||||
|
||||
std::optional<AttrStorage<double, 2, Symmetry>> attr_storage(1.0);
|
||||
for (const auto s : state) {
|
||||
for (const auto& key : keys) {
|
||||
attr_storage->Set(key, 10.0);
|
||||
}
|
||||
state.PauseTiming();
|
||||
attr_storage.emplace(1.0);
|
||||
state.ResumeTiming();
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_Attr2StorageSet<NoSymmetry>)->Arg(30);
|
||||
BENCHMARK(BM_Attr2StorageSet<ElementSymmetry<0, 1>>)->Arg(30);
|
||||
|
||||
void BM_Attr0StorageGet(benchmark::State& state) {
|
||||
AttrStorage<double, 0, NoSymmetry> attr_storage(1.0);
|
||||
|
||||
for (const auto s : state) {
|
||||
double v = attr_storage.Get(AttrKey());
|
||||
benchmark::DoNotOptimize(v);
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_Attr0StorageGet);
|
||||
|
||||
void BM_Attr1StorageGet(benchmark::State& state) {
|
||||
const int n = state.range(0);
|
||||
|
||||
AttrStorage<double, 1, NoSymmetry> attr_storage(1.0);
|
||||
const auto keys = Make1DKeys(n);
|
||||
// Insert half the keys.
|
||||
TrueEvery<2> sample;
|
||||
for (const auto& key : keys) {
|
||||
if (sample()) {
|
||||
attr_storage.Set(key, 10.0);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto s : state) {
|
||||
for (const auto& key : keys) {
|
||||
double v = attr_storage.Get(key);
|
||||
benchmark::DoNotOptimize(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_Attr1StorageGet)->Arg(900);
|
||||
|
||||
template <typename Symmetry>
|
||||
void BM_Attr2StorageGet(benchmark::State& state) {
|
||||
const int n = state.range(0);
|
||||
|
||||
AttrStorage<double, 2, Symmetry> attr_storage(1.0);
|
||||
const auto keys = Make2DKeys<Symmetry>(n);
|
||||
// Insert half the keys.
|
||||
TrueEvery<2> sample;
|
||||
for (const auto& key : keys) {
|
||||
if (sample()) {
|
||||
attr_storage.Set(key, 10.0);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto s : state) {
|
||||
for (const auto& key : keys) {
|
||||
double v = attr_storage.Get(key);
|
||||
benchmark::DoNotOptimize(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_Attr2StorageGet<NoSymmetry>)->Arg(30);
|
||||
BENCHMARK(BM_Attr2StorageGet<ElementSymmetry<0, 1>>)->Arg(30);
|
||||
|
||||
template <typename Symmetry>
|
||||
void BM_Attr2StorageSlice(benchmark::State& state) {
|
||||
const int n = state.range(0);
|
||||
|
||||
AttrStorage<double, 2, Symmetry> attr_storage(1.0);
|
||||
const auto keys = Make2DKeys<Symmetry>(n);
|
||||
// Insert 5% of the keys.
|
||||
TrueEvery<20> sample;
|
||||
for (const auto& key : keys) {
|
||||
if (sample()) {
|
||||
attr_storage.Set(key, 10.0);
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto s : state) {
|
||||
for (int key_id = 0; key_id < n; ++key_id) {
|
||||
auto slice0 = attr_storage.template Slice<0>(key_id);
|
||||
auto slice1 = attr_storage.template Slice<1>(key_id);
|
||||
benchmark::DoNotOptimize(slice0);
|
||||
benchmark::DoNotOptimize(slice1);
|
||||
}
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_Attr2StorageSlice<NoSymmetry>)->Arg(30);
|
||||
BENCHMARK(BM_Attr2StorageSlice<ElementSymmetry<0, 1>>)->Arg(30);
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
349
ortools/math_opt/elemental/attributes.h
Normal file
349
ortools/math_opt/elemental/attributes.h
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <ostream>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/array.h"
|
||||
#include "ortools/math_opt/elemental/arrays.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/elemental/symmetry.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// A base class for all attribute type descriptors.
|
||||
// `ValueTypeT` is the attribute value type, and `n` is the number of key
|
||||
// elements (e.g. `Double2` attribute has `ValueType` == `double` and `n` == 2).
|
||||
// This uses
|
||||
// [CRTP](https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern) in
|
||||
// `Impl` to deduce common descriptor properties from `Impl`. `Impl` must
|
||||
// inherit from `AttrTypeDescriptor` and define the following entities:
|
||||
// - `static constexpr absl::string_view kName`: The name of the attribute
|
||||
// type.
|
||||
// - `enum class AttrType`: The attribute type, with `k` enumerators
|
||||
// corresponding to attributes for this type. Enumerators must be numbered
|
||||
// `0..(k-1)` (a good way to do this is to leave them unnumbered).
|
||||
// - `std::array<AttrDescriptor, k> kAttrDescriptors`: A descriptor for each
|
||||
// of the `k` attributes for this type.
|
||||
template <typename ValueTypeT, int n, typename SymmetryT, typename Impl>
|
||||
struct AttrTypeDescriptor {
|
||||
// The type of attribute values (e.g. `bool`, `int64_t`, `double`).
|
||||
using ValueType = ValueTypeT;
|
||||
|
||||
// The number of key elements.
|
||||
static constexpr int kNumKeyElements = n;
|
||||
|
||||
// The key symmetry. For example, this can be used to enforce that
|
||||
// quadratic objective coefficients are the same for `(i, j)` and `(j, i)`
|
||||
// (see `kObjQuadCoef` below).
|
||||
using Symmetry = SymmetryT;
|
||||
|
||||
// A descriptor of an attribute of this attribute type.
|
||||
// E.g., this could describe the attribute `DoubleAttr1::kVarLb`.
|
||||
struct AttrDescriptor {
|
||||
// The name of the attribute value.
|
||||
absl::string_view name;
|
||||
// The default value.
|
||||
ValueType default_value;
|
||||
// The types of the `n` key elements.
|
||||
std::array<ElementType, n> key_types;
|
||||
};
|
||||
|
||||
// Returns the number of attributes of this attribute type.
|
||||
static constexpr int NumAttrs() { return Impl::kAttrDescriptors.size(); }
|
||||
|
||||
// Returns an array with all attributes of this attribute type.
|
||||
static constexpr auto Enumerate() {
|
||||
std::array<typename Impl::AttrType, NumAttrs()> result;
|
||||
for (int i = 0; i < NumAttrs(); ++i) {
|
||||
result[i] = {static_cast<typename Impl::AttrType>(i)};
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct BoolAttr0TypeDescriptor
|
||||
: public AttrTypeDescriptor<bool, 0, NoSymmetry, BoolAttr0TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "BoolAttr0";
|
||||
|
||||
enum class AttrType { kMaximize };
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>(
|
||||
{{.name = "maximize", .default_value = false, .key_types = {}}});
|
||||
};
|
||||
|
||||
struct BoolAttr1TypeDescriptor
|
||||
: public AttrTypeDescriptor<bool, 1, NoSymmetry, BoolAttr1TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "BoolAttr1";
|
||||
|
||||
enum class AttrType {
|
||||
kVarInteger,
|
||||
kAuxObjMaximize,
|
||||
kIndConActivateOnZero,
|
||||
};
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>(
|
||||
{{.name = "variable_integer",
|
||||
.default_value = false,
|
||||
.key_types = {ElementType::kVariable}},
|
||||
{.name = "auxiliary_objective_maximize",
|
||||
.default_value = false,
|
||||
.key_types = {ElementType::kAuxiliaryObjective}},
|
||||
{.name = "indicator_constraint_activate_on_zero",
|
||||
.default_value = false,
|
||||
.key_types = {ElementType::kIndicatorConstraint}}});
|
||||
};
|
||||
|
||||
struct IntAttr0TypeDescriptor
|
||||
: public AttrTypeDescriptor<int64_t, 0, NoSymmetry,
|
||||
IntAttr0TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "IntAttr0";
|
||||
|
||||
enum class AttrType {
|
||||
kObjPriority,
|
||||
};
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
|
||||
{.name = "objective_priority", .default_value = 0, .key_types = {}},
|
||||
});
|
||||
};
|
||||
|
||||
struct IntAttr1TypeDescriptor
|
||||
: public AttrTypeDescriptor<int64_t, 1, NoSymmetry,
|
||||
IntAttr1TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "IntAttr1";
|
||||
|
||||
enum class AttrType {
|
||||
kAuxObjPriority,
|
||||
};
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
|
||||
{.name = "auxiliary_objective_priority",
|
||||
.default_value = 0,
|
||||
.key_types = {ElementType::kAuxiliaryObjective}},
|
||||
});
|
||||
};
|
||||
|
||||
struct DoubleAttr0TypeDescriptor
|
||||
: public AttrTypeDescriptor<double, 0, NoSymmetry,
|
||||
DoubleAttr0TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "DoubleAttr0";
|
||||
|
||||
enum class AttrType { kObjOffset };
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>(
|
||||
{{.name = "objective_offset", .default_value = 0.0, .key_types = {}}});
|
||||
};
|
||||
|
||||
struct DoubleAttr1TypeDescriptor
|
||||
: public AttrTypeDescriptor<double, 1, NoSymmetry,
|
||||
DoubleAttr1TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "DoubleAttr1";
|
||||
|
||||
enum class AttrType {
|
||||
kVarLb,
|
||||
kVarUb,
|
||||
kObjLinCoef,
|
||||
kLinConLb,
|
||||
kLinConUb,
|
||||
kAuxObjOffset,
|
||||
kQuadConLb,
|
||||
kQuadConUb,
|
||||
kIndConLb,
|
||||
kIndConUb,
|
||||
};
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
|
||||
{.name = "variable_lower_bound",
|
||||
.default_value = -std::numeric_limits<double>::infinity(),
|
||||
.key_types = {ElementType::kVariable}},
|
||||
{.name = "variable_upper_bound",
|
||||
.default_value = std::numeric_limits<double>::infinity(),
|
||||
.key_types = {ElementType::kVariable}},
|
||||
{.name = "objective_linear_coefficient",
|
||||
.default_value = 0.0,
|
||||
.key_types = {ElementType::kVariable}},
|
||||
{.name = "linear_constraint_lower_bound",
|
||||
.default_value = -std::numeric_limits<double>::infinity(),
|
||||
.key_types = {ElementType::kLinearConstraint}},
|
||||
{.name = "linear_constraint_upper_bound",
|
||||
.default_value = std::numeric_limits<double>::infinity(),
|
||||
.key_types = {ElementType::kLinearConstraint}},
|
||||
{.name = "auxiliary_objective_offset",
|
||||
.default_value = 0.0,
|
||||
.key_types = {ElementType::kAuxiliaryObjective}},
|
||||
{.name = "quadratic_constraint_lower_bound",
|
||||
.default_value = -std::numeric_limits<double>::infinity(),
|
||||
.key_types = {ElementType::kQuadraticConstraint}},
|
||||
{.name = "quadratic_constraint_upper_bound",
|
||||
.default_value = std::numeric_limits<double>::infinity(),
|
||||
.key_types = {ElementType::kQuadraticConstraint}},
|
||||
{.name = "indicator_constraint_lower_bound",
|
||||
.default_value = -std::numeric_limits<double>::infinity(),
|
||||
.key_types = {ElementType::kIndicatorConstraint}},
|
||||
{.name = "indicator_constraint_upper_bound",
|
||||
.default_value = std::numeric_limits<double>::infinity(),
|
||||
.key_types = {ElementType::kIndicatorConstraint}},
|
||||
});
|
||||
};
|
||||
|
||||
struct DoubleAttr2TypeDescriptor
|
||||
: public AttrTypeDescriptor<double, 2, NoSymmetry,
|
||||
DoubleAttr2TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "DoubleAttr2";
|
||||
|
||||
enum class AttrType {
|
||||
kLinConCoef,
|
||||
kAuxObjLinCoef,
|
||||
kQuadConLinCoef,
|
||||
kIndConLinCoef
|
||||
};
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
|
||||
{.name = "linear_constraint_coefficient",
|
||||
.default_value = 0.0,
|
||||
.key_types = {ElementType::kLinearConstraint, ElementType::kVariable}},
|
||||
{.name = "auxiliary_objective_linear_coefficient",
|
||||
.default_value = 0.0,
|
||||
.key_types = {ElementType::kAuxiliaryObjective, ElementType::kVariable}},
|
||||
{.name = "quadratic_constraint_linear_coefficient",
|
||||
.default_value = 0.0,
|
||||
.key_types = {ElementType::kQuadraticConstraint,
|
||||
ElementType::kVariable}},
|
||||
{.name = "indicator_constraint_linear_coefficient",
|
||||
.default_value = 0.0,
|
||||
.key_types = {ElementType::kIndicatorConstraint,
|
||||
ElementType::kVariable}},
|
||||
});
|
||||
};
|
||||
|
||||
struct SymmetricDoubleAttr2TypeDescriptor
|
||||
: public AttrTypeDescriptor<double, 2, ElementSymmetry<0, 1>,
|
||||
SymmetricDoubleAttr2TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "SymmetricDoubleAttr2";
|
||||
|
||||
enum class AttrType {
|
||||
kObjQuadCoef,
|
||||
};
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
|
||||
{.name = "objective_quadratic_coefficient",
|
||||
.default_value = 0.0,
|
||||
.key_types = {ElementType::kVariable, ElementType::kVariable}},
|
||||
});
|
||||
};
|
||||
|
||||
// Note: For this type, we pick the symmetric elements to be the last 2 elements
|
||||
// of the key (index 1 and 2).
|
||||
struct SymmetricDoubleAttr3TypeDescriptor
|
||||
: public AttrTypeDescriptor<double, 3, ElementSymmetry<1, 2>,
|
||||
SymmetricDoubleAttr3TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "SymmetricDoubleAttr3";
|
||||
|
||||
enum class AttrType {
|
||||
kQuadConQuadCoef,
|
||||
};
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
|
||||
{.name = "quadratic_constraint_quadratic_coefficient",
|
||||
.default_value = 0.0,
|
||||
.key_types = {ElementType::kQuadraticConstraint, ElementType::kVariable,
|
||||
ElementType::kVariable}},
|
||||
});
|
||||
};
|
||||
|
||||
struct VariableAttr1TypeDescriptor
|
||||
: public AttrTypeDescriptor<VariableId, 1, NoSymmetry,
|
||||
VariableAttr1TypeDescriptor> {
|
||||
static constexpr absl::string_view kName = "VariableAttr1";
|
||||
|
||||
enum class AttrType {
|
||||
kIndConIndicator,
|
||||
};
|
||||
|
||||
static constexpr auto kAttrDescriptors = gtl::to_array<AttrDescriptor>({
|
||||
{.name = "indicator_constraint_indicator",
|
||||
.default_value = VariableId(),
|
||||
.key_types = {ElementType::kIndicatorConstraint}},
|
||||
});
|
||||
};
|
||||
|
||||
// The list of all available attribute descriptors. This is typically
|
||||
// manipulated using the `AllAttrs` helper in `derived_data.h`.
|
||||
using AllAttrTypeDescriptors =
|
||||
std::tuple<BoolAttr0TypeDescriptor, BoolAttr1TypeDescriptor,
|
||||
IntAttr0TypeDescriptor, IntAttr1TypeDescriptor,
|
||||
DoubleAttr0TypeDescriptor, DoubleAttr1TypeDescriptor,
|
||||
DoubleAttr2TypeDescriptor, SymmetricDoubleAttr2TypeDescriptor,
|
||||
SymmetricDoubleAttr3TypeDescriptor, VariableAttr1TypeDescriptor>;
|
||||
|
||||
// Aliases for types.
|
||||
using BoolAttr0 = BoolAttr0TypeDescriptor::AttrType;
|
||||
using BoolAttr1 = BoolAttr1TypeDescriptor::AttrType;
|
||||
using IntAttr0 = IntAttr0TypeDescriptor::AttrType;
|
||||
using IntAttr1 = IntAttr1TypeDescriptor::AttrType;
|
||||
using DoubleAttr0 = DoubleAttr0TypeDescriptor::AttrType;
|
||||
using DoubleAttr1 = DoubleAttr1TypeDescriptor::AttrType;
|
||||
using DoubleAttr2 = DoubleAttr2TypeDescriptor::AttrType;
|
||||
using SymmetricDoubleAttr2 = SymmetricDoubleAttr2TypeDescriptor::AttrType;
|
||||
using SymmetricDoubleAttr3 = SymmetricDoubleAttr3TypeDescriptor::AttrType;
|
||||
using VariableAttr1 = VariableAttr1TypeDescriptor::AttrType;
|
||||
|
||||
// Returns the index of `AttrT` in `AllAttrTypes` if `AttrT` is an attribute
|
||||
// type, -1 otherwise.
|
||||
template <typename AttrT>
|
||||
static constexpr int GetIndexIfAttr() {
|
||||
using Tuple = AllAttrTypeDescriptors;
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
return ApplyOnIndexRange<std::tuple_size_v<Tuple>>([]<int... i>() {
|
||||
return ((std::is_same_v<std::remove_cv_t<std::remove_reference_t<AttrT>>,
|
||||
typename std::tuple_element_t<i, Tuple>::AttrType>
|
||||
? (i + 1)
|
||||
: 0) +
|
||||
... + -1);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AttrT,
|
||||
typename = std::enable_if_t<(GetIndexIfAttr<AttrT>() >= 0)>>
|
||||
absl::string_view ToString(const AttrT attr) {
|
||||
using Descriptor =
|
||||
std::tuple_element_t<GetIndexIfAttr<AttrT>(), AllAttrTypeDescriptors>;
|
||||
const int attr_index = static_cast<int>(attr);
|
||||
return Descriptor::kAttrDescriptors[attr_index].name;
|
||||
}
|
||||
|
||||
template <typename Sink, typename AttrT,
|
||||
typename = std::enable_if_t<(GetIndexIfAttr<AttrT>() >= 0)>>
|
||||
void AbslStringify(Sink& sink, const AttrT attr_type) {
|
||||
sink.Append(ToString(attr_type));
|
||||
}
|
||||
|
||||
template <typename AttrT>
|
||||
std::enable_if_t<(GetIndexIfAttr<AttrT>() >= 0), std::ostream&> operator<<(
|
||||
std::ostream& ostr, AttrT attr) {
|
||||
ostr << ToString(attr);
|
||||
return ostr;
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ATTRIBUTES_H_
|
||||
58
ortools/math_opt/elemental/attributes_test.cc
Normal file
58
ortools/math_opt/elemental/attributes_test.cc
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/attributes.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/math_opt/elemental/arrays.h"
|
||||
#include "ortools/math_opt/testing/stream.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
|
||||
TEST(ToStringTests, EachTypeCanConvert) {
|
||||
EXPECT_EQ(ToString(BoolAttr0::kMaximize), "maximize");
|
||||
EXPECT_EQ(ToString(BoolAttr1::kVarInteger), "variable_integer");
|
||||
EXPECT_EQ(ToString(IntAttr0::kObjPriority), "objective_priority");
|
||||
EXPECT_EQ(ToString(IntAttr1::kAuxObjPriority),
|
||||
"auxiliary_objective_priority");
|
||||
EXPECT_EQ(ToString(DoubleAttr0::kObjOffset), "objective_offset");
|
||||
EXPECT_EQ(ToString(DoubleAttr1::kVarLb), "variable_lower_bound");
|
||||
EXPECT_EQ(ToString(DoubleAttr2::kLinConCoef),
|
||||
"linear_constraint_coefficient");
|
||||
EXPECT_EQ(ToString(SymmetricDoubleAttr2::kObjQuadCoef),
|
||||
"objective_quadratic_coefficient");
|
||||
EXPECT_EQ(ToString(SymmetricDoubleAttr3::kQuadConQuadCoef),
|
||||
"quadratic_constraint_quadratic_coefficient");
|
||||
// Now check that absl::Stringify wraps ToString()
|
||||
EXPECT_EQ(absl::StrCat(BoolAttr0::kMaximize), "maximize");
|
||||
// Now check that << wraps ToString()
|
||||
EXPECT_EQ(StreamToString(BoolAttr0::kMaximize), "maximize");
|
||||
}
|
||||
|
||||
// Validate that for all symmetric attribute types, the symmetry is consistent
|
||||
// with element types.
|
||||
TEST(SymmetryTest, AllSymmetricTypesAreCorrect) {
|
||||
ForEach(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<typename Descriptor>(const Descriptor&) {
|
||||
for (const auto& attr : Descriptor::kAttrDescriptors) {
|
||||
Descriptor::Symmetry::CheckElementTypes(attr.key_types);
|
||||
}
|
||||
},
|
||||
AllAttrTypeDescriptors{});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
81
ortools/math_opt/elemental/codegen/BUILD.bazel
Normal file
81
ortools/math_opt/elemental/codegen/BUILD.bazel
Normal file
@@ -0,0 +1,81 @@
|
||||
# Copyright 2010-2025 Google LLC
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@rules_cc//cc:cc_binary.bzl", "cc_binary")
|
||||
load("@rules_cc//cc:cc_library.bzl", "cc_library")
|
||||
|
||||
cc_library(
|
||||
name = "gen",
|
||||
srcs = ["gen.cc"],
|
||||
hdrs = ["gen.h"],
|
||||
deps = [
|
||||
"//ortools/math_opt/elemental:arrays",
|
||||
"//ortools/math_opt/elemental:attributes",
|
||||
"//ortools/math_opt/elemental:elements",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gen_c",
|
||||
srcs = ["gen_c.cc"],
|
||||
hdrs = ["gen_c.h"],
|
||||
deps = [
|
||||
":gen",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/strings:str_format",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "testing",
|
||||
testonly = 1,
|
||||
hdrs = ["testing.h"],
|
||||
deps = [":gen"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gen_python",
|
||||
srcs = ["gen_python.cc"],
|
||||
hdrs = ["gen_python.h"],
|
||||
deps = [
|
||||
":gen",
|
||||
"@abseil-cpp//absl/strings",
|
||||
"@abseil-cpp//absl/strings:str_format",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "codegen",
|
||||
srcs = ["codegen.cc"],
|
||||
visibility = [
|
||||
"//ortools/math_opt/elemental/c_api:__subpackages__",
|
||||
"//ortools/math_opt/elemental/python:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":gen",
|
||||
":gen_c",
|
||||
":gen_python",
|
||||
"//ortools/base",
|
||||
"@abseil-cpp//absl/flags:flag",
|
||||
"@abseil-cpp//absl/log",
|
||||
"@abseil-cpp//absl/log:check",
|
||||
"@abseil-cpp//absl/strings:string_view",
|
||||
],
|
||||
)
|
||||
52
ortools/math_opt/elemental/codegen/codegen.cc
Normal file
52
ortools/math_opt/elemental/codegen/codegen.cc
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/init_google.h"
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
#include "ortools/math_opt/elemental/codegen/gen_c.h"
|
||||
#include "ortools/math_opt/elemental/codegen/gen_python.h"
|
||||
|
||||
ABSL_FLAG(std::string, binding_type, "", "The binding type to generate.");
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
|
||||
namespace {
|
||||
void Main() {
|
||||
const std::string binding_type = absl::GetFlag(FLAGS_binding_type);
|
||||
if (binding_type == "c99_h") {
|
||||
std::cout << C99Declarations()->GenerateCode();
|
||||
} else if (binding_type == "c99_cc") {
|
||||
std::cout << C99Definitions()->GenerateCode();
|
||||
} else if (binding_type == "python_enums") {
|
||||
std::cout << PythonEnums()->GenerateCode();
|
||||
} else {
|
||||
LOG(FATAL) << "unknown binding type: '" << binding_type << "'";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
InitGoogle(argv[0], &argc, &argv, /*remove_flags=*/true);
|
||||
operations_research::math_opt::codegen::Main();
|
||||
return 0;
|
||||
}
|
||||
158
ortools/math_opt/elemental/codegen/gen.cc
Normal file
158
ortools/math_opt/elemental/codegen/gen.cc
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/math_opt/elemental/arrays.h"
|
||||
#include "ortools/math_opt/elemental/attributes.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
|
||||
namespace {
|
||||
|
||||
class NamedType : public Type {
|
||||
public:
|
||||
explicit NamedType(std::string name) : name_(std::move(name)) {}
|
||||
|
||||
void Print(absl::string_view, std::string* out) const final {
|
||||
absl::StrAppend(out, name_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
class PointerType : public Type {
|
||||
public:
|
||||
explicit PointerType(std::shared_ptr<Type> pointee)
|
||||
: pointee_(std::move(pointee)) {}
|
||||
|
||||
void Print(absl::string_view attr_value_type, std::string* out) const final {
|
||||
pointee_->Print(attr_value_type, out);
|
||||
absl::StrAppend(out, "*");
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Type> pointee_;
|
||||
};
|
||||
|
||||
class AttrValueTypeType : public Type {
|
||||
public:
|
||||
void Print(absl::string_view attr_value_type, std::string* out) const final {
|
||||
absl::StrAppend(out, attr_value_type);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::shared_ptr<Type> Type::Named(std::string name) {
|
||||
return std::make_shared<NamedType>(std::move(name));
|
||||
}
|
||||
|
||||
std::shared_ptr<Type> Type::Pointer(std::shared_ptr<Type> pointee) {
|
||||
return std::make_shared<PointerType>(std::move(pointee));
|
||||
}
|
||||
|
||||
std::shared_ptr<Type> Type::AttrValueType() {
|
||||
return std::make_shared<AttrValueTypeType>();
|
||||
}
|
||||
|
||||
Type::~Type() = default;
|
||||
|
||||
CodegenAttrTypeDescriptor::ValueType GetValueType(bool) {
|
||||
return CodegenAttrTypeDescriptor::ValueType::kBool;
|
||||
}
|
||||
|
||||
CodegenAttrTypeDescriptor::ValueType GetValueType(int64_t) {
|
||||
return CodegenAttrTypeDescriptor::ValueType::kInt64;
|
||||
}
|
||||
|
||||
CodegenAttrTypeDescriptor::ValueType GetValueType(double) {
|
||||
return CodegenAttrTypeDescriptor::ValueType::kDouble;
|
||||
}
|
||||
|
||||
template <ElementType element_type>
|
||||
CodegenAttrTypeDescriptor::ValueType GetValueType(ElementId<element_type>) {
|
||||
// Element ids are untyped in wrapped APIs.
|
||||
return CodegenAttrTypeDescriptor::ValueType::kInt64;
|
||||
}
|
||||
|
||||
template <typename Descriptor>
|
||||
CodegenAttrTypeDescriptor MakeAttrTypeDescriptor() {
|
||||
CodegenAttrTypeDescriptor descriptor;
|
||||
descriptor.value_type = GetValueType(typename Descriptor::ValueType{});
|
||||
descriptor.name = Descriptor::kName;
|
||||
descriptor.num_key_elements = Descriptor::kNumKeyElements;
|
||||
descriptor.symmetry = Descriptor::Symmetry::GetName();
|
||||
|
||||
descriptor.attribute_names.reserve(Descriptor::NumAttrs());
|
||||
for (const auto& attr_descriptor : Descriptor::kAttrDescriptors) {
|
||||
descriptor.attribute_names.push_back(attr_descriptor.name);
|
||||
}
|
||||
return descriptor;
|
||||
}
|
||||
|
||||
constexpr absl::string_view kOpNames[static_cast<int>(AttrOp::kNumOps)] = {
|
||||
"Get", "Set", "IsNonDefault", "NumNonDefaults", "GetNonDefaults"};
|
||||
|
||||
void CodeGenerator::EmitAttrType(const CodegenAttrTypeDescriptor& descriptor,
|
||||
std::string* out) const {
|
||||
StartAttrType(descriptor, out);
|
||||
for (int op = 0; op < kNumAttrOps; ++op) {
|
||||
const AttrOpFunctionInfo& op_info = attr_op_function_infos_[op];
|
||||
EmitAttrOp(kOpNames[op], descriptor, op_info, out);
|
||||
}
|
||||
}
|
||||
|
||||
void CodeGenerator::EmitAttributes(
|
||||
absl::Span<const CodegenAttrTypeDescriptor> descriptors,
|
||||
std::string* out) const {
|
||||
for (const auto& descriptor : descriptors) {
|
||||
StartAttrType(descriptor, out);
|
||||
for (int i = 0; i < kNumAttrOps; ++i) {
|
||||
EmitAttrOp(kOpNames[i], descriptor, attr_op_function_infos_[i], out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string CodeGenerator::GenerateCode() const {
|
||||
std::string out;
|
||||
EmitHeader(&out);
|
||||
|
||||
// Generate elements.
|
||||
EmitElements(kElementNames, &out);
|
||||
|
||||
// Generate attributes.
|
||||
std::vector<CodegenAttrTypeDescriptor> attr_type_descriptors;
|
||||
ForEach(
|
||||
[&attr_type_descriptors](auto type_descriptor) {
|
||||
attr_type_descriptors.push_back(
|
||||
MakeAttrTypeDescriptor<decltype(type_descriptor)>());
|
||||
},
|
||||
AllAttrTypeDescriptors{});
|
||||
EmitAttributes(attr_type_descriptors, &out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
143
ortools/math_opt/elemental/codegen/gen.h
Normal file
143
ortools/math_opt/elemental/codegen/gen.h
Normal file
@@ -0,0 +1,143 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Language-agnostic utilities for `Elemental` codegen.
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_
|
||||
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
|
||||
// The list of attribute operations supported by `Elemental`.
|
||||
enum class AttrOp {
|
||||
kGet,
|
||||
kSet,
|
||||
kIsNonDefault,
|
||||
kNumNonDefaults,
|
||||
kGetNonDefaults,
|
||||
// Do not use.
|
||||
kNumOps,
|
||||
};
|
||||
|
||||
static constexpr int kNumAttrOps = static_cast<int>(AttrOp::kNumOps);
|
||||
|
||||
// A struct to represent an attribute type descriptor during codegen.
|
||||
struct CodegenAttrTypeDescriptor {
|
||||
// The attribute type name.
|
||||
absl::string_view name;
|
||||
// The value type of the attribute.
|
||||
enum class ValueType {
|
||||
kBool,
|
||||
kInt64,
|
||||
kDouble,
|
||||
};
|
||||
ValueType value_type;
|
||||
// The number of key elements.
|
||||
int num_key_elements;
|
||||
// The key symmetry.
|
||||
std::string symmetry;
|
||||
|
||||
// The names of the attributes of this type.
|
||||
std::vector<absl::string_view> attribute_names;
|
||||
};
|
||||
|
||||
// Representations for types.
|
||||
class Type {
|
||||
public:
|
||||
// A named type, e.g. "double".
|
||||
static std::shared_ptr<Type> Named(std::string name);
|
||||
// A pointer type.
|
||||
static std::shared_ptr<Type> Pointer(std::shared_ptr<Type> pointee);
|
||||
// A placeholder for the attribute value type, which is yet unknown when types
|
||||
// are defined. This gets replaced by `attr_value_type` when calling `Print`.
|
||||
static std::shared_ptr<Type> AttrValueType();
|
||||
|
||||
virtual ~Type();
|
||||
|
||||
// Prints the type to `out`, replacing `AttrValueType` placeholders with
|
||||
// `attr_value_type`.
|
||||
virtual void Print(absl::string_view attr_value_type,
|
||||
std::string* out) const = 0;
|
||||
};
|
||||
|
||||
// Information about how to codegen a given `AttrOp` in a given language.
|
||||
struct AttrOpFunctionInfo {
|
||||
// The return type of the function.
|
||||
std::shared_ptr<Type> return_type;
|
||||
|
||||
// If true, the function has an `AttrKey` parameter.
|
||||
bool has_key_parameter;
|
||||
|
||||
// Extra parameters (e.g. {"double", "value"} for `Set` operations).
|
||||
struct ExtraParameter {
|
||||
std::shared_ptr<Type> type;
|
||||
std::string name;
|
||||
};
|
||||
std::vector<ExtraParameter> extra_parameters;
|
||||
};
|
||||
|
||||
using AttrOpFunctionInfos = std::array<AttrOpFunctionInfo, kNumAttrOps>;
|
||||
|
||||
// The code generator interface.
|
||||
class CodeGenerator {
|
||||
public:
|
||||
explicit CodeGenerator(const AttrOpFunctionInfos* attr_op_function_infos)
|
||||
: attr_op_function_infos_(*attr_op_function_infos) {}
|
||||
|
||||
virtual ~CodeGenerator() = default;
|
||||
|
||||
// Generates code.
|
||||
std::string GenerateCode() const;
|
||||
|
||||
// Emits the header for the generated code.
|
||||
virtual void EmitHeader(std::string* out) const {}
|
||||
|
||||
// Emits code for elements.
|
||||
virtual void EmitElements(absl::Span<const absl::string_view> elements,
|
||||
std::string* out) const {}
|
||||
|
||||
// Emits code for attributes. By default, this iterates attributes and for
|
||||
// each attribute:
|
||||
// - calls `StartAttrType`, and
|
||||
// - calls `EmitAttrOp` for each operation.
|
||||
virtual void EmitAttributes(
|
||||
absl::Span<const CodegenAttrTypeDescriptor> descriptors,
|
||||
std::string* out) const;
|
||||
|
||||
// Called before generating code for an attribute type.
|
||||
virtual void StartAttrType(const CodegenAttrTypeDescriptor& descriptor,
|
||||
std::string* out) const {}
|
||||
|
||||
// Emits code for operation `info` for attribute described by `descriptor`.
|
||||
virtual void EmitAttrOp(absl::string_view op_name,
|
||||
const CodegenAttrTypeDescriptor& descriptor,
|
||||
const AttrOpFunctionInfo& info,
|
||||
std::string* out) const {}
|
||||
|
||||
private:
|
||||
void EmitAttrType(const CodegenAttrTypeDescriptor& descriptor,
|
||||
std::string* out) const;
|
||||
|
||||
const AttrOpFunctionInfos& attr_op_function_infos_;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_H_
|
||||
245
ortools/math_opt/elemental/codegen/gen_c.cc
Normal file
245
ortools/math_opt/elemental/codegen/gen_c.cc
Normal file
@@ -0,0 +1,245 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen_c.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
|
||||
namespace {
|
||||
|
||||
// A helper to generate parameters to pass `n` key element indices, e.g:
|
||||
// ", int64_t key_0, int64_t key_1" (parameters)
|
||||
void AddKeyParams(int n, std::string* out) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
absl::StrAppend(out, ", int64_t key_", i);
|
||||
}
|
||||
}
|
||||
|
||||
// A helper to generate an AttrKey argument to pass `n` key element indices,
|
||||
// e.g: "AttrKey<2, NoSymmetry>(key_0, key_1)".
|
||||
void AddAttrKeyArg(int n, absl::string_view symmetry, std::string* out) {
|
||||
absl::StrAppendFormat(out, ", AttrKey<%i, %s>(", n, symmetry);
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (i != 0) {
|
||||
absl::StrAppend(out, ", ");
|
||||
}
|
||||
absl::StrAppend(out, "key_", i);
|
||||
}
|
||||
absl::StrAppend(out, ")");
|
||||
}
|
||||
|
||||
// Returns the C99 name for the given type.
|
||||
absl::string_view GetCTypeName(
|
||||
CodegenAttrTypeDescriptor::ValueType value_type) {
|
||||
switch (value_type) {
|
||||
case CodegenAttrTypeDescriptor::ValueType::kBool:
|
||||
return "_Bool";
|
||||
case CodegenAttrTypeDescriptor::ValueType::kInt64:
|
||||
return "int64_t";
|
||||
case CodegenAttrTypeDescriptor::ValueType::kDouble:
|
||||
return "double";
|
||||
}
|
||||
}
|
||||
|
||||
// Turns an element/attribute name (e.g. "some_name") into a camel case name
|
||||
// (e.g. "SomeName").
|
||||
std::string NameToCamelCase(absl::string_view attr_name) {
|
||||
std::string result;
|
||||
result.reserve(attr_name.size());
|
||||
CHECK(!attr_name.empty());
|
||||
const char first = attr_name[0];
|
||||
CHECK(absl::ascii_isalpha(first) && absl::ascii_islower(first))
|
||||
<< "invalid attr name: " << attr_name;
|
||||
result.push_back(absl::ascii_toupper(first));
|
||||
for (int i = 1; i < attr_name.size(); ++i) {
|
||||
const char c = attr_name[i];
|
||||
if (c == '_') {
|
||||
++i;
|
||||
CHECK(i < attr_name.size()) << "invalid attr name: " << attr_name;
|
||||
const char next_c = attr_name[i];
|
||||
CHECK(absl::ascii_isalnum(next_c)) << "invalid attr name: " << attr_name;
|
||||
result.push_back(absl::ascii_toupper(next_c));
|
||||
} else {
|
||||
CHECK(absl::ascii_isalnum(c)) << "invalid attr name: " << attr_name;
|
||||
CHECK(absl::ascii_islower(c)) << "invalid attr name: " << attr_name;
|
||||
result.push_back(c);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns the type of the C status.
|
||||
std::shared_ptr<Type> GetStatusType() { return Type::Named("int"); }
|
||||
|
||||
const AttrOpFunctionInfos* GetC99FunctionInfos() {
|
||||
static const auto* const kResult = new AttrOpFunctionInfos({
|
||||
// Get.
|
||||
AttrOpFunctionInfo{
|
||||
.return_type = GetStatusType(),
|
||||
.has_key_parameter = true,
|
||||
.extra_parameters = {{.type = Type::Pointer(Type::AttrValueType()),
|
||||
.name = "value"}}},
|
||||
// Set.
|
||||
AttrOpFunctionInfo{.return_type = GetStatusType(),
|
||||
.has_key_parameter = true,
|
||||
.extra_parameters = {{.type = Type::AttrValueType(),
|
||||
.name = "value"}}},
|
||||
// IsNonDefault.
|
||||
AttrOpFunctionInfo{
|
||||
.return_type = GetStatusType(),
|
||||
.has_key_parameter = true,
|
||||
.extra_parameters = {{.type = Type::Pointer(Type::Named("_Bool")),
|
||||
.name = "out_is_non_default"}}},
|
||||
// NumNonDefaults.
|
||||
AttrOpFunctionInfo{
|
||||
.return_type = GetStatusType(),
|
||||
.has_key_parameter = false,
|
||||
.extra_parameters = {{.type = Type::Pointer(Type::Named("int64_t")),
|
||||
.name = "out_num_non_defaults"}}},
|
||||
// GetNonDefaults.
|
||||
AttrOpFunctionInfo{
|
||||
.return_type = GetStatusType(),
|
||||
.has_key_parameter = false,
|
||||
.extra_parameters =
|
||||
{
|
||||
{.type = Type::Pointer(Type::Named("int64_t")),
|
||||
.name = "out_num_non_defaults"},
|
||||
{.type = Type::Pointer(Type::Pointer(Type::Named("int64_t"))),
|
||||
.name = "out_non_defaults"},
|
||||
}},
|
||||
});
|
||||
return kResult;
|
||||
}
|
||||
|
||||
class C99CodeGeneratorBase : public CodeGenerator {
|
||||
public:
|
||||
using CodeGenerator::CodeGenerator;
|
||||
|
||||
void EmitHeader(std::string* out) const final {
|
||||
absl::StrAppend(out, R"(
|
||||
/* DO NOT EDIT: This file is autogenerated. */
|
||||
#ifndef MATHOPTH_GENERATED
|
||||
#error "this file is intended to be included, do not use directly"
|
||||
#endif
|
||||
)");
|
||||
}
|
||||
};
|
||||
|
||||
// Emits the prototype for a function.
|
||||
void EmitPrototype(absl::string_view op_name,
|
||||
const CodegenAttrTypeDescriptor& descriptor,
|
||||
const AttrOpFunctionInfo& info, std::string* out) {
|
||||
absl::string_view attr_value_type = GetCTypeName(descriptor.value_type);
|
||||
// Adds the return type, function name and common parameters.
|
||||
info.return_type->Print(attr_value_type, out);
|
||||
absl::StrAppendFormat(out,
|
||||
" MathOpt%s%s(struct "
|
||||
"MathOptElemental* e, int attr",
|
||||
descriptor.name, op_name);
|
||||
// Add the key.
|
||||
if (info.has_key_parameter) {
|
||||
AddKeyParams(descriptor.num_key_elements, out);
|
||||
}
|
||||
// Add extra parameters.
|
||||
for (const auto& extra_param : info.extra_parameters) {
|
||||
absl::StrAppend(out, ", ");
|
||||
extra_param.type->Print(attr_value_type, out);
|
||||
absl::StrAppend(out, " ", extra_param.name);
|
||||
}
|
||||
// Finish prototype.
|
||||
absl::StrAppend(out, ")");
|
||||
}
|
||||
|
||||
class C99DeclarationsGenerator : public C99CodeGeneratorBase {
|
||||
public:
|
||||
C99DeclarationsGenerator() : C99CodeGeneratorBase(GetC99FunctionInfos()) {}
|
||||
|
||||
void EmitElements(absl::Span<const absl::string_view> elements,
|
||||
std::string* out) const override {
|
||||
// Generate an enum for the elements.
|
||||
absl::StrAppend(out,
|
||||
"// The type of an element in the model.\n"
|
||||
"enum MathOptElementType {\n");
|
||||
for (const auto& element_name : elements) {
|
||||
absl::StrAppendFormat(out, " kMathOpt%s,\n",
|
||||
NameToCamelCase(element_name));
|
||||
}
|
||||
absl::StrAppend(out, "};\n\n");
|
||||
}
|
||||
|
||||
void EmitAttrOp(absl::string_view op_name,
|
||||
const CodegenAttrTypeDescriptor& descriptor,
|
||||
const AttrOpFunctionInfo& info,
|
||||
std::string* out) const override {
|
||||
// Just emit a prototype.
|
||||
EmitPrototype(op_name, descriptor, info, out);
|
||||
absl::StrAppend(out, ";\n");
|
||||
}
|
||||
|
||||
void StartAttrType(const CodegenAttrTypeDescriptor& descriptor,
|
||||
std::string* out) const override {
|
||||
// Generate an enum for the attribute type.
|
||||
absl::StrAppendFormat(out, "typedef enum {\n");
|
||||
for (absl::string_view attr_name : descriptor.attribute_names) {
|
||||
absl::StrAppendFormat(out, " kMathOpt%s%s,\n", descriptor.name,
|
||||
NameToCamelCase(attr_name));
|
||||
}
|
||||
absl::StrAppendFormat(out, "} MathOpt%s;\n", descriptor.name);
|
||||
}
|
||||
};
|
||||
|
||||
class C99DefinitionsGenerator : public C99CodeGeneratorBase {
|
||||
public:
|
||||
C99DefinitionsGenerator() : C99CodeGeneratorBase(GetC99FunctionInfos()) {}
|
||||
|
||||
void EmitAttrOp(absl::string_view op_name,
|
||||
const CodegenAttrTypeDescriptor& descriptor,
|
||||
const AttrOpFunctionInfo& info,
|
||||
std::string* out) const override {
|
||||
EmitPrototype(op_name, descriptor, info, out);
|
||||
// Emit a call to the wrapper (e.g. `CAttrOp<Descriptor>::Op`).
|
||||
absl::StrAppendFormat(out, " {\n return CAttrOp<%s>::%s(e, attr",
|
||||
descriptor.name, op_name);
|
||||
// Add the key argument.
|
||||
if (info.has_key_parameter) {
|
||||
AddAttrKeyArg(descriptor.num_key_elements, descriptor.symmetry, out);
|
||||
}
|
||||
// Add extra parameter arguments.
|
||||
for (const auto& extra_param : info.extra_parameters) {
|
||||
absl::StrAppend(out, ", ", extra_param.name);
|
||||
}
|
||||
absl::StrAppend(out, ");\n}\n");
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<CodeGenerator> C99Declarations() {
|
||||
return std::make_unique<C99DeclarationsGenerator>();
|
||||
}
|
||||
|
||||
std::unique_ptr<CodeGenerator> C99Definitions() {
|
||||
return std::make_unique<C99DefinitionsGenerator>();
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
32
ortools/math_opt/elemental/codegen/gen_c.h
Normal file
32
ortools/math_opt/elemental/codegen/gen_c.h
Normal file
@@ -0,0 +1,32 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// The C99 code generator.
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
|
||||
// Returns a generator for C99 declarations.
|
||||
std::unique_ptr<CodeGenerator> C99Declarations();
|
||||
|
||||
// Returns a generator for C99 definitions.
|
||||
std::unique_ptr<CodeGenerator> C99Definitions();
|
||||
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_C_H_
|
||||
100
ortools/math_opt/elemental/codegen/gen_c_test.cc
Normal file
100
ortools/math_opt/elemental/codegen/gen_c_test.cc
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen_c.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
#include "ortools/math_opt/elemental/codegen/testing.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
namespace {
|
||||
|
||||
TEST(GenC99DeclarationsTest, EmitElements) {
|
||||
std::string code;
|
||||
C99Declarations()->EmitElements({"some_name", "other_name"}, &code);
|
||||
EXPECT_EQ(code,
|
||||
R"(// The type of an element in the model.
|
||||
enum MathOptElementType {
|
||||
kMathOptSomeName,
|
||||
kMathOptOtherName,
|
||||
};
|
||||
|
||||
)");
|
||||
}
|
||||
|
||||
TEST(GenC99DeclarationsTest, StartAttrType) {
|
||||
std::string code;
|
||||
C99Declarations()->StartAttrType(GetTestDescriptor(), &code);
|
||||
EXPECT_EQ(code,
|
||||
R"(typedef enum {
|
||||
kMathOptTestAttr2AName,
|
||||
kMathOptTestAttr2BName,
|
||||
} MathOptTestAttr2;
|
||||
)");
|
||||
}
|
||||
|
||||
TEST(GenC99DeclarationsTest, WithoutKey) {
|
||||
std::string code;
|
||||
C99Declarations()->EmitAttrOp("Op", GetTestDescriptor(),
|
||||
GetTestFunctionInfo(false), &code);
|
||||
EXPECT_EQ(
|
||||
code,
|
||||
R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, ExtraParam extra_param);
|
||||
)");
|
||||
}
|
||||
|
||||
TEST(GenC99DeclarationsTest, WithKey) {
|
||||
std::string code;
|
||||
C99Declarations()->EmitAttrOp("Op", GetTestDescriptor(),
|
||||
GetTestFunctionInfo(true), &code);
|
||||
EXPECT_EQ(
|
||||
code,
|
||||
R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, int64_t key_0, int64_t key_1, ExtraParam extra_param);
|
||||
)");
|
||||
}
|
||||
|
||||
TEST(GenC99DefinitionsTest, WithoutKey) {
|
||||
std::string code;
|
||||
C99Definitions()->EmitAttrOp("Op", GetTestDescriptor(),
|
||||
GetTestFunctionInfo(false), &code);
|
||||
EXPECT_EQ(
|
||||
code,
|
||||
R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, ExtraParam extra_param) {
|
||||
return CAttrOp<TestAttr2>::Op(e, attr, extra_param);
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
TEST(GenC99DefinitionsTest, WithKey) {
|
||||
std::string code;
|
||||
C99Definitions()->EmitAttrOp("Op", GetTestDescriptor(),
|
||||
GetTestFunctionInfo(true), &code);
|
||||
EXPECT_EQ(
|
||||
code,
|
||||
R"(ReturnType MathOptTestAttr2Op(struct MathOptElemental* e, int attr, int64_t key_0, int64_t key_1, ExtraParam extra_param) {
|
||||
return CAttrOp<TestAttr2>::Op(e, attr, AttrKey<2, SomeSymmetry>(key_0, key_1), extra_param);
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
TEST(GenC99DefinitionsTest, StartAttrType) {
|
||||
std::string code;
|
||||
C99Definitions()->StartAttrType(GetTestDescriptor(), &code);
|
||||
EXPECT_EQ(code, "");
|
||||
}
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
161
ortools/math_opt/elemental/codegen/gen_python.cc
Normal file
161
ortools/math_opt/elemental/codegen/gen_python.cc
Normal file
@@ -0,0 +1,161 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen_python.h"
|
||||
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
|
||||
namespace {
|
||||
|
||||
const AttrOpFunctionInfos* GetPythonFunctionInfos() {
|
||||
// We're not generating functions for python, only enums.
|
||||
static const auto* const kResult = new AttrOpFunctionInfos();
|
||||
return kResult;
|
||||
}
|
||||
|
||||
// Emits a set of numbered python enumerators for the given range.
|
||||
void EmitEnumerators(const absl::Span<const absl::string_view> names,
|
||||
std::string* out) {
|
||||
for (int i = 0; i < names.size(); ++i) {
|
||||
absl::StrAppendFormat(out, " %s = %i\n", absl::AsciiStrToUpper(names[i]),
|
||||
i);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the python type for the given value type.
|
||||
absl::string_view GetAttrPyValueType(
|
||||
const CodegenAttrTypeDescriptor::ValueType& value_type) {
|
||||
switch (value_type) {
|
||||
case CodegenAttrTypeDescriptor::ValueType::kBool:
|
||||
return "bool";
|
||||
case CodegenAttrTypeDescriptor::ValueType::kInt64:
|
||||
return "int";
|
||||
case CodegenAttrTypeDescriptor::ValueType::kDouble:
|
||||
return "float";
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the python type for the given value type.
|
||||
absl::string_view GetAttrNumpyValueType(
|
||||
const CodegenAttrTypeDescriptor::ValueType& value_type) {
|
||||
switch (value_type) {
|
||||
case CodegenAttrTypeDescriptor::ValueType::kBool:
|
||||
return "np.bool_";
|
||||
case CodegenAttrTypeDescriptor::ValueType::kInt64:
|
||||
return "np.int64";
|
||||
case CodegenAttrTypeDescriptor::ValueType::kDouble:
|
||||
return "np.float64";
|
||||
}
|
||||
}
|
||||
|
||||
class PythonEnumsGenerator : public CodeGenerator {
|
||||
public:
|
||||
PythonEnumsGenerator() : CodeGenerator(GetPythonFunctionInfos()) {}
|
||||
|
||||
void EmitHeader(std::string* out) const override {
|
||||
absl::StrAppend(out, R"(
|
||||
'''DO NOT EDIT: This file is autogenerated.'''
|
||||
|
||||
import enum
|
||||
from typing import Generic, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
)");
|
||||
}
|
||||
|
||||
void EmitElements(absl::Span<const absl::string_view> elements,
|
||||
std::string* out) const override {
|
||||
// Generate an enum for the elements.
|
||||
absl::StrAppend(out, "class ElementType(enum.Enum):\n");
|
||||
EmitEnumerators(elements, out);
|
||||
absl::StrAppend(out, "\n");
|
||||
}
|
||||
|
||||
void EmitAttributes(absl::Span<const CodegenAttrTypeDescriptor> descriptors,
|
||||
std::string* out) const override {
|
||||
absl::StrAppend(out, "\n");
|
||||
|
||||
{
|
||||
// Collect the list of unique types:
|
||||
std::set<absl::string_view> value_types;
|
||||
for (const auto& descriptor : descriptors) {
|
||||
value_types.insert(GetAttrNumpyValueType(descriptor.value_type));
|
||||
}
|
||||
|
||||
// Emit `AttrValueType`, a type variable for all attribute value types.
|
||||
absl::StrAppend(out, "AttrValueType = TypeVar('AttrValueType', ",
|
||||
absl::StrJoin(value_types, ", "), ")\n");
|
||||
}
|
||||
absl::StrAppend(out, "\n");
|
||||
{
|
||||
std::set<absl::string_view> py_value_types;
|
||||
for (const auto& descriptor : descriptors) {
|
||||
py_value_types.insert(GetAttrPyValueType(descriptor.value_type));
|
||||
}
|
||||
absl::StrAppend(out, "AttrPyValueType = TypeVar('AttrPyValueType', ",
|
||||
absl::StrJoin(py_value_types, ", "), ")\n");
|
||||
}
|
||||
|
||||
// `Attr` is an attribute with any value type.
|
||||
absl::StrAppend(out, R"(
|
||||
class Attr(Generic[AttrValueType]):
|
||||
pass
|
||||
)");
|
||||
|
||||
// `PyAttr` is an attribute with any value type.
|
||||
absl::StrAppend(out, R"(
|
||||
class PyAttr(Generic[AttrPyValueType]):
|
||||
pass
|
||||
)");
|
||||
|
||||
// Generate an enum for the attribute type.
|
||||
for (const auto& descriptor : descriptors) {
|
||||
absl::StrAppendFormat(
|
||||
out, "\nclass %s(Attr[%s], PyAttr[%s], int, enum.Enum):\n",
|
||||
descriptor.name, GetAttrNumpyValueType(descriptor.value_type),
|
||||
GetAttrPyValueType(descriptor.value_type));
|
||||
EmitEnumerators(descriptor.attribute_names, out);
|
||||
absl::StrAppend(out, "\n");
|
||||
}
|
||||
|
||||
// Add a type alias for the union of all attribute types.
|
||||
absl::StrAppend(
|
||||
out, "AnyAttr = Union[",
|
||||
absl::StrJoin(
|
||||
descriptors, ", ",
|
||||
[](std::string* out, const CodegenAttrTypeDescriptor& descriptor) {
|
||||
absl::StrAppend(out, descriptor.name);
|
||||
}),
|
||||
"]\n");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<CodeGenerator> PythonEnums() {
|
||||
return std::make_unique<PythonEnumsGenerator>();
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
30
ortools/math_opt/elemental/codegen/gen_python.h
Normal file
30
ortools/math_opt/elemental/codegen/gen_python.h
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// The python code generator.
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
|
||||
// Returns a generator for python enums, independent of the actual
|
||||
// implementation. These are used by the protocol.
|
||||
std::unique_ptr<CodeGenerator> PythonEnums();
|
||||
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_GEN_PYTHON_H_
|
||||
60
ortools/math_opt/elemental/codegen/gen_python_test.cc
Normal file
60
ortools/math_opt/elemental/codegen/gen_python_test.cc
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen_python.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/math_opt/elemental/codegen/testing.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
namespace {
|
||||
|
||||
TEST(GenPythonEnumsTest, EmitElements) {
|
||||
std::string code;
|
||||
PythonEnums()->EmitElements({"some_name", "other_name"}, &code);
|
||||
EXPECT_EQ(code,
|
||||
R"(class ElementType(enum.Enum):
|
||||
SOME_NAME = 0
|
||||
OTHER_NAME = 1
|
||||
|
||||
)");
|
||||
}
|
||||
|
||||
TEST(GenPythonEnumsTest, EmitAttributes) {
|
||||
std::string code;
|
||||
PythonEnums()->EmitAttributes({GetTestDescriptor()}, &code);
|
||||
EXPECT_EQ(code,
|
||||
R"(
|
||||
AttrValueType = TypeVar('AttrValueType', np.float64)
|
||||
|
||||
AttrPyValueType = TypeVar('AttrPyValueType', float)
|
||||
|
||||
class Attr(Generic[AttrValueType]):
|
||||
pass
|
||||
|
||||
class PyAttr(Generic[AttrPyValueType]):
|
||||
pass
|
||||
|
||||
class TestAttr2(Attr[np.float64], PyAttr[float], int, enum.Enum):
|
||||
A_NAME = 0
|
||||
B_NAME = 1
|
||||
|
||||
AnyAttr = Union[TestAttr2]
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
96
ortools/math_opt/elemental/codegen/gen_test.cc
Normal file
96
ortools/math_opt/elemental/codegen/gen_test.cc
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
namespace {
|
||||
using testing::HasSubstr;
|
||||
using testing::StartsWith;
|
||||
|
||||
const AttrOpFunctionInfos* GetFunctionInfos() {
|
||||
static const auto* const kResult = new AttrOpFunctionInfos({
|
||||
{.return_type = Type::Named("TypeForGet"),
|
||||
.has_key_parameter = false,
|
||||
.extra_parameters = {}},
|
||||
{.return_type = Type::Pointer(Type::AttrValueType()),
|
||||
.has_key_parameter = true,
|
||||
.extra_parameters = {}},
|
||||
{.return_type = Type::Named("T"),
|
||||
.has_key_parameter = false,
|
||||
.extra_parameters = {}},
|
||||
{.return_type = Type::Named("T"),
|
||||
.has_key_parameter = false,
|
||||
.extra_parameters = {}},
|
||||
{.return_type = Type::Named("T"),
|
||||
.has_key_parameter = false,
|
||||
.extra_parameters = {}},
|
||||
});
|
||||
return kResult;
|
||||
}
|
||||
|
||||
class TestCodeGenerator : public CodeGenerator {
|
||||
public:
|
||||
TestCodeGenerator() : CodeGenerator(GetFunctionInfos()) {}
|
||||
|
||||
void EmitHeader(std::string* out) const override {
|
||||
absl::StrAppend(out, "# DO NOT EDIT: Test\n");
|
||||
}
|
||||
|
||||
void EmitElements(absl::Span<const absl::string_view> elements,
|
||||
std::string* out) const override {
|
||||
absl::StrAppend(out, "Elements: ", absl::StrJoin(elements, ", "), "\n");
|
||||
}
|
||||
|
||||
void StartAttrType(const CodegenAttrTypeDescriptor&,
|
||||
std::string* out) const override {
|
||||
absl::StrAppend(out, "\n");
|
||||
}
|
||||
|
||||
void EmitAttrOp(absl::string_view op_name,
|
||||
const CodegenAttrTypeDescriptor& descriptor,
|
||||
const AttrOpFunctionInfo& info,
|
||||
std::string* out) const override {
|
||||
info.return_type->Print("fake_type", out);
|
||||
absl::StrAppend(out, " ", descriptor.name, op_name, "\n");
|
||||
}
|
||||
};
|
||||
|
||||
TEST(GenerateCodeTest, Attrs) {
|
||||
const std::string code = TestCodeGenerator().GenerateCode();
|
||||
EXPECT_THAT(code, StartsWith("# DO NOT EDIT: Test\n"));
|
||||
EXPECT_THAT(code, HasSubstr("Elements: variable, linear_constraint, "));
|
||||
EXPECT_THAT(code, HasSubstr("TypeForGet BoolAttr0Get\n"
|
||||
"fake_type* BoolAttr0Set\n"
|
||||
"T BoolAttr0IsNonDefault\n"
|
||||
"T BoolAttr0NumNonDefaults\n"
|
||||
"T BoolAttr0GetNonDefaults\n"));
|
||||
EXPECT_THAT(code, HasSubstr("TypeForGet DoubleAttr1Get\n"
|
||||
"fake_type* DoubleAttr1Set\n"
|
||||
"T DoubleAttr1IsNonDefault\n"
|
||||
"T DoubleAttr1NumNonDefaults\n"
|
||||
"T DoubleAttr1GetNonDefaults\n"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
40
ortools/math_opt/elemental/codegen/testing.h
Normal file
40
ortools/math_opt/elemental/codegen/testing.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Test descriptors. This avoids depending on attributes from `attributes.h`
|
||||
// in the tests to decouple the codegen tests from `attributes.h`.
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_
|
||||
|
||||
#include "ortools/math_opt/elemental/codegen/gen.h"
|
||||
|
||||
namespace operations_research::math_opt::codegen {
|
||||
|
||||
inline CodegenAttrTypeDescriptor GetTestDescriptor() {
|
||||
return {.name = "TestAttr2",
|
||||
.value_type = CodegenAttrTypeDescriptor::ValueType::kDouble,
|
||||
.num_key_elements = 2,
|
||||
.symmetry = "SomeSymmetry",
|
||||
.attribute_names = {"a_name", "b_name"}};
|
||||
}
|
||||
|
||||
inline AttrOpFunctionInfo GetTestFunctionInfo(bool with_key_parameter) {
|
||||
return {.return_type = Type::Named("ReturnType"),
|
||||
.has_key_parameter = with_key_parameter,
|
||||
.extra_parameters = {
|
||||
{{.type = Type::Named("ExtraParam"), .name = "extra_param"}}}};
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt::codegen
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_CODEGEN_TESTING_H_
|
||||
218
ortools/math_opt/elemental/derived_data.h
Normal file
218
ortools/math_opt/elemental/derived_data.h
Normal file
@@ -0,0 +1,218 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_DERIVED_DATA_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_DERIVED_DATA_H_
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "ortools/math_opt/elemental/arrays.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/attributes.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/util/fp_roundtrip_conv.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// A helper to manipulate the list of attributes.
|
||||
struct AllAttrs {
|
||||
// The number of available attribute types.
|
||||
static constexpr int kNumAttrTypes =
|
||||
std::tuple_size_v<AllAttrTypeDescriptors>;
|
||||
|
||||
// Returns the descriptor of the `i-th` attribute type in the list.
|
||||
template <int i>
|
||||
using TypeDescriptor = std::tuple_element_t<i, AllAttrTypeDescriptors>;
|
||||
|
||||
// Returns the `i-th` attribute type in the list.
|
||||
template <int i>
|
||||
using Type = typename TypeDescriptor<i>::AttrType;
|
||||
|
||||
// Returns the index of attribute type `AttrT`.
|
||||
// Fails to compile if `AttrT` is not an attribute.
|
||||
template <typename AttrT>
|
||||
static constexpr int GetIndex() {
|
||||
constexpr int index = GetIndexIfAttr<AttrT>();
|
||||
// This weird construct is to show `AttrT` explicitly instead of letting
|
||||
// the user fish it out of the stack trace when the static_assert fails.
|
||||
static_assert(
|
||||
std::is_const_v<std::conditional_t<(index >= 0), const AttrT, AttrT>>,
|
||||
"no such attribute");
|
||||
return index;
|
||||
}
|
||||
|
||||
// Applies `fn` on each value for each attribute type. `fn` must have a
|
||||
// overload set of `operator(AttrType)` that accepts a `AttrType` for
|
||||
// each attribute type.
|
||||
template <typename Fn>
|
||||
static void ForEachAttr(Fn&& fn) {
|
||||
ForEach(
|
||||
[&fn](const auto& descriptor) {
|
||||
for (auto attr : descriptor.Enumerate()) {
|
||||
fn(attr);
|
||||
}
|
||||
},
|
||||
AllAttrTypeDescriptors{});
|
||||
}
|
||||
};
|
||||
|
||||
// Returns the descriptor for attribute `AttrT`.
|
||||
template <typename AttrT>
|
||||
using AttrTypeDescriptorT =
|
||||
AllAttrs::TypeDescriptor<AllAttrs::GetIndex<AttrT>()>;
|
||||
|
||||
// Returns the default value for the attribute type `attr`.
|
||||
//
|
||||
// For example GetAttrDefaultValue<DoubleAttr2::kLinConCoef>() returns 0.0.
|
||||
template <auto attr>
|
||||
constexpr typename AttrTypeDescriptorT<decltype(attr)>::ValueType
|
||||
GetAttrDefaultValue() {
|
||||
return AttrTypeDescriptorT<decltype(attr)>::kAttrDescriptors[static_cast<int>(
|
||||
attr)]
|
||||
.default_value;
|
||||
}
|
||||
|
||||
// Returns the number of elements in a key for the attribute type `AttrType`.
|
||||
//
|
||||
// For example `GetAttrKeySize<DoubleAttr2>()` returns 2.
|
||||
template <typename AttrType>
|
||||
constexpr int GetAttrKeySize() {
|
||||
return AttrTypeDescriptorT<AttrType>::kNumKeyElements;
|
||||
}
|
||||
template <auto attr>
|
||||
constexpr int GetAttrKeySize() {
|
||||
return GetAttrKeySize<decltype(attr)>();
|
||||
}
|
||||
|
||||
// The type of the `AttrKey` for attribute type `AttrType`.
|
||||
template <typename AttrType>
|
||||
using AttrKeyFor = AttrKey<AttrTypeDescriptorT<AttrType>::kNumKeyElements,
|
||||
typename AttrTypeDescriptorT<AttrType>::Symmetry>;
|
||||
|
||||
// The value type for attribute type `AttrType`.
|
||||
template <typename AttrType>
|
||||
using ValueTypeFor = typename AttrTypeDescriptorT<AttrType>::ValueType;
|
||||
|
||||
// Returns the array of elements for the key for the attribute type `attr`.
|
||||
//
|
||||
// For example, GetElementTypes<DoubleAttr2>() returns the array
|
||||
// {ElementType::kLinearConstraint, ElementType::kVariable}.
|
||||
template <typename AttrType>
|
||||
constexpr std::array<ElementType, GetAttrKeySize<AttrType>()> GetElementTypes(
|
||||
const AttrType attr) {
|
||||
return AttrTypeDescriptorT<AttrType>::kAttrDescriptors[static_cast<int>(attr)]
|
||||
.key_types;
|
||||
}
|
||||
template <auto attr>
|
||||
constexpr std::array<ElementType, GetAttrKeySize<attr>()> GetElementTypes() {
|
||||
return GetElementTypes(attr);
|
||||
}
|
||||
|
||||
// After C++20, this can be replaced by a lambda. C++17 does not allow lambdas
|
||||
// in unevaluated contexts.
|
||||
template <template <int i> typename ValueType>
|
||||
struct EnumeratedTupleCpp17Helper {
|
||||
template <int... i>
|
||||
auto operator()() const {
|
||||
return std::make_tuple(ValueType<i>()...);
|
||||
}
|
||||
};
|
||||
|
||||
// A tuple of `ValueType<i>` for `i` in `0..n`.
|
||||
template <int n, template <int i> typename ValueType>
|
||||
using EnumeratedTuple =
|
||||
decltype(ApplyOnIndexRange<n>(EnumeratedTupleCpp17Helper<ValueType>{}));
|
||||
|
||||
// A map of attribute to `ValueType<i>`, where `i` is the index of the attribute
|
||||
// type.
|
||||
// See `AttrMapTest` for example usage.
|
||||
// NOTE: this is formally a map (it maps attributes to values), but internally
|
||||
// uses dense storage.
|
||||
template <template <int i> typename ValueType>
|
||||
class AttrMap {
|
||||
public:
|
||||
template <typename AttrT>
|
||||
ValueType<AllAttrs::GetIndex<AttrT>()>& operator[](AttrT a) {
|
||||
// TODO(b/365997645): post C++ 23, prefer `std::to_underlying(a)`.
|
||||
return std::get<AllAttrs::GetIndex<AttrT>()>(
|
||||
array_tuple_)[static_cast<int>(a)];
|
||||
}
|
||||
|
||||
template <typename AttrT>
|
||||
const ValueType<AllAttrs::GetIndex<AttrT>()>& operator[](AttrT a) const {
|
||||
// The `const_cast` is fine because non-const `operator[]` does not mutate
|
||||
// anything and we're casting the return type back to const.
|
||||
return (*const_cast<AttrMap*>(this))[a];
|
||||
}
|
||||
|
||||
// Applies `fn` on each value for each attribute type. `fn` must have a
|
||||
// overload set of `operator()` that accepts a `ValueType<i>` for `i` in
|
||||
// `0..AllAttrs::kSize`.
|
||||
// This cannot be an iterator because value types are not homogeneous.
|
||||
template <typename Fn>
|
||||
void ForEachAttrValue(Fn&& fn) {
|
||||
ForEach(
|
||||
[&fn](auto& array) {
|
||||
for (auto& value : array) {
|
||||
fn(value);
|
||||
}
|
||||
},
|
||||
array_tuple_);
|
||||
}
|
||||
|
||||
private:
|
||||
template <int i>
|
||||
using ArrayType =
|
||||
std::array<ValueType<i>, AllAttrs::TypeDescriptor<i>::NumAttrs()>;
|
||||
EnumeratedTuple<AllAttrs::kNumAttrTypes, ArrayType> array_tuple_;
|
||||
};
|
||||
|
||||
// Calls `fn<attr>()`.
|
||||
template <typename AttrType, typename Fn, int n = 0>
|
||||
decltype(auto) CallForAttr(AttrType attr, Fn&& fn) {
|
||||
using Descriptor = AttrTypeDescriptorT<AttrType>;
|
||||
if constexpr (n < Descriptor::NumAttrs()) {
|
||||
constexpr AttrType a = static_cast<AttrType>(n);
|
||||
if (a == attr) {
|
||||
return fn.template operator()<a>();
|
||||
}
|
||||
return CallForAttr<AttrType, Fn, n + 1>(attr, std::forward<Fn>(fn));
|
||||
} else {
|
||||
LOG(FATAL) << "impossible";
|
||||
return decltype(fn.template operator()<AttrType{}>()) {};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ValueType>
|
||||
std::string FormatAttrValue(const ValueType v) {
|
||||
return absl::StrCat(v);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string FormatAttrValue(const double v) {
|
||||
return RoundTripDoubleFormat::ToString(v);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline std::string FormatAttrValue(const bool v) {
|
||||
return v ? "true" : "false";
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_DERIVED_DATA_H_
|
||||
172
ortools/math_opt/elemental/derived_data_test.cc
Normal file
172
ortools/math_opt/elemental/derived_data_test.cc
Normal file
@@ -0,0 +1,172 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/derived_data.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/math_opt/elemental/arrays.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/attributes.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/elemental/symmetry.h"
|
||||
#include "ortools/math_opt/testing/stream.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::IsSupersetOf;
|
||||
|
||||
TEST(GetAttrDefaultValueTest, HasRightDefault) {
|
||||
EXPECT_EQ(GetAttrDefaultValue<DoubleAttr0::kObjOffset>(), 0.0);
|
||||
EXPECT_EQ(GetAttrDefaultValue<BoolAttr1::kVarInteger>(), false);
|
||||
EXPECT_EQ(GetAttrDefaultValue<DoubleAttr1::kVarUb>(),
|
||||
std::numeric_limits<double>::infinity());
|
||||
EXPECT_EQ(GetAttrDefaultValue<DoubleAttr2::kLinConCoef>(), 0.0);
|
||||
}
|
||||
|
||||
TEST(AttrKeyForTest, Works) {
|
||||
EXPECT_TRUE((std::is_same_v<AttrKeyFor<BoolAttr0>, AttrKey<0>>));
|
||||
EXPECT_TRUE((std::is_same_v<AttrKeyFor<DoubleAttr0>, AttrKey<0>>));
|
||||
EXPECT_TRUE((std::is_same_v<AttrKeyFor<DoubleAttr1>, AttrKey<1>>));
|
||||
EXPECT_TRUE((std::is_same_v<AttrKeyFor<DoubleAttr2>, AttrKey<2>>));
|
||||
EXPECT_TRUE((std::is_same_v<AttrKeyFor<SymmetricDoubleAttr2>,
|
||||
AttrKey<2, ElementSymmetry<0, 1>>>));
|
||||
}
|
||||
|
||||
TEST(ValueTypeForTest, Works) {
|
||||
EXPECT_TRUE((std::is_same_v<ValueTypeFor<BoolAttr0>, bool>));
|
||||
EXPECT_TRUE((std::is_same_v<ValueTypeFor<DoubleAttr0>, double>));
|
||||
EXPECT_TRUE((std::is_same_v<ValueTypeFor<DoubleAttr1>, double>));
|
||||
EXPECT_TRUE((std::is_same_v<ValueTypeFor<DoubleAttr2>, double>));
|
||||
}
|
||||
|
||||
TEST(GetAttrKeySizeTest, IsRightSize) {
|
||||
EXPECT_EQ(GetAttrKeySize<DoubleAttr0>(), 0);
|
||||
EXPECT_EQ(GetAttrKeySize<BoolAttr1>(), 1);
|
||||
EXPECT_EQ(GetAttrKeySize<DoubleAttr2>(), 2);
|
||||
EXPECT_EQ(GetAttrKeySize<DoubleAttr0::kObjOffset>(), 0);
|
||||
EXPECT_EQ(GetAttrKeySize<BoolAttr1::kVarInteger>(), 1);
|
||||
EXPECT_EQ(GetAttrKeySize<DoubleAttr2::kLinConCoef>(), 2);
|
||||
}
|
||||
|
||||
TEST(GetElementTypesTest, Attr1HasElement) {
|
||||
EXPECT_EQ(GetElementTypes<BoolAttr1::kVarInteger>()[0],
|
||||
ElementType::kVariable);
|
||||
}
|
||||
|
||||
TEST(GetElementTypesTest, Attr2HasElements) {
|
||||
EXPECT_EQ(GetElementTypes<DoubleAttr2::kLinConCoef>()[0],
|
||||
ElementType::kLinearConstraint);
|
||||
EXPECT_EQ(GetElementTypes<DoubleAttr2::kLinConCoef>()[1],
|
||||
ElementType::kVariable);
|
||||
}
|
||||
|
||||
TEST(AllAttrsTest, Indexing) {
|
||||
ForEachIndex<AllAttrs::kNumAttrTypes>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<int i>() { EXPECT_EQ((AllAttrs::GetIndex<AllAttrs::Type<i>>()), i); });
|
||||
}
|
||||
|
||||
TEST(AllAttrsTest, ForEachAttribute) {
|
||||
std::vector<std::string> invoked;
|
||||
AllAttrs::ForEachAttr(
|
||||
[&invoked](auto attr) { invoked.push_back(StreamToString(attr)); });
|
||||
EXPECT_THAT(invoked, IsSupersetOf({"objective_offset", "maximize",
|
||||
"variable_integer", "variable_lower_bound",
|
||||
"linear_constraint_coefficient"}));
|
||||
}
|
||||
|
||||
template <int i>
|
||||
struct Value {
|
||||
Value() = default;
|
||||
explicit Value(int v) : value(v) {}
|
||||
|
||||
int value = i;
|
||||
};
|
||||
|
||||
TEST(AttrMapTest, GetSet) {
|
||||
AttrMap<Value> attr_map;
|
||||
|
||||
constexpr int kBoolAttr0Index = AllAttrs::GetIndex<BoolAttr0>();
|
||||
constexpr int kBoolAttr1Index = AllAttrs::GetIndex<BoolAttr1>();
|
||||
constexpr int kDoubleAttr1Index = AllAttrs::GetIndex<DoubleAttr1>();
|
||||
constexpr int kDoubleAttr2Index = AllAttrs::GetIndex<DoubleAttr2>();
|
||||
|
||||
// Default initialization.
|
||||
EXPECT_EQ(attr_map[BoolAttr0::kMaximize].value, kBoolAttr0Index);
|
||||
EXPECT_EQ(attr_map[BoolAttr1::kVarInteger].value, kBoolAttr1Index);
|
||||
EXPECT_EQ(attr_map[DoubleAttr1::kVarLb].value, kDoubleAttr1Index);
|
||||
EXPECT_EQ(attr_map[DoubleAttr1::kVarUb].value, kDoubleAttr1Index);
|
||||
EXPECT_EQ(attr_map[DoubleAttr2::kLinConCoef].value, kDoubleAttr2Index);
|
||||
|
||||
// Mutation (typed).
|
||||
attr_map[BoolAttr0::kMaximize] = Value<kBoolAttr0Index>(42);
|
||||
attr_map[BoolAttr1::kVarInteger] = Value<kBoolAttr1Index>(43);
|
||||
attr_map[DoubleAttr1::kVarLb] = Value<kDoubleAttr1Index>(44);
|
||||
attr_map[DoubleAttr1::kVarUb] = Value<kDoubleAttr1Index>(45);
|
||||
attr_map[DoubleAttr2::kLinConCoef] = Value<kDoubleAttr2Index>(46);
|
||||
EXPECT_EQ(attr_map[BoolAttr0::kMaximize].value, 42);
|
||||
EXPECT_EQ(attr_map[BoolAttr1::kVarInteger].value, 43);
|
||||
EXPECT_EQ(attr_map[DoubleAttr1::kVarLb].value, 44);
|
||||
EXPECT_EQ(attr_map[DoubleAttr1::kVarUb].value, 45);
|
||||
EXPECT_EQ(attr_map[DoubleAttr2::kLinConCoef].value, 46);
|
||||
}
|
||||
|
||||
TEST(AttrMapTest, Iteration) {
|
||||
AttrMap<Value> attr_map;
|
||||
|
||||
// Collect all values in the default-initialized map.s
|
||||
std::vector<int> values;
|
||||
attr_map.ForEachAttrValue(
|
||||
[&values](const auto& v) { values.emplace_back(v.value); });
|
||||
|
||||
// We should have `NumAttrs()` values `i` per attribute.
|
||||
std::vector<int> expected_values;
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
ForEachIndex<AllAttrs::kNumAttrTypes>([&expected_values]<int i>() {
|
||||
for (int k = 0; k < AllAttrs::TypeDescriptor<i>::NumAttrs(); ++k) {
|
||||
expected_values.push_back(i);
|
||||
}
|
||||
});
|
||||
EXPECT_THAT(values, ElementsAreArray(expected_values));
|
||||
}
|
||||
|
||||
TEST(CallForAttrTest, Works) {
|
||||
EXPECT_EQ(CallForAttr(DoubleAttr1::kVarUb,
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[]<DoubleAttr1 a>() { return static_cast<int>(a); }),
|
||||
static_cast<int>(DoubleAttr1::kVarUb));
|
||||
}
|
||||
|
||||
TEST(FormatAttrValueTest, FormatsBool) {
|
||||
EXPECT_EQ(FormatAttrValue(true), "true");
|
||||
}
|
||||
|
||||
TEST(FormatAttrValueTest, FormatsInt64) {
|
||||
EXPECT_EQ(FormatAttrValue(int64_t{12}), "12");
|
||||
}
|
||||
|
||||
TEST(FormatAttrValueTest, FormatsDouble) {
|
||||
EXPECT_EQ(FormatAttrValue(4.2), "4.2");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
30
ortools/math_opt/elemental/diff.cc
Normal file
30
ortools/math_opt/elemental/diff.cc
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/diff.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
void Diff::Advance(const std::array<int64_t, kNumElements>& checkpoints) {
|
||||
for (int i = 0; i < kNumElements; ++i) {
|
||||
element_diffs_[i].Advance(checkpoints[i]);
|
||||
}
|
||||
attr_diffs_.ForEachAttrValue([](auto& diff) { diff.Advance(); });
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
157
ortools/math_opt/elemental/diff.h
Normal file
157
ortools/math_opt/elemental/diff.h
Normal file
@@ -0,0 +1,157 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_DIFF_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_DIFF_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "ortools/math_opt/elemental/attr_diff.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/attributes.h"
|
||||
#include "ortools/math_opt/elemental/derived_data.h"
|
||||
#include "ortools/math_opt/elemental/element_diff.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// Stores the modifications to the model since the previous checkpoint (or since
|
||||
// creation of the Diff if Advance() has never been called).
|
||||
//
|
||||
// Only the following modifications are tracked explicitly:
|
||||
// * elements before the checkpoint
|
||||
// * attributes with all elements in the key before the checkpoint
|
||||
// as all changes involving an element after the checkpoint are implied to be in
|
||||
// the difference.
|
||||
//
|
||||
// Note: users of ElementalImpl can only access a const Diff.
|
||||
//
|
||||
// When a element is deleted from the model, the creator of the Diff is
|
||||
// responsible both for:
|
||||
// 1. Calling Diff::DeleteElement() on the element,
|
||||
// 2. For each Attr with a key element on the element type, calling
|
||||
// Diff::EraseKeysForAttr()
|
||||
// We cannot do this all at once for the user, as we do not have access to the
|
||||
// relevant related keys in steps 3/4 above.
|
||||
class Diff {
|
||||
public:
|
||||
Diff() = default;
|
||||
|
||||
// Discards all tracked modifications, and in the future, track only
|
||||
// modifications where all elements are at most checkpoint.
|
||||
//
|
||||
// Generally, checkpoints should be component-wise non-decreasing with each
|
||||
// invocation of Advance(), but this is not checked here.
|
||||
void Advance(const std::array<int64_t, kNumElements>& checkpoints);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Elements
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// The current checkpoint for the element type `e`.
|
||||
//
|
||||
// This equals the next element id for the element type `e` when Advance() was
|
||||
// last called (or at creation time if advance was never called).
|
||||
int64_t checkpoint(const ElementType e) const {
|
||||
return element_diff(e).checkpoint();
|
||||
}
|
||||
|
||||
// The elements of element type `e` that have been deleted since the last call
|
||||
// to Advance() with id less than the checkpoint.
|
||||
const absl::flat_hash_set<int64_t>& deleted_elements(
|
||||
const ElementType e) const {
|
||||
return element_diff(e).deleted();
|
||||
}
|
||||
|
||||
// Tracks the element `id` of element type `e` as deleted if it is less than
|
||||
// the checkpoint.
|
||||
//
|
||||
// WARNING: this does not update any related attributes.
|
||||
void DeleteElement(const ElementType e, int64_t id) {
|
||||
mutable_element_diff(e).Delete(id);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Attributes
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Returns the keys with all elements below the checkpoint where the Attr2 `a`
|
||||
// was modified since the last call to Advance().
|
||||
template <typename AttrType>
|
||||
const AttrKeyHashSet<AttrKeyFor<AttrType>>& modified_keys(
|
||||
const AttrType a) const {
|
||||
return attr_diffs_[a].modified_keys();
|
||||
}
|
||||
|
||||
// Marks that the attribute `a` has been modified for `attr_key`.
|
||||
template <typename AttrType>
|
||||
void SetModified(const AttrType a, const AttrKeyFor<AttrType> attr_key) {
|
||||
if (IsBeforeCheckpoint(a, attr_key)) {
|
||||
attr_diffs_[a].SetModified(attr_key);
|
||||
}
|
||||
}
|
||||
|
||||
// Discard any tracked modifications for attribute `a` on `keys`.
|
||||
//
|
||||
// Typically invoke when the element with id `keys[i]` is deleted from the
|
||||
// model, and where `keys` are the keys `k` for all elements `e` where the `e`
|
||||
// has a non-default value for `a`.
|
||||
template <typename AttrType>
|
||||
void EraseKeysForAttr(const AttrType a,
|
||||
absl::Span<const AttrKeyFor<AttrType>> keys) {
|
||||
if (!attr_diffs_[a].has_modified_keys()) {
|
||||
return;
|
||||
}
|
||||
for (const auto& attr_key : keys) {
|
||||
if (IsBeforeCheckpoint(a, attr_key)) {
|
||||
attr_diffs_[a].Erase(attr_key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const ElementDiff& element_diff(const ElementType e) const {
|
||||
// TODO(b/365997645): post C++ 23, prefer std::to_underlying(e).
|
||||
return element_diffs_[static_cast<int>(e)];
|
||||
}
|
||||
|
||||
ElementDiff& mutable_element_diff(const ElementType e) {
|
||||
return element_diffs_[static_cast<int>(e)];
|
||||
}
|
||||
|
||||
// Returns true if all elements if `key` are before their respective
|
||||
// checkpoints.
|
||||
template <typename AttrType>
|
||||
bool IsBeforeCheckpoint(const AttrType a, const AttrKeyFor<AttrType> key) {
|
||||
for (int i = 0; i < GetAttrKeySize<AttrType>(); ++i) {
|
||||
if (key[i] >= element_diff(GetElementTypes(a)[i]).checkpoint()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::array<ElementDiff, kNumElements> element_diffs_;
|
||||
|
||||
template <int i>
|
||||
using DiffForAttr = AttrDiff<AllAttrs::TypeDescriptor<i>::kNumKeyElements,
|
||||
typename AllAttrs::TypeDescriptor<i>::Symmetry>;
|
||||
AttrMap<DiffForAttr> attr_diffs_;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_DIFF_H_
|
||||
289
ortools/math_opt/elemental/diff_test.cc
Normal file
289
ortools/math_opt/elemental/diff_test.cc
Normal file
@@ -0,0 +1,289 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/diff.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/attributes.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
std::array<int64_t, kNumElements> MakeUniformCheckpoint(int64_t id) {
|
||||
std::array<int64_t, kNumElements> result;
|
||||
result.fill(id);
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Element tests
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(DiffTest, InitDiffElementsEmpty) {
|
||||
Diff diff;
|
||||
EXPECT_EQ(diff.checkpoint(ElementType::kVariable), 0);
|
||||
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, DeleteElementAfterCheckpointNoEffect) {
|
||||
Diff diff;
|
||||
diff.DeleteElement(ElementType::kVariable, 2);
|
||||
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, DeletesTrackedBelowCheckpoint) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(5));
|
||||
EXPECT_EQ(diff.checkpoint(ElementType::kVariable), 5);
|
||||
diff.DeleteElement(ElementType::kVariable, 3);
|
||||
diff.DeleteElement(ElementType::kVariable, 1);
|
||||
diff.DeleteElement(ElementType::kVariable, 8);
|
||||
diff.DeleteElement(ElementType::kVariable, 5);
|
||||
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable),
|
||||
UnorderedElementsAre(3, 1));
|
||||
}
|
||||
|
||||
TEST(DiffTest, AdvanceClearsDeletedElements) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(5));
|
||||
diff.DeleteElement(ElementType::kVariable, 3);
|
||||
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable),
|
||||
UnorderedElementsAre(3));
|
||||
diff.Advance(MakeUniformCheckpoint(5));
|
||||
EXPECT_THAT(diff.deleted_elements(ElementType::kVariable), IsEmpty());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Attr0 Tests
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(DiffTest, InitBoolAttr0Empty) {
|
||||
Diff diff;
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr0::kMaximize), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetBoolAttr0ModifiedIsModified) {
|
||||
Diff diff;
|
||||
diff.SetModified(BoolAttr0::kMaximize, AttrKey());
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr0::kMaximize),
|
||||
UnorderedElementsAre(AttrKey()));
|
||||
}
|
||||
|
||||
TEST(DiffTest, BoolAttr0AdvanceClearsModification) {
|
||||
Diff diff;
|
||||
diff.SetModified(BoolAttr0::kMaximize, AttrKey());
|
||||
diff.Advance(MakeUniformCheckpoint(0));
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr0::kMaximize), IsEmpty());
|
||||
}
|
||||
|
||||
// Repeat all tests for double, they are short enough.
|
||||
|
||||
TEST(DiffTest, InitDoubleAttr0Empty) {
|
||||
Diff diff;
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr0::kObjOffset), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetDoubleAttr0ModifiedIsModified) {
|
||||
Diff diff;
|
||||
diff.SetModified(DoubleAttr0::kObjOffset, AttrKey());
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr0::kObjOffset),
|
||||
UnorderedElementsAre(AttrKey()));
|
||||
}
|
||||
|
||||
TEST(DiffTest, DoubleAttr0AdvanceClearsModification) {
|
||||
Diff diff;
|
||||
diff.SetModified(DoubleAttr0::kObjOffset, AttrKey());
|
||||
diff.Advance(MakeUniformCheckpoint(0));
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr0::kObjOffset), IsEmpty());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Attr1 Tests
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(DiffTest, InitBoolAttr1Empty) {
|
||||
Diff diff;
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetBoolAttr1ModifiedBeforeCheckpointIsModified) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
diff.SetModified(BoolAttr1::kVarInteger, AttrKey(0));
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger),
|
||||
UnorderedElementsAre(AttrKey(0)));
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetBoolAttr1ModifiedAtleastCheckpointNotTracked) {
|
||||
Diff diff;
|
||||
diff.SetModified(BoolAttr1::kVarInteger, AttrKey(0));
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, BoolAttr1AdvanceClearsModification) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
diff.SetModified(BoolAttr1::kVarInteger, AttrKey(0));
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, EraseElementForBoolAttr1IsNoLongerTracked) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
diff.SetModified(BoolAttr1::kVarInteger, AttrKey(0));
|
||||
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger),
|
||||
UnorderedElementsAre(AttrKey(0)));
|
||||
|
||||
diff.EraseKeysForAttr(BoolAttr1::kVarInteger, {AttrKey(0)});
|
||||
|
||||
EXPECT_THAT(diff.modified_keys(BoolAttr1::kVarInteger), IsEmpty());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Repeat all tests for DoubleAttr1, not ideal
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(DiffTest, InitDoubleAttr1Empty) {
|
||||
Diff diff;
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetDoubleAttr1ModifiedBeforeCheckpointIsModified) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
diff.SetModified(DoubleAttr1::kLinConUb, AttrKey(0));
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb),
|
||||
UnorderedElementsAre(AttrKey(0)));
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetDoubleAttr1ModifiedAtleastCheckpointNotTracked) {
|
||||
Diff diff;
|
||||
diff.SetModified(DoubleAttr1::kLinConUb, AttrKey(0));
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, DoubleAttr1AdvanceClearsModification) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
diff.SetModified(DoubleAttr1::kLinConUb, AttrKey(0));
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, EraseElementForDoubleAttr1IsNoLongerTracked) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
diff.SetModified(DoubleAttr1::kLinConUb, AttrKey(0));
|
||||
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb),
|
||||
UnorderedElementsAre(AttrKey(0)));
|
||||
|
||||
diff.EraseKeysForAttr(DoubleAttr1::kLinConUb, {AttrKey(0)});
|
||||
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr1::kLinConUb), IsEmpty());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Attr2 Tests
|
||||
//
|
||||
// UpdateAttr2OnFirstElementDeleted and UpdateAttr2OnSecondElementDeleted are a
|
||||
// bit under tested.
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(DiffTest, InitDoubleAttr2Empty) {
|
||||
Diff diff;
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetDoubleAttr2ModifiedBothKeysBeforeCheckpointIsModified) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(2));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(1, 0));
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef),
|
||||
UnorderedElementsAre(AttrKey(1, 0)));
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetDoubleAttr2ModifiedFirstKeyAtleastCheckpointNotTracked) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(2));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(4, 0));
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, SetDoubleAttr2ModifiedSecondtKeyAtleastCheckpointNotTracked) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(2));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(0, 4));
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, DoubleAttr2AdvanceClearsModification) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(0, 0));
|
||||
diff.Advance(MakeUniformCheckpoint(1));
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(DiffTest, EraseFirstElementForDoubleAttr2IsNoLongerTracked) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(5));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(1, 0));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(1, 2));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(1, 4));
|
||||
|
||||
EXPECT_THAT(
|
||||
diff.modified_keys(DoubleAttr2::kLinConCoef),
|
||||
UnorderedElementsAre(AttrKey(1, 0), AttrKey(1, 2), AttrKey(1, 4)));
|
||||
|
||||
diff.EraseKeysForAttr(
|
||||
DoubleAttr2::kLinConCoef,
|
||||
absl::Span<const AttrKey<2>>({AttrKey(1, 0), AttrKey(1, 4)}));
|
||||
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef),
|
||||
UnorderedElementsAre(AttrKey(1, 2)));
|
||||
}
|
||||
|
||||
TEST(DiffTest, EraseSecondElementForDoubleAttr2IsNoLongerTracked) {
|
||||
Diff diff;
|
||||
diff.Advance(MakeUniformCheckpoint(5));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(0, 1));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(2, 1));
|
||||
diff.SetModified(DoubleAttr2::kLinConCoef, AttrKey(4, 1));
|
||||
|
||||
EXPECT_THAT(
|
||||
diff.modified_keys(DoubleAttr2::kLinConCoef),
|
||||
UnorderedElementsAre(AttrKey(0, 1), AttrKey(2, 1), AttrKey(4, 1)));
|
||||
|
||||
diff.EraseKeysForAttr(
|
||||
DoubleAttr2::kLinConCoef,
|
||||
absl::Span<const AttrKey<2>>({AttrKey(0, 1), AttrKey(4, 1)}));
|
||||
|
||||
EXPECT_THAT(diff.modified_keys(DoubleAttr2::kLinConCoef),
|
||||
UnorderedElementsAre(AttrKey(2, 1)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
64
ortools/math_opt/elemental/element_diff.h
Normal file
64
ortools/math_opt/elemental/element_diff.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_DIFF_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_DIFF_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// Tracks the ids of the elements in a model that:
|
||||
// 1. Are less than the checkpoint for this element.
|
||||
// 2. Have been deleted since the most recent time the checkpoint was advanced
|
||||
// (or creation of the ElementDiff if advance was never called).
|
||||
//
|
||||
// Generally:
|
||||
// * Element ids should be nonnegative.
|
||||
// * Each element should be deleted at most once.
|
||||
// * Sequential calls to Advance() should be called on non-decreasing
|
||||
// checkpoints.
|
||||
// However, these are enforced higher up the stack, not in this class.
|
||||
class ElementDiff {
|
||||
public:
|
||||
// The current checkpoint for this element, generally the next_id for this
|
||||
// element when Advance() was last called (or at creation time if advance was
|
||||
// never called).
|
||||
int64_t checkpoint() const { return checkpoint_; }
|
||||
|
||||
// The elements that have been deleted before the checkpoint.
|
||||
const absl::flat_hash_set<int64_t>& deleted() const { return deleted_; }
|
||||
|
||||
// Tracks the element `id` as deleted if it is less than the checkpoint.
|
||||
void Delete(int64_t id) {
|
||||
if (id < checkpoint_) {
|
||||
deleted_.insert(id);
|
||||
}
|
||||
}
|
||||
|
||||
// Update the checkpoint and clears all tracked deletions.
|
||||
void Advance(int64_t checkpoint) {
|
||||
checkpoint_ = checkpoint;
|
||||
deleted_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t checkpoint_ = 0;
|
||||
absl::flat_hash_set<int64_t> deleted_;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_DIFF_H_
|
||||
59
ortools/math_opt/elemental/element_diff_test.cc
Normal file
59
ortools/math_opt/elemental/element_diff_test.cc
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/element_diff.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
TEST(ElementDiffTest, EmptyDiff) {
|
||||
ElementDiff diff;
|
||||
EXPECT_EQ(diff.checkpoint(), 0);
|
||||
EXPECT_THAT(diff.deleted(), IsEmpty());
|
||||
|
||||
diff.Delete(4);
|
||||
EXPECT_THAT(diff.deleted(), IsEmpty());
|
||||
}
|
||||
|
||||
TEST(ElementDiffTest, AddsPointsBelowCheckpoint) {
|
||||
ElementDiff diff;
|
||||
diff.Advance(4);
|
||||
EXPECT_EQ(diff.checkpoint(), 4);
|
||||
|
||||
diff.Delete(1);
|
||||
diff.Delete(3);
|
||||
diff.Delete(4);
|
||||
diff.Delete(5);
|
||||
EXPECT_THAT(diff.deleted(), UnorderedElementsAre(1, 3));
|
||||
}
|
||||
|
||||
TEST(ElementDiffTest, AdvanceClearsDiff) {
|
||||
ElementDiff diff;
|
||||
diff.Advance(4);
|
||||
|
||||
diff.Delete(1);
|
||||
diff.Delete(3);
|
||||
|
||||
diff.Advance(5);
|
||||
EXPECT_THAT(diff.deleted(), IsEmpty());
|
||||
EXPECT_EQ(diff.checkpoint(), 5);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
81
ortools/math_opt/elemental/element_ref_tracker.h
Normal file
81
ortools/math_opt/elemental/element_ref_tracker.h
Normal file
@@ -0,0 +1,81 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_REF_TRACKER_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_REF_TRACKER_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "ortools/base/map_util.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
template <typename ValueType, int n, typename Symmetry>
|
||||
class ElementRefTracker;
|
||||
|
||||
// The `ElementRefTracker` for a given attribute type descriptor.
|
||||
template <typename Descriptor>
|
||||
using ElementRefTrackerForAttrTypeDescriptor =
|
||||
ElementRefTracker<typename Descriptor::ValueType,
|
||||
Descriptor::kNumKeyElements,
|
||||
typename Descriptor::Symmetry>;
|
||||
|
||||
// A tracker for values that reference elements.
|
||||
//
|
||||
// This is used to delete attributes when the elements they reference are
|
||||
// deleted.
|
||||
template <ElementType element_type, int n, typename Symmetry>
|
||||
class ElementRefTracker<ElementId<element_type>, n, Symmetry> {
|
||||
public:
|
||||
using ElemId = ElementId<element_type>;
|
||||
using Key = AttrKey<n, Symmetry>;
|
||||
|
||||
// Returns the set of keys that reference element `id`.
|
||||
const absl::flat_hash_set<Key>& GetKeysReferencing(const ElemId id) const {
|
||||
return gtl::FindWithDefault(element_id_to_attr_keys_, id);
|
||||
}
|
||||
|
||||
// Tracks the fact that attribute with key `key` has a value that references
|
||||
// elements.
|
||||
void Track(const Key key, const ElemId id) {
|
||||
element_id_to_attr_keys_[id].insert(key);
|
||||
}
|
||||
|
||||
void Untrack(const Key key, const ElemId id) {
|
||||
const auto it = element_id_to_attr_keys_.find(id);
|
||||
if (it != element_id_to_attr_keys_.end()) {
|
||||
it->second.erase(key);
|
||||
if (it->second.empty()) {
|
||||
element_id_to_attr_keys_.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Clear() { element_id_to_attr_keys_.clear(); }
|
||||
|
||||
private:
|
||||
// A map of element id to the list of attribute keys that have a non-default
|
||||
// value for this element.
|
||||
absl::flat_hash_map<ElemId, absl::flat_hash_set<Key>>
|
||||
element_id_to_attr_keys_;
|
||||
};
|
||||
|
||||
// Other value types do not need tracking.
|
||||
template <typename ValueType, int n, typename Symmetry>
|
||||
class ElementRefTracker {};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_REF_TRACKER_H_
|
||||
54
ortools/math_opt/elemental/element_ref_tracker_test.cc
Normal file
54
ortools/math_opt/elemental/element_ref_tracker_test.cc
Normal file
@@ -0,0 +1,54 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/element_ref_tracker.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/elemental/symmetry.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
namespace {
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
TEST(ElementRefTrackerTest, ElementValued) {
|
||||
const VariableId x(0);
|
||||
const VariableId y(1);
|
||||
|
||||
ElementRefTracker<VariableId, 1, NoSymmetry> tracker;
|
||||
tracker.Track(AttrKey(1), x);
|
||||
tracker.Track(AttrKey(2), x);
|
||||
tracker.Track(AttrKey(3), y);
|
||||
EXPECT_THAT(tracker.GetKeysReferencing(x),
|
||||
UnorderedElementsAre(AttrKey(1), AttrKey(2)));
|
||||
EXPECT_THAT(tracker.GetKeysReferencing(y), UnorderedElementsAre(AttrKey(3)));
|
||||
|
||||
tracker.Untrack(AttrKey(1), x);
|
||||
EXPECT_THAT(tracker.GetKeysReferencing(x), UnorderedElementsAre(AttrKey(2)));
|
||||
EXPECT_THAT(tracker.GetKeysReferencing(y), UnorderedElementsAre(AttrKey(3)));
|
||||
|
||||
tracker.Untrack(AttrKey(2), x);
|
||||
EXPECT_THAT(tracker.GetKeysReferencing(x), IsEmpty());
|
||||
EXPECT_THAT(tracker.GetKeysReferencing(y), UnorderedElementsAre(AttrKey(3)));
|
||||
|
||||
tracker.Untrack(AttrKey(3), y);
|
||||
EXPECT_THAT(tracker.GetKeysReferencing(x), IsEmpty());
|
||||
EXPECT_THAT(tracker.GetKeysReferencing(y), IsEmpty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
56
ortools/math_opt/elemental/element_storage.cc
Normal file
56
ortools/math_opt/elemental/element_storage.cc
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/element_storage.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
namespace detail {
|
||||
|
||||
std::vector<int64_t> SparseElementStorage::AllIds() const {
|
||||
std::vector<int64_t> result;
|
||||
result.reserve(elements_.size());
|
||||
for (const auto& [id, unused] : elements_) {
|
||||
result.push_back(id);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int64_t> DenseElementStorage::AllIds() const {
|
||||
std::vector<int64_t> result(elements_.size());
|
||||
absl::c_iota(result, 0);
|
||||
return result;
|
||||
}
|
||||
|
||||
SparseElementStorage::SparseElementStorage(DenseElementStorage&& dense)
|
||||
: next_id_(dense.size()) {
|
||||
elements_.reserve(next_id_);
|
||||
for (int i = 0; i < next_id_; ++i) {
|
||||
elements_.emplace(i, std::move(dense.elements_[i]));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
std::vector<int64_t> ElementStorage::AllIds() const {
|
||||
return std::visit([](const auto& impl) { return impl.AllIds(); }, impl_);
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
192
ortools/math_opt/elemental/element_storage.h
Normal file
192
ortools/math_opt/elemental/element_storage.h
Normal file
@@ -0,0 +1,192 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_STORAGE_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_STORAGE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/status_builder.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// A dense element storage, for use when no elements have been erased.
|
||||
// Same API as `ElementStorage`, but no deletion.
|
||||
// TODO(b/369972336): We should stay in dense mode if we have a small percentage
|
||||
// of deletions.
|
||||
class DenseElementStorage {
|
||||
public:
|
||||
int64_t Add(const absl::string_view name) {
|
||||
const int64_t id = elements_.size();
|
||||
elements_.emplace_back(name);
|
||||
return id;
|
||||
}
|
||||
|
||||
bool exists(const int64_t id) const {
|
||||
return 0 <= id && id < elements_.size();
|
||||
}
|
||||
|
||||
absl::StatusOr<absl::string_view> GetName(const int64_t id) const {
|
||||
if (exists(id)) {
|
||||
return elements_[id];
|
||||
}
|
||||
return util::InvalidArgumentErrorBuilder() << "no element with id " << id;
|
||||
}
|
||||
|
||||
int64_t next_id() const { return size(); }
|
||||
|
||||
std::vector<int64_t> AllIds() const;
|
||||
|
||||
int64_t size() const { return elements_.size(); }
|
||||
|
||||
private:
|
||||
friend class SparseElementStorage;
|
||||
std::vector<std::string> elements_;
|
||||
};
|
||||
|
||||
// A sparse element storage, which supports deletion.
|
||||
class SparseElementStorage {
|
||||
public:
|
||||
explicit SparseElementStorage(DenseElementStorage&& dense);
|
||||
|
||||
int64_t Add(const absl::string_view name) {
|
||||
const int64_t id = next_id_;
|
||||
elements_.try_emplace(id, name);
|
||||
++next_id_;
|
||||
return id;
|
||||
}
|
||||
|
||||
bool Erase(int64_t id) { return elements_.erase(id) > 0; }
|
||||
|
||||
bool exists(int64_t id) const { return elements_.contains(id); }
|
||||
|
||||
absl::StatusOr<absl::string_view> GetName(int64_t id) const {
|
||||
if (const auto it = elements_.find(id); it != elements_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
return util::InvalidArgumentErrorBuilder() << "no element with id " << id;
|
||||
}
|
||||
|
||||
int64_t next_id() const { return next_id_; }
|
||||
|
||||
std::vector<int64_t> AllIds() const;
|
||||
|
||||
int64_t size() const { return elements_.size(); }
|
||||
|
||||
void ensure_next_id_at_least(int64_t id) {
|
||||
next_id_ = std::max(next_id_, id);
|
||||
}
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<int64_t, std::string> elements_;
|
||||
int64_t next_id_ = 0;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
class ElementStorage {
|
||||
// Functions with deduced return must be defined before they are used.
|
||||
private:
|
||||
// std::visit is very slow, see
|
||||
// 5253596299885805568
|
||||
//
|
||||
// This function is static, taking Self as template argument, to avoid
|
||||
// having const and non-const versions. Post C++ 23, prefer:
|
||||
// https://en.cppreference.com/w/cpp/language/member_functions#Explicit_object_member_functions
|
||||
template <typename Self, typename Fn>
|
||||
static auto Visit(Self& self, Fn fn) {
|
||||
if (std::holds_alternative<detail::DenseElementStorage>(self.impl_)) {
|
||||
return fn(std::get<detail::DenseElementStorage>(self.impl_));
|
||||
} else {
|
||||
return fn(std::get<detail::SparseElementStorage>(self.impl_));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// We start with a dense storage, which is more efficient, and switch to a
|
||||
// sparse storage when an element is erased.
|
||||
ElementStorage() : impl_(detail::DenseElementStorage()) {}
|
||||
|
||||
// Creates a new element and returns its id.
|
||||
int64_t Add(const absl::string_view name) {
|
||||
return Visit(*this, [name](auto& impl) { return impl.Add(name); });
|
||||
}
|
||||
|
||||
// Deletes an element by id, returning true on success and false if no element
|
||||
// was deleted (it was already deleted or the id was not from any existing
|
||||
// element).
|
||||
bool Erase(const int64_t id) { return AsSparse().Erase(id); }
|
||||
|
||||
// Returns true an element with this id was created and not yet erased.
|
||||
bool Exists(const int64_t id) const {
|
||||
return Visit(*this, [id](auto& impl) { return impl.exists(id); });
|
||||
}
|
||||
|
||||
// Returns the name of this element, or CHECK fails if no element with this id
|
||||
// exists.
|
||||
absl::StatusOr<absl::string_view> GetName(const int64_t id) const {
|
||||
return Visit(*this, [id](auto& impl) { return impl.GetName(id); });
|
||||
}
|
||||
|
||||
// Returns the id that will be used for the next element added.
|
||||
//
|
||||
// NOTE: when no elements have been erased, this equals size().
|
||||
int64_t next_id() const {
|
||||
return Visit(*this, [](auto& impl) { return impl.next_id(); });
|
||||
}
|
||||
|
||||
// Returns all ids of all elements in the model in an unsorted,
|
||||
// non-deterministic order.
|
||||
std::vector<int64_t> AllIds() const;
|
||||
|
||||
// Returns the number of elements added and not erased.
|
||||
int64_t size() const {
|
||||
return Visit(*this, [](auto& impl) { return impl.size(); });
|
||||
}
|
||||
|
||||
// Increases next_id() to `id` if it is currently less than `id`.
|
||||
//
|
||||
// Useful for reading a model back from proto, most users should not need to
|
||||
// call this directly.
|
||||
void EnsureNextIdAtLeast(const int64_t id) {
|
||||
if (id > next_id()) {
|
||||
AsSparse().ensure_next_id_at_least(id);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
detail::SparseElementStorage& AsSparse() {
|
||||
if (auto* sparse = std::get_if<detail::SparseElementStorage>(&impl_)) {
|
||||
return *sparse;
|
||||
}
|
||||
impl_ = detail::SparseElementStorage(
|
||||
std::move(std::get<detail::DenseElementStorage>(impl_)));
|
||||
return std::get<detail::SparseElementStorage>(impl_);
|
||||
}
|
||||
|
||||
std::variant<detail::DenseElementStorage, detail::SparseElementStorage> impl_;
|
||||
};
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENT_STORAGE_H_
|
||||
170
ortools/math_opt/elemental/element_storage_test.cc
Normal file
170
ortools/math_opt/elemental/element_storage_test.cc
Normal file
@@ -0,0 +1,170 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/element_storage.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "benchmark/benchmark.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ortools/base/gmock.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
namespace {
|
||||
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
using ::testing::status::IsOkAndHolds;
|
||||
using ::testing::status::StatusIs;
|
||||
|
||||
TEST(ElementStorageTest, EmptyModelGetters) {
|
||||
ElementStorage elements;
|
||||
EXPECT_EQ(elements.size(), 0);
|
||||
EXPECT_EQ(elements.next_id(), 0);
|
||||
EXPECT_THAT(elements.AllIds(), IsEmpty());
|
||||
EXPECT_FALSE(elements.Exists(0));
|
||||
EXPECT_FALSE(elements.Exists(1));
|
||||
EXPECT_FALSE(elements.Exists(-1));
|
||||
}
|
||||
|
||||
TEST(ElementStorageTest, AddElement) {
|
||||
ElementStorage elements;
|
||||
EXPECT_EQ(elements.Add("x"), 0);
|
||||
EXPECT_EQ(elements.Add("y"), 1);
|
||||
EXPECT_EQ(elements.size(), 2);
|
||||
EXPECT_EQ(elements.next_id(), 2);
|
||||
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(0, 1));
|
||||
EXPECT_FALSE(elements.Exists(2));
|
||||
EXPECT_FALSE(elements.Exists(-1));
|
||||
|
||||
ASSERT_TRUE(elements.Exists(0));
|
||||
ASSERT_TRUE(elements.Exists(1));
|
||||
EXPECT_THAT(elements.GetName(0), IsOkAndHolds("x"));
|
||||
EXPECT_THAT(elements.GetName(1), IsOkAndHolds("y"));
|
||||
EXPECT_THAT(elements.GetName(-42),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument));
|
||||
}
|
||||
|
||||
TEST(ElementStorageTest, AddElementDuplicateName) {
|
||||
ElementStorage elements;
|
||||
EXPECT_EQ(elements.Add("xyx"), 0);
|
||||
EXPECT_EQ(elements.Add("xyx"), 1);
|
||||
EXPECT_EQ(elements.size(), 2);
|
||||
EXPECT_EQ(elements.next_id(), 2);
|
||||
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(0, 1));
|
||||
|
||||
ASSERT_TRUE(elements.Exists(0));
|
||||
ASSERT_TRUE(elements.Exists(1));
|
||||
EXPECT_THAT(elements.GetName(0), IsOkAndHolds("xyx"));
|
||||
EXPECT_THAT(elements.GetName(1), IsOkAndHolds("xyx"));
|
||||
}
|
||||
|
||||
TEST(ElementStorageTest, DeleteElement) {
|
||||
ElementStorage elements;
|
||||
const int64_t x = elements.Add("x");
|
||||
const int64_t y = elements.Add("y");
|
||||
const int64_t z = elements.Add("z");
|
||||
|
||||
EXPECT_TRUE(elements.Erase(x));
|
||||
EXPECT_TRUE(elements.Erase(z));
|
||||
EXPECT_EQ(elements.size(), 1);
|
||||
EXPECT_EQ(elements.next_id(), 3);
|
||||
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(y));
|
||||
EXPECT_FALSE(elements.Exists(x));
|
||||
EXPECT_FALSE(elements.Exists(z));
|
||||
EXPECT_FALSE(elements.Exists(3));
|
||||
EXPECT_FALSE(elements.Exists(-1));
|
||||
|
||||
ASSERT_TRUE(elements.Exists(y));
|
||||
EXPECT_THAT(elements.GetName(y), IsOkAndHolds("y"));
|
||||
EXPECT_THAT(elements.GetName(x),
|
||||
StatusIs(absl::StatusCode::kInvalidArgument));
|
||||
}
|
||||
|
||||
TEST(ElementStorageTest, DeleteInvalidIdNoEffect) {
|
||||
ElementStorage elements;
|
||||
const int64_t x = elements.Add("x");
|
||||
const int64_t y = elements.Add("y");
|
||||
const int64_t z = elements.Add("z");
|
||||
|
||||
EXPECT_FALSE(elements.Erase(-2));
|
||||
EXPECT_FALSE(elements.Erase(5));
|
||||
|
||||
EXPECT_EQ(elements.size(), 3);
|
||||
EXPECT_EQ(elements.next_id(), 3);
|
||||
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(x, y, z));
|
||||
}
|
||||
|
||||
TEST(ElementStorageTest, DeleteTwiceNoAdditionalEffect) {
|
||||
ElementStorage elements;
|
||||
const int64_t x = elements.Add("x");
|
||||
const int64_t y = elements.Add("y");
|
||||
const int64_t z = elements.Add("z");
|
||||
|
||||
EXPECT_TRUE(elements.Erase(y));
|
||||
EXPECT_FALSE(elements.Erase(y));
|
||||
|
||||
EXPECT_EQ(elements.size(), 2);
|
||||
EXPECT_EQ(elements.next_id(), 3);
|
||||
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(x, z));
|
||||
}
|
||||
|
||||
TEST(ElementStorageTest, EnsureNextIdAtLeastIncreasesNextId) {
|
||||
ElementStorage elements;
|
||||
elements.EnsureNextIdAtLeast(5);
|
||||
EXPECT_EQ(elements.next_id(), 5);
|
||||
EXPECT_EQ(elements.Add("x"), 5);
|
||||
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(5));
|
||||
}
|
||||
|
||||
TEST(ElementStorageTest, EnsureNextIdAtLeastNoEffect) {
|
||||
ElementStorage elements;
|
||||
elements.Add("x");
|
||||
elements.EnsureNextIdAtLeast(0);
|
||||
EXPECT_EQ(elements.next_id(), 1);
|
||||
EXPECT_EQ(elements.Add("y"), 1);
|
||||
EXPECT_THAT(elements.AllIds(), UnorderedElementsAre(0, 1));
|
||||
}
|
||||
|
||||
void BM_AddElements(benchmark::State& state) {
|
||||
const int n = state.range(0);
|
||||
for (auto s : state) {
|
||||
ElementStorage storage;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
storage.Add("");
|
||||
}
|
||||
benchmark::DoNotOptimize(storage);
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_AddElements)->Arg(100)->Arg(10000);
|
||||
|
||||
void BM_Exists(benchmark::State& state) {
|
||||
const int n = state.range(0);
|
||||
ElementStorage storage;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
storage.Add("");
|
||||
}
|
||||
for (auto s : state) {
|
||||
for (int i = 0; i < 2 * n; ++i) {
|
||||
bool e = storage.Exists(i);
|
||||
benchmark::DoNotOptimize(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BENCHMARK(BM_Exists)->Arg(100)->Arg(10000);
|
||||
|
||||
} // namespace
|
||||
} // namespace operations_research::math_opt
|
||||
169
ortools/math_opt/elemental/elemental.cc
Normal file
169
ortools/math_opt/elemental/elemental.cc
Normal file
@@ -0,0 +1,169 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ortools/math_opt/elemental/elemental.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/math_opt/elemental/arrays.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/derived_data.h"
|
||||
#include "ortools/math_opt/elemental/diff.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
Elemental::Elemental(std::string model_name, std::string primary_objective_name)
|
||||
: model_name_(std::move(model_name)),
|
||||
primary_objective_name_(std::move(primary_objective_name)) {
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
ForEachIndex<AllAttrs::kNumAttrTypes>([this]<int attr_type_index>() {
|
||||
using Descriptor = AllAttrs::TypeDescriptor<attr_type_index>;
|
||||
for (const auto a : Descriptor::Enumerate()) {
|
||||
const int index = static_cast<int>(a);
|
||||
attrs_[a] = StorageForAttrType<attr_type_index>(
|
||||
Descriptor::kAttrDescriptors[index].default_value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<Elemental::DiffHandle> Elemental::GetDiffHandle(
|
||||
const int64_t id) const {
|
||||
if (diffs_->Get(id) == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return DiffHandle(id, diffs_.get());
|
||||
}
|
||||
|
||||
Elemental::DiffHandle Elemental::AddDiff() {
|
||||
auto diff = std::make_unique<Diff>();
|
||||
diff->Advance(CurrentCheckpoint());
|
||||
const int64_t diff_id = diffs_->Insert(std::move(diff));
|
||||
return DiffHandle(diff_id, diffs_.get());
|
||||
}
|
||||
|
||||
bool Elemental::DeleteDiff(const DiffHandle diff) {
|
||||
if (&diff.diffs_ != diffs_.get()) {
|
||||
return false;
|
||||
}
|
||||
return diffs_->Erase(diff.diff_id_);
|
||||
}
|
||||
|
||||
bool Elemental::Advance(const DiffHandle diff) {
|
||||
if (diffs_.get() != &diff.diffs_) {
|
||||
return false;
|
||||
}
|
||||
Diff* d = diffs_->UpdateAndGet(diff.diff_id_);
|
||||
if (d == nullptr) {
|
||||
return false;
|
||||
}
|
||||
d->Advance(CurrentCheckpoint());
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Elemental::DeleteElementUntyped(const ElementType e, int64_t id) {
|
||||
if (!mutable_element_storage(e).Erase(id)) {
|
||||
return false;
|
||||
}
|
||||
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
|
||||
diff->DeleteElement(e, id);
|
||||
}
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
AllAttrs::ForEachAttr([this, e, id]<typename AttrType>(AttrType a) {
|
||||
ForEachIndex<GetAttrKeySize<AttrType>()>(
|
||||
// NOLINTNEXTLINE(clang-diagnostic-pre-c++20-compat)
|
||||
[&]<int i>() {
|
||||
if (GetElementTypes(a)[i] == e) {
|
||||
UpdateAttrOnElementDeleted<AttrType, i>(a, id);
|
||||
}
|
||||
});
|
||||
// If `a` is element-valued, we need to remove all keys that refer to the
|
||||
// deleted element.
|
||||
if constexpr (is_element_id_v<ValueTypeFor<AttrType>>) {
|
||||
if (e == ValueTypeFor<AttrType>::type()) {
|
||||
const auto keys = element_ref_trackers_[a].GetKeysReferencing(
|
||||
ValueTypeFor<AttrType>(id));
|
||||
for (const auto key : keys) {
|
||||
SetAttr(a, key, ValueTypeFor<AttrType>());
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename AttrType, int i>
|
||||
void Elemental::UpdateAttrOnElementDeleted(const AttrType a, const int64_t id) {
|
||||
auto& attr_storage = attrs_[a];
|
||||
// We consider the case of n == 1 separately so that we can ensure that
|
||||
// for any attribute with a key size of one, the AttrDiff has no deleted
|
||||
// elements. (If we did not specialize this code, we would need to check for
|
||||
// deleted elements when building our ModelUpdateProto, see
|
||||
// README.md#checkpoints-and-model-updates for an explanation.)
|
||||
if constexpr (GetAttrKeySize<AttrType>() == 1) {
|
||||
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
|
||||
diff->EraseKeysForAttr(a, {AttrKey(id)});
|
||||
}
|
||||
attr_storage.Erase(AttrKey(id));
|
||||
} else {
|
||||
// NOTE: We explicitly spell out the type here, so that if `Slice` ever
|
||||
// returns a reference in `attr_storage` instead of a copy, we are forced
|
||||
// to update this code to make a copy of the slice (otherwise the slice
|
||||
// would be invalidated by calls to `Erase()` below).
|
||||
const std::vector<AttrKeyFor<AttrType>> keys =
|
||||
attr_storage.template Slice<i>(id);
|
||||
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
|
||||
diff->EraseKeysForAttr(a, keys);
|
||||
}
|
||||
for (const auto& key : keys) {
|
||||
attr_storage.Erase(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::array<int64_t, kNumElements> Elemental::CurrentCheckpoint() const {
|
||||
std::array<int64_t, kNumElements> result;
|
||||
for (int i = 0; i < kNumElements; ++i) {
|
||||
result[i] = elements_[i].next_id();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Elemental Elemental::Clone(
|
||||
std::optional<absl::string_view> new_model_name) const {
|
||||
Elemental result(std::string(new_model_name.value_or(model_name_)),
|
||||
primary_objective_name_);
|
||||
result.elements_ = elements_;
|
||||
result.attrs_ = attrs_;
|
||||
return result;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr, const Elemental& elemental) {
|
||||
ostr << elemental.DebugString();
|
||||
return ostr;
|
||||
}
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
545
ortools/math_opt/elemental/elemental.h
Normal file
545
ortools/math_opt/elemental/elemental.h
Normal file
@@ -0,0 +1,545 @@
|
||||
// Copyright 2010-2025 Google LLC
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENTAL_H_
|
||||
#define OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENTAL_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/die_if_null.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "ortools/base/status_builder.h"
|
||||
#include "ortools/math_opt/elemental/attr_key.h"
|
||||
#include "ortools/math_opt/elemental/attr_storage.h"
|
||||
#include "ortools/math_opt/elemental/derived_data.h"
|
||||
#include "ortools/math_opt/elemental/diff.h"
|
||||
#include "ortools/math_opt/elemental/element_ref_tracker.h"
|
||||
#include "ortools/math_opt/elemental/element_storage.h"
|
||||
#include "ortools/math_opt/elemental/elements.h"
|
||||
#include "ortools/math_opt/elemental/thread_safe_id_map.h"
|
||||
#include "ortools/math_opt/model.pb.h"
|
||||
#include "ortools/math_opt/model_update.pb.h"
|
||||
|
||||
namespace operations_research::math_opt {
|
||||
|
||||
// A MathOpt optimization model and modification trackers.
|
||||
//
|
||||
// Holds the elements, the attribute values, and tracks modifications to the
|
||||
// model by `Diff` objects, and keeps them all in sync. See README.md for
|
||||
// details.
|
||||
class Elemental {
|
||||
public:
|
||||
// An opaque value type for a reference to an underlying `Diff` (change
|
||||
// tracker).
|
||||
class DiffHandle {
|
||||
public:
|
||||
int64_t id() const { return diff_id_; }
|
||||
|
||||
private:
|
||||
explicit DiffHandle(int64_t diff_id, ThreadSafeIdMap<Diff>* diffs)
|
||||
: diff_id_(diff_id), diffs_(*ABSL_DIE_IF_NULL(diffs)) {}
|
||||
|
||||
int64_t diff_id_;
|
||||
ThreadSafeIdMap<Diff>& diffs_;
|
||||
|
||||
friend class Elemental;
|
||||
friend class ElementalTestPeer;
|
||||
};
|
||||
|
||||
explicit Elemental(std::string model_name = "",
|
||||
std::string primary_objective_name = "");
|
||||
|
||||
// The name of this optimization model.
|
||||
const std::string& model_name() const { return model_name_; }
|
||||
|
||||
// The name of the primary objective of this optimization model.
|
||||
const std::string& primary_objective_name() const {
|
||||
return primary_objective_name_;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Elements
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Creates and returns the id of a new element for the element type `e`.
|
||||
template <ElementType e>
|
||||
ElementId<e> AddElement(const absl::string_view name) {
|
||||
return ElementId<e>(AddElementUntyped(e, name));
|
||||
}
|
||||
|
||||
// Type-erased version of `AddElement`. Prefer the latter.
|
||||
int64_t AddElementUntyped(const ElementType e, const absl::string_view name) {
|
||||
return mutable_element_storage(e).Add(name);
|
||||
}
|
||||
|
||||
// Deletes the element with `id` for element type `e`, returning true on
|
||||
// success and false if no element was deleted (it was already deleted or the
|
||||
// id was not from any existing element).
|
||||
template <ElementType e>
|
||||
bool DeleteElement(const ElementId<e> id) {
|
||||
return DeleteElementUntyped(e, id.value());
|
||||
}
|
||||
|
||||
// Type-erased version of `DeleteElement`. Prefer the latter.
|
||||
bool DeleteElementUntyped(ElementType e, int64_t id);
|
||||
|
||||
// Returns true the element with `id` for element type `e` exists (it was
|
||||
// created and not yet deleted).
|
||||
template <ElementType e>
|
||||
bool ElementExists(const ElementId<e> id) const {
|
||||
return ElementExistsUntyped(e, id.value());
|
||||
}
|
||||
|
||||
// Type-erased version of `ElementExists`. Prefer the latter.
|
||||
bool ElementExistsUntyped(const ElementType e, const int64_t id) const {
|
||||
return element_storage(e).Exists(id);
|
||||
}
|
||||
|
||||
// Returns the name of the element with `id` for element type `e`, or an error
|
||||
// if this element does not exist.
|
||||
template <ElementType e>
|
||||
absl::StatusOr<absl::string_view> GetElementName(
|
||||
const ElementId<e> id) const {
|
||||
return GetElementNameUntyped(e, id.value());
|
||||
}
|
||||
|
||||
// Type-erased version of `GetElementName`. Prefer the latter.
|
||||
absl::StatusOr<absl::string_view> GetElementNameUntyped(
|
||||
const ElementType e, const int64_t id) const {
|
||||
return element_storage(e).GetName(id);
|
||||
}
|
||||
|
||||
// Returns the ids of all elements of element type `e` in the model in an
|
||||
// unsorted, non-deterministic order.
|
||||
template <ElementType e>
|
||||
ElementIdsVector<e> AllElements() const {
|
||||
return ElementIdsVector<e>(AllElementsUntyped(e));
|
||||
}
|
||||
|
||||
// Type-erased version of `AllElements`. Prefer the latter.
|
||||
std::vector<int64_t> AllElementsUntyped(const ElementType e) const {
|
||||
return element_storage(e).AllIds();
|
||||
}
|
||||
|
||||
// Returns the id of next element created for element type `e`.
|
||||
//
|
||||
// Equal to one plus the number of elements that were created for element type
|
||||
// `e`. When no elements have been deleted, this equals num_elements(e).
|
||||
int64_t NextElementId(const ElementType e) const {
|
||||
return element_storage(e).next_id();
|
||||
}
|
||||
|
||||
// Returns the number of elements in the model for element type `e`.
|
||||
//
|
||||
// Equal to the number of elements that were created minus the number
|
||||
// deleted for element type `e`.
|
||||
int64_t NumElements(const ElementType e) const {
|
||||
return element_storage(e).size();
|
||||
}
|
||||
|
||||
// Increases next_element_id(e) to `id` if it is currently less than `id`.
|
||||
//
|
||||
// Useful for reading a model back from proto, most users should not need to
|
||||
// call this directly.
|
||||
template <ElementType e>
|
||||
void EnsureNextElementIdAtLeast(const ElementId<e> id) {
|
||||
EnsureNextElementIdAtLeastUntyped(e, id.value());
|
||||
}
|
||||
|
||||
// Type-erased version of `EnsureNextElementIdAtLeast`. Prefer the latter.
|
||||
void EnsureNextElementIdAtLeastUntyped(const ElementType e, int64_t id) {
|
||||
mutable_element_storage(e).EnsureNextIdAtLeast(id);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Attributes
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
struct AlwaysOk {};
|
||||
|
||||
// In all the following functions, `key` must be a valid key for attribute `a`
|
||||
// (i.e. elements must exists for all elements ids of `key`). When this is not
|
||||
// the case, the behavior is defined by one of the following policies:
|
||||
// - Checks whether each element of the key exists, and dies if not.
|
||||
struct DiePolicy {
|
||||
using CheckResultT = AlwaysOk;
|
||||
template <typename T>
|
||||
using Wrapped = T;
|
||||
};
|
||||
// - Checks whether each element of the key exists, and sets `*status` if
|
||||
// not. When `*status` is not OK, the model is not modified and is still
|
||||
// valid.
|
||||
struct StatusPolicy {
|
||||
using CheckResultT = absl::Status;
|
||||
template <typename T>
|
||||
using Wrapped = absl::StatusOr<T>;
|
||||
};
|
||||
// - Does not check whether key elements exists. UB if the key does not exist
|
||||
// (DCHECK-fails in debug mode). Use if you know that the key exists and
|
||||
// you care about performance.
|
||||
struct UBPolicy {
|
||||
using CheckResultT = AlwaysOk;
|
||||
template <typename T>
|
||||
using Wrapped = T;
|
||||
};
|
||||
|
||||
// Restores the attribute `a` to its default value for all AttrKeys (or for
|
||||
// an Attr0, its only value).
|
||||
template <typename AttrType>
|
||||
void AttrClear(AttrType a);
|
||||
|
||||
// Return the vector of attribute keys where a is non-default.
|
||||
template <typename AttrType>
|
||||
std::vector<AttrKeyFor<AttrType>> AttrNonDefaults(const AttrType a) const {
|
||||
return attrs_[a].NonDefaults();
|
||||
}
|
||||
|
||||
// Returns the number of keys where `a` is non-default.
|
||||
template <typename AttrType>
|
||||
int64_t AttrNumNonDefaults(const AttrType a) const {
|
||||
return attrs_[a].num_non_defaults();
|
||||
}
|
||||
|
||||
// Returns the value of the attr `a` for `key`:
|
||||
// - `get_attr(DoubleAttr1::kVarUb, AttrKey(x))` returns a `double` value if
|
||||
// element id `x` exists, and crashes otherwise. The returned value is
|
||||
// the default value if the attribute has not been set for `x`.
|
||||
// - `get_attr<StatusPolicy>(DoubleAttr1::kVarUb, AttrKey(x))` returns a
|
||||
// valid `StatusOr<double>` if element id `x` exists, and an error
|
||||
// otherwise.
|
||||
template <typename Policy = DiePolicy, typename AttrType>
|
||||
typename Policy::template Wrapped<ValueTypeFor<AttrType>> GetAttr(
|
||||
AttrType a, AttrKeyFor<AttrType> key) const;
|
||||
|
||||
// Returns true if the attr `a` for `key` has a value different from its
|
||||
// default.
|
||||
template <typename Policy = DiePolicy, typename AttrType>
|
||||
typename Policy::template Wrapped<bool> AttrIsNonDefault(
|
||||
AttrType a, AttrKeyFor<AttrType> key) const;
|
||||
|
||||
// Sets the value of the attr `a` for the element `key` to `value`, and
|
||||
// returns true if the value of the attribute has changed.
|
||||
template <typename Policy = DiePolicy, typename AttrType>
|
||||
typename Policy::CheckResultT SetAttr(AttrType a, AttrKeyFor<AttrType> key,
|
||||
ValueTypeFor<AttrType> value);
|
||||
|
||||
// Returns the set of all keys `k` such that `k[i] == key_elem` and `k` has a
|
||||
// non-default value for the attribute `a`.
|
||||
template <int i, typename Policy = DiePolicy, typename AttrType = void>
|
||||
typename Policy::template Wrapped<std::vector<AttrKeyFor<AttrType>>> Slice(
|
||||
AttrType a, int64_t key_elem) const;
|
||||
|
||||
// Returns the size of the given slice: This is equivalent to `Slice(a,
|
||||
// key_elem).size()`, but `O(1)`.
|
||||
template <int i, typename Policy = DiePolicy, typename AttrType = void>
|
||||
typename Policy::template Wrapped<int64_t> GetSliceSize(
|
||||
AttrType a, int64_t key_elem) const;
|
||||
|
||||
// Returns a copy of this, but with no diffs. The name of the model can
|
||||
// optionally be replaced by `new_model_name`.
|
||||
Elemental Clone(
|
||||
std::optional<absl::string_view> new_model_name = std::nullopt) const;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Working with proto
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Returns an equivalent protocol buffer. Fails if the model is too big
|
||||
// to fit in the in-memory representation of the proto (it has more than
|
||||
// 2**31-1 elements of a type or non-defaults for an attribute).
|
||||
absl::StatusOr<ModelProto> ExportModel(bool remove_names = false) const;
|
||||
|
||||
// Creates an equivalent Elemental to `proto`.
|
||||
static absl::StatusOr<Elemental> FromModelProto(const ModelProto& proto);
|
||||
|
||||
// Applies the changes to the model in `update_proto`.
|
||||
absl::Status ApplyUpdateProto(const ModelUpdateProto& update_proto);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Diffs
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Returns the DiffHandle for `id`, if one exists, or nullopt otherwise.
|
||||
std::optional<DiffHandle> GetDiffHandle(int64_t id) const;
|
||||
|
||||
// Returned handle is valid until passed `DeleteDiff` or `*this` is
|
||||
// destructed.
|
||||
DiffHandle AddDiff();
|
||||
|
||||
// Deletes `diff` & invalidates it. Returns false if the handle was invalid or
|
||||
// from the wrong elemental). On success, invalidates `diff`.
|
||||
bool DeleteDiff(DiffHandle diff);
|
||||
|
||||
// The number of diffs currently tracking this.
|
||||
int64_t NumDiffs() const { return diffs_->Size(); }
|
||||
|
||||
// Returns true on success (fails if diff was null, deleted or from the wrong
|
||||
// elemental). Warning: diff is modified (owned by this).
|
||||
bool Advance(DiffHandle diff);
|
||||
|
||||
// Internal use only (users of Elemental cannot access Diff directly), but
|
||||
// prefer to invoking Diff::modified_keys() directly..
|
||||
//
|
||||
// Returns the modified keys in a Diff for an attribute, filtering out the
|
||||
// keys referring to an element that has been deleted.
|
||||
//
|
||||
// This is needed because in some situations where a variable is deleted
|
||||
// we cannot clean up the diff, see README.md.
|
||||
template <typename AttrType>
|
||||
std::vector<AttrKeyFor<AttrType>> ModifiedKeysThatExist(
|
||||
AttrType attr, const Diff& diff) const;
|
||||
|
||||
// Returns a proto describing all changes to the model for `diff` since the
|
||||
// most recent call to `Advance(diff)` (or the creation of `diff` if
|
||||
// `Advance()` was never called).
|
||||
//
|
||||
// Returns std::nullopt the resulting ModelUpdateProto would be the empty
|
||||
// message (there have been no changes to the model to report).
|
||||
//
|
||||
// Fails if the update is too big to fit in the in-memory representation of
|
||||
// the proto (it has more than 2**31-1 elements in a RepeatedField).
|
||||
absl::StatusOr<std::optional<ModelUpdateProto>> ExportModelUpdate(
|
||||
DiffHandle diff, bool remove_names = false) const;
|
||||
|
||||
// Prints out the model by element and attribute. If print_diffs is true, also
|
||||
// prints out the deleted elements and modified keys for each attribute for
|
||||
// each DiffHandle tracked.
|
||||
//
|
||||
// This is a debug format. Do not assume the output is consistent across CLs
|
||||
// and do not parse this format.
|
||||
std::string DebugString(bool print_diffs = true) const;
|
||||
|
||||
private:
|
||||
DiePolicy::CheckResultT CheckElementExists(ElementType elem_type,
|
||||
int64_t elem_id, DiePolicy) const;
|
||||
StatusPolicy::CheckResultT CheckElementExists(ElementType elem_type,
|
||||
int64_t elem_id,
|
||||
StatusPolicy) const;
|
||||
UBPolicy::CheckResultT CheckElementExists(ElementType elem_type,
|
||||
int64_t elem_id, UBPolicy) const;
|
||||
|
||||
template <typename AttrType>
|
||||
bool AttrKeyExists(AttrType attr, AttrKeyFor<AttrType> key) const;
|
||||
|
||||
template <typename Policy, typename AttrType>
|
||||
typename Policy::CheckResultT CheckAttrKeyExists(
|
||||
AttrType a, AttrKeyFor<AttrType> key) const;
|
||||
|
||||
template <typename AttrType, int i>
|
||||
void UpdateAttrOnElementDeleted(AttrType a, int64_t id);
|
||||
|
||||
std::array<int64_t, kNumElements> CurrentCheckpoint() const;
|
||||
|
||||
const ElementStorage& element_storage(const ElementType e) const {
|
||||
// TODO(b/365997645): post C++ 23, prefer std::to_underlying(e).
|
||||
return elements_[static_cast<int>(e)];
|
||||
}
|
||||
|
||||
ElementStorage& mutable_element_storage(const ElementType e) {
|
||||
return elements_[static_cast<int>(e)];
|
||||
}
|
||||
|
||||
std::string model_name_;
|
||||
std::string primary_objective_name_;
|
||||
std::array<ElementStorage, kNumElements> elements_;
|
||||
template <int i>
|
||||
using StorageForAttrType =
|
||||
AttrStorage<typename AllAttrs::TypeDescriptor<i>::ValueType,
|
||||
AllAttrs::TypeDescriptor<i>::kNumKeyElements,
|
||||
typename AllAttrs::TypeDescriptor<i>::Symmetry>;
|
||||
AttrMap<StorageForAttrType> attrs_;
|
||||
// For each attribute whose value is an element, we need to keep a map of
|
||||
// element to the set of keys whose value refers to that element. This
|
||||
// is used to erase the attribute when the element is deleted.
|
||||
// This is kept outside of `attrs_` so that we can update the diffs when
|
||||
// element deletions trigger attribute deletions.
|
||||
template <int i>
|
||||
using ElementRefTrackerForAttrType =
|
||||
ElementRefTrackerForAttrTypeDescriptor<AllAttrs::TypeDescriptor<i>>;
|
||||
AttrMap<ElementRefTrackerForAttrType> element_ref_trackers_;
|
||||
// Note: it is important that this is a unique_ptr for two reasons:
|
||||
// 1. We need a stable memory address for diffs_ to refer to in DiffHandle,
|
||||
// and the Elemental type is moveable.
|
||||
// 2. We want Elemental to be moveable, but ThreadSafeIdMap<Diff> is not.
|
||||
std::unique_ptr<ThreadSafeIdMap<Diff>> diffs_ =
|
||||
std::make_unique<ThreadSafeIdMap<Diff>>();
|
||||
};
|
||||
|
||||
template <typename Sink>
|
||||
void AbslStringify(Sink& sink, const Elemental& elemental) {
|
||||
sink.Append(elemental.DebugString());
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& ostr, const Elemental& elemental);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Inline and template implementation
|
||||
|
||||
template <typename AttrType>
|
||||
void Elemental::AttrClear(const AttrType a) {
|
||||
// Note: this is slightly faster than setting each non-default back to the
|
||||
// default value.
|
||||
const std::vector<AttrKeyFor<AttrType>> non_defaults = AttrNonDefaults(a);
|
||||
if (!non_defaults.empty()) {
|
||||
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
|
||||
for (const auto key : non_defaults) {
|
||||
diff->SetModified(a, key);
|
||||
}
|
||||
}
|
||||
}
|
||||
attrs_[a].Clear();
|
||||
if constexpr (is_element_id_v<ValueTypeFor<AttrType>>) {
|
||||
element_ref_trackers_[a].Clear();
|
||||
}
|
||||
}
|
||||
|
||||
#define ELEMENTAL_RETURN_IF_ERROR(expr) \
|
||||
{ \
|
||||
auto error = (expr); \
|
||||
if constexpr (!std::is_same_v<decltype(error), AlwaysOk>) { \
|
||||
if (!error.ok()) { \
|
||||
return error; \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename Policy, typename AttrType>
|
||||
typename Policy::template Wrapped<ValueTypeFor<AttrType>> Elemental::GetAttr(
|
||||
const AttrType a, const AttrKeyFor<AttrType> key) const {
|
||||
ELEMENTAL_RETURN_IF_ERROR(CheckAttrKeyExists<Policy>(a, key));
|
||||
return attrs_[a].Get(key);
|
||||
}
|
||||
|
||||
template <typename Policy, typename AttrType>
|
||||
typename Policy::template Wrapped<bool> Elemental::AttrIsNonDefault(
|
||||
const AttrType a, const AttrKeyFor<AttrType> key) const {
|
||||
ELEMENTAL_RETURN_IF_ERROR(CheckAttrKeyExists<Policy>(a, key));
|
||||
return attrs_[a].IsNonDefault(key);
|
||||
}
|
||||
|
||||
template <typename Policy, typename AttrType>
|
||||
typename Policy::CheckResultT Elemental::SetAttr(
|
||||
const AttrType a, const AttrKeyFor<AttrType> key,
|
||||
const ValueTypeFor<AttrType> value) {
|
||||
ELEMENTAL_RETURN_IF_ERROR(CheckAttrKeyExists<Policy>(a, key));
|
||||
const std::optional<ValueTypeFor<AttrType>> prev_value =
|
||||
attrs_[a].Set(key, value);
|
||||
if (prev_value.has_value()) {
|
||||
if constexpr (is_element_id_v<ValueTypeFor<AttrType>>) {
|
||||
element_ref_trackers_[a].Untrack(key, *prev_value);
|
||||
}
|
||||
for (auto& [unused, diff] : diffs_->UpdateAndGetAll()) {
|
||||
diff->SetModified(a, key);
|
||||
}
|
||||
if constexpr (is_element_id_v<ValueTypeFor<AttrType>>) {
|
||||
element_ref_trackers_[a].Track(key, value);
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
template <int i, typename Policy, typename AttrType>
|
||||
typename Policy::template Wrapped<std::vector<AttrKeyFor<AttrType>>>
|
||||
Elemental::Slice(const AttrType a, const int64_t key_elem) const {
|
||||
ELEMENTAL_RETURN_IF_ERROR(
|
||||
CheckElementExists(GetElementTypes<AttrType>(a)[i], key_elem, Policy{}));
|
||||
return attrs_[a].template Slice<i>(key_elem);
|
||||
}
|
||||
|
||||
template <int i, typename Policy, typename AttrType>
|
||||
typename Policy::template Wrapped<int64_t> Elemental::GetSliceSize(
|
||||
const AttrType a, const int64_t key_elem) const {
|
||||
ELEMENTAL_RETURN_IF_ERROR(
|
||||
CheckElementExists(GetElementTypes<AttrType>(a)[i], key_elem, Policy{}));
|
||||
return attrs_[a].template GetSliceSize<i>(key_elem);
|
||||
}
|
||||
|
||||
inline Elemental::DiePolicy::CheckResultT Elemental::CheckElementExists(
|
||||
const ElementType elem_type, const int64_t elem_id, DiePolicy) const {
|
||||
// This is ~30% faster than:
|
||||
// CHECK_OK(CheckElementExistsUntyped(elem_type, elem_id, StatusPolicy{}));
|
||||
CHECK(ElementExistsUntyped(elem_type, elem_id))
|
||||
<< "no element with id " << elem_id << " for element type " << elem_type;
|
||||
return {};
|
||||
}
|
||||
|
||||
inline Elemental::StatusPolicy::CheckResultT Elemental::CheckElementExists(
|
||||
const ElementType elem_type, const int64_t elem_id, StatusPolicy) const {
|
||||
if (!ElementExistsUntyped(elem_type, elem_id)) {
|
||||
return util::InvalidArgumentErrorBuilder()
|
||||
<< "no element with id " << elem_id << " for element type "
|
||||
<< elem_type;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
inline Elemental::UBPolicy::CheckResultT Elemental::CheckElementExists(
|
||||
const ElementType elem_type, const int64_t elem_id, UBPolicy) const {
|
||||
// Try to be useful in debug mode.
|
||||
DCHECK_OK(CheckElementExists(elem_type, elem_id, StatusPolicy{}));
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename AttrType>
|
||||
bool Elemental::AttrKeyExists(const AttrType attr,
|
||||
const AttrKeyFor<AttrType> key) const {
|
||||
for (int i = 0; i < key.size(); ++i) {
|
||||
if (!ElementExistsUntyped(GetElementTypes<AttrType>(attr)[i], key[i])) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Policy, typename AttrType>
|
||||
typename Policy::CheckResultT Elemental::CheckAttrKeyExists(
|
||||
const AttrType a, const AttrKeyFor<AttrType> key) const {
|
||||
for (int i = 0; i < key.size(); ++i) {
|
||||
ELEMENTAL_RETURN_IF_ERROR(
|
||||
CheckElementExists(GetElementTypes<AttrType>(a)[i], key[i], Policy{}));
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
template <typename AttrType>
|
||||
std::vector<AttrKeyFor<AttrType>> Elemental::ModifiedKeysThatExist(
|
||||
AttrType attr, const Diff& diff) const {
|
||||
using Key = AttrKeyFor<AttrType>;
|
||||
std::vector<Key> keys;
|
||||
// Can be a slight overestimate.
|
||||
keys.reserve(diff.modified_keys(attr).size());
|
||||
for (const Key key : diff.modified_keys(attr)) {
|
||||
if constexpr (Key::size() > 1) {
|
||||
if (AttrKeyExists(attr, key)) {
|
||||
keys.push_back(key);
|
||||
}
|
||||
} else {
|
||||
keys.push_back(key);
|
||||
}
|
||||
}
|
||||
return keys;
|
||||
}
|
||||
|
||||
#undef ELEMENTAL_RETURN_IF_ERROR
|
||||
|
||||
} // namespace operations_research::math_opt
|
||||
|
||||
#endif // OR_TOOLS_MATH_OPT_ELEMENTAL_ELEMENTAL_H_
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user